From f13421cff29c321f1fff4c9f13c472153c14e99f Mon Sep 17 00:00:00 2001 From: Lianmin Zheng <mercy_zheng@sjtu.edu.cn> Date: Sat, 10 Mar 2018 03:39:38 +0800 Subject: [PATCH] fix keeping trivial loop (#982) --- src/schedule/schedule_ops.cc | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index 1fbffb61f..b9b02050a 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -57,8 +57,10 @@ class InjectAttach : public IRMutator { public: InjectAttach(const Stage& stage, const Stage& attach_spec, - const std::unordered_map<IterVar, Range>& dom_map) - : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map) {} + const std::unordered_map<IterVar, Range>& dom_map, + bool del_trivial_loop) + : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map), + del_trivial_loop_(del_trivial_loop) {} Stmt Mutate(Stmt stmt) final { CHECK(stmt.defined()); @@ -74,7 +76,7 @@ class InjectAttach : public IRMutator { found_attach = true; stmt = AttrStmt::make( op->node, op->attr_key, op->value, - MakePipeline(stage_, dom_map_, op->body, true)); + MakePipeline(stage_, dom_map_, op->body, del_trivial_loop_)); } } return stmt; @@ -89,6 +91,8 @@ class InjectAttach : public IRMutator { const Stage& attach_spec_; // domain map const std::unordered_map<IterVar, Range>& dom_map_; + // whether delete trivial loops with extent of 1 + bool del_trivial_loop_; }; // inject the operator's realization on the stmt. @@ -97,9 +101,10 @@ class InjectScanStep : public IRMutator { InjectScanStep(const Stage& stage, const Operation& scan_op, const std::unordered_map<IterVar, Range>& dom_map, - bool is_init) + bool is_init, + bool del_trivial_loop) : stage_(stage), scan_op_(scan_op), - dom_map_(dom_map), is_init_(is_init) {} + dom_map_(dom_map), is_init_(is_init), del_trivial_loop_(del_trivial_loop) {} Stmt Mutate(Stmt stmt) final { CHECK(stmt.defined()); @@ -113,7 +118,7 @@ class InjectScanStep : public IRMutator { found_attach = true; stmt = AttrStmt::make( op->node, op->attr_key, op->value, - MakePipeline(stage_, dom_map_, op->body, true)); + MakePipeline(stage_, dom_map_, op->body, del_trivial_loop_)); } } return stmt; @@ -130,6 +135,8 @@ class InjectScanStep : public IRMutator { const std::unordered_map<IterVar, Range>& dom_map_; // whether it is init. bool is_init_; + // whether delete trivial loops with extent of 1 + bool del_trivial_loop_; }; // Postprocessing of schedule op @@ -365,14 +372,14 @@ Stmt ScheduleOps( if (scan_init.count(s->op)) { CHECK(body.defined()); - InjectScanStep mu(s, scan_init.at(s->op), dom_map, true); + InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, del_trivial_loop); body = mu.Mutate(body); CHECK(mu.found_attach) << "did not find attachment point for scan.init"; } else if (attach_spec->attach_type == kScanUpdate) { // Handle scan update CHECK(body.defined()); - InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false); + InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, del_trivial_loop); body = mu.Mutate(body); CHECK(mu.found_attach) << "did not find attachment point for scan.update"; @@ -384,7 +391,7 @@ Stmt ScheduleOps( } else { CHECK_EQ(attach_spec->attach_type, kScope); CHECK(body.defined()); - InjectAttach mutator(s, attach_spec, dom_map); + InjectAttach mutator(s, attach_spec, dom_map, del_trivial_loop); body = mutator.Mutate(body); CHECK(mutator.found_attach) << "did not find attachment point for " << s << " in " -- GitLab