From 2a8e07469da3e0178e693ecc9fbb4d3d3a15e826 Mon Sep 17 00:00:00 2001
From: Xingjian Shi <xshiab@ust.hk>
Date: Wed, 27 Dec 2017 01:52:37 -0800
Subject: [PATCH] [TOPI]Support dim-0 tensor in topi broadcast/reduce (#731)

* support dim-0 tensor in topi ops

revert transform

* revert
---
 topi/python/topi/reduction.py            | 4 +---
 topi/tests/python/test_topi_broadcast.py | 2 ++
 topi/tests/python/test_topi_reduce.py    | 5 ++++-
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/topi/python/topi/reduction.py b/topi/python/topi/reduction.py
index 3c6bf1ca0..997ec8e9b 100644
--- a/topi/python/topi/reduction.py
+++ b/topi/python/topi/reduction.py
@@ -107,10 +107,8 @@ def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum, is_idx_reduce=Fal
     ret : tvm.Tensor
     """
     ndim = len(data.shape)
+    assert ndim != 0, "Reduce a dim-0 input is not supported!"
     real_axis = _get_real_axis(ndim, axis)
-    if real_axis == list(range(ndim)) and keepdims is False:
-        raise ValueError("Currently we do not support all reduce + keepdims = False!"
-                         " axis={}, keepdims={}".format(axis, keepdims))
     reduce_axes = [tvm.reduce_axis((0, data.shape[i]), "k%d" %i) for i in real_axis]
     if keepdims:
         target_shape = [1 if i in real_axis else data.shape[i] for i in range(ndim)]
diff --git a/topi/tests/python/test_topi_broadcast.py b/topi/tests/python/test_topi_broadcast.py
index e5f88e9d4..28a9e721a 100644
--- a/topi/tests/python/test_topi_broadcast.py
+++ b/topi/tests/python/test_topi_broadcast.py
@@ -89,12 +89,14 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape, typ="add"):
 
 def test_broadcast_to():
     verify_broadcast_to_ele((1,), (10,))
+    verify_broadcast_to_ele((), (10,))
     verify_broadcast_to_ele((1, 1, 5, 4), (3, 4, 4, 4, 5, 4))
     verify_broadcast_to_ele((1, 128, 1, 32), (64, 128, 64, 32))
 
 
 def test_broadcast_binary():
     verify_broadcast_binary_ele((5, 2, 3), (2, 1), typ="add")
+    verify_broadcast_binary_ele((5, 2, 3), (), typ="add")
     verify_broadcast_binary_ele((5, 64, 128), (2, 5, 64, 1), typ="mul")
     verify_broadcast_binary_ele((2, 3, 1, 32), (64, 32), typ="div")
     verify_broadcast_binary_ele((1, 32), (64, 32), typ="sub")
diff --git a/topi/tests/python/test_topi_reduce.py b/topi/tests/python/test_topi_reduce.py
index 13cd8fcdc..08e66e140 100644
--- a/topi/tests/python/test_topi_reduce.py
+++ b/topi/tests/python/test_topi_reduce.py
@@ -108,7 +108,10 @@ def test_reduce_map():
                           axis=None,
                           keepdims=True,
                           type="argmax")
-
+    verify_reduce_map_ele(in_shape=(31, 21, 15),
+                          axis=None,
+                          keepdims=False,
+                          type="sum")
 
 if __name__ == "__main__":
     test_reduce_map()
-- 
GitLab