diff --git a/nnvm/src/top/tensor/reduce.cc b/nnvm/src/top/tensor/reduce.cc index 527a6a5abd743aec945e903a609859d31373053a..91d2ea7202b8445e4883acdd4a99b77e23788789 100644 --- a/nnvm/src/top/tensor/reduce.cc +++ b/nnvm/src/top/tensor/reduce.cc @@ -3,6 +3,9 @@ * \file reduce.cc * \brief reduce operator. */ +// Enforce TOPI to use old behavior that reduces to at least 1d +#define TOPI_REDUCE_ATLEAST1D 1 + #include <nnvm/op.h> #include <nnvm/node.h> #include <nnvm/op_attr_types.h> @@ -17,6 +20,8 @@ #include "topi/reduction.h" #include "topi/transform.h" +static_assert(TOPI_REDUCE_ATLEAST1D, "need to use legacy reduce behavior"); + namespace nnvm { namespace top { using namespace tvm; diff --git a/topi/include/topi/reduction.h b/topi/include/topi/reduction.h index ccc85e96643eec9adacff32e81adc88069bb3019..d68b9b39041959141a9fef6230d6c28c182b0186 100644 --- a/topi/include/topi/reduction.h +++ b/topi/include/topi/reduction.h @@ -20,6 +20,14 @@ #include "topi/detail/constant_utils.h" #include "tvm/tvm.h" +/*! + * \brief macro flag to enable some legacy behavior which requires + * reduction result to be at least 1d. + */ +#ifndef TOPI_REDUCE_ATLEAST1D +#define TOPI_REDUCE_ATLEAST1D 0 +#endif + namespace topi { using namespace tvm; @@ -96,6 +104,9 @@ inline Array<Expr> MakeReduceTargetShape(const std::vector<int>& real_axis, } } } + if (target_shape.size() == 0 && TOPI_REDUCE_ATLEAST1D) { + target_shape.push_back(1); + } return target_shape; } diff --git a/topi/src/topi.cc b/topi/src/topi.cc index cac3545a75a2ebc42fec94fd40aa88af0314d9a7..fef2487e67702d8937e7b733c9c16634e7437805 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -3,6 +3,8 @@ * \brief Registration of TVM operators and schedules * \file topi.cc */ +#define TOPI_REDUCE_ATLEAST1D 0 + #include <tvm/runtime/packed_func.h> #include <tvm/runtime/module.h> #include <tvm/runtime/registry.h>