From abccd9cda439dedbf4c6b0deabdba70cef313240 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Tue, 15 Aug 2017 09:18:56 -0700
Subject: [PATCH] [TOPI] Improve dilate (#330)

---
 topi/python/topi/nn/dilate.py | 19 +++++++++++++------
 topi/python/topi/util.py      | 28 +++++++++++++++++++++++++---
 2 files changed, 38 insertions(+), 9 deletions(-)

diff --git a/topi/python/topi/nn/dilate.py b/topi/python/topi/nn/dilate.py
index 2b3b2f424..6ba52b02e 100644
--- a/topi/python/topi/nn/dilate.py
+++ b/topi/python/topi/nn/dilate.py
@@ -2,6 +2,7 @@
 """Dilation operators"""
 from __future__ import absolute_import as _abs
 import tvm
+from .. import util
 
 
 @tvm.tag_scope(tag="dilation")
@@ -29,15 +30,21 @@ def dilate(Input, strides):
         output_size += (tvm.ir_pass.Simplify((Input.shape[i]-1)*strides[i]+1),)
 
     def _dilate(data, *indices):
-        not_zero = (indices[0]%strides[0]).equal(0)
-        index_tuple = ()
+        not_zero = []
+        index_tuple = []
         for i in range(n):
-            index_tuple += (indices[i]/strides[i],)
-            not_zero = tvm.all(not_zero, (indices[i]%strides[i]).equal(0))
-        return tvm.select(not_zero, data[index_tuple], tvm.const(0.0, data.dtype))
+            if not util.equal_const_int(strides[i], 1):
+                index_tuple.append(indices[i]/strides[i])
+                not_zero.append((indices[i] % strides[i]).equal(0))
+            else:
+                index_tuple.append(indices[i])
+        if not_zero:
+            not_zero = tvm.all(*not_zero)
+            return tvm.select(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
+        return data(*index_tuple)
 
     Output = tvm.compute(
-        (output_size),
+        output_size,
         lambda *indices: _dilate(Input, *indices),
         name='DilatedInput')
 
diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py
index 859e3d6ca..ee53c1815 100644
--- a/topi/python/topi/util.py
+++ b/topi/python/topi/util.py
@@ -7,21 +7,43 @@ def get_const_int(expr):
 
     Parameters
     ----------
-    expr :
+    expr : tvm.Expr
         The input expression.
 
     Returns
     -------
-    out_tuple : tuple of int
+    out_value : int
         The output.
     """
     if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
-        expr = tvm.ir_pass.Simplfy(expr)
+        expr = tvm.ir_pass.Simplify(expr)
     if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
         raise ValueError("Expect value to be constant int")
     return expr.value
 
 
+def equal_const_int(expr, value):
+    """Returns if expr equals value.
+
+    Parameters
+    ----------
+    expr : tvm.Expr
+        The input expression.
+
+    Returns
+    -------
+    equal : bool
+        Whether they equals.
+    """
+    if isinstance(expr, int):
+        return expr == value
+    if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+        expr = tvm.ir_pass.Simplify(expr)
+    if not isinstance(expr, (tvm.expr.IntImm, tvm.expr.UIntImm)):
+        return False
+    return expr.value == value
+
+
 def get_const_tuple(in_tuple):
     """Verifies input tuple is IntImm, returns tuple of int.
 
-- 
GitLab