From 2a8e07469da3e0178e693ecc9fbb4d3d3a15e826 Mon Sep 17 00:00:00 2001 From: Xingjian Shi <xshiab@ust.hk> Date: Wed, 27 Dec 2017 01:52:37 -0800 Subject: [PATCH] [TOPI]Support dim-0 tensor in topi broadcast/reduce (#731) * support dim-0 tensor in topi ops revert transform * revert --- topi/python/topi/reduction.py | 4 +--- topi/tests/python/test_topi_broadcast.py | 2 ++ topi/tests/python/test_topi_reduce.py | 5 ++++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/topi/python/topi/reduction.py b/topi/python/topi/reduction.py index 3c6bf1ca0..997ec8e9b 100644 --- a/topi/python/topi/reduction.py +++ b/topi/python/topi/reduction.py @@ -107,10 +107,8 @@ def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum, is_idx_reduce=Fal ret : tvm.Tensor """ ndim = len(data.shape) + assert ndim != 0, "Reduce a dim-0 input is not supported!" real_axis = _get_real_axis(ndim, axis) - if real_axis == list(range(ndim)) and keepdims is False: - raise ValueError("Currently we do not support all reduce + keepdims = False!" - " axis={}, keepdims={}".format(axis, keepdims)) reduce_axes = [tvm.reduce_axis((0, data.shape[i]), "k%d" %i) for i in real_axis] if keepdims: target_shape = [1 if i in real_axis else data.shape[i] for i in range(ndim)] diff --git a/topi/tests/python/test_topi_broadcast.py b/topi/tests/python/test_topi_broadcast.py index e5f88e9d4..28a9e721a 100644 --- a/topi/tests/python/test_topi_broadcast.py +++ b/topi/tests/python/test_topi_broadcast.py @@ -89,12 +89,14 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"): def test_broadcast_to(): verify_broadcast_to_ele((1,), (10,)) + verify_broadcast_to_ele((), (10,)) verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4)) verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32)) def test_broadcast_binary(): verify_broadcast_binary_ele((5, 2, 3), (2, 1), typ="add") + verify_broadcast_binary_ele((5, 2, 3), (), typ="add") verify_broadcast_binary_ele((5, 64, 128), (2, 5, 64, 1), typ="mul") verify_broadcast_binary_ele((2, 3, 1, 32), (64, 32), typ="div") verify_broadcast_binary_ele((1, 32), (64, 32), typ="sub") diff --git a/topi/tests/python/test_topi_reduce.py b/topi/tests/python/test_topi_reduce.py index 13cd8fcdc..08e66e140 100644 --- a/topi/tests/python/test_topi_reduce.py +++ b/topi/tests/python/test_topi_reduce.py @@ -108,7 +108,10 @@ def test_reduce_map(): axis=None, keepdims=True, type="argmax") - + verify_reduce_map_ele(in_shape=(31, 21, 15), + axis=None, + keepdims=False, + type="sum") if __name__ == "__main__": test_reduce_map() -- GitLab