From f13421cff29c321f1fff4c9f13c472153c14e99f Mon Sep 17 00:00:00 2001
From: Lianmin Zheng <mercy_zheng@sjtu.edu.cn>
Date: Sat, 10 Mar 2018 03:39:38 +0800
Subject: [PATCH] fix keeping trivial loop (#982)

---
 src/schedule/schedule_ops.cc | 25 ++++++++++++++++---------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc
index 1fbffb61f..b9b02050a 100644
--- a/src/schedule/schedule_ops.cc
+++ b/src/schedule/schedule_ops.cc
@@ -57,8 +57,10 @@ class InjectAttach : public IRMutator {
  public:
   InjectAttach(const Stage& stage,
                const Stage& attach_spec,
-               const std::unordered_map<IterVar, Range>& dom_map)
-      : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map) {}
+               const std::unordered_map<IterVar, Range>& dom_map,
+               bool del_trivial_loop)
+      : stage_(stage), attach_spec_(attach_spec), dom_map_(dom_map),
+        del_trivial_loop_(del_trivial_loop) {}
 
   Stmt Mutate(Stmt stmt) final {
     CHECK(stmt.defined());
@@ -74,7 +76,7 @@ class InjectAttach : public IRMutator {
         found_attach = true;
         stmt = AttrStmt::make(
             op->node, op->attr_key, op->value,
-            MakePipeline(stage_, dom_map_, op->body, true));
+            MakePipeline(stage_, dom_map_, op->body, del_trivial_loop_));
       }
     }
     return stmt;
@@ -89,6 +91,8 @@ class InjectAttach : public IRMutator {
   const Stage& attach_spec_;
   // domain map
   const std::unordered_map<IterVar, Range>& dom_map_;
+  // whether delete trivial loops with extent of 1
+  bool del_trivial_loop_;
 };
 
 // inject the operator's realization on the stmt.
@@ -97,9 +101,10 @@ class InjectScanStep : public IRMutator {
   InjectScanStep(const Stage& stage,
                  const Operation& scan_op,
                  const std::unordered_map<IterVar, Range>& dom_map,
-                 bool is_init)
+                 bool is_init,
+                 bool del_trivial_loop)
       : stage_(stage), scan_op_(scan_op),
-        dom_map_(dom_map), is_init_(is_init) {}
+        dom_map_(dom_map), is_init_(is_init), del_trivial_loop_(del_trivial_loop) {}
 
   Stmt Mutate(Stmt stmt) final {
     CHECK(stmt.defined());
@@ -113,7 +118,7 @@ class InjectScanStep : public IRMutator {
         found_attach = true;
         stmt = AttrStmt::make(
             op->node, op->attr_key, op->value,
-            MakePipeline(stage_, dom_map_, op->body, true));
+            MakePipeline(stage_, dom_map_, op->body, del_trivial_loop_));
       }
     }
     return stmt;
@@ -130,6 +135,8 @@ class InjectScanStep : public IRMutator {
   const std::unordered_map<IterVar, Range>& dom_map_;
   // whether it is init.
   bool is_init_;
+  // whether delete trivial loops with extent of 1
+  bool del_trivial_loop_;
 };
 
 // Postprocessing of schedule op
@@ -365,14 +372,14 @@ Stmt ScheduleOps(
 
     if (scan_init.count(s->op)) {
       CHECK(body.defined());
-      InjectScanStep mu(s, scan_init.at(s->op), dom_map, true);
+      InjectScanStep mu(s, scan_init.at(s->op), dom_map, true, del_trivial_loop);
       body = mu.Mutate(body);
       CHECK(mu.found_attach)
           << "did not find attachment point for scan.init";
     } else if (attach_spec->attach_type == kScanUpdate) {
       // Handle scan update
       CHECK(body.defined());
-      InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false);
+      InjectScanStep mu(s, attach_spec->attach_stage->op, dom_map, false, del_trivial_loop);
       body = mu.Mutate(body);
       CHECK(mu.found_attach)
           << "did not find attachment point for scan.update";
@@ -384,7 +391,7 @@ Stmt ScheduleOps(
     } else {
       CHECK_EQ(attach_spec->attach_type, kScope);
       CHECK(body.defined());
-      InjectAttach mutator(s, attach_spec, dom_map);
+      InjectAttach mutator(s, attach_spec, dom_map, del_trivial_loop);
       body = mutator.Mutate(body);
       CHECK(mutator.found_attach)
           << "did not find attachment point for " << s << " in "
-- 
GitLab