From 401ffe131e207bc83fa424df3dbc14ed1c987731 Mon Sep 17 00:00:00 2001
From: "Steven S. Lyubomirsky" <slyubomirsky@gmail.com>
Date: Mon, 19 Nov 2018 21:35:21 -0800
Subject: [PATCH] [Relay][Op] Add test for batch_flatten (#2134)

* Add tests for batch_flatten and softmax

* Softmax is already tested elsewhere
---
 python/tvm/relay/op/nn/_nn.py        |  1 +
 tests/python/relay/test_op_level2.py | 22 ++++++++++++++++++++++
 2 files changed, 23 insertions(+)

diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index cd807ad62..b48bfde97 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -9,6 +9,7 @@ from ..op import OpPattern, schedule_injective
 reg.register_schedule("nn.relu", schedule_injective)
 reg.register_pattern("nn.relu", OpPattern.ELEMWISE)
 
+# softmax
 @reg.register_schedule("nn.softmax")
 def schedule_softmax(_, outputs, target):
     """Schedule definition of softmax"""
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index 1ae372407..cd9321c5a 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -391,6 +391,27 @@ def test_l2_normalize():
         tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
 
 
+def batch_flatten(data):
+    """NumPy reference for nn.batch_flatten: keep the batch (first)
+    axis and collapse all remaining axes into a single dimension."""
+    # np.prod over the trailing dims gives the flattened width directly.
+    target_dim = int(np.prod(data.shape[1:]))
+    return np.reshape(data, (data.shape[0], target_dim))
+
+
+def test_batch_flatten():
+    """Run nn.batch_flatten on every target and compare against NumPy."""
+    t1 = relay.TensorType((5, 10, 5))
+    x = relay.Var("x", t1)
+    func = relay.Function([x], relay.nn.batch_flatten(x))
+    data = np.random.rand(5, 10, 5).astype(t1.dtype)
+    ref_res = batch_flatten(data)
+    for target, ctx in ctx_list():
+        intrp = relay.create_executor("graph", ctx=ctx, target=target)
+        op_res = intrp.evaluate(func)(data)
+        tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+
+
 if __name__ == "__main__":
     test_pool2d()
     test_avg_pool2d_no_count_pad()
@@ -403,3 +424,4 @@ if __name__ == "__main__":
     test_conv2d_transpose_infer_type()
     test_conv2d_transpose_run()
     test_conv2d_run()
+    test_batch_flatten()
-- 
GitLab