diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst
index 7f150ddbf7cd23c39a583a3c25849957cfca2ad5..767dfe1ba8448ef4b21ddeb7e914063ec0fbae1c 100644
--- a/docs/api/python/topi.rst
+++ b/docs/api/python/topi.rst
@@ -49,6 +49,7 @@ List of operators
    topi.min
    topi.argmax
    topi.argmin
+   topi.prod
    topi.broadcast_to
    topi.add
    topi.subtract
@@ -107,6 +108,7 @@ topi
 .. autofunction:: topi.max
 .. autofunction:: topi.sum
 .. autofunction:: topi.min
+.. autofunction:: topi.prod
 .. autofunction:: topi.broadcast_to
 .. autofunction:: topi.add
 .. autofunction:: topi.subtract
diff --git a/docs/nnvm_top.rst b/docs/nnvm_top.rst
index 663c85ac789e33fad91dcf4660d316be224e3edc..be1077f664c35637e43c6ce188bb2c93e4911cc6 100644
--- a/docs/nnvm_top.rst
+++ b/docs/nnvm_top.rst
@@ -114,6 +114,8 @@ This level enables typical convnet models.
    nnvm.symbol.sum
    nnvm.symbol.min
    nnvm.symbol.max
+   nnvm.symbol.mean
+   nnvm.symbol.prod
    nnvm.symbol.broadcast_add
    nnvm.symbol.broadcast_sub
    nnvm.symbol.broadcast_mul
@@ -228,6 +230,8 @@ Detailed Definitions
 .. autofunction:: nnvm.symbol.sum
 .. autofunction:: nnvm.symbol.min
 .. autofunction:: nnvm.symbol.max
+.. autofunction:: nnvm.symbol.mean
+.. autofunction:: nnvm.symbol.prod
 .. autofunction:: nnvm.symbol.broadcast_add
 .. autofunction:: nnvm.symbol.broadcast_sub
 .. autofunction:: nnvm.symbol.broadcast_mul
diff --git a/include/tvm/expr.h b/include/tvm/expr.h
index fe645bcf580a25323ee9b5d4f0d964d0ed9e6707..fb2233dacb69e37cea8f53833468211bccba39d3 100644
--- a/include/tvm/expr.h
+++ b/include/tvm/expr.h
@@ -36,6 +36,7 @@
 using HalideIR::Internal::Variable;
 using HalideIR::Internal::make_const;
 using HalideIR::Internal::make_zero;
+using HalideIR::Internal::make_one;
 using HalideIR::Internal::as_const_int;
 using HalideIR::Internal::as_const_uint;
 using HalideIR::Internal::const_true;
diff --git a/include/tvm/ir_operator.h b/include/tvm/ir_operator.h
index e809b06e49b5f45cc04a123011ad7ffa4892a540..39588a2228f994d9d6e687f1baa7176685392055 100644
--- a/include/tvm/ir_operator.h
+++ b/include/tvm/ir_operator.h
@@ -41,6 +41,12 @@ TVM_DLL Expr max(Expr source, Array<IterVar> axis);
  */
 TVM_DLL Expr min(Expr source, Array<IterVar> axis);
 
+/*!
+ * \brief product of source expression over axis
+ * \param source The source expression.
+ * \param axis List of iteration variables that will be used for reduction.
+ */
+TVM_DLL Expr prod(Expr source, Array<IterVar> axis);
 
 // Unary intrinsic operators
 #define TVM_DECLARE_INTRIN_UNARY(OpName) \
diff --git a/nnvm/python/nnvm/top/reduction.py b/nnvm/python/nnvm/top/reduction.py
index fd8e2f8df56e5800967ffa3f3ff96c7339677717..aef6e1dcc4a8e9f1e0b53f900fa16f06d9bc5d78 100644
--- a/nnvm/python/nnvm/top/reduction.py
+++ b/nnvm/python/nnvm/top/reduction.py
@@ -49,3 +49,11 @@ reg.register_schedule("argmax", _fschedule_reduce)
 # argmin
 reg.register_pattern("argmin", OpPattern.COMM_REDUCE)
 reg.register_schedule("argmin", _fschedule_reduce)
+
+# mean
+reg.register_pattern("mean", OpPattern.COMM_REDUCE)
+reg.register_schedule("mean", _fschedule_reduce)
+
+# product
+reg.register_pattern("prod", OpPattern.COMM_REDUCE)
+reg.register_schedule("prod", _fschedule_reduce)
diff --git a/nnvm/src/top/tensor/reduce.cc b/nnvm/src/top/tensor/reduce.cc
index d8f426b4f4bc235e7b93616e8d80d18672169760..10dd957422222e5b98ae5176cfe0237eaf401890 100644
--- a/nnvm/src/top/tensor/reduce.cc
+++ b/nnvm/src/top/tensor/reduce.cc
@@ -322,6 +322,70 @@ values over a given axis.
       topi::argmin(inputs[0], axis, param.keepdims) };
 });
 
+NNVM_REGISTER_REDUCE_OP(mean)
+  .describe(R"code(Computes the mean of array elements over given axes.
+
+Example::
+
+  data = [[[1,2],[2,3],[1,3]],
+          [[1,4],[4,3],[5,2]],
+          [[7,1],[7,2],[7,3]]]
+
+  mean(data)
+  [3.22]
+
+  mean(data, axis=[1,2])
+  [ 2.  3.16666667  4.5]
+
+)code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
+    TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
+                                  param.axis, param.exclude);
+    if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
+    auto axis = ShapeToArray(r_axes);
+
+    Expr count = make_one(inputs[0]->dtype);
+    for (auto& i : r_axes) {
+      count *= inputs[0]->shape[i];
+    }
+
+    return Array<Tensor>{
+      topi::divide(topi::sum(inputs[0], axis, param.keepdims), count) };
+});
+
+NNVM_REGISTER_REDUCE_OP(prod)
+  .describe(R"code(Computes the product of array elements over given axes.
+
+Example::
+
+  data = [[[1,2],[2,3],[1,3]],
+          [[1,4],[4,3],[5,2]],
+          [[7,1],[7,2],[7,3]]]
+
+  prod(data)
+  [35562240]
+
+  prod(data, axis=[1,2])
+  [ 36  480  2058]
+
+)code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const ReduceParam& param = nnvm::get<ReduceParam>(attrs.parsed);
+    TShape r_axes = GetReduceAxes(inputs[0]->shape.size(),
+                                  param.axis, param.exclude);
+    if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) };
+    auto axis = ShapeToArray(r_axes);
+    return Array<Tensor>{
+      topi::prod(inputs[0], axis, param.keepdims) };
+});
+
 }  // namespace top
 }  // namespace nnvm
diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py
index 6503d2d2292d1d72969bdd568189db6daabb498b..16b02f956cccc463e11cf88cf979e5ad2fcfef5d 100644
--- a/nnvm/tests/python/compiler/test_top_level4.py
+++ b/nnvm/tests/python/compiler/test_top_level4.py
@@ -31,6 +31,9 @@ def verify_reduce_explicit(dshape, data, result, fsym, oshape=None, otype='float
     x = sym.Variable("x")
     y = fsym(x + 0, **kwargs)
     for target, ctx in ctx_list():
+        # TODO(yuruofei): remove when cuda reduce schedule is done
+        if target == 'cuda' and fsym == sym.mean:
+            continue
         graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
         m = graph_runtime.create(graph, lib, ctx)
         # set input
@@ -93,6 +96,13 @@ def test_reduce():
     verify_reduce((4, 4, 3), np.min, sym.min, keepdims=True)
     verify_reduce((4, 4, 3), np.sum, sym.sum, axis=(0, 2))
     verify_reduce((4, 4, 3), np.sum, sym.sum)
+    verify_reduce((128, 24, 128), np.mean, sym.mean, axis=(0, 1), keepdims=False)
+    verify_reduce((128, 24, 128), np.mean, sym.mean, axis=(0, 2), keepdims=False)
+    verify_reduce((128, 24, 128), np.mean, sym.mean, axis=(0, 1), keepdims=True)
+    verify_reduce((128, 24, 128), np.mean, sym.mean, axis=(0, 2), keepdims=True)
+    verify_reduce((128, 24, 128), np.mean, sym.mean, keepdims=True)
+    verify_reduce((128, 24, 128), np.mean, sym.mean, keepdims=False)
+    verify_reduce((128, 24, 128), np.mean, sym.mean, axis=(0, 1, 2), keepdims=True)
     data = np.array([[[1,2],[3,4]],[[3,44],[5,6]]], dtype=np.float32)
     verify_reduce_explicit([2,2,2], data, np.array([[1,1],[1,0]]), sym.argmax,
                            otype='int32', axis=[0,2], exclude=True)
diff --git a/src/lang/ir_operator.cc b/src/lang/ir_operator.cc
index ded27bbdce7ea105ecb5cd18a88f2f5986e15a3d..50e598d13dc23bedf8bf0bc2eb34cf53455b6a75 100644
--- a/src/lang/ir_operator.cc
+++ b/src/lang/ir_operator.cc
@@ -35,4 +35,13 @@ Expr min(Expr source, Array<IterVar> rdom) {
   return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
 }
 
+Expr prod(Expr source, Array<IterVar> rdom) {
+  Var x("x", source.type()), y("y", source.type());
+  Expr result = ir::Mul::make(x, y);
+  Expr identity_element = make_one(source.type());
+  ir::CommReducer combiner =
+    ir::CommReducerNode::make({x}, {y}, {result}, {identity_element});
+  return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0);
+}
+
 }  // namespace tvm
diff --git a/topi/include/topi/reduction.h b/topi/include/topi/reduction.h
index f14187471faf4b9450e66cd48abe4a269277a9fa..1ac3f2d6157c6d03eb46bd1f8434eed4fe034763 100644
--- a/topi/include/topi/reduction.h
+++ b/topi/include/topi/reduction.h
@@ -12,6 +12,7 @@
 #include <vector>
 #include <iterator>
 
+#include "topi/broadcast.h"
 #include "topi/elemwise.h"
 #include "topi/tags.h"
 #include "topi/transform.h"
@@ -288,6 +289,11 @@ inline Expr MaxOp(Expr source, Array<IterVar> axis) {
   return tvm::max(source, axis);  // NOLINT(*)
 }
 
+/*! \brief Wrap tvm::prod to ensure we get the correct overload */
+inline Expr ProdOp(Expr source, Array<IterVar> axis) {
+  return tvm::prod(source, axis);  // NOLINT(*)
+}
+
 /*!
  * \brief Creates an operation that sums array elements over a given axis
  *
@@ -426,5 +432,21 @@ inline Tensor argmax(const Tensor& data, Array<Expr> axis, bool keepdims = false
   return CommReduceIdx(data, axis, func, keepdims);
 }
 
+/*!
+* \brief Creates product operation over given axis.
+*
+* \param data The input tensor
+* \param axis The axis to do product over. If axis is empty, the
+* operation will do the product over all elements of the array.
+* \param keepdims If this is set to true, the axes which are reduced are
+* left in the result as dimensions with size one. This enables the result
+* to broadcast correctly against the input array.
+*
+* \return A Tensor whose op member is the prod operation
+*/
+inline Tensor prod(const Tensor& data, Array<Expr> axis, bool keepdims = false) {  // NOLINT(*)
+  return CommReduce(data, axis, ProdOp, keepdims);
+}
+
 }  // namespace topi
 #endif  // TOPI_REDUCTION_H_
diff --git a/topi/python/topi/reduction.py b/topi/python/topi/reduction.py
index 9f88953bb770f3a20ed96a7dead2bed9b04e8dd6..52121a506f43222e8aaf0d7a3f2d7c1c49acb9c5 100644
--- a/topi/python/topi/reduction.py
+++ b/topi/python/topi/reduction.py
@@ -2,8 +2,8 @@
 """Reduce operators"""
 from __future__ import absolute_import as _abs
 import tvm
+from . import cpp
 from . import tag
-from .util import ravel_index
 
 def _get_real_axis(ndim, axis):
     if axis is None:
@@ -26,130 +26,6 @@ def _get_real_axis(ndim, axis):
     return real_axis
 
 
-def get_reduce_out_shape(src_shape, axis=None, keepdims=False):
-    """Get the output shape for the reduction OPs
-
-    Parameters
-    ----------
-    src_shape : tuple of int or tvm.expr.IntImm
-
-    axis : None or int or tuple of int
-
-    keepdims : bool
-
-    Returns
-    -------
-    dst_shape : tuple of int or tvm.expr.IntImm
-    """
-    real_axis = _get_real_axis(len(src_shape), axis)
-    if keepdims:
-        dst_shape = [src_shape[i] if i in real_axis else 1 for i in range(len(src_shape))]
-    else:
-        dst_shape = []
-        for i in range(len(src_shape)):
-            if i not in real_axis:
-                dst_shape.append(src_shape[i])
-    return dst_shape
-
-
-def _argmax_comp(lhs, rhs):
-    """Compare function of argmax"""
-    idx = tvm.make.Select((lhs[1] >= rhs[1]), lhs[0], rhs[0])
-    val = tvm.make.Select((lhs[1] >= rhs[1]), lhs[1], rhs[1])
-    return idx, val
-
-
-def _argmax_init(idx_typ, val_typ):
-    """Initial ind and val of argmax"""
-    return tvm.const(-1, idx_typ), tvm.min_value(val_typ)
-
-
-def _argmin_comp(lhs, rhs):
-    """Compare function of argmin"""
-    idx = tvm.make.Select((lhs[1] <= rhs[1]), lhs[0], rhs[0])
-    val = tvm.make.Select((lhs[1] <= rhs[1]), lhs[1], rhs[1])
-    return idx, val
-
-
-def _argmin_init(idx_typ, val_typ):
-    """Initial ind and val of argmax"""
-    return tvm.const(-1, idx_typ), tvm.max_value(val_typ)
-
-
-def _choose_idx(idx, _, *indices):
-    """Chose the idx from idx and val"""
-    return idx(*indices)
-
-
-def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum, is_idx_reduce=False):
-    """Reducing the data
-
-    Parameters
-    ----------
-    data : tvm.Tensor
-        The input data
-
-    axis : None or int or tuple of int
-        Axis or axes along which a sum is performed.
-        The default, axis=None, will sum all of the elements of the input array.
-        If axis is negative it counts from the last to the first axis.
-
-    keepdims : bool
-        If this is set to True, the axes which are reduced are left in the result as dimensions
-        with size one.
-        With this option, the result will broadcast correctly against the input array.
-
-    func : function
-        functions like tvm.sum, tvm.max, tvm.min
-
-    Returns
-    -------
-    ret : tvm.Tensor
-    """
-    ndim = len(data.shape)
-    assert ndim != 0, "Reduce a dim-0 input is not supported!"
-    real_axis = _get_real_axis(ndim, axis)
-    reduce_axes = [tvm.reduce_axis((0, data.shape[i]), "k%d" %i) for i in real_axis]
-    if keepdims:
-        target_shape = [1 if i in real_axis else data.shape[i] for i in range(ndim)]
-    else:
-        target_shape = []
-        for i in range(ndim):
-            if i not in real_axis:
-                target_shape.append(tvm.convert(data.shape[i]))
-    def _compute(*indices):
-        eval_range = []
-        eval_indices = []
-        if not keepdims:
-            arg_counter = 0
-        else:
-            arg_counter = None
-        red_counter = 0
-        for i in range(len(data.shape)):
-            if i in real_axis:
-                eval_range.append(reduce_axes[red_counter])
-                eval_indices.append(reduce_axes[red_counter].var)
-                red_counter += 1
-            else:
-                if not keepdims:
-                    eval_range.append(indices[arg_counter])
-                    arg_counter += 1
-                else:
-                    eval_range.append(indices[i])
-        if not is_idx_reduce:
-            return func(data[tuple(eval_range)], axis=reduce_axes)
-        idx = ravel_index(eval_indices, [data.shape[i] for i in real_axis])
-        return func((idx, data[tuple(eval_range)]), axis=reduce_axes)
-    if is_idx_reduce:
-        temp_idx, temp_val = tvm.compute(target_shape, _compute, name=data.name + "_red_temp")
-        out = tvm.compute(target_shape,
-                          lambda *indices: _choose_idx(temp_idx, temp_val, *indices),
-                          name=data.name + "_red")
-    else:
-        out = tvm.compute(target_shape, _compute, name=data.name + "_red")
-    return out
-
-
 @tvm.tag_scope(tag=tag.COMM_REDUCE)
 def sum(data, axis=None, keepdims=False):
     """Sum of array elements over a given axis or a list of axes
@@ -173,7 +49,7 @@ def sum(data, axis=None, keepdims=False):
     -------
     ret : tvm.Tensor
     """
-    return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.sum)
+    return cpp.sum(data, axis, keepdims)
 
 
 @tvm.tag_scope(tag=tag.COMM_REDUCE)
@@ -199,7 +75,7 @@ def max(data, axis=None, keepdims=False):
     -------
     ret : tvm.Tensor
     """
-    return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.max)
+    return cpp.max(data, axis, keepdims)
 
 
 @tvm.tag_scope(tag=tag.COMM_REDUCE)
@@ -225,7 +101,7 @@ def min(data, axis=None, keepdims=False):
     -------
     ret : tvm.Tensor
     """
-    return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.min)
+    return cpp.min(data, axis, keepdims)
 
 
 @tvm.tag_scope(tag=tag.COMM_REDUCE_IDX)
@@ -251,8 +127,7 @@ def argmax(data, axis=None, keepdims=False):
     -------
     ret : tvm.Tensor
    """
-    _argmax = tvm.comm_reducer(fcombine=_argmax_comp, fidentity=_argmax_init, name='argmax')
-    return comm_reduce(data, axis=axis, keepdims=keepdims, func=_argmax, is_idx_reduce=True)
+    return cpp.argmax(data, axis, keepdims)
 
 
 @tvm.tag_scope(tag=tag.COMM_REDUCE_IDX)
@@ -278,5 +153,30 @@ def argmin(data, axis=None, keepdims=False):
     -------
     ret : tvm.Tensor
     """
-    _argmin = tvm.comm_reducer(fcombine=_argmin_comp, fidentity=_argmin_init, name='argmin')
-    return comm_reduce(data, axis=axis, keepdims=keepdims, func=_argmin, is_idx_reduce=True)
+    return cpp.argmin(data, axis, keepdims)
+
+
+@tvm.tag_scope(tag=tag.COMM_REDUCE)
+def prod(data, axis=None, keepdims=False):
+    """Product of array elements over a given axis or a list of axes
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        The input tvm tensor
+
+    axis : None or int or tuple of int
+        Axis or axes along which the product is performed.
+        The default, axis=None, computes the product over all elements of the
+        input array. If axis is negative it counts from the last to the first axis.
+
+    keepdims : bool
+        If this is set to True, the axes which are reduced are left in the result as dimensions
+        with size one.
+        With this option, the result will broadcast correctly against the input array.
+
+    Returns
+    -------
+    ret : tvm.Tensor
+    """
+    return cpp.prod(data, axis, keepdims)
diff --git a/topi/src/topi.cc b/topi/src/topi.cc
index 4cdab44014593f51ff64af30c16d799f7c996c6b..cac3545a75a2ebc42fec94fd40aa88af0314d9a7 100644
--- a/topi/src/topi.cc
+++ b/topi/src/topi.cc
@@ -230,6 +230,11 @@ TVM_REGISTER_GLOBAL("topi.argmax")
     *rv = topi::argmax(args[0], ArrayOrInt(args[1]), args[2]);
   });
 
+TVM_REGISTER_GLOBAL("topi.prod")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+  *rv = topi::prod(args[0], ArrayOrInt(args[1]), args[2]);
+  });
+
 /* Ops from transform.h */
 TVM_REGISTER_GLOBAL("topi.expand_dims")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
diff --git a/topi/tests/python/test_topi_reduce.py b/topi/tests/python/test_topi_reduce.py
index 0be652948060922c737633536b4e7d886ef1efa7..ceb2a4fe1bb17c154946f11dc08741b2da4a3f5b 100644
--- a/topi/tests/python/test_topi_reduce.py
+++ b/topi/tests/python/test_topi_reduce.py
@@ -72,6 +72,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
             out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims)
         else:
             raise NotImplementedError
+        out_npy = np.atleast_1d(out_npy)
         data_tvm = tvm.nd.array(in_npy, ctx=ctx)
         out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
         for _ in range(1):
diff --git a/topi/tests/python_cpp/test_topi_reduce.py b/topi/tests/python_cpp/test_topi_reduce.py
index 7bf369c7f1ff75521feb0a7a7663c07bc3216dcf..ab4ac9372373f946603acdc5090f088159594399 100644
--- a/topi/tests/python_cpp/test_topi_reduce.py
+++ b/topi/tests/python_cpp/test_topi_reduce.py
@@ -42,6 +42,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
     elif type == "argmin":
         B = topi.cpp.argmin(A1, axis, keepdims)
         out_dtype = "int32"
+    elif type == "prod":
+        B = topi.cpp.prod(A1, axis, keepdims)
     else:
         raise NotImplementedError
 
@@ -57,7 +59,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
         else:
             s = topi.cpp.cuda.schedule_reduce(target, [B])
 
-        foo = tvm.build(s, [A, B], device, name="sum")
+        foo = tvm.build(s, [A, B], device, name=type)
         # Test
         in_npy = np.random.uniform(size=in_shape).astype(np.float32)
         in_npy_map = np.sqrt(np.exp(in_npy)).astype(np.float32)
@@ -71,6 +73,8 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
             out_npy = _my_npy_argmax(in_npy_map, axis=axis, keepdims=keepdims)
         elif type == "argmin":
            out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims)
+        elif type == "prod":
+            out_npy = in_npy_map.prod(axis=axis, keepdims=keepdims)
         else:
             raise NotImplementedError
         out_npy = np.atleast_1d(out_npy)
@@ -100,21 +104,29 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
 
 def test_reduce_map():
     verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
-                           axis=(1, 2, 3),
-                           keepdims=True,
-                           type="sum")
+                          axis=(1, 2, 3),
+                          keepdims=True,
+                          type="sum")
     verify_reduce_map_ele(in_shape=(128, 24 * 128 * 24),
-                           axis=(1,),
-                           keepdims=False,
-                           type="max")
+                          axis=(1,),
+                          keepdims=False,
+                          type="max")
     verify_reduce_map_ele(in_shape=(32, 128, 24),
-                           axis=None,
-                           keepdims=True,
-                           type="sum")
+                          axis=None,
+                          keepdims=True,
+                          type="sum")
     verify_reduce_map_ele(in_shape=(128, 24, 128, 24),
-                           axis=(0, 2),
-                           keepdims=False,
-                           type="min")
+                          axis=(0, 2),
+                          keepdims=False,
+                          type="min")
+    verify_reduce_map_ele(in_shape=(128, 4, 4, 128),
+                          axis=(1, ),
+                          keepdims=True,
+                          type="prod")
+    verify_reduce_map_ele(in_shape=(4, 4),
+                          axis=(0, 1),
+                          keepdims=False,
+                          type="prod")
     verify_reduce_map_ele(in_shape=(32, 128),
                           axis=1,
                           keepdims=True,
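
Usage sketch (reviewer note, not part of the patch): the snippet below exercises the
new `topi.prod` Python wrapper end to end. The llvm target, the plain
`tvm.create_schedule` in place of a dedicated reduce schedule, and the shapes are
illustrative assumptions::

    import numpy as np
    import tvm
    import topi

    # Reduce a (4, 4, 3) tensor over axes 0 and 2; the result has shape (4,).
    A = tvm.placeholder((4, 4, 3), name="A")
    B = topi.prod(A, axis=(0, 2), keepdims=False)

    # A default schedule is enough for a CPU smoke test.
    s = tvm.create_schedule(B.op)
    f = tvm.build(s, [A, B], "llvm", name="prod")

    # Draw values near 1 so the float32 product stays well-conditioned.
    a_np = np.random.uniform(0.5, 1.5, size=(4, 4, 3)).astype(A.dtype)
    a = tvm.nd.array(a_np)
    b = tvm.nd.empty((4,), B.dtype)
    f(a, b)
    np.testing.assert_allclose(b.asnumpy(), a_np.prod(axis=(0, 2)), rtol=1e-4)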