diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 5c972595ff009fff9a26bc7e7243b265468180d3..d4cb2b4c632b1f135100c2db7d92447c3d1a2250 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -321,27 +321,32 @@ Stmt MakeComputeStmt(const ComputeOpNode* self, source.push_back(stage->op.output(i)); } MakeReduction(self, source, &init, &provide); - init = op::Substitute(init, n.init_vmap); init = MergeNest(n.init_nest, init); + init = op::Substitute(init, n.init_vmap); // common nest std::vector<std::vector<Stmt> > common( n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1); std::vector<std::vector<Stmt> > reduce( n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.end()); - provide = op::Substitute(provide, n.main_vmap); provide = MergeNest(reduce, provide); if (debug_keep_trivial_loop) { - return MergeNest(common, provide); + provide = MergeNest(common, provide); } else { - return MergeNest(common, Block::make(init, provide)); + provide = MergeNest(common, Block::make(init, provide)); } + // run substitution in the on the full nest, because loop condition + // could depend on outer loops. + return op::Substitute(provide, n.main_vmap); } else { std::vector<Stmt> provides; for (size_t i = 0; i < self->body.size(); ++i) { provides.emplace_back(MakeProvide(self, stage->op.output(i))); } - Stmt provide = op::Substitute(Block::make(provides), n.main_vmap); - return MergeNest(n.main_nest, provide); + Stmt provide = Block::make(provides); + provide = MergeNest(n.main_nest, provide); + // run substitution in the on the full nest, because loop condition + // could depend on outer loops. + return op::Substitute(provide, n.main_vmap); } } diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index e60073fe9f5c453456b782e3119856a856433482..e59a73529d24f6d156470bbb1e3719dfdd9fdc97 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -409,7 +409,18 @@ def test_schedule_tensor_compute3(): stmt = tvm.schedule.ScheduleOps(s, bounds) +def test_loop_dep_reduce(): + X = tvm.placeholder(shape=(10,), name="x") + def f(n): + rv = tvm.reduce_axis((0, n)) + return tvm.sum(X[rv], axis=rv) + Y = tvm.compute(X.shape, f, name="y") + s = tvm.create_schedule([Y.op]) + f = tvm.build(s, [X, Y]) + + if __name__ == "__main__": + test_loop_dep_reduce() test_schedule_middle_cache() test_inline_multi_reduce() test_schedule_cache_relayout4()