From e42cc112bf95b1b5d3f295033b48819f1db604a8 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Sat, 4 Feb 2017 22:41:28 -0800
Subject: [PATCH] [PASS] UnrollLoop, isolate arithmetic module. (#32)

---
 include/tvm/ir_pass.h                       | 10 ++-
 include/tvm/runtime/packed_func.h           |  2 +-
 python/tvm/build.py                         |  1 -
 src/README.md                               |  1 +
 src/api/api_pass.cc                         | 10 +++
 src/{schedule => arithmetic}/compute_expr.h | 10 +--
 src/{schedule => arithmetic}/int_set.cc     | 96 +++------------------
 src/{schedule => arithmetic}/int_set.h      | 75 ++++------------
 src/pass/inline.cc                          | 22 ++++-
 src/pass/simple_passes.cc                   |  4 +-
 src/pass/unroll_loop.cc                     | 78 +++++++++++++++++
 src/schedule/bound.cc                       | 78 ++++++++++++++++-
 src/schedule/graph.cc                       |  1 -
 src/schedule/schedule_ops.cc                | 14 ++-
 tests/python/unittest/test_pass_unroll.py   | 20 +++++
 15 files changed, 261 insertions(+), 161 deletions(-)
 rename src/{schedule => arithmetic}/compute_expr.h (94%)
 rename src/{schedule => arithmetic}/int_set.cc (82%)
 rename src/{schedule => arithmetic}/int_set.h (56%)
 create mode 100644 src/pass/unroll_loop.cc
 create mode 100644 tests/python/unittest/test_pass_unroll.py

diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 8eaec0f52..f8412dc36 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -68,7 +68,7 @@ Stmt ConvertSSA(Stmt stmt);
  * \param value_map The map of new values.
  * \return The converted form.
  */
-Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map);
+Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map);
 
 /*!
  * \brief inline all calls of f in stmt.
@@ -97,6 +97,13 @@ Stmt Inline(Stmt stmt,
 Stmt StorageFlatten(Stmt stmt,
                     Map<Tensor, Buffer> extern_buffer);
 
+/*!
+ * \brief unroll the constant loops
+ * \param stmt The statment to be unrolled.
+ * \param max_auto_step The maximum step to stop performing automatic unrolling.
+ */
+Stmt UnrollLoop(Stmt stmt, int max_auto_step);
+
 /*!
  * \brief Make an user callable API LoweredFunc.
  *
@@ -153,6 +160,7 @@ Array<LoweredFunc> SplitHostDevice(LoweredFunc func);
  */
 LoweredFunc StorageSync(LoweredFunc stmt, std::string storage_scope);
 
+
 }  // namespace ir
 }  // namespace tvm
 
diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h
index eafc367fe..3b1921ee8 100644
--- a/include/tvm/runtime/packed_func.h
+++ b/include/tvm/runtime/packed_func.h
@@ -562,7 +562,7 @@ inline TVMArgValue TVMArgs::operator[](int i) const {
   CHECK_LT(i, num_args)
       << "not enough argument passed, "
       << num_args << " passed"
-      << "but request arg" << i;
+      << " but request arg[" << i << "].";
   return TVMArgValue(values[i], type_codes[i]);
 }
 
diff --git a/python/tvm/build.py b/python/tvm/build.py
index 29321eabe..fbed0a33f 100644
--- a/python/tvm/build.py
+++ b/python/tvm/build.py
@@ -70,7 +70,6 @@ def build(sch,
     fsplits = [x for x in fsplits]
     for i in range(1, len(fsplits)):
         fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared")
-        fsplits[i] = ir_pass.StorageSync(fsplits[i], "global")
 
     if record_codes is not None:
         output_ssa = False
diff --git a/src/README.md b/src/README.md
index 16dfc19d8..91cb47ece 100644
--- a/src/README.md
+++ b/src/README.md
@@ -3,5 +3,6 @@
 - api API functionr registration
 - lang The definition of DSL related data structure
 - schedule The operations on the schedule graph before converting to IR.
+- arithmetic Arithmetic expression and set simplification
 - pass The optimization pass on the IR structure
 - runtime Minimum runtime related codes.
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index 6e7bbd849..df79996e4 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -6,6 +6,7 @@
 #include <tvm/expr.h>
 #include <tvm/ir.h>
 #include <tvm/ir_pass.h>
+#include <tvm/ir_visitor.h>
 #include <tvm/api_registry.h>
 
 namespace tvm {
@@ -29,6 +30,14 @@ TVM_REGISTER_API(_pass_Equal)
     }
   });
 
+TVM_REGISTER_API(_pass_PostOrderVisit)
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+    PackedFunc f = args[1];
+    ir::PostOrderVisit(args[0], [f](const NodeRef& n) {
+        f(n);
+      });
+  });
+
 // make from two arguments
 #define REGISTER_PASS1(PassName)                                  \
   TVM_REGISTER_API(_pass_## PassName)                             \
@@ -52,6 +61,7 @@ REGISTER_PASS1(ConvertSSA);
 REGISTER_PASS1(VerifySSA);
 REGISTER_PASS4(Inline);
 REGISTER_PASS2(StorageFlatten);
+REGISTER_PASS2(UnrollLoop);
 REGISTER_PASS2(StorageSync);
 REGISTER_PASS4(MakeAPI);
 REGISTER_PASS1(SplitHostDevice);
diff --git a/src/schedule/compute_expr.h b/src/arithmetic/compute_expr.h
similarity index 94%
rename from src/schedule/compute_expr.h
rename to src/arithmetic/compute_expr.h
index ee1947b61..9550c1c96 100644
--- a/src/schedule/compute_expr.h
+++ b/src/arithmetic/compute_expr.h
@@ -4,14 +4,14 @@
  * \brief Utility integer expression with quick eager simplification.
  *  This is weaker than Simplify but can be done Eagerly.
  */
-#ifndef TVM_SCHEDULE_COMPUTE_EXPR_H_
-#define TVM_SCHEDULE_COMPUTE_EXPR_H_
+#ifndef TVM_ARITHMETIC_COMPUTE_EXPR_H_
+#define TVM_ARITHMETIC_COMPUTE_EXPR_H_
 
 #include <tvm/ir.h>
 #include <pass/Interval.h>
 
 namespace tvm {
-namespace schedule {
+namespace arith {
 
 using Halide::Internal::add_would_overflow;
 using Halide::Internal::sub_would_overflow;
@@ -104,6 +104,6 @@ inline Expr ComputeExpr<ir::Min>(Expr a, Expr b) {
   return Halide::Internal::Interval::make_min(a, b);
 }
 
-}  // namespace schedule
+}  // namespace arith
 }  // namespace tvm
-#endif   // TVM_SCHEDULE_COMPUTE_EXPR_H_
+#endif   // TVM_ARITHMETIC_COMPUTE_EXPR_H_
diff --git a/src/schedule/int_set.cc b/src/arithmetic/int_set.cc
similarity index 82%
rename from src/schedule/int_set.cc
rename to src/arithmetic/int_set.cc
index 0da1a39e7..04b40191d 100644
--- a/src/schedule/int_set.cc
+++ b/src/arithmetic/int_set.cc
@@ -1,6 +1,6 @@
 /*!
- *  Copyright (c) 2016 by Contributors
- * \file int_set_impl.cc
+ *  Copyright (c) 2017 by Contributors
+ * \file int_set.cc
  * \brief The integer set functions
  */
 #include <tvm/ir.h>
@@ -10,7 +10,7 @@
 #include "./compute_expr.h"
 
 namespace tvm {
-namespace schedule {
+namespace arith {
 
 using Halide::Internal::Interval;
 
@@ -94,6 +94,12 @@ bool IntSet::is_single_point() const {
   return (s_int && s_int->i.is_single_point());
 }
 
+Expr IntSet::point_value() const {
+  const IntervalSet* s_int = (*this).as<IntervalSet>();
+  CHECK(s_int && s_int->i.is_single_point());
+  return s_int->i.min;
+}
+
 IntSet IntSet::everything() {
   return IntervalSet::make(Interval::everything());
 }
@@ -115,8 +121,8 @@ IntSet IntSet::range(Range r) {
 }
 
 // Check if a is created from b.
-inline bool MatchRange(const IntSet& a,
-                       const Range& b) {
+bool IntSet::match_range(const Range& b) const {
+  const IntSet& a = *this;
   const IntervalSet* a_int = a.as<IntervalSet>();
   if (!a_int) return false;
   const Interval& i = a_int->i;
@@ -349,84 +355,6 @@ inline IntSet Combine(const IntSet& a, const IntSet &b) {
   return CombineSets<OP>(a, b);
 }
 
-// Implementation of Evaluations and passing.
-void PassUp(const SplitNode* s,
-            const std::unordered_map<IterVar, Range>& dom_map,
-            const IntSet& outer,
-            const IntSet& inner,
-            IntSet* parent) {
-  if (dom_map.count(s->outer) &&
-      dom_map.count(s->inner) &&
-      dom_map.count(s->parent) &&
-      MatchRange(outer, dom_map.at(s->outer)) &&
-      MatchRange(inner, dom_map.at(s->inner))) {
-    *parent = IntSet::range(dom_map.at(s->parent));
-    return;
-  }
-  Expr factor = dom_map.at(s->inner)->extent;
-  Expr parent_min = dom_map.at(s->parent)->min;
-  CHECK(outer.defined());
-  CHECK(inner.defined());
-  CHECK(factor.defined());
-
-  *parent = Combine<Add>(
-      Combine<Add>(
-          Combine<Mul>(outer, IntSet::single_point(factor)), inner),
-      IntSet::single_point(parent_min));
-}
-
-void PassUp(const FuseNode* s,
-            const std::unordered_map<IterVar, Range>& dom_map,
-            const IntSet& fused,
-            IntSet* outer,
-            IntSet* inner) {
-  CHECK(dom_map.count(s->outer));
-  CHECK(dom_map.count(s->inner));
-  CHECK(dom_map.count(s->fused));
-
-  if (MatchRange(fused, dom_map.at(s->fused))) {
-    *outer = IntSet::range(dom_map.at(s->outer));
-    *inner = IntSet::range(dom_map.at(s->inner));
-    return;
-  }
-
-  Expr outer_min = dom_map.at(s->outer)->min;
-  Expr inner_min = dom_map.at(s->inner)->min;
-
-  const IntervalSet* fused_int = fused.as<IntervalSet>();
-
-  if (fused_int && fused_int->i.is_single_point()) {
-    Expr value = fused_int->i.min;
-    Expr factor = dom_map.at(s->inner)->extent;
-    Expr v_outer  = value / factor;
-    Expr v_inner  = value % factor;
-    if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
-    if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
-    *outer = IntSet::single_point(v_outer);
-    *inner = IntSet::single_point(v_inner);
-  } else {
-    LOG(WARNING) << "use fallback inference rule in fuse";
-    // simply use the entire set, this rule can be enhanced.
-    *outer = IntSet::range(dom_map.at(s->outer));
-    *inner = IntSet::range(dom_map.at(s->inner));
-    return;
-  }
-}
-
-
-void PassUp(const RebaseNode* s,
-            const std::unordered_map<IterVar, Range>& dom_map,
-            const IntSet& rebased,
-            IntSet* parent) {
-  CHECK(dom_map.count(s->parent));
-  if (MatchRange(rebased, dom_map.at(s->rebased))) {
-    *parent = IntSet::range(dom_map.at(s->parent));
-    return;
-  }
-  Expr parent_min = dom_map.at(s->parent)->min;
-  *parent = Combine<Add>(rebased, IntSet::single_point(parent_min));
-}
-
 // Evaluator to evalute the epxression.
 class IntSetEvaluator {
  public:
@@ -527,5 +455,5 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 
-}  // namespace schedule
+}  // namespace arith
 }  // namespace tvm
diff --git a/src/schedule/int_set.h b/src/arithmetic/int_set.h
similarity index 56%
rename from src/schedule/int_set.h
rename to src/arithmetic/int_set.h
index 5866c123d..80c2fae79 100644
--- a/src/schedule/int_set.h
+++ b/src/arithmetic/int_set.h
@@ -3,14 +3,14 @@
  * \file int_set.h
  * \brief Abstraction for all integer set operations.
  */
-#ifndef TVM_SCHEDULE_INT_SET_H_
-#define TVM_SCHEDULE_INT_SET_H_
+#ifndef TVM_ARITHMETIC_INT_SET_H_
+#define TVM_ARITHMETIC_INT_SET_H_
 
 #include <tvm/expr.h>
 #include <tvm/schedule.h>
 
 namespace tvm {
-namespace schedule {
+namespace arith {
 
 // internal node container of int set.
 class IntSetNode;
@@ -44,6 +44,18 @@ class IntSet : public NodeRef {
   bool is_everything() const;
   /*! \return Whether the set is a single point */
   bool is_single_point() const;
+  /*!
+   * \brief The single point value, call only if is_single_point is true
+   * \return The point value.
+   */
+  Expr point_value() const;
+  /*!
+   * \brief Try to match IntSet with range r.
+   *
+   * \note It is guanrateed that IntSet::range(r).match_range(r) == true
+   * \return true if we can prove they are the same.
+   */
+  bool match_range(const Range& r) const;
   /*! \return Whether the set contains everything */
   static IntSet everything();
   /*!
@@ -88,59 +100,6 @@ IntSet EvalSet(Expr e,
 IntSet EvalSet(Range r,
                const Map<IterVar, IntSet>& dom_map);
 
-/*!
- * \brief Conditional upward message passing.
- *
- * Get domain of parent, condition on domain of children.
- * Domain is represented as IntSet.
- *
- * \param s The Split relation node.
- * \param dom_map The old domain result from downward message passing.
- *    Contains the domain set if all the children are full set.
- * \param outer domain of outer iteration.
- * \param inner domain of inner iteration.
- * \param parent The result domain of parent.
- */
-void PassUp(const SplitNode* s,
-            const std::unordered_map<IterVar, Range>& dom_map,
-            const IntSet& outer,
-            const IntSet& inner,
-            IntSet* parent);
-/*!
- * \brief Conditional upward message passing.
- *
- * Get domain of parent, condition on domain of children.
- * Domain is represented as IntSet.
- *
- * \param s The Fuse relation node.
- * \param dom_map The old domain result from downward message passing.
- *    Contains the domain set if all the children are full set.
- * \param fused domain of fused iteration.
- * \param outer The result domain of outer iteration.
- * \param inner The result domain of inner iteration.
- */
-void PassUp(const FuseNode* s,
-            const std::unordered_map<IterVar, Range>& dom_map,
-            const IntSet& fused,
-            IntSet* outer,
-            IntSet* inner);
-
-/*!
- * \brief Conditional upward message passing.
- *
- * Get domain of parent, condition on domain of children.
- * Domain is represented as IntSet.
- *
- * \param s The Fuse relation node.
- * \param dom_map The old domain result from downward message passing.
- *    Contains the domain set if all the children are full set.
- * \param rebased domain of rebased iteration.
- * \param parent The result domain of parent iteration.
- */
-void PassUp(const RebaseNode* s,
-            const std::unordered_map<IterVar, Range>& dom_map,
-            const IntSet& fused,
-            IntSet* parent);
 /*!
  * \brief Create an union set of all sets
  * \param sets The sets to be unioned
@@ -153,7 +112,7 @@ inline const IntSetNode* IntSet::operator->() const {
   return static_cast<const IntSetNode*>(node_.get());
 }
 
-}  // namespace schedule
+}  // namespace arith
 }  // namespace tvm
 
-#endif  // TVM_SCHEDULE_INT_SET_H_
+#endif  // TVM_ARITHMETIC_INT_SET_H_
diff --git a/src/pass/inline.cc b/src/pass/inline.cc
index de452c364..1dee4776e 100644
--- a/src/pass/inline.cc
+++ b/src/pass/inline.cc
@@ -24,10 +24,24 @@ class IRInline : public IRMutator {
     if (op->func == f_) {
       CHECK_EQ(op->value_index, 0);
       Expr expr = body_;
-      CHECK_EQ(args_.size(), op->args.size())
-          << op->args.size() << " vs " << args_.size();
-      for (size_t i = 0; i < args_.size(); ++i) {
-        expr = Let::make(args_[i], op->args[i], expr);
+      CHECK_EQ(args_.size(), op->args.size());
+
+      bool has_side_effect = false;
+      for (size_t i = 0; i < op->args.size(); ++i) {
+        if (HasSideEffect(op->args[i])) has_side_effect = true;
+      }
+
+      if (has_side_effect) {
+        for (size_t i = 0; i < args_.size(); ++i) {
+          expr = Let::make(args_[i], op->args[i], expr);
+        }
+      } else {
+        Map<Var, Expr> vmap;
+        for (size_t i = 0; i < args_.size(); ++i) {
+          vmap.Set(args_[i], op->args[i]);
+        }
+        expr = Substitute(
+            Evaluate::make(expr), vmap).as<Evaluate>()->value;
       }
       return expr;
     } else {
diff --git a/src/pass/simple_passes.cc b/src/pass/simple_passes.cc
index 0fe6b94eb..5fc928cdd 100644
--- a/src/pass/simple_passes.cc
+++ b/src/pass/simple_passes.cc
@@ -47,10 +47,10 @@ class IRSubstitue : public IRMutator {
   std::unordered_map<const Variable*, Expr> smap;
 };
 
-Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map) {
+Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map) {
   IRSubstitue m;
   for (auto kv : value_map) {
-    m.smap[kv.first->var.get()] = kv.second;
+    m.smap[kv.first.get()] = kv.second;
   }
   return m.Mutate(stmt);
 }
diff --git a/src/pass/unroll_loop.cc b/src/pass/unroll_loop.cc
new file mode 100644
index 000000000..555e5b970
--- /dev/null
+++ b/src/pass/unroll_loop.cc
@@ -0,0 +1,78 @@
+/*!
+ *  Copyright (c) 2016 by Contributors
+ *  SSA related checks and pass.
+ * \file ssa.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_pass.h>
+#include <tvm/ir_mutator.h>
+#include <unordered_set>
+#include <unordered_map>
+#include <vector>
+#include "../arithmetic//compute_expr.h"
+
+namespace tvm {
+namespace ir {
+
+class LoopUnroller : public IRMutator {
+ public:
+  explicit LoopUnroller(int max_auto_step)
+      : max_auto_step_(max_auto_step) {
+  }
+
+  Stmt Mutate_(const For* op, const Stmt& s) {
+    Stmt stmt = s;
+    // constant folding.
+    Expr extent = ir::Simplify(op->extent);
+    const IntImm* v1 = extent.as<IntImm>();
+    const UIntImm* v2 = extent.as<UIntImm>();
+    int value = -1;
+    if (v1 != nullptr) {
+      value = static_cast<int>(v1->value);
+    }
+    if (v2 != nullptr) {
+      value = static_cast<int>(v2->value);
+    }
+    bool allow_unroll = value >= 0 && value <= max_auto_step_;
+    if (op->for_type == ForType::Unrolled) {
+      CHECK_GE(value, 0)
+          << "Cannot unroll non-constant loop";
+      allow_unroll = true;
+    }
+
+    if (allow_unroll) {
+      using arith::ComputeExpr;
+      if (value == 0) return Evaluate::make(0);
+      Stmt body = op->body;
+      Map<Var, Expr> vmap;
+      Stmt unrolled;
+      for (int i = 0; i < value; ++i) {
+        Var lv(op->loop_var.node_);
+        vmap.Set(lv,
+                 ComputeExpr<Add>(
+                     op->min, make_const(op->loop_var.type(), i)));
+        Stmt step = Substitute(body, vmap);
+        if (unrolled.defined()) {
+          unrolled = Block::make(unrolled, step);
+        } else {
+          unrolled = step;
+        }
+      }
+      return this->Mutate(unrolled);
+    } else {
+      return IRMutator::Mutate_(op, stmt);
+    }
+  }
+
+ private:
+  int max_auto_step_;
+};
+
+
+Stmt UnrollLoop(Stmt stmt, int max_auto_step) {
+  Stmt ret = LoopUnroller(max_auto_step).Mutate(stmt);
+  return ConvertSSA(ret);
+}
+
+}  // namespace ir
+}  // namespace tvm
diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc
index 706550843..9fe530b67 100644
--- a/src/schedule/bound.cc
+++ b/src/schedule/bound.cc
@@ -7,13 +7,15 @@
 #include <tvm/ir_visitor.h>
 #include <tvm/ir_pass.h>
 #include <tvm/schedule_pass.h>
-#include "./int_set.h"
 #include "./graph.h"
+#include "../arithmetic/int_set.h"
 #include "../runtime/thread_storage_scope.h"
 
 namespace tvm {
 namespace schedule {
 
+using namespace arith;
+
 // result = ceil((a / b)), both a and b are positive integer
 inline Expr DivCeil(Expr a, Expr b) {
   return ir::Simplify((a + b - 1) / b);
@@ -70,6 +72,80 @@ void PassDown(const Stage& s,
 // pass the integer set on each leave loop up to the root
 // dom_map is the result of PassDown, it records the domain of each IterVar.
 // dom_map can be used to get cached result in reverse construction.
+// Implementation of Evaluations and passing.
+void PassUp(const SplitNode* s,
+            const std::unordered_map<IterVar, Range>& dom_map,
+            const IntSet& outer,
+            const IntSet& inner,
+            IntSet* parent) {
+  if (dom_map.count(s->outer) &&
+      dom_map.count(s->inner) &&
+      dom_map.count(s->parent) &&
+      outer.match_range(dom_map.at(s->outer)) &&
+      inner.match_range(dom_map.at(s->inner))) {
+    *parent = IntSet::range(dom_map.at(s->parent));
+    return;
+  }
+  Expr factor = dom_map.at(s->inner)->extent;
+  Expr parent_min = dom_map.at(s->parent)->min;
+  CHECK(outer.defined());
+  CHECK(inner.defined());
+  CHECK(factor.defined());
+  *parent = EvalSet(
+      s->outer->var * factor + s->inner->var + parent_min,
+      {{s->outer, outer}, {s->inner, inner}});
+}
+
+void PassUp(const FuseNode* s,
+            const std::unordered_map<IterVar, Range>& dom_map,
+            const IntSet& fused,
+            IntSet* outer,
+            IntSet* inner) {
+  CHECK(dom_map.count(s->outer));
+  CHECK(dom_map.count(s->inner));
+  CHECK(dom_map.count(s->fused));
+
+  if (fused.match_range(dom_map.at(s->fused))) {
+    *outer = IntSet::range(dom_map.at(s->outer));
+    *inner = IntSet::range(dom_map.at(s->inner));
+    return;
+  }
+  Expr outer_min = dom_map.at(s->outer)->min;
+  Expr inner_min = dom_map.at(s->inner)->min;
+
+  if (fused.is_single_point()) {
+    Expr value = fused.point_value();
+    Expr factor = dom_map.at(s->inner)->extent;
+    Expr v_outer  = value / factor;
+    Expr v_inner  = value % factor;
+    if (!is_zero(outer_min)) v_outer = v_outer + outer_min;
+    if (!is_zero(inner_min)) v_inner = v_inner + inner_min;
+    *outer = IntSet::single_point(v_outer);
+    *inner = IntSet::single_point(v_inner);
+  } else {
+    LOG(WARNING) << "use fallback inference rule in fuse";
+    // simply use the entire set, this rule can be enhanced.
+    *outer = IntSet::range(dom_map.at(s->outer));
+    *inner = IntSet::range(dom_map.at(s->inner));
+    return;
+  }
+}
+
+
+void PassUp(const RebaseNode* s,
+            const std::unordered_map<IterVar, Range>& dom_map,
+            const IntSet& rebased,
+            IntSet* parent) {
+  CHECK(dom_map.count(s->parent));
+  if (rebased.match_range(dom_map.at(s->rebased))) {
+    *parent = IntSet::range(dom_map.at(s->parent));
+    return;
+  }
+  Expr parent_min = dom_map.at(s->parent)->min;
+  *parent = EvalSet(s->rebased->var + parent_min,
+                    {{s->rebased, rebased}});
+}
+
 void PassUp(const Stage& s,
             const std::unordered_map<IterVar, Range>& dom_map,
             std::unordered_map<IterVar, IntSet>* p_state) {
diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc
index 530eecaac..33272fceb 100644
--- a/src/schedule/graph.cc
+++ b/src/schedule/graph.cc
@@ -6,7 +6,6 @@
 #include <tvm/ir.h>
 #include <tvm/ir_visitor.h>
 #include <unordered_set>
-#include "./int_set.h"
 #include "./graph.h"
 
 namespace tvm {
diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc
index e1390b589..58d5f6bdb 100644
--- a/src/schedule/schedule_ops.cc
+++ b/src/schedule/schedule_ops.cc
@@ -9,13 +9,13 @@
 #include <tvm/schedule_pass.h>
 
 #include "../pass/ir_util.h"
-#include "./int_set.h"
+#include "../arithmetic/compute_expr.h"
 #include "./graph.h"
-#include "./compute_expr.h"
 
 namespace tvm {
 namespace schedule {
 
+using namespace arith;
 using namespace ir;
 
 /*!
@@ -230,6 +230,15 @@ MakeLoopNest(const Stage& sch,
   return nest;
 }
 
+Stmt Substitute(Stmt s,
+                const std::unordered_map<IterVar, Expr>& value_map) {
+  Map<Var, Expr> temp;
+  for (const auto& kv : value_map) {
+    temp.Set(kv.first->var, kv.second);
+  }
+  return ir::Substitute(s, temp);
+}
+
 Stmt MakeLoop(const Stage& s,
               const Map<IterVar, Range>& dom_map,
               Stmt provide,
@@ -244,7 +253,6 @@ Stmt MakeLoop(const Stage& s,
   auto nest = MakeLoopNest(s, dom_map, 0, false,
                            bound_state, {}, &value_map);
 
-
   provide = Substitute(provide, value_map);
   if (init.defined()) {
     // try to find the location to insert the initialization.
diff --git a/tests/python/unittest/test_pass_unroll.py b/tests/python/unittest/test_pass_unroll.py
new file mode 100644
index 000000000..191377baa
--- /dev/null
+++ b/tests/python/unittest/test_pass_unroll.py
@@ -0,0 +1,20 @@
+import tvm
+
+def test_unroll_loop():
+    dtype = 'int64'
+    n = tvm.Var('n')
+    Ab = tvm.Buffer((n, ), dtype)
+    i = tvm.Var('i')
+    j = tvm.Var('j')
+    # for i in 0 to n-1:
+    stmt = tvm.make.For(
+        i, n, 2, 0, 0,
+        tvm.make.For(j, 0, n, 0, 0,
+                     tvm.make.Store(Ab.data,
+                                    tvm.make.Load(dtype, Ab.data, i) + 1,
+                                    j + 1)))
+    stmt = tvm.ir_pass.UnrollLoop(stmt, 8)
+    print(stmt)
+
+if __name__ == "__main__":
+    test_unroll_loop()
-- 
GitLab