diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst
index 7f150ddbf7cd23c39a583a3c25849957cfca2ad5..767dfe1ba8448ef4b21ddeb7e914063ec0fbae1c 100644
--- a/docs/api/python/topi.rst
+++ b/docs/api/python/topi.rst
@@ -49,6 +49,7 @@ List of operators
    topi.min
    topi.argmax
    topi.argmin
+   topi.prod
    topi.broadcast_to
    topi.add
    topi.subtract
@@ -107,6 +108,7 @@ topi
 .. autofunction:: topi.max
 .. autofunction:: topi.sum
 .. autofunction:: topi.min
+.. autofunction:: topi.prod
 .. autofunction:: topi.broadcast_to
 .. autofunction:: topi.add
 .. autofunction:: topi.subtract
diff --git a/docs/nnvm_top.rst b/docs/nnvm_top.rst
index 663c85ac789e33fad91dcf4660d316be224e3edc..be1077f664c35637e43c6ce188bb2c93e4911cc6 100644
--- a/docs/nnvm_top.rst
+++ b/docs/nnvm_top.rst
@@ -114,6 +114,8 @@ This level enables typical convnet models.
    nnvm.symbol.sum
    nnvm.symbol.min
    nnvm.symbol.max
+   nnvm.symbol.mean
+   nnvm.symbol.prod
    nnvm.symbol.broadcast_add
    nnvm.symbol.broadcast_sub
    nnvm.symbol.broadcast_mul
@@ -228,6 +230,8 @@ Detailed Definitions
 .. autofunction:: nnvm.symbol.sum
 .. autofunction:: nnvm.symbol.min
 .. autofunction:: nnvm.symbol.max
+.. autofunction:: nnvm.symbol.mean
+.. autofunction:: nnvm.symbol.prod
 .. autofunction:: nnvm.symbol.broadcast_add
 .. autofunction:: nnvm.symbol.broadcast_sub
 .. autofunction:: nnvm.symbol.broadcast_mul
diff --git a/include/tvm/expr.h b/include/tvm/expr.h
index fe645bcf580a25323ee9b5d4f0d964d0ed9e6707..fb2233dacb69e37cea8f53833468211bccba39d3 100644
--- a/include/tvm/expr.h
+++ b/include/tvm/expr.h
@@ -36,6 +36,7 @@ using HalideIR::Internal::Variable;
 
 using HalideIR::Internal::make_const;
 using HalideIR::Internal::make_zero;
+using HalideIR::Internal::make_one;
 using HalideIR::Internal::as_const_int;
 using HalideIR::Internal::as_const_uint;
 using HalideIR::Internal::const_true;
diff --git a/include/tvm/ir_operator.h b/include/tvm/ir_operator.h
index e809b06e49b5f45cc04a123011ad7ffa4892a540..39588a2228f994d9d6e687f1baa7176685392055 100644
--- a/include/tvm/ir_operator.h
+++ b/include/tvm/ir_operator.h
@@ -41,6 +41,12 @@ TVM_DLL Expr max(Expr source, Array<IterVar> axis);
  */
 TVM_DLL Expr min(Expr source, Array<IterVar> axis);
 
+/*!
+ * \brief product of source expression over axis
+ * \param source The source expression.
+ * \param axis List of iteration variables that will be used for reduction.
+ */
+TVM_DLL Expr prod(Expr source, Array<IterVar> axis);
 
 // Unary intrinsic operators
 #define TVM_DECLARE_INTRIN_UNARY(OpName)                                \
diff --git a/nnvm/python/nnvm/top/reduction.py b/nnvm/python/nnvm/top/reduction.py
index fd8e2f8df56e5800967ffa3f3ff96c7339677717..aef6e1dcc4a8e9f1e0b53f900fa16f06d9bc5d78 100644
--- a/nnvm/python/nnvm/top/reduction.py
+++ b/nnvm/python/nnvm/top/reduction.py
@@ -49,3 +49,11 @@ reg.register_schedule("argmax", _fschedule_reduce)
 # argmin
 reg.register_pattern("argmin", OpPattern.COMM_REDUCE)
 reg.register_schedule("argmin", _fschedule_reduce)
+
+# mean
+reg.register_pattern("mean", OpPattern.COMM_REDUCE)
+reg.register_schedule("mean", _fschedule_reduce)
+
+# product
+reg.register_pattern("prod", OpPattern.COMM_REDUCE)
+reg.register_schedule("prod", _fschedule_reduce)
diff --git a/nnvm/src/top/tensor/reduce.cc b/nnvm/src/top/tensor/reduce.cc
index d8f426b4f4bc235e7b93616e8d80d18672169760..10dd957422222e5b98ae5176cfe0237eaf401890 100644
--- a/nnvm/src/top/tensor/reduce.cc
+++ b/nnvm/src/top/tensor/reduce.cc
@@ -322,6 +322,70 @@ values over a given axis.
       topi::argmin(inputs[0], axis, param.keepdims) };
 });
 
+NNVM_REGISTER_REDUCE_OP(mean)
+  .describe(R"code(Computes the mean of array elements over given axes.
+
+Example::
+
+  data = [[[1,2],[2,3],[1,3]],
+          [[1,4],[4,3],[5,2]],
+          [[7,1],[7,2],[7,3]]]
+
+  mean(data)
+  [3.22]
+
+  mean(data, axis=[1,2])
+  [ 2.  3.16666667  4.5]
+
+)code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
+    TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
+                                  param.axis, param.exclude);
+    if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
+    auto axis = ShapeToArray(r_axes);
+
+    Expr count = make_one(inputs[0]->dtype);
+    for (auto& i : r_axes) {
+      count *= inputs[0]->shape[i];
+    }
+
+    return Array<Tensor>{
+      topi::divide(topi::sum(inputs[0], axis, param.keepdims), count) };
+});
+
+NNVM_REGISTER_REDUCE_OP(prod)
+  .describe(R"code(Computes the products of array elements over given axes.
+
+Example::
+
+  data = [[[1,2],[2,3],[1,3]],
+          [[1,4],[4,3],[5,2]],
+          [[7,1],[7,2],[7,3]]]
+
+  prod(data)
+  [35562240]
+
+  prod(data, axis=[1,2])
+  [ 36  480  2058]
+
+)code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
+    TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
+                                  param.axis, param.exclude);
+    if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
+    auto axis = ShapeToArray(r_axes);
+    return Array<Tensor>{
+      topi::prod(inputs[0], axis, param.keepdims) };
+});
+
 
 }  // namespace top
 }  // namespace nnvm
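
Note: the mean lowering above deliberately reuses topi::sum and divides by the
number of reduced elements instead of introducing a dedicated mean reducer. A
minimal NumPy sketch of the same decomposition (function name is illustrative,
not part of the patch)::

    import numpy as np

    def mean_via_sum(data, axis, keepdims=False):
        # count = product of the extents of the reduced axes, mirroring the
        # `count *= inputs[0]->shape[i]` loop in the FTVMCompute above
        count = np.prod([data.shape[i] for i in axis])
        return np.sum(data, axis=tuple(axis), keepdims=keepdims) / count

    data = np.arange(24.0).reshape(2, 3, 4)
    assert np.allclose(mean_via_sum(data, [1, 2]), data.mean(axis=(1, 2)))
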
diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py
index 6503d2d2292d1d72969bdd568189db6daabb498b..16b02f956cccc463e11cf88cf979e5ad2fcfef5d 100644
--- a/nnvm/tests/python/compiler/test_top_level4.py
+++ b/nnvm/tests/python/compiler/test_top_level4.py
@@ -31,6 +31,9 @@ def verify_reduce_explicit(dshape, data, result, fsym, oshape=None, otype='float
     x = sym.Variable("x")
     y = fsym(x + 0, **kwargs)
     for target, ctx in ctx_list():
+        # TODO(yuruofei): remove when cuda reduce schedule is done
+        if target == 'cuda' and fsym == sym.mean:
+            continue
         graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
         m = graph_runtime.create(graph, lib, ctx)
         # set input
@@ -93,6 +96,13 @@ def test_reduce():
     verify_reduce((4, 4, 3), np.min, sym.min, keepdims=True)
     verify_reduce((4, 4, 3), np.sum, sym.sum, axis=(0, 2))
     verify_reduce((4, 4, 3), np.sum, sym.sum)
+    verify_reduce((128, 24, 128), np.mean, sym.mean, axis=(0, 1), keepdims=False)
+    verify_reduce((128, 24, 128), np.mean, sym.mean, axis=(0, 2), keepdims=False)
+    verify_reduce((128, 24, 128), np.mean, sym.mean, axis=(0, 1), keepdims=True)
+    verify_reduce((128, 24, 128), np.mean, sym.mean, axis=(0, 2), keepdims=True)
+    verify_reduce((128, 24, 128), np.mean, sym.mean, keepdims=True)
+    verify_reduce((128, 24, 128), np.mean, sym.mean, keepdims=False)
+    verify_reduce((128, 24, 128), np.mean, sym.mean, axis=(0, 1, 2), keepdims=True)
 
     data = np.array([[[1,2],[3,4]],[[3,44],[5,6]]], dtype=np.float32)
     verify_reduce_explicit([2,2,2], data, np.array([[1,1],[1,0]]), sym.argmax, otype='int32', axis=[0,2], exclude=True)
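
For reference, the new symbols compose like any other nnvm reduction; a sketch
along the lines of verify_reduce, with illustrative shapes and an LLVM target
(not part of the patch)::

    import numpy as np
    import tvm
    import nnvm.symbol as sym
    import nnvm.compiler
    from tvm.contrib import graph_runtime

    x = sym.Variable("x")
    y = sym.prod(x, axis=(0, 2), keepdims=True)
    graph, lib, _ = nnvm.compiler.build(y, "llvm", {"x": (2, 3, 4)})
    m = graph_runtime.create(graph, lib, tvm.cpu(0))
    data = np.random.uniform(1, 2, size=(2, 3, 4)).astype("float32")
    m.run(x=data)
    out = m.get_output(0, tvm.nd.empty((1, 3, 1))).asnumpy()
    np.testing.assert_allclose(
        out, data.prod(axis=(0, 2), keepdims=True), rtol=1e-5)
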
diff --git a/src/lang/ir_operator.cc b/src/lang/ir_operator.cc
index ded27bbdce7ea105ecb5cd18a88f2f5986e15a3d..50e598d13dc23bedf8bf0bc2eb34cf53455b6a75 100644
--- a/src/lang/ir_operator.cc
+++ b/src/lang/ir_operator.cc
@@ -35,4 +35,13 @@ Expr min(Expr source, Array<IterVar> rdom) {
   return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
 }
 
+Expr prod(Expr source, Array<IterVar> rdom) {
+  Var x("x"), y("y");
+  Expr result = ir::Mul::make(x, y);
+  Expr identity_element = make_one(source.type());
+  ir::CommReducer combiner =
+    ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+  return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
+}
+
 }  // namespace tvm
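
The new prod reducer pairs the combiner x * y with the multiplicative identity
from make_one, exactly as sum pairs x + y with make_zero. For intuition, the
same reducer can be expressed at user level with the existing tvm.comm_reducer
API (a sketch for illustration; the patch exposes prod directly)::

    import tvm

    # combiner: x * y; identity: the constant 1 in the source dtype
    product = tvm.comm_reducer(lambda x, y: x * y,
                               lambda t: tvm.const(1, dtype=t),
                               name="product")

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    k = tvm.reduce_axis((0, n), name="k")
    B = tvm.compute((1,), lambda i: product(A[k], axis=k), name="B")
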
diff --git a/topi/include/topi/reduction.h b/topi/include/topi/reduction.h
index f14187471faf4b9450e66cd48abe4a269277a9fa..1ac3f2d6157c6d03eb46bd1f8434eed4fe034763 100644
--- a/topi/include/topi/reduction.h
+++ b/topi/include/topi/reduction.h
@@ -12,6 +12,7 @@
 #include <vector>
 #include <iterator>
 
+#include "topi/broadcast.h"
 #include "topi/elemwise.h"
 #include "topi/tags.h"
 #include "topi/transform.h"
@@ -288,6 +289,11 @@ inline Expr MaxOp(Expr source, Array<IterVar> axis) {
   return tvm::max(source, axis);  // NOLINT(*)
 }
 
+/*! \brief Wrap tvm::prod to ensure we get the correct overload */
+inline Expr ProdOp(Expr source, Array<IterVar> axis) {
+  return tvm::prod(source, axis);  // NOLINT(*)
+}
+
 /*!
 * \brief Creates an operation that sums array elements over a given axis
 *
@@ -426,5 +432,21 @@ inline Tensor argmax(const Tensor& data, Array<Expr> axis, bool keepdims = false
   return CommReduceIdx(data, axis, func, keepdims);
 }
 
+/*!
+* \brief Creates an operation that computes the product of array elements
+* over a given axis
+*
+* \param data The input tensor
+* \param axis The axis along which the product is computed. If axis is empty,
+* the product is taken over all elements of the array.
+* \param keepdims If this is set to true, the axes which are reduced are
+* left in the result as dimensions with size one. This enables the result
+* to broadcast correctly against the input array.
+*
+* \return A Tensor whose op member is the prod operation
+*/
+inline Tensor prod(const Tensor& data, Array<Expr> axis, bool keepdims = false) {  // NOLINT(*)
+  return CommReduce(data, axis, ProdOp, keepdims);
+}
+
 }  // namespace topi
 #endif  // TOPI_REDUCTION_H_
diff --git a/topi/python/topi/reduction.py b/topi/python/topi/reduction.py
index 9f88953bb770f3a20ed96a7dead2bed9b04e8dd6..52121a506f43222e8aaf0d7a3f2d7c1c49acb9c5 100644
--- a/topi/python/topi/reduction.py
+++ b/topi/python/topi/reduction.py
@@ -2,8 +2,8 @@
 """Reduce operators"""
 from __future__ import absolute_import as _abs
 import tvm
+from . import cpp
 from . import tag
-from .util import ravel_index
 
 def _get_real_axis(ndim, axis):
     if axis is None:
@@ -26,130 +26,6 @@ def _get_real_axis(ndim, axis):
     return real_axis
 
 
-def get_reduce_out_shape(src_shape, axis=None, keepdims=False):
-    """Get the output shape for the reduction OPs
-
-    Parameters
-    ----------
-    src_shape : tuple of int or tvm.expr.IntImm
-
-    axis : None or int or tuple of int
-
-    keepdims : bool
-
-    Returns
-    -------
-    dst_shape : tuple of int or tvm.expr.IntImm
-    """
-    real_axis = _get_real_axis(len(src_shape), axis)
-    if keepdims:
-        dst_shape = [src_shape[i] if i in real_axis else 1 for i in range(len(src_shape))]
-    else:
-        dst_shape = []
-        for i in range(len(src_shape)):
-            if i not in real_axis:
-                dst_shape.append(src_shape[i])
-    return dst_shape
-
-
-def _argmax_comp(lhs, rhs):
-    """Compare function of argmax"""
-    idx = tvm.make.Select((lhs[1] >= rhs[1]), lhs[0], rhs[0])
-    val = tvm.make.Select((lhs[1] >= rhs[1]), lhs[1], rhs[1])
-    return idx, val
-
-
-def _argmax_init(idx_typ, val_typ):
-    """Initial ind and val of argmax"""
-    return tvm.const(-1, idx_typ), tvm.min_value(val_typ)
-
-
-def _argmin_comp(lhs, rhs):
-    """Compare function of argmin"""
-    idx = tvm.make.Select((lhs[1] <= rhs[1]), lhs[0], rhs[0])
-    val = tvm.make.Select((lhs[1] <= rhs[1]), lhs[1], rhs[1])
-    return idx, val
-
-
-def _argmin_init(idx_typ, val_typ):
-    """Initial ind and val of argmax"""
-    return tvm.const(-1, idx_typ), tvm.max_value(val_typ)
-
-
-def _choose_idx(idx, _, *indices):
-    """Chose the idx from idx and val"""
-    return idx(*indices)
-
-
-def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum, is_idx_reduce=False):
-    """Reducing the data
-
-    Parameters
-    ----------
-    data : tvm.Tensor
-        The input data
-
-    axis : None or int or tuple of int
-        Axis or axes along which a sum is performed.
-        The default, axis=None, will sum all of the elements of the input array.
-        If axis is negative it counts from the last to the first axis.
-
-    keepdims : bool
-        If this is set to True, the axes which are reduced are left in the result as dimensions
-         with size one.
-        With this option, the result will broadcast correctly against the input array.
-
-    func : function
-        functions like tvm.sum, tvm.max, tvm.min
-
-    Returns
-    -------
-    ret : tvm.Tensor
-    """
-    ndim = len(data.shape)
-    assert ndim != 0, "Reduce a dim-0 input is not supported!"
-    real_axis = _get_real_axis(ndim, axis)
-    reduce_axes = [tvm.reduce_axis((0, data.shape[i]), "k%d" %i) for i in real_axis]
-    if keepdims:
-        target_shape = [1 if i in real_axis else data.shape[i] for i in range(ndim)]
-    else:
-        target_shape = []
-        for i in range(ndim):
-            if i not in real_axis:
-                target_shape.append(tvm.convert(data.shape[i]))
-    def _compute(*indices):
-        eval_range = []
-        eval_indices = []
-        if not keepdims:
-            arg_counter = 0
-        else:
-            arg_counter = None
-        red_counter = 0
-        for i in range(len(data.shape)):
-            if i in real_axis:
-                eval_range.append(reduce_axes[red_counter])
-                eval_indices.append(reduce_axes[red_counter].var)
-                red_counter += 1
-            else:
-                if not keepdims:
-                    eval_range.append(indices[arg_counter])
-                    arg_counter += 1
-                else:
-                    eval_range.append(indices[i])
-        if not is_idx_reduce:
-            return func(data[tuple(eval_range)], axis=reduce_axes)
-        idx = ravel_index(eval_indices, [data.shape[i] for i in real_axis])
-        return func((idx, data[tuple(eval_range)]), axis=reduce_axes)
-    if is_idx_reduce:
-        temp_idx, temp_val = tvm.compute(target_shape, _compute, name=data.name + "_red_temp")
-        out = tvm.compute(target_shape,
-                          lambda *indices: _choose_idx(temp_idx, temp_val, *indices),
-                          name=data.name + "_red")
-    else:
-        out = tvm.compute(target_shape, _compute, name=data.name + "_red")
-    return out
-
-
 @tvm.tag_scope(tag=tag.COMM_REDUCE)
 def sum(data, axis=None, keepdims=False):
     """Sum of array elements over a given axis or a list of axes
@@ -173,7 +49,7 @@ def sum(data, axis=None, keepdims=False):
     -------
     ret : tvm.Tensor
     """
-    return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.sum)
+    return cpp.sum(data, axis, keepdims)
 
 
 @tvm.tag_scope(tag=tag.COMM_REDUCE)
@@ -199,7 +75,7 @@ def max(data, axis=None, keepdims=False):
     -------
     ret : tvm.Tensor
     """
-    return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.max)
+    return cpp.max(data, axis, keepdims)
 
 
 @tvm.tag_scope(tag=tag.COMM_REDUCE)
@@ -225,7 +101,7 @@ def min(data, axis=None, keepdims=False):
     -------
     ret : tvm.Tensor
     """
-    return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.min)
+    return cpp.min(data, axis, keepdims)
 
 
 @tvm.tag_scope(tag=tag.COMM_REDUCE_IDX)
@@ -251,8 +127,7 @@ def argmax(data, axis=None, keepdims=False):
     -------
     ret : tvm.Tensor
     """
-    _argmax = tvm.comm_reducer(fcombine=_argmax_comp, fidentity=_argmax_init, name='argmax')
-    return comm_reduce(data, axis=axis, keepdims=keepdims, func=_argmax, is_idx_reduce=True)
+    return cpp.argmax(data, axis, keepdims)
 
 
 @tvm.tag_scope(tag=tag.COMM_REDUCE_IDX)
@@ -278,5 +153,30 @@ def argmin(data, axis=None, keepdims=False):
     -------
     ret : tvm.Tensor
     """
-    _argmin = tvm.comm_reducer(fcombine=_argmin_comp, fidentity=_argmin_init, name='argmin')
-    return comm_reduce(data, axis=axis, keepdims=keepdims, func=_argmin, is_idx_reduce=True)
+    return cpp.argmin(data, axis, keepdims)
+
+
+@tvm.tag_scope(tag=tag.COMM_REDUCE)
+def prod(data, axis=None, keepdims=False):
+    """Product of array elements over a given axis or a list of axes
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        The input tvm tensor
+
+    axis : None or int or tuple of int
+        Axis or axes along which the product is performed.
+        The default, axis=None, will compute the product of all elements of the
+        input array. If axis is negative it counts from the last to the first axis.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as dimensions
+        with size one.
+        With this option, the result will broadcast correctly against the input array.
+
+    Returns
+    -------
+    ret : tvm.Tensor
+    """
+    return cpp.prod(data, axis, keepdims)
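
Like the other wrappers above, prod now just dispatches to the C++ TOPI
implementation. A small end-to-end usage sketch on the LLVM backend
(illustrative only)::

    import numpy as np
    import tvm
    import topi

    A = tvm.placeholder((4, 5), name="A")
    B = topi.prod(A, axis=1)            # reduce the second axis
    s = tvm.create_schedule(B.op)
    f = tvm.build(s, [A, B], "llvm")

    a = tvm.nd.array(np.random.uniform(1, 2, size=(4, 5)).astype("float32"))
    b = tvm.nd.array(np.zeros((4,), dtype="float32"))
    f(a, b)
    np.testing.assert_allclose(b.asnumpy(), a.asnumpy().prod(axis=1), rtol=1e-5)
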
diff --git a/topi/src/topi.cc b/topi/src/topi.cc
index 4cdab44014593f51ff64af30c16d799f7c996c6b..cac3545a75a2ebc42fec94fd40aa88af0314d9a7 100644
--- a/topi/src/topi.cc
+++ b/topi/src/topi.cc
@@ -230,6 +230,11 @@ TVM_REGISTER_GLOBAL("topi.argmax")
   *rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2]);
   });
 
+TVM_REGISTER_GLOBAL("topi.prod")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = topi::prod(args[0], ArrayOrInt(args[1]), args[2]);
+  });
+
 /* Ops from transform.h */
 TVM_REGISTER_GLOBAL("topi.expand_dims")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
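
The TVM_REGISTER_GLOBAL entry is what the C++-backed Python tests below reach
as topi.cpp.prod; a direct-call sketch (illustrative)::

    import tvm
    import topi

    A = tvm.placeholder((4, 5), name="A")
    # axis is passed as a list and keepdims as a bool, matching ArrayOrInt
    B = topi.cpp.prod(A, [1], False)
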
diff --git a/topi/tests/python/test_topi_reduce.py b/topi/tests/python/test_topi_reduce.py
index 0be652948060922c737633536b4e7d886ef1efa7..ceb2a4fe1bb17c154946f11dc08741b2da4a3f5b 100644
--- a/topi/tests/python/test_topi_reduce.py
+++ b/topi/tests/python/test_topi_reduce.py
@@ -72,6 +72,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
             out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims)
         else:
             raise NotImplementedError
+        out_npy = np.atleast_1d(out_npy)
         data_tvm = tvm.nd.array(in_npy, ctx=ctx)
         out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
         for _ in range(1):
diff --git a/topi/tests/python_cpp/test_topi_reduce.py b/topi/tests/python_cpp/test_topi_reduce.py
index 7bf369c7f1ff75521feb0a7a7663c07bc3216dcf..ab4ac9372373f946603acdc5090f088159594399 100644
--- a/topi/tests/python_cpp/test_topi_reduce.py
+++ b/topi/tests/python_cpp/test_topi_reduce.py
@@ -42,6 +42,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
     elif type == "argmin":
         B = topi.cpp.argmin(A1, axis, keepdims)
         out_dtype = "int32"
+    elif type == "prod":
+        B = topi.cpp.prod(A1, axis, keepdims)
     else:
         raise NotImplementedError
 
@@ -57,7 +59,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
         else:
             s = topi.cpp.cuda.schedule_reduce(target, [B])
 
-        foo = tvm.build(s, [A, B], device, name="sum")
+        foo = tvm.build(s, [A, B], device, name=type)
         # Test
         in_npy = np.random.uniform(size=in_shape).astype(np.float32)
         in_npy_map = np.sqrt(np.exp(in_npy)).astype(np.float32)
@@ -71,6 +73,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
             out_npy = _my_npy_argmax(in_npy_map, axis=axis, keepdims=keepdims)
         elif type == "argmin":
             out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims)
+        elif type == "prod":
+            out_npy = in_npy_map.prod(axis=axis, keepdims=keepdims)
         else:
             raise NotImplementedError
         out_npy = np.atleast_1d(out_npy)
@@ -100,21 +104,29 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
 
 def test_reduce_map():
     verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
-                        axis=(1, 2, 3),
-                        keepdims=True,
-                        type="sum")
+                          axis=(1, 2, 3),
+                          keepdims=True,
+                          type="sum")
     verify_reduce_map_ele(in_shape=(128, 24 * 128 * 24),
-                        axis=(1,),
-                        keepdims=False,
-                        type="max")
+                          axis=(1,),
+                          keepdims=False,
+                          type="max")
     verify_reduce_map_ele(in_shape=(32, 128, 24),
-                        axis=None,
-                        keepdims=True,
-                        type="sum")
+                          axis=None,
+                          keepdims=True,
+                          type="sum")
     verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
-                        axis=(0, 2),
-                        keepdims=False,
-                        type="min")
+                          axis=(0, 2),
+                          keepdims=False,
+                          type="min")
+    verify_reduce_map_ele(in_shape=(128, 4, 4, 128),
+                          axis=(1, ),
+                          keepdims=True,
+                          type="prod")
+    verify_reduce_map_ele(in_shape=(4, 4),
+                          axis=(0, 1),
+                          keepdims=False,
+                          type="prod")
     verify_reduce_map_ele(in_shape=(32, 128),
                           axis=1,
                           keepdims=True,