diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index 51e27a9e94bf6f60454279834ea7f711ffde248f..568af8252f4ac1e089ffbc44f921cb43db2f13b5 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -313,10 +313,12 @@ class Schedule : public NodeRef {
    *
    * \param tensor The tensor to be factored.
    * \param axis The reduction axis in tensor's schedule to be factored.
+   * \param factor_axis The position where the new axis is placed.
    * \return The created factored tensors.
    */
   EXPORT Array<Tensor> rfactor(const Tensor& tensor,
-                               const IterVar& axis);
+                               const IterVar& axis,
+                               int factor_axis = 0);
   /*!
    * \brief Normalize the schedule.
    *  This is needed before bound inference.
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index dda5f67d1b8990c9dccb9edc1fe48d4b55f4318f..b04945292adf224a3f2e8b009df001d1ef927cf2 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -279,7 +279,7 @@ class Schedule(NodeBase):
         """
         return _api_internal._ScheduleCacheWrite(self, tensor, scope)
 
-    def rfactor(self, tensor, axis):
+    def rfactor(self, tensor, axis, factor_axis=0):
         """ Factor a reduction axis in tensor's schedule to be an explicit axis.
 
         This will create a new stage that generated the new tensor with axis
@@ -292,13 +292,15 @@
             The tensor to be factored.
         axis : IterVar
             The reduction axis in the schedule to be factored.
+        factor_axis : int
+            The position where the new axis is placed.
 
         Returns
         -------
         tfactor : Tensor or Array of Tensor
             The created factored tensor.
         """
-        factored = _api_internal._ScheduleRFactor(self, tensor, axis)
+        factored = _api_internal._ScheduleRFactor(self, tensor, axis, factor_axis)
         return factored[0] if len(factored) == 1 else factored
 
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
index de388cf0b51fcf8d6e59f5614246fce39b0fce78..d1994340702d51563de64ecbcddf2bd9726c1a70 100644
--- a/src/api/api_lang.cc
+++ b/src/api/api_lang.cc
@@ -432,7 +432,7 @@ TVM_REGISTER_API("_ScheduleCacheWrite")
 TVM_REGISTER_API("_ScheduleRFactor")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = args[0].operator Schedule()
-        .rfactor(args[1], args[2]);
+        .rfactor(args[1], args[2], args[3]);
   });
 
 TVM_REGISTER_API("_CommReducerCombine")
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index 59d425287be0fac08a19fb66f107403394817c4c..562eff417dd271923cb908c4375af7a98e5271de 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -395,7 +395,8 @@ Schedule Schedule::normalize() {
 
 // Handle reduction factor.
 Array<Tensor> Schedule::rfactor(const Tensor& tensor,
-                                const IterVar& axis) {
+                                const IterVar& axis,
+                                int factor_axis) {
   (*this)->InvalidateCache();
   using ir::Reduce;
   CHECK_EQ(axis->iter_type, kCommReduce)
@@ -448,6 +449,9 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
       reduce_stage, dom_map, value_map, true, skip_bound_check);
 
   // Get the factored op node.
+  const int factor_axis_pos = \
+      factor_axis >= 0 ? factor_axis : static_cast<int>(compute_op->axis.size() + 1) + factor_axis;
+  CHECK_LE(factor_axis_pos, compute_op->axis.size());
   auto n = std::make_shared<ComputeOpNode>();
   n->name = compute_op->name + ".rf";
   {
@@ -458,10 +462,16 @@
         << "Can only factor reduction domain starting from 0";
     iv_node->var = axis->var;
     iv_node->iter_type = kDataPar;
-    n->axis.push_back(IterVar(iv_node));
-    for (IterVar iv : compute_op->axis) {
-      n->axis.push_back(iv);
+    const int size = compute_op->axis.size();
+    for (int idx = 0; idx < size; ++idx) {
+      if (factor_axis_pos == idx) {
+        n->axis.push_back(IterVar(iv_node));
+      }
+      n->axis.push_back(compute_op->axis[idx]);
+    }
+    if (factor_axis_pos == size) {
+      n->axis.push_back(IterVar(iv_node));
     }
   }
   // predicate generation, copy not touched axis.
@@ -548,9 +558,15 @@
   Array<Tensor> repl_tensors = compute(old_tensors[0]->shape,
     [&](const Array<Var>& i) {
       Array<Expr> indices;
-      indices.push_back(repl_red_axis->var);
-      for (Var v : i) {
-        indices.push_back(v);
+      const int idx_size = static_cast<int>(i.size());
+      for (int idx = 0; idx < idx_size; ++idx) {
+        if (factor_axis_pos == idx) {
+          indices.push_back(repl_red_axis->var);
+        }
+        indices.push_back(i[idx]);
+      }
+      if (factor_axis_pos == idx_size) {
+        indices.push_back(repl_red_axis->var);
       }
       Array<Expr> factor_exprs;
       for (int idx = 0; idx < size; ++idx) {
diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py
index 228786a7de615a6d257e6f7aee36efa042a3fccf..c8fb98746bf6e0da4e0998972be9624e3289e3d2 100644
--- a/tests/python/integration/test_reduce.py
+++ b/tests/python/integration/test_reduce.py
@@ -83,6 +83,36 @@ def test_rfactor():
     check_target()
 
 
+def test_rfactor_factor_axis():
+    n = tvm.convert(1027)
+    A = tvm.placeholder((n,), name='A')
+    k = tvm.reduce_axis((0, n))
+    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')
+    # schedule
+    s = tvm.create_schedule(B.op)
+    kf, ki = s[B].split(k, nparts=4)
+    BF = s.rfactor(B, kf, 1)
+    s[BF].parallel(BF.op.axis[0])
+    # one line to build the function.
+    def check_target(target="llvm"):
+        if not tvm.module.enabled(target):
+            return
+        ctx = tvm.cpu(0)
+        fapi = tvm.lower(s, args=[A, B])
+        fsum = tvm.build(fapi,
+                         target=target,
+                         name="mysum")
+        # launch the kernel.
+        n = 1027
+        a = tvm.nd.array(np.random.uniform(size=(n,)).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.zeros(1, dtype=B.dtype), ctx)
+        fsum(a, b)
+        res = np.sum(a.asnumpy(), axis=0)
+        np.testing.assert_allclose(
+            b.asnumpy(), res, rtol=1e-4)
+
+    check_target()
+
 def test_rfactor_threads():
     nn = 1027
@@ -294,6 +324,7 @@ def test_rfactor_argmax():
 if __name__ == "__main__":
     test_rfactor_elemwise_threads()
     test_rfactor_threads()
+    test_rfactor_factor_axis()
     test_rfactor()
     test_reduce_prims()
     test_argmax()
diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py
index 6c29f1067632fdd53c665444133cc8a65f13a5fc..b29ebec180dfa42596bc5d8e5b8e50cdafce971c 100644
--- a/tests/python/unittest/test_lang_schedule.py
+++ b/tests/python/unittest/test_lang_schedule.py
@@ -137,6 +137,16 @@ def test_rfactor():
     assert(BF.op.body[0].axis[0] == k2)
     assert(BF.op.body[0].axis[1].var == ko.var)
     assert(s[B].op.body[0].axis[0].dom.extent.value == 4)
+    # schedule with factor_axis
+    s = tvm.create_schedule(B.op)
+    ko, ki = s[B].split(k1, factor=4)
+    xo, xi = s[B].split(B.op.axis[0], factor=8)
+    BF = s.rfactor(B, ki, 1)
+    assert(n == BF.shape[0])
+    assert(BF.shape[1].value == 4)
+    assert(BF.op.body[0].axis[0] == k2)
+    assert(BF.op.body[0].axis[1].var == ko.var)
+    assert(s[B].op.body[0].axis[0].dom.extent.value == 4)
 
 def test_tensor_intrin():
     n = 16
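For reference, a minimal usage sketch of the new factor_axis argument, separate from the patch itself. It assumes the TVM 0.x Python API exercised by the tests above; the tensor and axis names (A, B, k, BF) and the split factor are illustrative, not part of the change.

    import tvm

    n = 1027
    A = tvm.placeholder((n,), name='A')
    k = tvm.reduce_axis((0, n), name='k')
    # B reduces A over k into a single-element tensor.
    B = tvm.compute((1,), lambda i: tvm.sum(A[k], axis=k), name='B')

    s = tvm.create_schedule(B.op)
    ko, ki = s[B].split(k, factor=16)
    # factor_axis=0 (the default) makes the factored reduction axis the
    # leading dimension of BF, i.e. shape (16, 1); factor_axis=1 places it
    # after B's existing axis, i.e. shape (1, 16).
    BF = s.rfactor(B, ki, factor_axis=1)
    print(BF.shape)

The new assertions in test_lang_schedule.py check the same placement: after splitting with factor=4 and calling rfactor with factor_axis=1, BF.shape[1].value == 4 while the original output extent stays in BF.shape[0].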