From 401ffe131e207bc83fa424df3dbc14ed1c987731 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" <slyubomirsky@gmail.com> Date: Mon, 19 Nov 2018 21:35:21 -0800 Subject: [PATCH] [Relay][Op] Add test for batch_flatten (#2134) * Add tests for batch_flatten and softmax * Softmax is already tested elsewhere --- python/tvm/relay/op/nn/_nn.py | 1 + tests/python/relay/test_op_level2.py | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index cd807ad62..b48bfde97 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -9,6 +9,7 @@ from ..op import OpPattern, schedule_injective reg.register_schedule("nn.relu", schedule_injective) reg.register_pattern("nn.relu", OpPattern.ELEMWISE) +# softmax @reg.register_schedule("nn.softmax") def schedule_softmax(_, outputs, target): """Schedule definition of softmax""" diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 1ae372407..cd9321c5a 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -391,6 +391,27 @@ def test_l2_normalize(): tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) +def batch_flatten(data): + shape = data.shape + target_dim = 1 + for i in range(len(shape) - 1): + target_dim = target_dim * shape[i + 1] + return np.reshape(data, (shape[0], target_dim)) + + +def test_batch_flatten(): + t1 = relay.TensorType((5, 10, 5)) + x = relay.Var("x", t1) + func = relay.Function([x], relay.nn.batch_flatten(x)) + + data = np.random.rand(5, 10, 5).astype(t1.dtype) + ref_res = batch_flatten(data) + for target, ctx in ctx_list(): + intrp = relay.create_executor("graph", ctx=ctx, target=target) + op_res = intrp.evaluate(func)(data) + np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) + + if __name__ == "__main__": test_pool2d() test_avg_pool2d_no_count_pad() @@ -403,3 +424,4 @@ if __name__ == "__main__": test_conv2d_transpose_infer_type() test_conv2d_transpose_run() test_conv2d_run() + test_batch_flatten() -- GitLab