From 9473dca266e307cf1f9faece219af111686ca946 Mon Sep 17 00:00:00 2001
From: "Steven S. Lyubomirsky" <slyubomirsky@gmail.com>
Date: Sat, 24 Nov 2018 22:33:14 -0500
Subject: [PATCH] [Relay][Op] Add compute, schedule, and tests for expand_dims
 and squeeze (#2133)

---
 python/tvm/relay/op/_transform.py    | 45 +++++++++++++++++++++++++++-
 tests/python/relay/test_op_level1.py | 17 +++++++++++
 tests/python/relay/test_op_level3.py | 17 +++++++++++
 3 files changed, 78 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 01814e0f7..cd32aea38 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -1,8 +1,51 @@
 #pylint: disable=invalid-name, unused-argument
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
+import topi
+import topi.cuda
+from tvm import container
 from . import op as _reg
-from .op import schedule_injective, OpPattern
+from .op import (schedule_injective, register_compute, register_schedule,
+                 register_pattern, OpPattern)
+
+schedule_broadcast = schedule_injective
+
+# squeeze
+@register_compute("squeeze")
+def squeeze_compiler(attrs, inputs, output_type, target):
+    """Compute definition for squeeze."""
+    assert len(inputs) == 1
+
+    if attrs.axis is None:
+        axis = None
+    elif isinstance(attrs.axis, container.Array):
+        axis = tuple(attrs.axis)
+    else:
+        axis = int(attrs.axis)
+
+    return [topi.squeeze(inputs[0], axis)]
+
+register_pattern("squeeze", OpPattern.INJECTIVE)
+register_schedule("squeeze", schedule_injective)
+
+# expand_dims
+@register_compute("expand_dims")
+def expand_dims_compiler(attrs, inputs, output_type, target):
+    """Compute definition for expand_dims."""
+    assert len(inputs) == 1
+
+    new_axis = int(attrs.num_newaxis)
+    assert new_axis >= 0
+
+    # axis should be in range [-data.ndim - 1, data.ndim]
+    axis = int(attrs.axis)
+    assert axis >= -len(inputs[0].shape) - 1
+    assert axis <= len(inputs[0].shape)
+
+    return [topi.expand_dims(inputs[0], axis, new_axis)]
+
+register_schedule("expand_dims", schedule_broadcast)
+register_pattern("expand_dims", OpPattern.BROADCAST)
 
 # strided_slice
 _reg.register_schedule("strided_slice", schedule_injective)
diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py
index 35844ddd4..d28aa0a56 100644
--- a/tests/python/relay/test_op_level1.py
+++ b/tests/python/relay/test_op_level1.py
@@ -90,6 +90,22 @@ def test_binary_op():
         check_binary_op(opfunc, ref)
 
 
+def test_expand_dims():
+    # based on topi test
+    def verify_expand_dims(dshape, dtype, oshape, axis, num_newaxis):
+        x = relay.Var("x", relay.TensorType(dshape, dtype))
+        func = relay.Function([x], relay.expand_dims(x, axis, num_newaxis))
+        for target, ctx in ctx_list():
+            data = np.random.uniform(size=dshape).astype(dtype)
+            ref_res = data.reshape(oshape)
+            intrp = relay.create_executor("graph", ctx=ctx, target=target)
+            op_res = intrp.evaluate(func)(data)
+            np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
+
+    verify_expand_dims((3, 10), 'float32', (3, 10, 1, 1), 2, 2)
+    verify_expand_dims((3, 10), 'float32', (1, 3, 10), -3, 1)
+
+
 def test_bias_add():
     xshape=(10, 2, 3, 4)
     bshape=(2,)
@@ -295,6 +311,7 @@ if __name__ == "__main__":
     test_binary_op()
     test_expand_dims_infer_type()
     test_concatenate()
+    test_expand_dims()
     test_softmax()
     test_log_softmax()
     test_dropout()
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index 6f8fbd551..f6951b5ab 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -60,6 +60,22 @@ def test_clip():
     np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
 
 
+def test_squeeze():
+    def verify_squeeze(shape, dtype, axis):
+        x = relay.var("x", relay.TensorType(shape, dtype))
+        squeeze = relay.squeeze(x, axis=axis)
+
+        np_axis = tuple(axis) if axis is not None else None
+
+        data = np.random.random_sample(shape).astype(dtype)
+        intrp = create_executor()
+        op_res = intrp.evaluate(squeeze, { x : relay.const(data) })
+        ref_res = np.squeeze(data, axis=np_axis)
+        np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
+
+    verify_squeeze((1, 3, 2, 5), "float32", None)
+    verify_squeeze((1, 3, 1), "float32", [0])
+    verify_squeeze((1, 2, 1, 2, 1), "float32", [0, 2])
 
 
 def test_transpose_infer_type():
@@ -308,6 +324,7 @@ if __name__ == "__main__":
     test_full_like()
     test_infer_type_leaky_relu()
     test_infer_type_prelu()
+    test_squeeze()
     test_squeeze_infer_type()
     test_squeeze_bad_axes_infer_type()
     test_split_infer_type()
-- 
GitLab