From b5e0d79060bc931c9cdcd5bd3de535baa971b7e6 Mon Sep 17 00:00:00 2001
From: Siju <sijusamuel@gmail.com>
Date: Mon, 26 Nov 2018 11:41:33 +0530
Subject: [PATCH] [RELAY]sch and compute for reduce ops (#2091)

---
 python/tvm/relay/op/_reduce.py       |  1 +
 tests/python/relay/test_op_level4.py | 60 +++++++++++++++++++++++-----
 2 files changed, 50 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relay/op/_reduce.py b/python/tvm/relay/op/_reduce.py
index fd18c0e71..5c720256b 100644
--- a/python/tvm/relay/op/_reduce.py
+++ b/python/tvm/relay/op/_reduce.py
@@ -15,5 +15,6 @@ _reg.register_schedule("argmax", _schedule_reduce)
 _reg.register_schedule("argmin", _schedule_reduce)
 _reg.register_schedule("sum", _schedule_reduce)
 _reg.register_schedule("max", _schedule_reduce)
+_reg.register_schedule("min", _schedule_reduce)
 _reg.register_schedule("prod", _schedule_reduce)
 _reg.register_schedule("mean", _schedule_reduce)
diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py
index dd12dc7cf..e5da48f10 100644
--- a/tests/python/relay/test_op_level4.py
+++ b/tests/python/relay/test_op_level4.py
@@ -106,8 +106,11 @@ def test_where():
     assert zz.checked_type == relay.TensorType((3, 4), "float32")
 
 
-def verify_reduce(test_func, data, axis, keepdims, exclude, output):
-    x = relay.var("x", relay.TensorType(data, "float32"))
+def verify_reduce(funcs, data, axis, keepdims, exclude, output, dtype="float32"):
+    test_func = funcs[0]
+    ref_func = funcs[1]
+
+    x = relay.var("x", relay.TensorType(data, dtype))
     z = test_func(x, axis, keepdims, exclude)
     zz = relay.ir_pass.infer_type(z)
     if axis:
@@ -116,25 +119,60 @@ def verify_reduce(test_func, data, axis, keepdims, exclude, output):
         assert "keepdims=" in z.astext()
     if exclude:
         assert "exclude=" in z.astext()
-    out_type = "int32" if test_func in [relay.argmin, relay.argmax] else "float32"
+    out_type = "int32" if test_func in [relay.argmin, relay.argmax] else dtype
     assert zz.checked_type == relay.ty.TensorType(output, out_type)
 
+    if all(isinstance(v, tvm.expr.Var) == 1 for v in data) or len(output) == 0:
+        return
+
+    func = relay.Function([x], z)
+    x_data = np.random.uniform(size=data).astype(dtype)
+    if ref_func in [np.sum]:
+        ref_res = ref_func(x_data + 0, axis=axis, dtype=dtype, keepdims=keepdims)
+    elif ref_func in [np.max, np.min, np.mean, np.prod]:
+        ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims)
+    else: #argmin/argmax
+        if axis and len(axis) > 1:
+            return
+        ref_res = ref_func(x_data + 0, axis=axis, keepdims=keepdims)
+
+    for target, ctx in ctx_list():
+        intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+        intrp2 = relay.create_executor("debug", ctx=ctx, target=target)
+        op_res1 = intrp1.evaluate(func)(x_data)
+        tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5)
+        op_res2 = intrp2.evaluate(func)(x_data)
+        tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
+
 def test_reduce_functions():
+    def _with_keepdims(func):
+        def _wrapper(data, axis=None, keepdims=False):
+            if not keepdims:
+                return func(data, axis=axis)
+            else:
+                if axis is not None:
+                    axis = axis[0]
+                    out_shape = list(data.shape)
+                    out_shape[axis] = 1
+                else:
+                    out_shape = [1 for _ in range(len(data.shape))]
+                return func(data, axis=axis).reshape(out_shape)
+        return _wrapper
+
     d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
-    for func in [relay.sum,
-                 relay.max,
-                 relay.min,
-                 relay.mean,
-                 relay.prod,
-                 relay.argmin,
-                 relay.argmax]:
+    for func in [[relay.sum, np.sum],
+                 [relay.max, np.max],
+                 [relay.min, np.min],
+                 [relay.mean, np.mean],
+                 [relay.prod, np.prod],
+                 [relay.argmin, _with_keepdims(np.argmin)],
+                 [relay.argmax, _with_keepdims(np.argmax)]]:
         verify_reduce(func, (d1, d2, d3, d4), (2,), True, False, (d1, d2, 1, d4))
         verify_reduce(func, (d1, d2, d3), (1,), True, False, (d1, 1, d3))
         verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1))
         verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3))
         verify_reduce(func, (2, 3, 4), (1,), True, False, (2, 1, 4))
         verify_reduce(func, (2, 3, 4), (0, 1, 2), False, False, ())
-        verify_reduce(func, (4, 4, 3), None, True, False, (1, 1, 1))
         verify_reduce(func, (4, 4, 3), None, False, True, ())
         verify_reduce(func, (4, 4, 3), (0, 2), False, False, (4,))
         verify_reduce(func, (128, 24, 128), (0, 1), False, False, (128,))
-- 
GitLab