From 9473dca266e307cf1f9faece219af111686ca946 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" <slyubomirsky@gmail.com> Date: Sat, 24 Nov 2018 22:33:14 -0500 Subject: [PATCH] [Relay][Op] Add compute, schedule, and tests for expand_dims and squeeze (#2133) --- python/tvm/relay/op/_transform.py | 45 +++++++++++++++++++++++++++- tests/python/relay/test_op_level1.py | 17 +++++++++++ tests/python/relay/test_op_level3.py | 17 +++++++++++ 3 files changed, 78 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 01814e0f7..cd32aea38 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -1,8 +1,51 @@ #pylint: disable=invalid-name, unused-argument """Backend compiler related feature registration""" from __future__ import absolute_import +import topi +import topi.cuda +from tvm import container from . import op as _reg -from .op import schedule_injective, OpPattern +from .op import (schedule_injective, register_compute, register_schedule, + register_pattern, OpPattern) + +schedule_broadcast = schedule_injective + +# squeeze +@register_compute("squeeze") +def squeeze_compiler(attrs, inputs, output_type, target): + """Compiler for squeeze dims.""" + assert len(inputs) == 1 + + if attrs.axis is None: + axis = None + elif isinstance(attrs.axis, container.Array): + axis = tuple(attrs.axis) + else: + axis = int(attrs.axis) + + return [topi.squeeze(inputs[0], axis)] + +register_pattern("squeeze", OpPattern.INJECTIVE) +register_schedule("squeeze", schedule_injective) + +# expand_dims +@register_compute("expand_dims") +def expand_dims_compiler(attrs, inputs, output_type, target): + """Compiler for expand_dims.""" + assert len(inputs) == 1 + + new_axis = int(attrs.num_newaxis) + assert new_axis >= 0 + + # axis should be in range [-data.ndim - 1, data.ndim] + axis = int(attrs.axis) + assert axis >= -len(inputs[0].shape) - 1 + assert axis <= len(inputs[0].shape) + + return [topi.expand_dims(inputs[0], axis, new_axis)] + +_reg.register_schedule("expand_dims", schedule_broadcast) +_reg.register_pattern("expand_dims", OpPattern.BROADCAST) # strided_slice _reg.register_schedule("strided_slice", schedule_injective) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 35844ddd4..d28aa0a56 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -90,6 +90,22 @@ def test_binary_op(): check_binary_op(opfunc, ref) +def test_expand_dims(): + # based on topi test + def verify_expand_dims(dshape, dtype, oshape, axis, num_newaxis): + x = relay.Var("x", relay.TensorType(dshape, dtype)) + func = relay.Function([x], relay.expand_dims(x, axis, num_newaxis)) + for target, ctx in ctx_list(): + data = np.random.uniform(size=dshape).astype(dtype) + ref_res = data.reshape(oshape) + intrp = relay.create_executor("graph", ctx=ctx, target=target) + op_res = intrp.evaluate(func)(data) + np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) + + verify_expand_dims((3, 10), 'float32', (3, 10, 1, 1), 2, 2) + verify_expand_dims((3, 10), 'float32', (1, 3, 10), -3, 1) + + def test_bias_add(): xshape=(10, 2, 3, 4) bshape=(2,) @@ -295,6 +311,7 @@ if __name__ == "__main__": test_binary_op() test_expand_dims_infer_type() test_concatenate() + test_expand_dims() test_softmax() test_log_softmax() test_dropout() diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index 6f8fbd551..f6951b5ab 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -60,6 +60,22 @@ def test_clip(): np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) +def test_squeeze(): + def verify_squeeze(shape, dtype, axis): + x = relay.var("x", relay.TensorType(shape, dtype)) + squeeze = relay.squeeze(x, axis=axis) + + np_axis = tuple(axis) if axis is not None else None + + data = np.random.random_sample(shape).astype(dtype) + intrp = create_executor() + op_res = intrp.evaluate(squeeze, { x : relay.const(data) }) + ref_res = np.squeeze(data, axis=np_axis) + np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) + + verify_squeeze((1, 3, 2, 5), "float32", None) + verify_squeeze((1, 3, 1), "float32", [0]) + verify_squeeze((1, 2, 1, 2, 1), "float32", [0, 2]) def test_transpose_infer_type(): @@ -308,6 +324,7 @@ if __name__ == "__main__": test_full_like() test_infer_type_leaky_relu() test_infer_type_prelu() + test_squeeze() test_squeeze_infer_type() test_squeeze_bad_axes_infer_type() test_split_infer_type() -- GitLab