From c8ec41118dc7a63581e387a8a74e61b55375341d Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Mon, 20 Feb 2017 13:09:23 -0800
Subject: [PATCH] [SCAN/Refactor] Refactor scan interface, enable fix point
 analysis. (#47)

---
 include/tvm/operation.h                       |   4 +-
 include/tvm/schedule.h                        |   4 +-
 include/tvm/tensor.h                          |   2 +
 python/tvm/addon/nvcc_compiler.py             |   9 +-
 python/tvm/api.py                             |  11 +-
 python/tvm/build.py                           |   3 +-
 src/api/api_schedule.cc                       |   3 +
 src/arithmetic/int_set.cc                     |  10 +-
 src/lang/operation.cc                         |  28 +-
 src/pass/inline.cc                            |   4 +-
 src/schedule/bound.cc                         | 117 +++----
 src/schedule/graph.cc                         | 307 ++++++++++++++++-
 src/schedule/graph.h                          |  54 +++
 src/schedule/schedule_dataflow_rewrite.cc     | 312 ++++++++++++++++++
 src/schedule/schedule_lang.cc                 | 265 +++------------
 src/schedule/schedule_ops.cc                  |  51 ++-
 tests/python/integration/test_scan.py         |   6 +-
 tests/python/unittest/test_lang_tensor.py     |   5 +-
 .../unittest/test_schedule_bound_inference.py |  31 +-
 tests/python/unittest/test_schedule_graph.py  | 101 ++++++
 .../unittest/test_schedule_schedule_ops.py    |  28 +-
 21 files changed, 977 insertions(+), 378 deletions(-)
 create mode 100644 src/schedule/schedule_dataflow_rewrite.cc
 create mode 100644 tests/python/unittest/test_schedule_graph.py

diff --git a/include/tvm/operation.h b/include/tvm/operation.h
index 745277308..85b289f5d 100644
--- a/include/tvm/operation.h
+++ b/include/tvm/operation.h
@@ -152,14 +152,12 @@ Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor"
 /*!
  * \brief Construct new tensors by scan over scan_axis.
  *
- * \param scan_axis The iteration representing the scan.
  * \param init The intialize tensor of first K steps.
  * \param update The update tensor indicated the updated result after each timestamp.
  * \param state_placeholder The placeholder for the states.
  * \param name The optional name of the tensor.
  */
-Array<Tensor> scan(IterVar scan_axis,
-                   Array<Tensor> init,
+Array<Tensor> scan(Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
                    std::string name = "scan");
diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index 184075677..c6bbc6566 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -26,7 +26,9 @@ enum AttachType : int {
   kNone = 0,
   kRoot = 1,
   kInline = 2,
-  kScope = 3
+  kInlinedAlready = 3,
+  kScope = 4,
+  kScanUpdate = 5
 };
 
 /*! \brief IterVar type */
diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h
index 92786b331..11766cd00 100644
--- a/include/tvm/tensor.h
+++ b/include/tvm/tensor.h
@@ -175,6 +175,8 @@ class OperationNode : public FunctionBaseNode {
   virtual Type output_dtype(size_t i) const = 0;
   /*! \return shape of i-th output */
   virtual Array<Expr> output_shape(size_t i) const = 0;
+
+  static constexpr const char* _type_key = "Operation";
 };
 
 // Implementations of inline functions
diff --git a/python/tvm/addon/nvcc_compiler.py b/python/tvm/addon/nvcc_compiler.py
index a1c2b938d..7895a2b98 100644
--- a/python/tvm/addon/nvcc_compiler.py
+++ b/python/tvm/addon/nvcc_compiler.py
@@ -4,7 +4,7 @@ import sys
 import tempfile
 import subprocess
 
-def compile_source(code, target="cubin"):
+def compile_source(code, target="cubin", options=None):
     """Compile cuda code with NVCC from env.
 
     Parameters
@@ -12,9 +12,12 @@ def compile_source(code, target="cubin"):
     code : str
         The cuda code.
 
-    target: str
+    target : str
         The target format
 
+    options : str
+        The additional options
+
     Return
     ------
     cubin : bytearray
@@ -32,6 +35,8 @@ def compile_source(code, target="cubin"):
     cmd = ["nvcc"]
     cmd += ["--%s" % target, "-O3"]
     cmd += ["-o", path_target]
+    if options:
+        cmd += options
     cmd += [path_code]
     args = ' '.join(cmd)
 
diff --git a/python/tvm/api.py b/python/tvm/api.py
index 2c3f54483..d6c81bac6 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -140,14 +140,11 @@ def compute(shape, fcompute, name="compute"):
     return op_node.output(0)
 
 
-def scan(axis, init, update, state_placeholder, name="scan"):
+def scan(init, update, state_placeholder, name="scan"):
     """Construct new tensors by scanning over axis.
 
     Parameters
     ----------
-    axis: IterVar
-        The scanning axis.
-
     init: Tensor or list of Tensor
         The initial condition of first init.shape[0] timestamps
 
@@ -170,12 +167,11 @@ def scan(axis, init, update, state_placeholder, name="scan"):
     # The following code is equivalent to numpy.cumsum
     m = tvm.Var("m")
     n = tvm.Var("n")
-    t = tvm.IterVar((1, m), name="t")
     X = tvm.placeholder((m, n), name="X")
     s_state = tvm.placeholder((m, n))
     s_init = tvm.compute((1, n), lambda _, i: X[0, i])
-    s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
-    res = tvm.scan(t, s_init, s_update, s_state)
+    s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
+    res = tvm.scan(s_init, s_update, s_state)
     """
     if isinstance(init, _tensor.Tensor):
         init = [init]
@@ -185,6 +181,7 @@ def scan(axis, init, update, state_placeholder, name="scan"):
         state_placeholder = [state_placeholder]
     if len(init) != len(update) or len(init) != len(state_placeholder):
         raise ValueError("init, update, state_placeholder must have same length")
+    axis = IterVar((init[0].shape[0], update[0].shape[0]), "%s.idx" % name)
     op = _api_internal._ScanOp(name, axis, init, update, state_placeholder)
     res = [op.output(i) for i in range(len(update))]
     return (res[0] if len(res) == 1 else res)
diff --git a/python/tvm/build.py b/python/tvm/build.py
index 40cb92b45..764db0ae5 100644
--- a/python/tvm/build.py
+++ b/python/tvm/build.py
@@ -63,7 +63,8 @@ def build(sch,
             arg_list.append(x)
         else:
             raise ValueError("args must be Tensor, Buffer or Var")
-    # lowering
+    # normalize schedule first
+    sch.normalize()
     bounds = schedule.InferBound(sch)
     stmt = schedule.ScheduleOps(sch, bounds)
     stmt = ir_pass.StorageFlatten(stmt, binds)
diff --git a/src/api/api_schedule.cc b/src/api/api_schedule.cc
index 882ff94bd..d953e37e2 100644
--- a/src/api/api_schedule.cc
+++ b/src/api/api_schedule.cc
@@ -34,6 +34,9 @@ TVM_REGISTER_API(_schedule_AutoInlineElemWise)
 REGISTER_SCHEDULE_PASS1(InferBound);
 REGISTER_SCHEDULE_PASS1(CreateReadGraph);
 REGISTER_SCHEDULE_PASS2(PostDFSOrder);
+REGISTER_SCHEDULE_PASS1(ScanGetBody);
+REGISTER_SCHEDULE_PASS1(CreateAttachPath);
+REGISTER_SCHEDULE_PASS2(ScanFixPointAnalysis);
 REGISTER_SCHEDULE_PASS2(ScheduleOps);
 
 }  // namespace schedule
diff --git a/src/arithmetic/int_set.cc b/src/arithmetic/int_set.cc
index 8fdba6650..8c89d93e6 100644
--- a/src/arithmetic/int_set.cc
+++ b/src/arithmetic/int_set.cc
@@ -166,7 +166,15 @@ IntSet Union(const Array<IntSet>& set) {
   if (set.size() == 1) return set[0];
   Interval x = set[0].cover_interval().as<IntervalSet>()->i;
   for (size_t i = 1; i < set.size(); ++i) {
-    x.include(set[i].cover_interval().as<IntervalSet>()->i);
+    IntSet s = set[i].cover_interval();
+    const Interval& y = s.as<IntervalSet>()->i;
+    if (can_prove(x.max + 1 >= y.min)) {
+      x.max = y.max;
+    } else if (can_prove(y.max + 1 >= x.min)) {
+      x.min = y.min;
+    } else {
+      x.include(y);
+    }
   }
   return IntervalSet::make(x);
 }
diff --git a/src/lang/operation.cc b/src/lang/operation.cc
index ddc4770f0..ac1e95417 100644
--- a/src/lang/operation.cc
+++ b/src/lang/operation.cc
@@ -51,8 +51,6 @@ Operation PlaceholderOpNode::make(std::string name,
   return Operation(n);
 }
 
-
-
 Tensor placeholder(Array<Expr> shape, Type dtype, std::string name) {
   return PlaceholderOpNode::make(name, shape, dtype).output(0);
 }
@@ -162,24 +160,25 @@ Operation ScanOpNode::make(std::string name,
         << " scan_axis.dom.min + scan_axis.dom.extent";
     CHECK_EQ(state_placeholder[i].ndim(), init[i].ndim())
         << "The dimension of init need to match state_placeholder";
-    CHECK_EQ(update[i].ndim() + 1, state_placeholder[i].ndim())
+    CHECK_EQ(update[i].ndim(), state_placeholder[i].ndim())
         << "The update.ndim need to be state_placeholder.ndim - 1";
     for (size_t k = 0;  k < update[i].ndim(); ++k) {
       CHECK(prove_equal(
-          update[i]->shape[k], state_placeholder[i]->shape[k + 1]));
-      // setup spatial axis
-      std::ostringstream spatial_name;
-      spatial_name << name << ".out" << i << ".i" << k + 1;
-      n->spatial_axis_.push_back(
-          IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
-                  spatial_name.str()));
+          update[i]->shape[k], state_placeholder[i]->shape[k]));
+      if (k != 0) {
+        // setup spatial axis
+        std::ostringstream spatial_name;
+        spatial_name << name << ".out" << i << ".i" << k;
+        n->spatial_axis_.push_back(
+            IterVar(Range::make_with_min_extent(0, update[i]->shape[k]),
+                    spatial_name.str()));
+      }
     }
     for (size_t k = 1;  k < init[i].ndim(); ++k) {
       CHECK(prove_equal(
           init[i]->shape[k], state_placeholder[i]->shape[k]));
     }
   }
-
   n->name = name;
   n->scan_axis = axis;
   n->init = init;
@@ -188,11 +187,14 @@ Operation ScanOpNode::make(std::string name,
   return Operation(n);
 }
 
-Array<Tensor> scan(IterVar scan_axis,
-                   Array<Tensor> init,
+Array<Tensor> scan(Array<Tensor> init,
                    Array<Tensor> update,
                    Array<Tensor> state_placeholder,
                    std::string name) {
+  IterVar scan_axis(
+      Range::make_with_min_extent(
+          init[0]->shape[0], update[0]->shape[0] - init[0]->shape[0]),
+      name + ".idx");
   Operation op = ScanOpNode::make(
       name, scan_axis, init, update, state_placeholder);
   Array<Tensor> res;
diff --git a/src/pass/inline.cc b/src/pass/inline.cc
index 1dee4776e..87f54ce0b 100644
--- a/src/pass/inline.cc
+++ b/src/pass/inline.cc
@@ -61,7 +61,9 @@ Stmt Inline(Stmt stmt,
             Expr body) {
   CHECK_EQ(f->num_outputs(), 1)
       << "can only inline output single value operation";
-  return ConvertSSA(IRInline(f, args, body).Mutate(stmt));
+  Stmt ret = IRInline(f, args, body).Mutate(stmt);
+  if (ret.same_as(stmt)) return ret;
+  return ConvertSSA(ret);
 }
 }  // namespace ir
 }  // namespace tvm
diff --git a/src/schedule/bound.cc b/src/schedule/bound.cc
index 4724d9762..c2fa061bd 100644
--- a/src/schedule/bound.cc
+++ b/src/schedule/bound.cc
@@ -1,3 +1,4 @@
+
 /*!
  *  Copyright (c) 2016 by Contributors
  * \file bound.cc
@@ -259,11 +260,14 @@ void BoundProp(const Operation& op,
         init_dom->data[0].push_back(IntSet::range(
             Range::make_with_min_extent(0, scan->init[i]->shape[0])));
       }
+      if (update_dom) {
+        update_dom->data[0].push_back(dom_map.at(scan->scan_axis->var.get()));
+      }
       // The update dimensions
-      for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
+      for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
         IterVar sp_ax = scan->spatial_axis_[sp_idx];
         if (init_dom) {
-          init_dom->data[k + 1].push_back(dom_map.at(sp_ax->var.get()));
+          init_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
         }
         if (update_dom) {
           update_dom->data[k].push_back(dom_map.at(sp_ax->var.get()));
@@ -277,10 +281,12 @@ void BoundProp(const Operation& op,
   }
 }
 
+
 // Given the bound of output of op
 // Pass the bound to the related axis in op.
 void GatherOpBound(const ScanOpNode* scan,
                    const Operation& op,
+                   const FeedGraph& fg,
                    const std::unordered_map<Tensor, TensorDom>& tmap,
                    std::unordered_map<IterVar, Range>* rmap) {
   CHECK(!rmap->count(scan->scan_axis));
@@ -299,21 +305,29 @@ void GatherOpBound(const ScanOpNode* scan,
   Range r = arith::Union(time_dom).cover_range(sdom);
   (*rmap)[scan->scan_axis] = Range::make_with_min_extent(
       sdom->min, ir::Simplify(r->extent + r->min - sdom->min));
+  Array<Operation> body = ScanGetBody_(scan, fg);
+  Map<IterVar, Expr> fix_pt = ScanFixPointAnalysis(op, body);
   // Update for spatial axis.
   size_t sp_idx = 0;
   for (size_t i = 0; i < output.size(); ++i) {
-    for (size_t k = 0; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
+    const TensorDom& d = tmap.at(output[i]);
+    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
       IterVar sp_ax = scan->spatial_axis_[sp_idx];
       CHECK(!rmap->count(sp_ax));
-      // In default, we always need all spatial axis
-      // Unless that axis only refers back to itself as a fixed point.
-      // TODO(tqchen): Add fix point detection.
-      (*rmap)[sp_ax] = sp_ax->dom;
+      CHECK(fix_pt.count(sp_ax));
+      if (fix_pt[sp_ax].as<ir::IntImm>()->value) {
+        // fix point, we can slice it.
+        (*rmap)[sp_ax] = arith::Union(d.data[k + 1]).cover_range(sp_ax->dom);
+      } else {
+        // not a fix point, need to include everything.
+        (*rmap)[sp_ax] = sp_ax->dom;
+      }
     }
   }
 }
 
 void GatherOpBound(const Operation& op,
+                   const FeedGraph& fg,
                    const std::unordered_map<Tensor, TensorDom>& tmap,
                    std::unordered_map<IterVar, Range>* rmap) {
   if (op.as<ComputeOpNode>()) {
@@ -329,7 +343,7 @@ void GatherOpBound(const Operation& op,
       (*rmap)[compute->reduce_axis[i]] = compute->reduce_axis[i]->dom;
     }
   } else if (op.as<ScanOpNode>()) {
-    GatherOpBound(op.as<ScanOpNode>(), op, tmap, rmap);
+    GatherOpBound(op.as<ScanOpNode>(), op, fg, tmap, rmap);
   } else if (op.as<PlaceholderOpNode>()) {
     // dp nothing
   } else {
@@ -347,20 +361,14 @@ inline bool ScopeRelax(const IterVar& iv, const std::string& scope) {
   return StorageScope::make(scope).rank <= ThreadScope::make(iv->thread_tag).rank;
 }
 
-// The map beteen tensor and operation it feeds ti
-using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
-
-// AttachPath maps op-> a list of IterVar
-// That represents the loop nest op sits in from inner most to outermost
-using AttachPath = Map<Operation, Array<IterVar> >;
-
-
 void InferRootBound(const Stage& stage,
                     const FeedGraph& feed_graph,
                     const AttachPath& attach_path,
                     std::unordered_map<IterVar, Range>* rmap) {
-  if (stage->attach_type == kInline) return;
-  if (stage->attach_type == kRoot || stage->attach_type == kNone) {
+  CHECK_NE(stage->attach_type, kInline)
+      << "call schedule.normalize before scheduleops";
+  if (stage->attach_type == kInlinedAlready) return;
+  if (stage->is_output || stage->op.as<PlaceholderOpNode>()) {
     for (auto iv :  OutputRelatedIterVars(stage->op)) {
       CHECK(iv->dom.defined());
       CHECK(!rmap->count(iv));
@@ -368,11 +376,11 @@ void InferRootBound(const Stage& stage,
     }
     return;
   }
-  // Infer root bounds for the attached node.
-  CHECK_EQ(stage->attach_type, kScope);
-  Stage parent = stage->attach_stage;
-  CHECK(parent.defined());
-
+  // parent stage, if any
+  Stage parent;
+  if (stage->attach_type == kScope || stage->attach_type == kScanUpdate) {
+    parent = stage->attach_stage;
+  }
   // The tensor domain.
   std::unordered_map<Tensor, TensorDom> tmap;
   // consumers other than parent
@@ -385,7 +393,7 @@ void InferRootBound(const Stage& stage,
     auto it = feed_graph.find(t);
     if (it != feed_graph.end()) {
       for (const Operation& op : it->second) {
-        if (op != parent->op) {
+        if (!parent.defined() || op != parent->op) {
           consumers.insert(op);
         } else {
           direct_consume_by_parent = true;
@@ -404,16 +412,20 @@ void InferRootBound(const Stage& stage,
       relax_set[iv->var.get()] = IntSet::range(rmap->at(iv));
     }
   }
-
   if (direct_consume_by_parent) {
+    // parent stage if exist
+    Stage parent = stage->attach_stage;
     // Bound inference logics in parent.
     std::unordered_map<IterVar, IntSet> up_state;
     bool fix_value = true;
     for (auto iv : parent->leaf_iter_vars) {
-      Range vrange = rmap->at(iv);
+      auto it = rmap->find(iv);
+      CHECK(it != rmap->end());
+      Range vrange = it->second;
       CHECK(is_zero(vrange->min))
           << "InferBound requires every leaf iter var's min equals 0, "
-          << "call schedule.normalize to achieve this.";
+          << " call schedule.normalize to achieve this. "
+          << " stage=" << parent;
       // special optimization to remove trivial loop
       if (is_one(vrange->extent)) {
         up_state[iv] = IntSet::single_point(vrange->min);
@@ -464,8 +476,9 @@ void InferRootBound(const Stage& stage,
   for (const Operation& op : consumers) {
     std::unordered_map<const Variable*, IntSet> dom_map;
     bool found = false;
+    Array<IterVar> attach = attach_path.at(stage->op);
     for (IterVar iv : attach_path.at(op)) {
-      if (iv == stage->attach_ivar) {
+      if (attach.size() != 0 && iv == attach[0]) {
         found = true; break;
       }
       Range vrange = rmap->at(iv);
@@ -474,7 +487,7 @@ void InferRootBound(const Stage& stage,
           << "call schedule.normalize to achieve this.";
       relax_set[iv->var.get()] = IntSet::range(vrange);
     }
-    CHECK(found)
+    CHECK(found || attach.size() == 0)
         << "Invalid Schedule, cannot find the producer " << stage->op
         << " along the loop nest specified by compute_at of consumer " << op;
     for (auto iv : OutputRelatedIterVars(op)) {
@@ -483,50 +496,15 @@ void InferRootBound(const Stage& stage,
     }
     BoundProp(op, dom_map, &tmap);
   }
-  GatherOpBound(stage->op, tmap, rmap);
+  GatherOpBound(stage->op, feed_graph, tmap, rmap);
 }
 
-FeedGraph CreateFeedGraph(const Schedule& sch) {
+Map<IterVar, Range> InferBound(const Schedule& sch) {
   Array<Operation> roots;
   for (Operation op : sch->outputs) {
     roots.push_back(sch->stage_map[op]->op);
   }
-  auto g = CreateReadGraph(roots);
-  FeedGraph fg;
-  for (auto kv : g) {
-    for (Tensor t : kv.second) {
-      fg[t].push_back(kv.first);
-    }
-  }
-  return fg;
-}
-
-// Create AttachPath that  maps op-> a list of IterVar
-// That represents the loop nest op sits in from inner most to outermost
-AttachPath CreateAttachPath(const Schedule& sch) {
-  AttachPath ret;
-  for (Stage stage : sch->stages) {
-    Array<IterVar> path;
-    for (Stage s = stage; s->attach_type == kScope;) {
-      IterVar attach_ivar = s->attach_ivar;
-      s = s->attach_stage;
-      bool start_attach = false;
-      for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
-        IterVar iv = s->leaf_iter_vars[i - 1];
-        if (iv == attach_ivar) start_attach = true;
-        if (start_attach) path.push_back(iv);
-      }
-      CHECK(start_attach)
-          << "Invalid Schedule: cannot find attach point " << attach_ivar
-          << " in the schedule of " << s->op;
-    }
-    ret.Set(stage->op, path);
-  }
-  return ret;
-}
-
-Map<IterVar, Range> InferBound(const Schedule& sch) {
-  FeedGraph feed_graph = CreateFeedGraph(sch);
+  FeedGraph feed_graph = CreateFeedGraph(CreateReadGraph(roots));
   AttachPath attach_path = CreateAttachPath(sch);
 
   std::unordered_map<IterVar, Range> ret;
@@ -535,6 +513,11 @@ Map<IterVar, Range> InferBound(const Schedule& sch) {
     InferRootBound(stage, feed_graph, attach_path, &ret);
     // pass down to get bound of all iter vars.
     PassDown(stage, &ret);
+    // setup outer most threads.
+    for (IterVar iv : stage->outermost_threads) {
+      CHECK(iv->dom.defined());
+      ret[iv] = iv->dom;
+    }
   }
   return Map<IterVar, Range>(ret.begin(), ret.end());
 }
diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc
index f1047bf95..5cd6b9519 100644
--- a/src/schedule/graph.cc
+++ b/src/schedule/graph.cc
@@ -8,6 +8,46 @@
 #include <unordered_set>
 #include "./graph.h"
 
+namespace tvm {
+namespace schedule {
+// key to specific tensor dimension.
+struct TensorDimKey {
+  FunctionRef f;
+  int value_index;
+  int dim;
+  TensorDimKey() {}
+  TensorDimKey(const ir::Call* op, int dim)
+      : f(op->func), value_index(op->value_index), dim(dim) {
+  }
+  TensorDimKey(const Tensor& t, int dim)
+      : f(t->op), value_index(t->value_index), dim(dim) {
+  }
+  inline bool operator==(const TensorDimKey& other) const {
+    return f == other.f &&
+        value_index == other.value_index &&
+        dim == other.dim;
+  }
+  inline bool operator!=(const TensorDimKey& other) const {
+    return !operator==(other);
+  }
+};
+}  // namespace schedule
+}  // namespace tvm
+
+namespace std {
+template <>
+struct hash<::tvm::schedule::TensorDimKey> {
+  std::size_t operator()(const ::tvm::schedule::TensorDimKey& k) const {
+    size_t lhs = k.f.hash();
+    size_t rhs = static_cast<size_t>(k.value_index) << 32UL |
+        static_cast<size_t>(k.dim);
+    lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
+    return lhs;
+  }
+};
+}  // namespace std
+
+
 namespace tvm {
 namespace schedule {
 
@@ -28,7 +68,7 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
     stack.pop_back();
     Array<Tensor> deps;
     if (op.as<ComputeOpNode>()) {
-      auto fvisit = [&deps, &visited, &stack](const NodeRef& n) {
+      auto fvisit = [&deps](const NodeRef& n) {
         auto *call = n.as<ir::Call>();
         if (call != nullptr && call->func.defined()) {
           Operation call_op(call->func.node_);
@@ -59,7 +99,6 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots) {
   return rmap;
 }
 
-
 void PostDFSOrder(const Operation& op,
                   const ReadGraph& g,
                   std::unordered_set<Operation>* visited,
@@ -83,5 +122,269 @@ Array<Operation> PostDFSOrder(
   return post_order;
 }
 
+FeedGraph CreateFeedGraph(const ReadGraph& g) {
+  FeedGraph fg;
+  for (auto kv : g) {
+    for (Tensor t : kv.second) {
+      fg[t].push_back(kv.first);
+    }
+  }
+  return fg;
+}
+
+AttachPath CreateAttachPath(Schedule sch) {
+  AttachPath ret;
+
+  for (Stage stage : sch->stages) {
+    if (stage->attach_type == kScanUpdate) {
+      const Stage& parent = stage->attach_stage;
+      stage->attach_ivar =
+          parent->leaf_iter_vars[parent->leaf_iter_vars.size() - 1];
+    }
+  }
+
+  for (Stage stage : sch->stages) {
+    Array<IterVar> path;
+
+    for (Stage s = stage; s->attach_type == kScope || s->attach_type == kScanUpdate;) {
+      IterVar attach_ivar = s->attach_ivar;
+      s = s->attach_stage;
+      bool start_attach = false;
+      for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
+        IterVar iv = s->leaf_iter_vars[i - 1];
+        if (iv == attach_ivar) start_attach = true;
+        if (start_attach) path.push_back(iv);
+      }
+      CHECK(start_attach)
+          << "Invalid Schedule: cannot find attach point " << attach_ivar
+          << " in the schedule of " << s->op;
+    }
+
+    if (!ret.count(stage->op)) {
+      ret.Set(stage->op, path);
+    }
+  }
+  return ret;
+}
+
+// graph of push reach relation of tensor dimensions
+using ReachGraph = std::unordered_map<TensorDimKey, std::vector<TensorDimKey> >;
+
+ReachGraph GetReachGraph(const Array<Operation>& ops) {
+  ReachGraph reach;
+  std::unordered_set<const Node*> bset;
+  for (size_t i = 0; i < ops.size(); ++i) {
+    bset.insert(ops[i].get());
+  }
+
+  for (Operation op : ops) {
+    if (op.as<ScanOpNode>()) {
+      const auto& update = op.as<ScanOpNode>()->update;
+      const auto& init = op.as<ScanOpNode>()->init;
+      for (size_t i = 0; i < update.size(); ++i) {
+        Tensor t = op.output(i);
+        for (size_t k = 1; k < update[i]->shape.size(); ++k) {
+          reach[TensorDimKey(t, k)].emplace_back(
+              TensorDimKey(update[i], k));
+          reach[TensorDimKey(t, k)].emplace_back(
+              TensorDimKey(init[i], k));
+        }
+      }
+    } else if (op.as<ComputeOpNode>()) {
+      std::unordered_map<const Node*, TensorDimKey> vmap;
+      const auto& axis = op.as<ComputeOpNode>()->axis;
+      Tensor t = op.output(0);
+      for (size_t i = 0; i < axis.size(); ++i) {
+        vmap[axis[i]->var.get()] = TensorDimKey(t, i);
+        reach[TensorDimKey(t, i)] = {};
+      }
+      auto fvisit = [&vmap, &reach, &bset](const NodeRef& n) {
+        const ir::Call *call = n.as<ir::Call>();
+        if (call != nullptr && call->func.defined()) {
+          if (!bset.count(call->func.get())) return;
+          for (size_t i = 0; i < call->args.size(); ++i) {
+            TensorDimKey dkey(call, i);
+            auto fpush = [&dkey, &vmap, &reach](const NodeRef& node) {
+              const Variable *v = node.as<Variable>();
+              auto it = vmap.find(v);
+              if (it != vmap.end()) {
+                reach[it->second].push_back(dkey);
+              }
+            };
+            ir::PostOrderVisit(call->args[i], fpush);
+          }
+        }
+      };
+      ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
+    }
+  }
+  return reach;
+}
+
+// Get all the operations that forms body of scan
+void ScanGetBodyPostDFS_(
+    Operation op,
+    const ScanOpNode* scan,
+    const FeedGraph& feed_graph,
+    std::unordered_set<const Node*>* visited,
+    Array<Operation>* result) {
+  if (op.get() == scan) return;
+  bool empty_feed = true;
+  for (int i = 0; i < op->num_outputs(); ++i) {
+    auto it = feed_graph.find(op.output(i));
+    if (it != feed_graph.end() && it->second.size()) {
+      empty_feed = false;
+      for (const Operation& xop : it->second) {
+        if (visited->count(xop.get())) continue;
+        visited->insert(xop.get());
+        ScanGetBodyPostDFS_(xop, scan, feed_graph, visited, result);
+        result->push_back(xop);
+      }
+    }
+  }
+  if (empty_feed && op.get() != scan) {
+    LOG(FATAL) << "Bad scan body, tensor reads scan_state but not connect to scan";
+  }
+}
+
+Array<Operation> ScanGetBody_(
+    const ScanOpNode* scan,
+    const FeedGraph& feed_graph) {
+  CHECK(scan != nullptr);
+  std::unordered_set<const Node*> visited;
+  Array<Operation> result;
+  for (Tensor t : scan->state_placeholder) {
+    ScanGetBodyPostDFS_(t->op, scan, feed_graph, &visited, &result);
+  }
+  return result;
+}
+
+Array<Operation> ScanGetBody(const Operation& scan) {
+  return ScanGetBody_(scan.as<ScanOpNode>(),
+                      CreateFeedGraph(CreateReadGraph({scan})));
+}
+
+Map<IterVar, Expr> ScanFixPointAnalysis(
+    const Operation& scan_op, const Array<Operation>& body) {
+  const ScanOpNode* scan = scan_op.as<ScanOpNode>();
+  CHECK(body[0].get() == scan);
+
+  std::unordered_map<TensorDimKey, const Node*> exact_reach;
+  std::unordered_set<const Node*> fail_set;
+
+  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
+    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
+      TensorDimKey key(scan->state_placeholder[i], k);
+      exact_reach[key] = scan->spatial_axis_[sp_idx].get();
+    }
+  }
+  // merge exact reach
+  auto f_merge_key = [&exact_reach, &fail_set](
+      const TensorDimKey& dst, const TensorDimKey& src) {
+    auto sit = exact_reach.find(src);
+    if (sit == exact_reach.end()) return;
+    auto dit = exact_reach.find(dst);
+    if (dit == exact_reach.end()) {
+      exact_reach[dst] = sit->second;
+    } else {
+      if (dit->second != sit->second) {
+        fail_set.insert(dit->second);
+        fail_set.insert(sit->second);
+      }
+    }
+  };
+  // prop exact reach back.
+  for (size_t i = body.size(); i != 1; --i) {
+    const Operation& op = body[i - 1];
+    if (op.as<ScanOpNode>()) {
+      const auto& update = op.as<ScanOpNode>()->update;
+      const auto& init = op.as<ScanOpNode>()->init;
+      for (size_t i = 0; i < update.size(); ++i) {
+        Tensor t = op.output(i);
+        for (size_t k = 1; i < update[i]->shape.size(); ++k) {
+          f_merge_key(TensorDimKey(t, k), TensorDimKey(update[i], k));
+          f_merge_key(TensorDimKey(t, k), TensorDimKey(init[i], k));
+        }
+      }
+    } else if (op.as<ComputeOpNode>()) {
+      std::unordered_map<const Node*, TensorDimKey> vmap;
+      const auto& axis = op.as<ComputeOpNode>()->axis;
+      Tensor t = op.output(0);
+      for (size_t i = 0; i < axis.size(); ++i) {
+        vmap[axis[i]->var.get()] = TensorDimKey(t, i);
+      }
+      auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set](
+          const NodeRef& n) {
+        const ir::Call *call = n.as<ir::Call>();
+        if (call != nullptr && call->func.defined()) {
+          for (size_t i = 0; i < call->args.size(); ++i) {
+            auto it = vmap.find(call->args[i].get());
+            TensorDimKey src(call, i);
+            if (it != vmap.end()) {
+              f_merge_key(it->second, src);
+            } else {
+              if (exact_reach.count(src)) {
+                fail_set.insert(exact_reach.at(src));
+              }
+            }
+          }
+        }
+      };
+      ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit);
+    }
+  }
+  ReachGraph reach;
+  Map<IterVar, Expr> ret;
+  std::unordered_set<TensorDimKey> place_holder_ref;
+  for (size_t i = 0; i < scan->state_placeholder.size(); ++i) {
+    for (size_t k = 0; k < scan->state_placeholder[i]->shape.size(); ++k) {
+      place_holder_ref.insert(TensorDimKey(scan->state_placeholder[i], k));
+    }
+  }
+
+  for (size_t i = 0, sp_idx = 0; i < scan->update.size(); ++i) {
+    for (size_t k = 1; k < scan->update[i]->shape.size(); ++k, ++sp_idx) {
+      TensorDimKey key(scan->update[i], k);
+      TensorDimKey target(scan->state_placeholder[i], k);
+      IterVar sp_iv = scan->spatial_axis_[sp_idx];
+      if (fail_set.count(sp_iv.get()) ||
+          !exact_reach.count(key) ||
+          exact_reach.at(key) != sp_iv.get()) {
+        ret.Set(sp_iv, make_const(Int(32), 0));
+      } else {
+        // now we proved exact match, need to prove no interference with other graph.
+        if (reach.size() == 0) reach = GetReachGraph(body);
+        // do a DFS
+        std::unordered_set<TensorDimKey> visited;
+        std::vector<TensorDimKey> stack{key};
+        visited.insert(key);
+        while (!stack.empty()) {
+          TensorDimKey k = stack.back();
+          if (k != target && place_holder_ref.count(k)) break;
+          stack.pop_back();
+          if (!reach.count(k)) {
+            LOG(FATAL) << "cannot find reach of " << k.f << "-" << k.dim;
+          }
+
+          for (TensorDimKey kk : reach.at(k)) {
+            if (visited.count(kk)) {
+              continue;
+            }
+            visited.insert(kk);
+            stack.push_back(kk);
+          }
+        }
+        if (!stack.empty()) {
+          // failed the prove.
+          ret.Set(sp_iv, make_const(Int(32), 0));
+        } else {
+          ret.Set(sp_iv, make_const(Int(32), 1));
+        }
+      }
+    }
+  }
+  return ret;
+}
+
 }  // namespace schedule
 }  // namespace tvm
diff --git a/src/schedule/graph.h b/src/schedule/graph.h
index 5a40c8e4c..4b4b2df6e 100644
--- a/src/schedule/graph.h
+++ b/src/schedule/graph.h
@@ -9,6 +9,7 @@
 #include <tvm/expr.h>
 #include <tvm/schedule.h>
 #include <unordered_map>
+#include <unordered_set>
 #include <vector>
 
 namespace tvm {
@@ -19,6 +20,16 @@ namespace schedule {
  */
 using ReadGraph = Map<Operation, Array<Tensor> >;
 
+/*!
+ * \brief The map beteen tensor and operation it feeds to
+ */
+using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >;
+
+/*!
+ * \brief AttachPath maps op-> a list of IterVar
+ */
+using AttachPath = Map<Operation, Array<IterVar> >;
+
 /*!
  * \brief Get read graph of each operation to all the
  *  Tensors that it directly depends on.
@@ -41,6 +52,49 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots);
 Array<Operation> PostDFSOrder(
     const Array<Operation>& roots, const ReadGraph& g);
 
+/*!
+ * \brief Create feedgraph for given Schedule
+ * \param  g The read graph.
+ * \return The created feedgraph.
+ */
+FeedGraph CreateFeedGraph(const ReadGraph& g);
+
+/*!
+ * \brief Create AttachPath that  maps op-> a list of IterVar
+ *  That represents the loop nest op sits in from inner most to outermost
+ *  Also inserts attach_stage for scan updates when needed.
+ *
+ * \param sch The schedule.
+ * \return The attach path.
+ */
+AttachPath CreateAttachPath(Schedule sch);
+
+/*!
+ * \brief Get all operations inside the recursion of scan.
+ * \param scan The scan node.
+ * \param feed_graph The feed graph to help analysis.
+ * \return The body operations, in read dependency order.
+ */
+Array<Operation> ScanGetBody_(
+    const ScanOpNode* scan, const FeedGraph& feed_graph);
+// same as ScanGetBody_, but create FeedGraph internally.
+Array<Operation> ScanGetBody(const Operation& scan);
+
+/*!
+ * \brief Analyze each spatial dimension of scan's result.
+ *  Give check on whether each dimension is fix point,
+ *  An axis is a fixed point if it only refers back to itself in recursion
+ *  and it is not used in axis of other recursion field.
+ *
+ *  next_state[t, ..., axis, ...] = f(prev_state[t-1, ...,axis,...]
+ *
+ * \param scan The scan node.
+ * \param body The body of scan, sorted in reverse PostDFSOrder.
+ * \return Map of spatial_axis -> IntImm
+ */
+Map<IterVar, Expr> ScanFixPointAnalysis(
+    const Operation& scan, const Array<Operation>& body);
+
 }  // namespace schedule
 }  // namespace tvm
 
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
new file mode 100644
index 000000000..9a44a9641
--- /dev/null
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -0,0 +1,312 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \file schedule_dataflow_rewrite.cc
+ */
+#include <tvm/schedule.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_pass.h>
+#include <unordered_set>
+
+namespace tvm {
+
+// find first occurance location in leaf
+template<typename T>
+size_t FindNodeRef(ArrayNode* array_node, const T& v) {
+  const Node* n = v.get();
+  for (size_t i = 0; i < array_node->data.size(); ++i) {
+    if (array_node->data[i].get() == n) return i;
+  }
+  return array_node->data.size();
+}
+
+using ir::TensorKey;
+
+// The replacer of cache.
+class TensorReplacer : public ir::IRMutator {
+ public:
+  explicit TensorReplacer(const std::unordered_map<TensorKey, Tensor>& vmap)
+      : vmap_(vmap) {}
+  Expr Mutate_(const ir::Call* op, const Expr& e) {
+    if (op->call_type == ir::Call::Halide) {
+      ir::TensorKey key{op->func, op->value_index};
+      auto it = vmap_.find(key);
+      if (it != vmap_.end()) {
+        Expr ret = ir::Call::make(
+            op->type, it->second->op->name, op->args,
+            op->call_type, it->second->op, it->second->value_index);
+        found = true;
+        return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
+      }
+    }
+    return IRMutator::Mutate_(op, e);
+  }
+
+  // whether it is found.
+  bool found{false};
+
+ private:
+  const std::unordered_map<TensorKey, Tensor>& vmap_;
+};
+
+class VarReplacer : public ir::IRMutator {
+ public:
+  explicit VarReplacer(
+      const std::unordered_map<const Variable*, Expr>& vsub)
+      : vsub_(vsub) {}
+  Expr Mutate_(const Variable* op, const Expr& e) {
+    auto it = vsub_.find(op);
+    if (it != vsub_.end()) return it->second;
+    return e;
+  }
+
+ private:
+  const std::unordered_map<const Variable*, Expr>& vsub_;
+};
+
+// Replace data flow appears in all stages given the tensor change.
+// Also update vmap if subsequent dataflow need to be replaced.
+void ReplaceDataFlow(const Array<Stage>& stages,
+                     std::unordered_map<TensorKey, Tensor>* vmap) {
+  for (Stage s : stages) {
+    if (s->op.as<ComputeOpNode>()) {
+      const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
+      TensorReplacer repl(*vmap);
+      Expr body = repl.Mutate(compute->body);
+      if (repl.found) {
+        Operation op = ComputeOpNode::make(
+            compute->name, compute->axis, body);
+        (*vmap)[TensorKey{s->op, 0}] = op.output(0);
+        s->op = op;
+      }
+    } else if (s->op.as<ScanOpNode>()) {
+      const ScanOpNode* scan = s->op.as<ScanOpNode>();
+      std::shared_ptr<ScanOpNode> n =
+          std::make_shared<ScanOpNode>(*scan);
+      // copy on write semantics ganrantees correctness
+      for (size_t i = 0; i < n->init.size(); ++i) {
+        TensorKey key{n->init[i]->op, n->init[i]->value_index};
+        if (vmap->count(key)) {
+          n->init.Set(i, vmap->at(key));
+        }
+      }
+      for (size_t i = 0; i < n->update.size(); ++i) {
+        TensorKey key{n->update[i]->op, n->update[i]->value_index};
+        if (vmap->count(key)) {
+          n->update.Set(i, vmap->at(key));
+        }
+      }
+      if (!n->init.same_as(scan->init) ||
+          !n->update.same_as(scan->update)) {
+        Operation op(n);
+        for (int i = 0; i < op->num_outputs(); ++i) {
+          (*vmap)[TensorKey{s->op, i}] = op.output(i);
+        }
+        s->op = op;
+      }
+    } else if (s->op.as<PlaceholderOpNode>()) {
+    } else {
+      LOG(FATAL) << "unhandled problem";
+    }
+  }
+}
+
+Tensor Schedule::cache_read(const Tensor& tensor,
+                            const std::string& scope,
+                            const Array<Operation>& readers) {
+  // create identity mapping.
+  std::ostringstream os;
+  os << tensor->op->name;
+  if (tensor->op->num_outputs() != 1) {
+    os << ".v" << tensor->value_index;
+  }
+  os << "." << scope;
+
+  Tensor cache = compute(tensor->shape, [&tensor](const Array<Var>& i) {
+      return tensor(Array<Expr>(i.begin(), i.end()));
+    }, os.str());
+  std::unordered_map<TensorKey, Tensor> vsub;
+  vsub[TensorKey{tensor->op, tensor->value_index}] = cache;
+
+  std::unordered_map<TensorKey, Tensor> vmap;
+  for (Operation op : readers) {
+    const ComputeOpNode* compute = op.as<ComputeOpNode>();
+    CHECK(compute)
+        << "cache read only take ComputeOp as readers";
+    Stage s = operator[](op);
+    compute = s->op.as<ComputeOpNode>();
+
+    TensorReplacer repl(vsub);
+    Expr body = repl.Mutate(compute->body);
+    CHECK(repl.found)
+        << "Cannot find " << tensor
+        << " in the body of specified reader " << op;
+    Operation repl_op = ComputeOpNode::make(
+        compute->name, compute->axis, body);
+    vmap[TensorKey{s->op, 0}] = repl_op.output(0);
+    s->op = repl_op;
+  }
+  ReplaceDataFlow((*this)->stages, &vmap);
+  ArrayNode* stages = (*this)->stages.CopyOnWrite();
+  size_t pos = FindNodeRef(stages, operator[](tensor->op));
+  Stage cache_stage = Stage(cache->op);
+  cache_stage.set_scope(scope);
+  CHECK_LT(pos, stages->data.size());
+  stages->data.insert(stages->data.begin() + pos + 1,
+                      cache_stage.node_);
+  (*this)->stage_map.Set(cache->op, cache_stage);
+  return cache;
+}
+
+Tensor Schedule::cache_write(const Tensor& tensor,
+                             const std::string& scope) {
+  Stage orig_stage = operator[](tensor->op);
+  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
+  CHECK(compute)
+      << "cache write only take ComputeOp as writers";
+  CHECK_EQ(orig_stage->relations.size(), 0U)
+      << "Create cache_write before doing split/fuse/reorder";
+  compute = orig_stage->op.as<ComputeOpNode>();
+  CHECK(compute);
+  Array<Expr> args;
+  Array<IterVar> new_axis;
+  std::unordered_map<const Variable*, Expr> vsub;
+  for (IterVar iv : compute->axis) {
+    args.push_back(iv->var);
+    IterVar new_iv(iv->dom, iv->var->name_hint + ".c");
+    new_axis.push_back(new_iv);
+    vsub[iv->var.get()] = new_iv->var;
+  }
+  VarReplacer repl(vsub);
+  Expr body = repl.Mutate(compute->body);
+  Operation cache_op = ComputeOpNode::make(
+      compute->name + "." + scope, new_axis, body);
+  Tensor cache_tensor = cache_op.output(0);
+  Operation orig_new_op = ComputeOpNode::make(
+      compute->name, compute->axis,
+      cache_tensor(args));
+
+  std::unordered_map<TensorKey, Tensor> vmap;
+  vmap[TensorKey{orig_stage->op, 0}] = orig_new_op.output(0);
+  ReplaceDataFlow((*this)->stages, &vmap);
+
+  // mutate orig stage
+  orig_stage->op = orig_new_op;
+  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
+  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
+  // create schedule for new cached stage.
+  ArrayNode* stages = (*this)->stages.CopyOnWrite();
+  size_t pos = FindNodeRef(stages, orig_stage);
+  Stage cache_stage = Stage(cache_op);
+  cache_stage.set_scope(scope);
+  CHECK_LT(pos, stages->data.size());
+  stages->data.insert(stages->data.begin() + pos,
+                      cache_stage.node_);
+  (*this)->stage_map.Set(cache_op, cache_stage);
+  return cache_tensor;
+}
+
+
+void RebaseNonZeroMinLoop(const Schedule& sch) {
+  std::unordered_map<IterVar, IterVar> rebase_map;
+  std::unordered_map<const Node*, int> attach_mark;
+
+  for (Stage s : sch->stages) {
+    if (s->attach_type == kScope) {
+      attach_mark[s->attach_stage.get()] = 1;
+    }
+    if (s->op.as<ScanOpNode>()) {
+      attach_mark[s.get()] = 1;
+    }
+  }
+
+  for (Stage s : sch->stages) {
+    if (!attach_mark.count(s.get())) continue;
+    auto root_iter_vars = s->op->root_iter_vars();
+    ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
+    for (IterVar iv : root_iter_vars) {
+      size_t idx = FindNodeRef(leaf_vars, iv);
+      if (idx < leaf_vars->data.size()) {
+        // insert rebase
+        IterVar rebased(Range(), iv->var->name_hint + ".rb");
+        s->relations.push_back(RebaseNode::make(iv, rebased));
+        leaf_vars->data[idx] = rebased.node_;
+        rebase_map[iv] = rebased;
+      }
+    }
+  }
+  // remap the parent relation
+  for (Stage s : sch->stages) {
+    if (s->attach_type != kScope) continue;
+    if (rebase_map.count(s->attach_ivar)) {
+      s->attach_ivar = rebase_map.at(s->attach_ivar);
+    }
+  }
+}
+
+void SetScanAttach(const Schedule& sch) {  // NOLINT(*)
+  for (Stage stage : sch->stages) {
+    if (stage->attach_type == kScanUpdate) {
+      const Stage& parent = stage->attach_stage;
+      stage->attach_ivar =
+          parent->leaf_iter_vars[parent->leaf_iter_vars.size() - 1];
+    }
+  }
+}
+
+
+void InjectInline(const Schedule& sch) {
+  std::vector<Expr> new_body(sch->stages.size());
+  // inline all the ops
+  for (size_t i = sch->stages.size(); i != 0; --i) {
+    Stage stage = sch->stages[i - 1];
+    if (stage->attach_type == kInline) {
+      stage->attach_type = kInlinedAlready;
+      Array<Var> args;
+      Expr body;
+      {
+        // setup args
+        const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
+        CHECK(compute)
+            << "can only inline compute op";
+        for (auto iv : compute->axis) {
+          args.push_back(iv->var);
+        }
+        body = compute->body;
+      }
+      for (size_t j = i; j < sch->stages.size(); ++j) {
+        Stage s = sch->stages[j];
+        const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
+        if (compute) {
+          if (!new_body[j].defined()) {
+            new_body[j] = s->op.as<ComputeOpNode>()->body;
+          }
+          new_body[j] = ir::Inline(ir::Evaluate::make(new_body[j]),
+                                   stage->op, args, body).as<ir::Evaluate>()->value;
+        }
+      }
+    }
+  }
+  std::unordered_map<TensorKey, Tensor> repl;
+  // rewrite dataflow
+  for (size_t i = 0; i < sch->stages.size(); ++i) {
+    if (new_body[i].defined() &&
+        !new_body[i].same_as(sch->stages[i]->op)) {
+      const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
+      CHECK(compute);
+      Operation op = ComputeOpNode::make(
+          compute->name, compute->axis, new_body[i]);
+      repl[TensorKey{sch->stages[i]->op, 0}] = op.output(0);
+      Stage s = sch->stages[i];
+      s->op = op;
+    }
+  }
+  ReplaceDataFlow(sch->stages, &repl);
+}
+
+void Schedule::normalize() {
+  RebaseNonZeroMinLoop(*this);
+  SetScanAttach(*this);
+  InjectInline(*this);
+}
+
+}  // namespace tvm
diff --git a/src/schedule/schedule_lang.cc b/src/schedule/schedule_lang.cc
index b18ae28e5..308070a8b 100644
--- a/src/schedule/schedule_lang.cc
+++ b/src/schedule/schedule_lang.cc
@@ -1,6 +1,6 @@
 /*!
  *  Copyright (c) 2016 by Contributors
- * \file schedule.cc
+ * \file schedule_lang.cc
  */
 #include <tvm/schedule.h>
 #include <tvm/ir_mutator.h>
@@ -37,6 +37,10 @@ size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v)
 
 void Split(StageNode* self, IterVar parent,
            IterVar outer, IterVar inner, Expr factor) {
+  if (self->attach_type == kScanUpdate) {
+    CHECK(!parent.same_as(self->all_iter_vars[0]))
+        << "Cannot split on axis[0] of scan update";
+  }
   ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
   ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
   size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
@@ -83,6 +87,8 @@ Stage& Stage::set_scope(std::string scope) {  // NOLINT(*)
 }
 
 Stage& Stage::compute_at(Stage parent, IterVar scope) {   // NOLINT(*)
+  CHECK_NE((*this)->attach_type, kScanUpdate)
+      << "Cannot specify compute_at for scan updates";
   (*this)->attach_type = kScope;
   (*this)->attach_ivar = scope;
   (*this)->attach_stage = parent;
@@ -93,16 +99,22 @@ Stage& Stage::compute_at(Stage parent, IterVar scope) {   // NOLINT(*)
     }
   }
   CHECK(found)
-      << "Cannot find the axis in parent's leaf_iter_vars or outermost_threads";
+      << "Cannot find the axis " << scope
+      << " in parent's leaf_iter_vars or outermost_threads:"
+      << " parent=" << parent;
   return *this;
 }
 
 Stage& Stage::compute_inline() {   // NOLINT(*)
+  CHECK_NE((*this)->attach_type, kScanUpdate)
+      << "Cannot specify compute_at for scan updates";
   (*this)->attach_type = kInline;
   return *this;
 }
 
 Stage& Stage::compute_root() {   // NOLINT(*)
+  CHECK_NE((*this)->attach_type, kScanUpdate)
+      << "Cannot specify compute_at for scan updates";
   (*this)->attach_type = kRoot;
   return *this;
 }
@@ -128,9 +140,15 @@ Stage& Stage::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor
 }
 
 Stage& Stage::fuse(IterVar inner, IterVar outer, IterVar* p_target) {  // NOLINT(*)
+  StageNode* self = operator->();
+  if (self->attach_type == kScanUpdate) {
+    CHECK(!inner.same_as(self->all_iter_vars[0]))
+        << "Cannot split on axis[0] of scan update";
+    CHECK(!outer.same_as(self->all_iter_vars[0]))
+        << "Cannot split on axis[0] of scan update";
+  }
   IterVar fused(Range(), outer->var->name_hint + "." + inner->var->name_hint + ".fused");
   *p_target = fused;
-  StageNode* self = operator->();
   ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
   ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
 
@@ -157,6 +175,10 @@ Stage& Stage::reorder(const Array<IterVar>& order) {  // NOLINT(*)
   std::vector<size_t> pos;
 
   for (size_t i = 0; i < order.size(); ++i) {
+    if ((*this)->attach_type == kScanUpdate) {
+      CHECK(!order[i].same_as(self->all_iter_vars[0]))
+          << "Cannot split on axis[0] of scan update";
+    }
     pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
   }
   std::vector<std::shared_ptr<Node> > temp;
@@ -239,12 +261,25 @@ Schedule::Schedule(Array<Operation> ops) {
     stage->is_output = output_set.count(op);
     n->stages.push_back(stage);
     n->stage_map.Set(op, stage);
+    // mark scan updates.
+    if (op.as<ScanOpNode>()) {
+      const ScanOpNode* scan = op.as<ScanOpNode>();
+      for (size_t i = 0; i < scan->update.size(); ++i) {
+        Stage s = n->stage_map[scan->update[i]->op];
+        s->attach_type = kScanUpdate;
+        s->attach_stage = stage;
+      }
+    }
   }
   node_ = std::move(n);
 }
 
 Stage Schedule::operator[](const Operation& op) {
-  return (*this)->stage_map.at(op);
+  auto it = (*this)->stage_map.find(op);
+  CHECK(it != (*this)->stage_map.end())
+      << "Cannot find Stage for operator " << op
+      << " in the schedule";
+  return (*it).second;
 }
 
 IterVarRelation SplitNode::make(
@@ -274,42 +309,6 @@ IterVarRelation RebaseNode::make(IterVar parent, IterVar rebased) {
   return IterVarRelation(n);
 }
 
-void Schedule::normalize() {
-  std::unordered_map<IterVar, IterVar> rebase_map;
-  std::unordered_map<const Node*, int> attach_mark;
-
-
-  for (Stage s : (*this)->stages) {
-    if (s->attach_type == kScope) {
-      attach_mark[s->attach_stage.get()] = 1;
-    }
-  }
-
-  for (Stage s : (*this)->stages) {
-    if (!attach_mark.count(s.get())) continue;
-    auto root_iter_vars = s->op->root_iter_vars();
-    ArrayNode* leaf_vars = s->leaf_iter_vars.CopyOnWrite();
-
-    for (IterVar iv : root_iter_vars) {
-      size_t idx = FindNodeRef(leaf_vars, iv);
-      if (idx < leaf_vars->data.size()) {
-        // insert rebase
-        IterVar rebased(Range(), iv->var->name_hint + ".rb");
-        s->relations.push_back(RebaseNode::make(iv, rebased));
-        leaf_vars->data[idx] = rebased.node_;
-        rebase_map[iv] = rebased;
-      }
-    }
-  }
-  // remap the parent relation
-  for (Stage s : (*this)->stages) {
-    if (s->attach_type != kScope) continue;
-    if (rebase_map.count(s->attach_ivar)) {
-      s->attach_ivar = rebase_map.at(s->attach_ivar);
-    }
-  }
-}
-
 IterVarAttr::IterVarAttr(IterVarType t) {
   std::shared_ptr<IterVarAttrNode> n = std::make_shared<IterVarAttrNode>();
   n->iter_type = t;
@@ -323,190 +322,4 @@ TVM_REGISTER_NODE_TYPE(FuseNode);
 TVM_REGISTER_NODE_TYPE(RebaseNode);
 TVM_REGISTER_NODE_TYPE(ScheduleNode);
 
-using ir::TensorKey;
-
-// The replacer of cache.
-class TensorReplacer : public ir::IRMutator {
- public:
-  TensorReplacer(const std::unordered_map<TensorKey, Tensor>& vmap)
-      : vmap_(vmap) {}
-  Expr Mutate_(const ir::Call* op, const Expr& e) {
-    if (op->call_type == ir::Call::Halide) {
-      ir::TensorKey key{op->func, op->value_index};
-      auto it = vmap_.find(key);
-      if (it != vmap_.end()) {
-        Expr ret = ir::Call::make(
-            op->type, it->second->op->name, op->args,
-            op->call_type, it->second->op, it->second->value_index);
-        found = true;
-        return IRMutator::Mutate_(ret.as<ir::Call>(), ret);
-      }
-    }
-    return IRMutator::Mutate_(op, e);
-  }
-
-  // whether it is found.
-  bool found{false};
-
- private:
-  const std::unordered_map<TensorKey, Tensor>& vmap_;
-};
-
-class VarReplacer : public ir::IRMutator {
- public:
-  explicit VarReplacer(
-      const std::unordered_map<const Variable*, Expr>& vsub)
-      : vsub_(vsub) {}
-  Expr Mutate_(const Variable* op, const Expr& e) {
-    auto it = vsub_.find(op);
-    if (it != vsub_.end()) return it->second;
-    return e;
-  }
-
- private:
-  const std::unordered_map<const Variable*, Expr>& vsub_;
-};
-
-// Replace data flow appears in all stages given the tensor change.
-// Also update vmap if subsequent dataflow need to be replaced.
-void ReplaceDataFlow(const Array<Stage>& stages,
-                     std::unordered_map<TensorKey, Tensor>* vmap) {
-  for (Stage s : stages) {
-    if (s->op.as<ComputeOpNode>()) {
-      const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
-      TensorReplacer repl(*vmap);
-      Expr body = repl.Mutate(compute->body);
-      if (repl.found) {
-        Operation op = ComputeOpNode::make(
-            compute->name, compute->axis, body);
-        (*vmap)[TensorKey{s->op, 0}] = op.output(0);
-        s->op = op;
-      }
-    } else if (s->op.as<ScanOpNode>()) {
-      const ScanOpNode* scan = s->op.as<ScanOpNode>();
-      std::shared_ptr<ScanOpNode> n =
-          std::make_shared<ScanOpNode>(*scan);
-      // copy on write semantics ganrantees correctness
-      for (size_t i = 0; i < n->init.size(); ++i) {
-        TensorKey key{n->init[i]->op, n->init[i]->value_index};
-        if (vmap->count(key)) {
-          n->init.Set(i, vmap->at(key));
-        }
-      }
-      for (size_t i = 0; i < n->update.size(); ++i) {
-        TensorKey key{n->update[i]->op, n->update[i]->value_index};
-        if (vmap->count(key)) {
-          n->update.Set(i, vmap->at(key));
-        }
-      }
-      if (!n->init.same_as(scan->init) ||
-          !n->update.same_as(scan->update)) {
-        Operation op(n);
-        for (int i = 0; i < op->num_outputs(); ++i) {
-          (*vmap)[TensorKey{s->op, i}] = op.output(i);
-        }
-        s->op = op;
-      }
-    } else if (s->op.as<PlaceholderOpNode>()) {
-    } else {
-      LOG(FATAL) << "unhandled problem";
-    }
-  }
-}
-
-Tensor Schedule::cache_read(const Tensor& tensor,
-                            const std::string& scope,
-                            const Array<Operation>& readers) {
-  // create identity mapping.
-  std::ostringstream os;
-  os << tensor->op->name;
-  if (tensor->op->num_outputs() != 1) {
-    os << ".v" << tensor->value_index;
-  }
-  os << "." << scope;
-
-  Tensor cache = compute(tensor->shape, [&tensor](const Array<Var>& i) {
-      return tensor(Array<Expr>(i.begin(), i.end()));
-    }, os.str());
-  std::unordered_map<TensorKey, Tensor> vsub;
-  vsub[TensorKey{tensor->op, tensor->value_index}] = cache;
-
-  std::unordered_map<TensorKey, Tensor> vmap;
-  for (Operation op : readers) {
-    const ComputeOpNode* compute = op.as<ComputeOpNode>();
-    CHECK(compute)
-        << "cache read only take ComputeOp as readers";
-    Stage s = operator[](op);
-    compute = s->op.as<ComputeOpNode>();
-
-    TensorReplacer repl(vsub);
-    Expr body = repl.Mutate(compute->body);
-    CHECK(repl.found)
-        << "Cannot find " << tensor
-        << " in the body of specified reader" << op;
-    Operation repl_op = ComputeOpNode::make(
-        compute->name, compute->axis, body);
-    vmap[TensorKey{s->op, 0}] = repl_op.output(0);
-    s->op = repl_op;
-  }
-  ReplaceDataFlow((*this)->stages, &vmap);
-  ArrayNode* stages = (*this)->stages.CopyOnWrite();
-  size_t pos = FindNodeRef(stages, operator[](tensor->op));
-  Stage cache_stage = Stage(cache->op);
-  cache_stage.set_scope(scope);
-  CHECK_LT(pos, stages->data.size());
-  stages->data.insert(stages->data.begin() + pos + 1,
-                      cache_stage.node_);
-  (*this)->stage_map.Set(cache->op, cache_stage);
-  return cache;
-}
-
-Tensor Schedule::cache_write(const Tensor& tensor,
-                             const std::string& scope) {
-  Stage orig_stage = operator[](tensor->op);
-  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
-  CHECK(compute)
-      << "cache write only take ComputeOp as writers";
-  CHECK(!orig_stage.is_scheduled())
-      << "Create cache_write before doing split/fuse/reorder";
-  compute = orig_stage->op.as<ComputeOpNode>();
-  CHECK(compute);
-  Array<Expr> args;
-  Array<IterVar> new_axis;
-  std::unordered_map<const Variable*, Expr> vsub;
-  for (IterVar iv : compute->axis) {
-    args.push_back(iv->var);
-    IterVar new_iv(iv->dom, iv->var->name_hint + ".c");
-    new_axis.push_back(new_iv);
-    vsub[iv->var.get()] = new_iv->var;
-  }
-  VarReplacer repl(vsub);
-  Expr body = repl.Mutate(compute->body);
-  Operation cache_op = ComputeOpNode::make(
-      compute->name + "." + scope, new_axis, body);
-  Tensor cache_tensor = cache_op.output(0);
-  Operation orig_new_op = ComputeOpNode::make(
-      compute->name, compute->axis,
-      cache_tensor(args));
-
-  std::unordered_map<TensorKey, Tensor> vmap;
-  vmap[TensorKey{orig_stage->op, 0}] = orig_new_op.output(0);
-  ReplaceDataFlow((*this)->stages, &vmap);
-
-  // mutate orig stage
-  orig_stage->op = orig_new_op;
-  orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
-  orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
-  // create schedule for new cached stage.
-  ArrayNode* stages = (*this)->stages.CopyOnWrite();
-  size_t pos = FindNodeRef(stages, orig_stage);
-  Stage cache_stage = Stage(cache_op);
-  cache_stage.set_scope(scope);
-  CHECK_LT(pos, stages->data.size());
-  stages->data.insert(stages->data.begin() + pos,
-                      cache_stage.node_);
-  (*this)->stage_map.Set(cache_op, cache_stage);
-  return cache_tensor;
-}
-
 }  // namespace tvm
diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc
index aa7c38363..4b7c7f886 100644
--- a/src/schedule/schedule_ops.cc
+++ b/src/schedule/schedule_ops.cc
@@ -369,7 +369,7 @@ Stmt MakeRealize(const ScanOpNode* op,
     CHECK_EQ(static_cast<size_t>(t->value_index), i);
     Halide::Internal::Region bounds;
     bounds.push_back(tdom);
-    for (size_t k = 0; k < op->update[i]->shape.size(); ++k, ++sp_idx) {
+    for (size_t k = 1; k < op->update[i]->shape.size(); ++k, ++sp_idx) {
       IterVar sp_ax = op->spatial_axis_[sp_idx];
       bounds.push_back(dom_map.at(sp_ax));
     }
@@ -561,6 +561,7 @@ class InjectScanStep : public IRMutator {
 
 Stmt InjectInline(const Operation op, Stmt body) {
   CHECK(body.defined());
+
   const ComputeOpNode* compute = op.as<ComputeOpNode>();
   CHECK(compute != nullptr)
       << "can only inline compute op";
@@ -614,7 +615,7 @@ class SchedulePostProc : public IRMutator {
         if (it->second.defined()) {
           Stmt ret = AttrStmt::make(
               it->second, op->type_key, op->value, op->body);
-          return this->Mutate_(ret.as<AttrStmt>(), ret);
+          return this->Mutate(ret);
         } else {
           return this->Mutate(op->body);
         }
@@ -631,7 +632,7 @@ class SchedulePostProc : public IRMutator {
         Stmt ret = Realize::make(
             it->second->op, it->second->value_index,
             op->type, op->bounds, op->condition, op->body);
-        return this->Mutate_(ret.as<Realize>(), ret);
+        return this->Mutate(ret);
       } else {
         return this->Mutate(op->body);
       }
@@ -644,11 +645,10 @@ class SchedulePostProc : public IRMutator {
     TensorKey key{op->func, op->value_index};
     auto it = replace_buffer_.find(key);
     if (it != replace_buffer_.end()) {
-      const Tensor& dst = it->second.first;
+      const Tensor& dst = it->second;
       Stmt ret = Provide::make(
-          dst->op, dst->value_index, op->value,
-          RewriteArgs(it->second.second, op->args));
-      return IRMutator::Mutate_(ret.as<Provide>(), ret);
+          dst->op, dst->value_index, op->value, op->args);
+      return this->Mutate(ret);
     } else {
       return IRMutator::Mutate_(op, s);
     }
@@ -659,12 +659,11 @@ class SchedulePostProc : public IRMutator {
       TensorKey key{op->func, op->value_index};
       auto it = replace_buffer_.find(key);
       if (it != replace_buffer_.end()) {
-        const Tensor& dst = it->second.first;
+        const Tensor& dst = it->second;
         Expr ret = Call::make(
-            op->type, dst->op->name,
-            RewriteArgs(it->second.second, op->args),
+            op->type, dst->op->name, op->args,
             op->call_type, dst->op, dst->value_index);
-        return IRMutator::Mutate_(ret.as<Call>(), ret);
+        return this->Mutate(ret);
       }
     }
     return IRMutator::Mutate_(op, e);
@@ -685,14 +684,14 @@ class SchedulePostProc : public IRMutator {
         const ScanOpNode* scan = s->op.as<ScanOpNode>();
         for (size_t i = 0; i < scan->update.size(); ++i) {
           Tensor t = s->origin_op.output(i);
-          AddReplace(scan->init[i], t, Expr());
-          AddReplace(scan->update[i], t, scan->scan_axis->var);
-          AddReplace(scan->state_placeholder[i], t, Expr());
+          AddReplace(scan->init[i], t);
+          AddReplace(scan->update[i], t);
+          AddReplace(scan->state_placeholder[i], t);
         }
       } else if (!s->op.same_as(s->origin_op)) {
         Tensor target = s->origin_op.output(0);
         AddReplace(s->op.output(0), target,
-                   Expr(), target, s->origin_op);
+                   target, s->origin_op);
       }
     }
   }
@@ -700,26 +699,17 @@ class SchedulePostProc : public IRMutator {
  private:
   void AddReplace(Tensor src,
                   Tensor dst,
-                  Expr head_idx,
                   Tensor repl_realize = Tensor(),
                   Operation repl_op = Operation()) {
     TensorKey key{src->op, src->value_index};
-    replace_buffer_[key] = std::make_pair(dst, head_idx);
+    replace_buffer_[key] = dst;
     replace_realize_[key] = repl_realize;
     replace_op_[src->op.get()] = repl_op;
   }
-  Array<Expr> RewriteArgs(Expr head, Array<Expr> args) {
-    if (!head.defined()) return args;
-    Array<Expr> new_args{head};
-    for (Expr e : args) {
-      new_args.push_back(e);
-    }
-    return new_args;
-  }
   // The scan value
   std::unordered_map<const Variable*, Expr> var_value_;
   // buffer replacement
-  std::unordered_map<TensorKey, std::pair<Tensor, Expr> > replace_buffer_;
+  std::unordered_map<TensorKey, Tensor> replace_buffer_;
   // buffere realization to be replaced
   std::unordered_map<TensorKey, Tensor> replace_realize_;
   // replace producer consumer.
@@ -755,10 +745,13 @@ Stmt ScheduleOps(
   // reverse the post DFS order.
   for (size_t i = sch->stages.size(); i != 0; --i) {
     Stage s = sch->stages[i - 1];
+    CHECK_NE(s->attach_type, kInline)
+        << "call schedule.normalize before scheduleops";
     // no need to specify place holder op.
     if (s->op.as<PlaceholderOpNode>()) continue;
     if (scan_attach.count(s->op)) {
-      CHECK(s->attach_type == kNone || s->attach_type == kInline)
+      CHECK(s->attach_type == kNone ||
+            s->attach_type == kScanUpdate)
           << "Cannot specify compute_at for scan's init/update";
       CHECK(body.defined());
       const auto& p = scan_attach.at(s->op);
@@ -766,8 +759,8 @@ Stmt ScheduleOps(
       body = mu.Mutate(body);
       CHECK(mu.found_attach)
           << "did not find attachment point for scan.init/update";
-    } else if (s->attach_type == kInline) {
-      body = InjectInline(s->op, body);
+    } else if (s->attach_type == kInlinedAlready) {
+      // do nothing
     } else if (s->attach_type == kRoot || s-> attach_type == kNone) {
       body = MakePipeline(s, dom_map, body);
     } else if (s->attach_type == kScope) {
diff --git a/tests/python/integration/test_scan.py b/tests/python/integration/test_scan.py
index 38cd832f2..08adab491 100644
--- a/tests/python/integration/test_scan.py
+++ b/tests/python/integration/test_scan.py
@@ -8,8 +8,8 @@ def test_scan():
     X = tvm.placeholder((m, n), name="X")
     s_state = tvm.placeholder((m, n))
     s_init = tvm.compute((1, n), lambda _, i: X[0, i])
-    s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + X[t, i])
-    res = tvm.scan(t, s_init, s_update, s_state)
+    s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
+    res = tvm.scan(s_init, s_update, s_state)
 
     # schedule
     s = tvm.Schedule(res.op)
@@ -18,7 +18,7 @@ def test_scan():
     thread_x = tvm.IterVar((0, num_thread), thread_tag="threadIdx.x")
     _, x = s[s_init].split(s_init.op.axis[1], factor=num_thread, outer=block_x)
     _, x = s[s_init].split(x, outer=thread_x)
-    _, x = s[s_update].split(s_update.op.axis[0], factor=num_thread, outer=block_x)
+    _, x = s[s_update].split(s_update.op.axis[1], factor=num_thread, outer=block_x)
     _, x = s[s_update].split(x, outer=thread_x)
 
     # one line to build the function.
diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py
index 3459e80e9..c5dfb748d 100644
--- a/tests/python/unittest/test_lang_tensor.py
+++ b/tests/python/unittest/test_lang_tensor.py
@@ -40,9 +40,8 @@ def test_tensor_scan():
     t = tvm.IterVar((1, m), "t")
     x = tvm.placeholder((m, n))
     s = tvm.placeholder((m, n))
-    res = tvm.scan(t,
-                   tvm.compute((1, n), lambda _, i: x[0, i]),
-                   tvm.compute((n,), lambda i: s[t-1, i] + x[t, i]),
+    res = tvm.scan(tvm.compute((1, n), lambda _, i: x[0, i]),
+                   tvm.compute((m, n), lambda t, i: s[t-1, i] + x[t, i]),
                    s)
     assert tuple(res.shape) == (m, n)
 
diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py
index e80fb275c..3e187766f 100644
--- a/tests/python/unittest/test_schedule_bound_inference.py
+++ b/tests/python/unittest/test_schedule_bound_inference.py
@@ -50,25 +50,30 @@ def test_bound3():
     assert(bounds[A1.op.axis[0]].extent.value==32)
     assert(bounds[A1.op.axis[1]].extent.value==16)
 
+def test_bound_scan():
+    m = tvm.Var("m")
+    n = tvm.Var("n")
+    X = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
+    s_state = tvm.placeholder((m, n))
+    s_init = tvm.compute((1, n), lambda _, i: X[0, i])
+    s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + X[t, i])
+    s_scan = tvm.scan(s_init, s_update, s_state)
 
-def test_create_read_graph():
-    m = tvm.Var('m')
-    l = tvm.Var('l')
-    A = tvm.placeholder((m, l), name='A')
-    A1 = tvm.compute((m, l), lambda i, j: A[i, j])
-    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3)
+    assert tuple(s_scan.shape) == (m, n)
 
-    g = tvm.schedule.CreateReadGraph([A2.op])
+    s = tvm.Schedule(s_scan.op)
+    XX = s.cache_read(X, "local", s_update)
+    xo, xi = s[s_update].split(s_update.op.axis[1], factor=4)
+    s[XX].compute_at(s[s_update], xo)
 
-    assert g[A2.op][0] == A1
-    assert g[A1.op][0] == A
-    post_order = tvm.schedule.PostDFSOrder([A2.op], g)
-    assert(post_order[0] == A.op)
-    assert(post_order[1] == A1.op)
+    s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+    assert bounds[XX.op.axis[1]].extent.value == 4
 
 
 if __name__ == "__main__":
-    test_create_read_graph()
+    test_bound_scan()
     test_bound3()
     test_bound1()
     test_bound2()
diff --git a/tests/python/unittest/test_schedule_graph.py b/tests/python/unittest/test_schedule_graph.py
new file mode 100644
index 000000000..2d1af01d7
--- /dev/null
+++ b/tests/python/unittest/test_schedule_graph.py
@@ -0,0 +1,101 @@
+import tvm
+
+def test_scan():
+    m = tvm.Var("m")
+    n = tvm.Var("n")
+    x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
+    s_state = tvm.placeholder((m, n))
+    s_init = tvm.compute((1, n), lambda _, i: x[0, i], name="s_init")
+    x_trans = tvm.compute((m, n), lambda i, j: x[i, j] + 1, name="x_trans")
+    s_up1 = tvm.compute((m, n), lambda t, i: s_state[t - 1, i] + 1, name="up1")
+    s_update = tvm.compute((m, n), lambda t, i: s_up1[t, i] + x_trans[t, i], name="update")
+    s_scan = tvm.scan(s_init, s_update, s_state)
+
+    def test_getbody():
+        body = tvm.schedule.ScanGetBody(s_scan.op)
+        assert set(body) == set([s_scan.op, s_update.op, s_up1.op])
+
+    def test_attach_path():
+        s = tvm.Schedule(s_scan.op)
+        s[x_trans].compute_at(s[s_update], s_update.op.axis[0])
+        apath = tvm.schedule.CreateAttachPath(s)
+        assert(tuple(apath[s_update.op]) == tuple([s_scan.op.scan_axis]))
+        assert(tuple(apath[x_trans.op]) == tuple([s_update.op.axis[0], s_scan.op.scan_axis]))
+
+    def test_fix_pt():
+        body = tvm.schedule.ScanGetBody(s_scan.op)
+        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
+        assert(fxpt[s_scan.spatial_axis_[0]].value != 0)
+
+def test_scan_fix_point():
+    m = tvm.Var("m")
+    n = tvm.Var("n")
+    l = tvm.Var("l")
+    x = tvm.compute((l, m, n), lambda *i: tvm.const(1, "float32"), name="x")
+    s_state = tvm.placeholder((l, m, n))
+    s_init = tvm.compute((1, m, n), lambda _, i, j: x[0, i, j], name="s_init")
+
+    def test_scan0():
+        s_update = tvm.compute((l, m, n),
+                               lambda t, i, j: x[t, j, i]  + s_state[t-1, i, j], name="update")
+        s_scan = tvm.scan(s_init, s_update, s_state)
+        body = tvm.schedule.ScanGetBody(s_scan.op)
+        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
+        assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1)
+        assert(fxpt[s_scan.op.spatial_axis_[1]].value == 1)
+
+    def test_scan1():
+        s_update = tvm.compute((l, m, n),
+                               lambda t, i, j: x[t, j, i]  + s_state[t-1, j, i], name="update")
+        s_scan = tvm.scan(s_init, s_update, s_state)
+        body = tvm.schedule.ScanGetBody(s_scan.op)
+        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
+        assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0)
+        assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
+
+    def test_scan3_not_exact_reach():
+        s_h1 = tvm.compute((l, n, m), lambda t, j, i: s_state[t-1, i, j], name="h1")
+        s_h2 = tvm.compute((l, m, n), lambda t, i, j: s_state[t-1, i, 10] * 2, name="h1")
+        s_update = tvm.compute((l, m, n), lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
+        s_scan = tvm.scan(s_init, s_update, s_state)
+        body = tvm.schedule.ScanGetBody(s_scan.op)
+        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
+        assert(fxpt[s_scan.op.spatial_axis_[0]].value == 1)
+        assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
+
+    def test_scan4_reach_other():
+        s_h1 = tvm.compute((l, n, m), lambda t, j, i: s_state[t-1, j, j], name="h1")
+        s_h2 = tvm.compute((l, m, n), lambda t, i, j: s_state[t-1, i, j] * 2, name="h1")
+        s_update = tvm.compute((l, m, n),
+                               lambda t, i, j: s_h1[t, j, i] + s_h2[t, i, j], name="update")
+        s_scan = tvm.scan(s_init, s_update, s_state)
+        body = tvm.schedule.ScanGetBody(s_scan.op)
+        fxpt = tvm.schedule.ScanFixPointAnalysis(s_scan.op, body)
+        assert(fxpt[s_scan.op.spatial_axis_[0]].value == 0)
+        assert(fxpt[s_scan.op.spatial_axis_[1]].value == 0)
+
+    test_scan0()
+    test_scan1()
+    test_scan3_not_exact_reach()
+    test_scan4_reach_other()
+
+def test_create_read_graph():
+    m = tvm.Var('m')
+    l = tvm.Var('l')
+    A = tvm.placeholder((m, l), name='A')
+    A1 = tvm.compute((m, l), lambda i, j: A[i, j])
+    A2 = tvm.compute((m, l), lambda i, j: A1[i, j] + 3)
+
+    g = tvm.schedule.CreateReadGraph([A2.op])
+
+    assert g[A2.op][0] == A1
+    assert g[A1.op][0] == A
+    post_order = tvm.schedule.PostDFSOrder([A2.op], g)
+    assert(post_order[0] == A.op)
+    assert(post_order[1] == A1.op)
+
+
+if __name__ == "__main__":
+    test_scan()
+    test_create_read_graph()
+    test_scan_fix_point()
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index 625bee596..f24e7ffd1 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -43,13 +43,11 @@ def test_schedule2():
 def test_schedule_scan():
     m = tvm.Var("m")
     n = tvm.Var("n")
-    l = tvm.Var("l")
-    t = tvm.IterVar((1, m), name="t")
     x = tvm.compute((m, n), lambda i, j: tvm.const(1, "float32"), name="x")
     s_state = tvm.placeholder((m, n))
     s_init = tvm.compute((1, n), lambda _, i: x[0, i])
-    s_update = tvm.compute((n,), lambda i: s_state[t-1, i] + x[t, i])
-    res = tvm.scan(t, s_init, s_update, s_state)
+    s_update = tvm.compute((m, n), lambda t, i: s_state[t-1, i] + x[t, i])
+    res = tvm.scan(s_init, s_update, s_state)
 
     assert tuple(res.shape) == (m, n)
     s = tvm.Schedule(res.op)
@@ -59,7 +57,6 @@ def test_schedule_scan():
     stmt = tvm.schedule.ScheduleOps(s, bounds)
     print(stmt)
 
-
 def test_auto_inline():
     m = tvm.Var('m')
     n = tvm.Var('n')
@@ -71,9 +68,27 @@ def test_auto_inline():
 
     s = tvm.Schedule(T2.op)
     tvm.schedule.AutoInlineElemWise(s)
+    s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
 
+def test_inline_mixed():
+    n = tvm.Var('n')
+    A = tvm.placeholder((n, ), name='A')
+    A1 = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='A1')
+    A2 = tvm.compute(A.shape, lambda *i: A1(*i) + 2, name='A2')
+    C = tvm.compute((n,), lambda i: A2[i] + A1[i], name='C')
+
+    s = tvm.Schedule(C.op)
+    xo, xi = s[C].split(C.op.axis[0], factor=8)
+    s[A1].compute_at(s[C], xo)
+    s[A2].compute_inline()
+    s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+    print(stmt)
+
+
 def test_schedule_cache():
     m = tvm.Var('m')
     n = tvm.Var('n')
@@ -90,9 +105,10 @@ def test_schedule_cache():
 
 
 if __name__ == "__main__":
+    test_inline_mixed()
+    test_auto_inline()
     test_schedule_scan()
     test_schedule0()
     test_schedule1()
     test_schedule2()
-    test_auto_inline()
     test_schedule_cache()
-- 
GitLab