diff --git a/topi/python/topi/nn/dilate.py b/topi/python/topi/nn/dilate.py index 2b3b2f424aa699e9c956af57d78bccced0999099..6ba52b02ef529792288ef3bdf5dad620cf75ca44 100644 --- a/topi/python/topi/nn/dilate.py +++ b/topi/python/topi/nn/dilate.py @@ -2,6 +2,7 @@ """Dilation operators""" from __future__ import absolute_import as _abs import tvm +from .. import util @tvm.tag_scope(tag="dilation") @@ -29,15 +30,21 @@ def dilate(Input, strides): output_size += (tvm.ir_pass.Simplify((Input.shape[i]-1)*strides[i]+1),) def _dilate(data, *indices): - not_zero = (indices[0]%strides[0]).equal(0) - index_tuple = () + not_zero = [] + index_tuple = [] for i in range(n): - index_tuple += (indices[i]/strides[i],) - not_zero = tvm.all(not_zero, (indices[i]%strides[i]).equal(0)) - return tvm.select(not_zero, data[index_tuple], tvm.const(0.0, data.dtype)) + if not util.equal_const_int(strides[i], 1): + index_tuple.append(indices[i]/strides[i]) + not_zero.append((indices[i] % strides[i]).equal(0)) + else: + index_tuple.append(indices[i]) + if not_zero: + not_zero = tvm.all(*not_zero) + return tvm.select(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype)) + return data(*index_tuple) Output = tvm.compute( - (output_size), + output_size, lambda *indices: _dilate(Input, *indices), name='DilatedInput') diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index 859e3d6caa2f6e62ea8cf8f0c9cd6871ebba35e3..ee53c1815be70df671873b6a50fbc7b462ea2791 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -7,21 +7,43 @@ def get_const_int(expr): Parameters ---------- - expr : + expr : tvm.Expr The input expression. Returns ------- - out_tuple : tuple of int + out_value : int The output. """ if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): - expr = tvm.ir_pass.Simplfy(expr) + expr = tvm.ir_pass.Simplify(expr) if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): raise ValueError("Expect value to be constant int") return expr.value +def equal_const_int(expr, value): + """Returns if expr equals value. + + Parameters + ---------- + expr : tvm.Expr + The input expression. + + Returns + ------- + equal : bool + Whether they equals. + """ + if isinstance(expr, int): + return expr == value + if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + expr = tvm.ir_pass.Simplify(expr) + if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)): + return False + return expr.value == value + + def get_const_tuple(in_tuple): """Verifies input tuple is IntImm, returns tuple of int.