From ac7b42054f85647ba2536aac716691eee72ac6c6 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Wed, 19 Sep 2018 10:10:26 -0700
Subject: [PATCH] [TOPI] Fix reduce behavior to be consistent to numpy (#1738)

[TOPI] Fix reduce behavior to be consistent with numpy
---
 topi/include/topi/reduction.h             | 3 ---
 topi/tests/python/test_topi_reduce.py     | 1 -
 topi/tests/python_cpp/test_topi_reduce.py | 1 -
 3 files changed, 5 deletions(-)

diff --git a/topi/include/topi/reduction.h b/topi/include/topi/reduction.h
index 1ac3f2d61..ccc85e966 100644
--- a/topi/include/topi/reduction.h
+++ b/topi/include/topi/reduction.h
@@ -95,9 +95,6 @@ inline Array<Expr> MakeReduceTargetShape(const std::vector<int>& real_axis,
         target_shape.push_back(data->shape[i]);
       }
     }
-    if (target_shape.size() == 0) {
-      target_shape.push_back(1);
-    }
   }
   return target_shape;
 }
diff --git a/topi/tests/python/test_topi_reduce.py b/topi/tests/python/test_topi_reduce.py
index ceb2a4fe1..0be652948 100644
--- a/topi/tests/python/test_topi_reduce.py
+++ b/topi/tests/python/test_topi_reduce.py
@@ -72,7 +72,6 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
             out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims)
         else:
             raise NotImplementedError
-        out_npy = np.atleast_1d(out_npy)
         data_tvm = tvm.nd.array(in_npy, ctx=ctx)
         out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
         for _ in range(1):
diff --git a/topi/tests/python_cpp/test_topi_reduce.py b/topi/tests/python_cpp/test_topi_reduce.py
index ab4ac9372..b17176938 100644
--- a/topi/tests/python_cpp/test_topi_reduce.py
+++ b/topi/tests/python_cpp/test_topi_reduce.py
@@ -77,7 +77,6 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
             out_npy = in_npy_map.prod(axis=axis, keepdims=keepdims)
         else:
             raise NotImplementedError
-        out_npy = np.atleast_1d(out_npy)
         data_tvm = tvm.nd.array(in_npy, ctx=ctx)
         out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
         for _ in range(1):
-- 
GitLab