diff --git a/include/tvm/operation.h b/include/tvm/operation.h index d598df8d21b10525dc0cfdcb5e4b00f1d3168e89..9b950c3d544f5f93a6000ef67ac8dea747de9b83 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -117,11 +117,13 @@ class OperationNode : public FunctionBaseNode { * \brief Build the statement that provide the output tensors. * \param stage The schedule stage of the op. * \param dom_map The domain map of all iteration domains. + * \param del_trivial_loop Whether eliminate trivial loop with extent of 1 * \return A statement that add production and wraps consumer. */ virtual Stmt BuildProvide( const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map) const = 0; + const std::unordered_map<IterVar, Range>& dom_map, + bool del_trivial_loop) const = 0; static constexpr const char* _type_key = "Operation"; @@ -160,7 +162,8 @@ class PlaceholderOpNode : public OperationNode { const Stmt& body) const final; Stmt BuildProvide( const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map) const final; + const std::unordered_map<IterVar, Range>& dom_map, + bool del_trivial_loop) const final; void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); @@ -211,7 +214,8 @@ class ComputeOpNode : public OperationNode { const Stmt& body) const final; Stmt BuildProvide( const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map) const final; + const std::unordered_map<IterVar, Range>& dom_map, + bool del_trivial_loop) const final; void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); @@ -282,7 +286,8 @@ class ScanOpNode : public OperationNode { const Stmt& body) const final; Stmt BuildProvide( const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map) const final; + const std::unordered_map<IterVar, Range>& dom_map, + bool del_trivial_loop) const final; void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); @@ -345,7 +350,8 @@ class ExternOpNode : public OperationNode { const Stmt& 
body) const final; Stmt BuildProvide( const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map) const final; + const std::unordered_map<IterVar, Range>& dom_map, + bool del_trivial_loop) const final; void VisitAttrs(AttrVisitor* v) final { v->Visit("name", &name); diff --git a/include/tvm/schedule_pass.h b/include/tvm/schedule_pass.h index 719448513fb87a921bd9e24151fb61676ddbd315..011c7510ced929cde42bf8c58548ed502d0baddb 100644 --- a/include/tvm/schedule_pass.h +++ b/include/tvm/schedule_pass.h @@ -29,9 +29,10 @@ Map<IterVar, Range> InferBound(const Schedule& sch); * * \param s The schedule to be realized * \param dom_map The domain of each iter vars. + * \param del_trivial_loop Whether delete trivial loops with extent of 1 * \return the result Stmt */ -Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map); +Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map, bool del_trivial_loop); /*! * \brief To automatically inline the element-wise operations. diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc index 0b5ef251503cf85e3b00f28e3c8bae3c6a14ad21..b1a6729ec6620c36def9004d84c44edebb9d01d3 100644 --- a/src/api/api_schedule.cc +++ b/src/api/api_schedule.cc @@ -24,6 +24,14 @@ TVM_REGISTER_API("schedule.AutoInlineInjective") AutoInlineInjective(args[0]); }); +TVM_REGISTER_API("schedule.ScheduleOps") +.set_body([](TVMArgs args, TVMRetValue* ret) { + if (args.size() == 2) + *ret = ScheduleOps(args[0], args[1], true); + else + *ret = ScheduleOps(args[0], args[1], args[2]); +}); + #define REGISTER_SCHEDULE_PASS1(PassName) \ TVM_REGISTER_API("schedule."#PassName) \ .set_body([](TVMArgs args, TVMRetValue *ret) { \ @@ -43,7 +51,6 @@ REGISTER_SCHEDULE_PASS2(PostDFSOrder); REGISTER_SCHEDULE_PASS1(CreateAttachPath); REGISTER_SCHEDULE_PASS1(ScanGetBody); REGISTER_SCHEDULE_PASS1(ScanFixPointAnalysis); -REGISTER_SCHEDULE_PASS2(ScheduleOps); } // namespace schedule } // namespace tvm diff --git a/src/codegen/build_module.cc 
b/src/codegen/build_module.cc index 53aef46a97513ea5a7f19c7173beb295be85d08f..2e8e5bb278eb4124388525856f3fb2435fef1c1f 100644 --- a/src/codegen/build_module.cc +++ b/src/codegen/build_module.cc @@ -211,7 +211,7 @@ Stmt BuildStmt(Schedule sch, // Phase 0 auto bounds = schedule::InferBound(sch); - auto stmt = schedule::ScheduleOps(sch, bounds); + auto stmt = schedule::ScheduleOps(sch, bounds, true); stmt = ir::InjectPrefetch(stmt); // Phase 1 diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index 9e7db1deea455285ed56665abd8b4bd3b84db9ee..8b8bfbfe602eeacc9701eaf09e1dadb7b498ea40 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -305,9 +305,10 @@ Stmt MakeProvide(const ComputeOpNode* op, Stmt MakeComputeStmt(const ComputeOpNode* self, const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map) { + const std::unordered_map<IterVar, Range>& dom_map, + bool del_trivial_loop) { // grab the nest structure - ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map); + ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, del_trivial_loop); // Normal loop structure n.init_nest.emplace_back(op::MakeIfNest(n.init_predicates)); n.main_nest.emplace_back(op::MakeIfNest(n.main_predicates)); @@ -387,28 +388,30 @@ ComputeType DetectComputeType(const ComputeOpNode* self, // implement the provide utility. Stmt ComputeOpNode::BuildProvide( const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map) const { + const std::unordered_map<IterVar, Range>& dom_map, + bool del_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); ComputeType ctype = DetectComputeType(this, stage); if (ctype == ComputeType::kCrossThreadReduction) { // specially handle cross thread reduction. 
- return MakeCrossThreadReduction(this, stage, dom_map); + return MakeCrossThreadReduction(this, stage, dom_map, del_trivial_loop); } else if (ctype == ComputeType::kTensorize) { - return MakeTensorize(this, stage, dom_map); + return MakeTensorize(this, stage, dom_map, del_trivial_loop); } else { - return MakeComputeStmt(this, stage, dom_map); + return MakeComputeStmt(this, stage, dom_map, del_trivial_loop); } } ComputeLoopNest ComputeLoopNest::make( const ComputeOpNode* self, const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map) { + const std::unordered_map<IterVar, Range>& dom_map, + bool del_trivial_loop) { CHECK_EQ(stage->op.operator->(), self); ComputeLoopNest ret; // make main loop nest ret.main_nest = op::MakeLoopNest( - stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap); + stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap, del_trivial_loop); ret.main_predicates = schedule::MakeBoundCheck( stage, dom_map, ret.main_vmap, false, std::unordered_set<IterVar>()); @@ -450,7 +453,7 @@ ComputeLoopNest ComputeLoopNest::make( } ret.init_nest = op::MakeLoopNest( stage, dom_map, begin_loop, true, - skip_iter, &(ret.init_vmap)); + skip_iter, &(ret.init_vmap), del_trivial_loop); ret.init_predicates = schedule::MakeBoundCheck( stage, dom_map, ret.init_vmap, true, skip_iter); for (auto& e : ret.init_predicates) { diff --git a/src/op/compute_op.h b/src/op/compute_op.h index 95dc0f44d8d438524dd45d79eaba5dd618034749..2164feee6988f84c0cee190549d262fbabb5a524 100644 --- a/src/op/compute_op.h +++ b/src/op/compute_op.h @@ -37,12 +37,14 @@ struct ComputeLoopNest { * \param self The pointer to compute op. * \param stage The scxhedule stage. * \param dom_map The domain map. 
+ * \param del_trivial_loop Whether eliminate trivial loops with extent of 1 * \return The constructed loop nest */ static ComputeLoopNest make( const ComputeOpNode* self, const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map); + const std::unordered_map<IterVar, Range>& dom_map, + bool del_trivial_loop); }; /*! @@ -50,23 +52,27 @@ struct ComputeLoopNest { * \param self The pointer to ComputeOpNode * \param stage The schedule stage. * \param dom_map The domain map. + * \param del_trivial_loop Whether eliminate trivial loops with extent of 1 * \return The created statement. */ Stmt MakeCrossThreadReduction( const ComputeOpNode* self, const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map); + const std::unordered_map<IterVar, Range>& dom_map, + bool del_trivial_loop); /*! * \brief Build body of compute for tensorization. * \param self The pointer to ComputeOpNode * \param stage The schedule stage. * \param dom_map The domain map. + * \param del_trivial_loop Whether eliminate trivial loops with extent of 1 * \return The created statement. 
*/ Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map); + const std::unordered_map<IterVar, Range>& dom_map, + bool del_trivial_loop); } // namespace tvm #endif // TVM_OP_COMPUTE_OP_H_ diff --git a/src/op/cross_thread_reduction.cc b/src/op/cross_thread_reduction.cc index 6eec3bd69d6a245c36fd10c197614c50550ef209..e32b3dcd4407d3a814773b9abad0074272c38fb0 100644 --- a/src/op/cross_thread_reduction.cc +++ b/src/op/cross_thread_reduction.cc @@ -13,14 +13,15 @@ using namespace ir; Stmt MakeCrossThreadReduction( const ComputeOpNode* self, const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map) { + const std::unordered_map<IterVar, Range>& dom_map, + bool del_trivial_loop) { Array<Expr> args; for (IterVar iv : self->axis) { args.push_back(iv->var); } std::unordered_map<IterVar, Expr> value_map; auto nest = op::MakeLoopNest( - stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map); + stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map, del_trivial_loop); auto conds = schedule::MakeBoundCheck( stage, dom_map, value_map, false, std::unordered_set<IterVar>()); diff --git a/src/op/extern_op.cc b/src/op/extern_op.cc index e83f97b14652e922c11eb48a0c906e0753de719e..df3a32d50fe7b31d2743658c6b1191d0078eb396 100644 --- a/src/op/extern_op.cc +++ b/src/op/extern_op.cc @@ -128,7 +128,8 @@ Stmt ExternOpNode::BuildRealize( Stmt ExternOpNode::BuildProvide( const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map) const { + const std::unordered_map<IterVar, Range>& dom_map, + bool del_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body); auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) { diff --git a/src/op/op_util.cc b/src/op/op_util.cc index 78e092ca844ebc8a32792f32db3a80e040db956a..ef7af85bf079c43ab4e9e7d18aaeedf8e472eef3 100644 --- a/src/op/op_util.cc +++ 
b/src/op/op_util.cc @@ -23,7 +23,8 @@ MakeLoopNest(const Stage& stage, size_t begin_iter_pos, bool new_loop_var, const std::unordered_set<IterVar>& skip_iter, - std::unordered_map<IterVar, Expr>* p_value_map) { + std::unordered_map<IterVar, Expr>* p_value_map, + bool del_trivial_loop) { auto leaf_iter_vars = stage->leaf_iter_vars; Stmt no_op = Evaluate::make(0); // create the loop nest @@ -75,7 +76,7 @@ MakeLoopNest(const Stage& stage, AttrStmt::make(iv, ir::attr::pragma_scope, p, no_op)); } } - if (is_one(dom->extent)) { + if (del_trivial_loop && is_one(dom->extent)) { nest[i + 1].emplace_back( LetStmt::make(var, dom->min, no_op)); value_map[iv] = dom->min; @@ -130,7 +131,7 @@ MakeLoopNest(const Stage& stage, // annotate the extent of the IterVar nest[i + 1].emplace_back( AttrStmt::make(bind_iv, ir::attr::thread_extent, dom->extent, no_op)); - if (is_one(dom->extent)) { + if (del_trivial_loop && is_one(dom->extent)) { value_map[iv] = dom->min; } else { value_map[iv] = var; diff --git a/src/op/op_util.h b/src/op/op_util.h index 783fbb989422062618b4c8d5600491a126fed056..9b8f7dc629bda0ac6a3718b5e663fe0eb39c5b14 100644 --- a/src/op/op_util.h +++ b/src/op/op_util.h @@ -29,6 +29,7 @@ using ir::MergeNest; * \param new_loop_var Whether create new loop variable. * \param skip_iter Whether skip certain iteration. * \param p_value_map The result value of each IterVar. + * \param del_trivial_loop Whether eliminate trivial loops with extent of 1 */ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage, @@ -36,7 +37,8 @@ MakeLoopNest(const Stage& stage, size_t begin_iter_pos, bool new_loop_var, const std::unordered_set<IterVar>& skip_iter, - std::unordered_map<IterVar, Expr>* p_value_map); + std::unordered_map<IterVar, Expr>* p_value_map, + bool del_trivial_loop); /*! * \brief Create a nest of if checking the predicates. 
diff --git a/src/op/placeholder_op.cc b/src/op/placeholder_op.cc index 4e9d1d094d7444e5c9c88f574c07f6f50214a337..27c1fa9c500176da1cc3babf0673dda4c175e218 100644 --- a/src/op/placeholder_op.cc +++ b/src/op/placeholder_op.cc @@ -78,7 +78,8 @@ Stmt PlaceholderOpNode::BuildRealize( Stmt PlaceholderOpNode::BuildProvide( const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map) const { + const std::unordered_map<IterVar, Range>& dom_map, + bool del_trivial_loop) const { return Stmt(); } } // namespace tvm diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc index 94e3a4aa65869bc604c23eadd0a716845b022513..5c61eae0f1839b16a2c5a0f7d80724a20f962524 100644 --- a/src/op/scan_op.cc +++ b/src/op/scan_op.cc @@ -252,7 +252,8 @@ Stmt ScanOpNode::BuildRealize( Stmt ScanOpNode::BuildProvide( const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map) const { + const std::unordered_map<IterVar, Range>& dom_map, + bool del_trivial_loop) const { CHECK_EQ(stage->op.operator->(), this); Stmt provide = AttrStmt::make( stage->op, attr::scan_update_scope, this->scan_axis->var, @@ -270,7 +271,7 @@ Stmt ScanOpNode::BuildProvide( std::unordered_map<IterVar, Expr> vmap; std::unordered_set<IterVar> empty; auto nest = op::MakeLoopNest( - stage, dom_map, 0, false, empty, &vmap); + stage, dom_map, 0, false, empty, &vmap, del_trivial_loop); nest[begin_scan].push_back(init); nest.push_back( op::MakeIfNest( diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc index 6fa5459829fce72b1728aa745d3da666c365e908..1f03ec9c0ebb85419fa357d2edd7653cb1062b70 100644 --- a/src/op/tensorize.cc +++ b/src/op/tensorize.cc @@ -369,14 +369,15 @@ Stmt TransformUpdate(const Stage& stage, Stmt MakeTensorize(const ComputeOpNode* self, const Stage& stage, - const std::unordered_map<IterVar, Range>& dom_map) { + const std::unordered_map<IterVar, Range>& dom_map, + bool del_trivial_loop) { std::unordered_map<IterVar, Range> out_dom; std::unordered_map<Tensor, Array<Range> > in_region; size_t tloc = 
InferTensorizeRegion(self, stage, dom_map, &out_dom, &in_region); TensorIntrin intrin = stage->iter_var_attrs.at( stage->leaf_iter_vars[tloc])->tensor_intrin; CHECK(intrin.defined()); - ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map); + ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, del_trivial_loop); VerifyTensorizeLoopNest(self, stage, n, tloc); VerifyTensorizeBody(self, stage, out_dom, in_region, intrin); // Start bind data. diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index 875df556466a0b442e1a76b3102ec6eab589a5f0..e0dc3321b1fc901f7f97444642624db8f88146cb 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -22,8 +22,9 @@ using namespace ir; Stmt MakePipeline(const Stage& s, const std::unordered_map<IterVar, Range>& dom_map, - Stmt consumer) { - Stmt producer = s->op->BuildProvide(s, dom_map); + Stmt consumer, + bool del_trivial_loop) { + Stmt producer = s->op->BuildProvide(s, dom_map, del_trivial_loop); if (producer.defined()) { producer = ProducerConsumer::make(s->op, true, producer); } @@ -68,7 +69,7 @@ class InjectAttach : public IRMutator { found_attach = true; stmt = AttrStmt::make( op->node, op->attr_key, op->value, - MakePipeline(stage_, dom_map_, op->body)); + MakePipeline(stage_, dom_map_, op->body, true)); } } return stmt; @@ -107,7 +108,7 @@ class InjectScanStep : public IRMutator { found_attach = true; stmt = AttrStmt::make( op->node, op->attr_key, op->value, - MakePipeline(stage_, dom_map_, op->body)); + MakePipeline(stage_, dom_map_, op->body, true)); } } return stmt; @@ -324,7 +325,7 @@ class SchedulePostProc : public IRMutator { }; Stmt ScheduleOps( - Schedule sch, Map<IterVar, Range> dom_map_) { + Schedule sch, Map<IterVar, Range> dom_map_, bool del_trivial_loop) { Stmt body = Stmt(); std::unordered_map<IterVar, Range> dom_map = as_unordered_map(dom_map_); // scan init and scan updates @@ -374,7 +375,7 @@ Stmt ScheduleOps( // do nothing } else if 
(attach_spec->attach_type == kGroupRoot) { CHECK(!s->group.defined()); - body = MakePipeline(s, dom_map, body); + body = MakePipeline(s, dom_map, body, del_trivial_loop); } else { CHECK_EQ(attach_spec->attach_type, kScope); CHECK(body.defined());