diff --git a/topi/python/topi/nn/upsampling.py b/topi/python/topi/nn/upsampling.py index e1234741e286c68bacc5f8d95099cbc087edece1..df77bbdb23c04ec3358c46713dfa311bd9cd5d3f 100644 --- a/topi/python/topi/nn/upsampling.py +++ b/topi/python/topi/nn/upsampling.py @@ -1,6 +1,7 @@ """TVM operator upsampling compute.""" from __future__ import absolute_import import tvm +from .. import util def upsampling(data, scale): @@ -21,8 +22,8 @@ def upsampling(data, scale): 4-D with shape [batch, channel, in_height*scale, in_width*scale] """ batch, channel, height, width = data.shape - out_height = height * scale - out_width = width * scale + out_height = util.simplify(height * scale) + out_width = util.simplify(width * scale) return tvm.compute((batch, channel, out_height, out_width), \ lambda n, c, h, w: data[n, c, h/scale, w/scale]) diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index 246d283cd2ed236f0275bdceb2772b2b9f566216..00c8b9d42e2ad69b174c48e4b1b7b0456632fa5d 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -59,9 +59,8 @@ def get_const_tuple(in_tuple): """ out_tuple = () for elem in in_tuple: - if not isinstance(elem, (tvm.expr.IntImm, tvm.expr.UIntImm)): - raise ValueError("Element of input tuple should be const int") - out_tuple = out_tuple + (elem.value, ) + value = get_const_int(elem) + out_tuple = out_tuple + (value, ) return out_tuple