From 400c1c483e7aa9aaebd93f3552d1e8e31697e497 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Thu, 7 Sep 2017 15:50:25 -0700
Subject: [PATCH] [SCHEDULE] Enhance cache_write to enable layout change.
 (#432)

* [SCHEDULE] Enhance cache_write to enable layout change.

* more tests
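
For reference, a minimal usage sketch of the new behavior, mirroring the
added unit tests (the placeholder shapes and the "global" scope are
illustrative, not required):

    import tvm

    m = tvm.var('m')
    n = tvm.var('n')
    A = tvm.placeholder((m*4, n), name='A')
    B = tvm.placeholder((m*4, n), name='B')
    C = tvm.compute(A.shape, lambda i, j: A(i, j) * B(i, j), name='C')

    s = tvm.create_schedule(C.op)
    xo, xi = s[C].split(C.op.axis[0], factor=4)
    s[C].reorder(xo, C.op.axis[1], xi)
    # The cache created here stores data in the (xo, j, xi) leaf order;
    # the original stage C copies it back into the (i, j) layout.
    CC = s.cache_write(C, "global")
    s = s.normalize()
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
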
---
 include/tvm/schedule.h                        |  11 +-
 python/tvm/schedule.py                        |   8 +
 src/op/compute_op.cc                          |   9 +-
 src/op/cross_thread_reduction.cc              |   6 +-
 src/op/op_util.cc                             |  85 ----------
 src/op/op_util.h                              |  17 +-
 src/op/scan_op.cc                             |   2 +-
 src/schedule/message_passing.cc               | 140 +++++++++++++++-
 src/schedule/message_passing.h                |  32 ++++
 src/schedule/schedule_dataflow_rewrite.cc     | 154 ++++++++++++++----
 .../unittest/test_schedule_schedule_ops.py    |  52 +++++-
 11 files changed, 368 insertions(+), 148 deletions(-)
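
For the split-by-4 schedule in the sketch above, the rewrite conceptually
produces stages equivalent to the following hand-written approximation (not
generated code; the index expressions follow PassUpIndex for the cache body
and PassDownIndex for the read-back args):

    # Cache stage: axes follow the leaf iteration order (xo, j, xi).
    CC = tvm.compute((m, n, 4),
                     lambda xo, j, xi: A(xo*4 + xi, j) * B(xo*4 + xi, j),
                     name='C.global')
    # Original stage: reads the cache back in the original (i, j) layout.
    C = tvm.compute((m*4, n), lambda i, j: CC(i/4, j, i%4), name='C')

If the original stage is later inlined with compute_inline, only the
transformed layout remains in the dataflow.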

diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index aeb5ffa66..957b425a9 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -284,8 +284,15 @@ class Schedule : public NodeRef {
   /*!
   * \brief Create a cache write tensor for the given tensor.
   *  The cache tensor will take over the body of the original tensor op.
-   *  The original tensor's body will be changed to an identity read
-   *  from the corresponding cache.
+   *
+   *  This function can be used to do data layout transformation.
+   *  If there is a split/fuse/reorder on the data parallel axes of the tensor
+   *  before cache_write is called, the intermediate cache stores
+   *  the data in a layout that matches the iteration order of the leaf axes.
+   *  The data will be transformed back to the original layout in the original tensor.
+   *  The user can further call compute_inline to inline the original layout and keep
+   *  the data stored in the transformed layout.
+   *
    * \param tensor The tensor to be produced.
    * \param scope The scope of the storage.
    * \return The created tensor.
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index ecaeb50bc..26be2de1a 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -248,6 +248,14 @@ class Schedule(NodeBase):
         This will mutate the body of the tensor.
        A new cache stage will be created before it feeds into the tensor.
 
+        This function can be used to support data layout transformation.
+        If there is a split/fuse/reorder on the data parallel axes of the tensor
+        before cache_write is called, the intermediate cache stores
+        the data in a layout that matches the iteration order of the leaf axes.
+        The data will be transformed back to the original layout in the original tensor.
+        The user can further call compute_inline to inline the original layout and keep
+        the data stored in the transformed layout.
+
         Parameters
         ----------
         tensor : Tensor
diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index c7e1b54a4..89d98770b 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -383,8 +383,9 @@ ComputeLoopNest ComputeLoopNest::make(
   // make main loop nest
   ret.main_nest = op::MakeLoopNest(
       stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap);
-  ret.main_predicates = op::MakeBoundCheck(stage, dom_map, false,
-      std::unordered_set<IterVar>(), ret.main_vmap);
+  ret.main_predicates = schedule::MakeBoundCheck(
+      stage, dom_map, ret.main_vmap, false,
+      std::unordered_set<IterVar>());
   for (auto& e : ret.main_predicates) {
     e = likely(e);
   }
@@ -424,8 +425,8 @@ ComputeLoopNest ComputeLoopNest::make(
     ret.init_nest = op::MakeLoopNest(
         stage, dom_map, begin_loop, true,
         skip_iter, &(ret.init_vmap));
-    ret.init_predicates = op::MakeBoundCheck(
-        stage, dom_map, true, skip_iter, ret.init_vmap);
+    ret.init_predicates = schedule::MakeBoundCheck(
+        stage, dom_map, ret.init_vmap, true, skip_iter);
     for (auto& e : ret.init_predicates) {
       e = likely(e);
     }
diff --git a/src/op/cross_thread_reduction.cc b/src/op/cross_thread_reduction.cc
index e79b81c8c..6eec3bd69 100644
--- a/src/op/cross_thread_reduction.cc
+++ b/src/op/cross_thread_reduction.cc
@@ -21,9 +21,9 @@ Stmt MakeCrossThreadReduction(
   std::unordered_map<IterVar, Expr> value_map;
   auto nest = op::MakeLoopNest(
       stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map);
-  auto conds = op::MakeBoundCheck(
-      stage, dom_map, false,
-      std::unordered_set<IterVar>(), value_map);
+  auto conds = schedule::MakeBoundCheck(
+      stage, dom_map, value_map, false,
+      std::unordered_set<IterVar>());
 
   size_t size = self->body.size();
   CHECK_GT(size, 0);
diff --git a/src/op/op_util.cc b/src/op/op_util.cc
index ea64bacdc..cd0d5e436 100644
--- a/src/op/op_util.cc
+++ b/src/op/op_util.cc
@@ -147,91 +147,6 @@ MakeLoopNest(const Stage& stage,
   return nest;
 }
 
-
-/*!
- * \brief message passing to find if boundary checking on IterVar is needed.
- * \param s The stage to be used.
- * \param p_state The message passing state
- *     IterVar->flag
- */
-void PassUpBoundCheck(const Stage& s,
-                      const Map<IterVar, Range>& dom_map,
-                      std::unordered_map<IterVar, bool>* p_state) {
-  auto& state = *p_state;
-  using Halide::Internal::can_prove;
-  for (size_t i = s->relations.size(); i != 0; --i) {
-    IterVarRelation rel = s->relations[i - 1];
-    if (rel.as<SplitNode>()) {
-      const SplitNode* s = rel.as<SplitNode>();
-      bool outer = state.at(s->outer);
-      bool inner = state.at(s->inner);
-      Expr factor = dom_map.at(s->inner)->extent;
-      Expr step = dom_map.at(s->outer)->extent;
-
-      if (outer || inner) {
-        state[s->parent] = true;
-      } else {
-        if (can_prove(dom_map.at(s->parent)->extent == factor * step)) {
-          state[s->parent] = false;
-        } else {
-          state[s->parent] = true;
-        }
-      }
-    } else if (rel.as<FuseNode>()) {
-      const FuseNode* s = rel.as<FuseNode>();
-      bool fused = state.at(s->fused);
-      state[s->outer] = fused;
-      state[s->inner] = fused;
-    } else if (rel.as<RebaseNode>()) {
-      const RebaseNode* s = rel.as<RebaseNode>();
-      state[s->parent] = state.at(s->rebased);
-    } else {
-      LOG(FATAL) << "unknown relation type";
-    }
-  }
-}
-
-std::vector<Expr> MakeBoundCheck(
-    const Stage& stage,
-    const Map<IterVar, Range>& dom_map,
-    bool skip_ivar_domain,
-    const std::unordered_set<IterVar>& skip_iter,
-    const std::unordered_map<IterVar, Expr>& value_map) {
-  std::unordered_map<IterVar, bool> bound_state;
-  for (IterVar iv : stage->leaf_iter_vars) {
-    bound_state[iv] = false;
-  }
-  PassUpBoundCheck(stage, dom_map, &bound_state);
-  std::vector<Expr> preds;
-  std::unordered_map<const Variable*, IntSet> iset_dmap;
-
-  // setup domain map for set analysis
-  for (const auto& kv : dom_map) {
-    iset_dmap[kv.first->var.get()] = IntSet::range(kv.second);
-  }
-
-  for (IterVar iv : stage->op->root_iter_vars()) {
-    if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
-    Range dom = dom_map.at(iv);
-    if (bound_state.at(iv)) {
-      Expr value = ComputeExpr<Sub>(value_map.at(iv), dom->min);
-      Expr vmax = EvalSet(value, iset_dmap).max();
-      if (vmax.type() != value.type() || !can_prove(vmax < dom->extent)) {
-        preds.emplace_back(value < dom->extent);
-      }
-    }
-    CHECK(iv->dom.defined());
-    if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
-      Expr value = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min);
-      Expr vmax = EvalSet(value, iset_dmap).max();
-      if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) {
-        preds.emplace_back(value < iv->dom->extent);
-      }
-    }
-  }
-  return preds;
-}
-
 std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) {
   Stmt no_op = Evaluate::make(0);
   std::vector<Stmt> nest;
diff --git a/src/op/op_util.h b/src/op/op_util.h
index 165113863..783fbb989 100644
--- a/src/op/op_util.h
+++ b/src/op/op_util.h
@@ -13,6 +13,7 @@
 #include <vector>
 #include "../pass/ir_util.h"
 #include "../pass/arg_binder.h"
+#include "../schedule/message_passing.h"
 
 namespace tvm {
 namespace op {
@@ -36,22 +37,6 @@ MakeLoopNest(const Stage& stage,
              bool new_loop_var,
              const std::unordered_set<IterVar>& skip_iter,
              std::unordered_map<IterVar, Expr>* p_value_map);
-/*!
- * \brief Create boundary check condition for given stage.
- *
- * \param stage The stage to create a loop nest.
- * \param dom_map The range of each iter var.
- * \param skip_ivar_domain Whether we can skip check for IterVar's original domain.
- * \param skip_iter Whether skip certain iteration.
- * \param value_map The result value of each IterVar.
- * \return List of predicates that we need to check.
- */
-std::vector<Expr>
-MakeBoundCheck(const Stage& stage,
-               const Map<IterVar, Range>& dom_map,
-               bool skip_ivar_domain,
-               const std::unordered_set<IterVar>& skip_iter,
-               const std::unordered_map<IterVar, Expr>& value_map);
 
 /*!
  * \brief Create a nest of if checking the predicates.
diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc
index f03eb95f1..48565b6eb 100644
--- a/src/op/scan_op.cc
+++ b/src/op/scan_op.cc
@@ -274,7 +274,7 @@ Stmt ScanOpNode::BuildProvide(
   nest[begin_scan].push_back(init);
   nest.push_back(
       op::MakeIfNest(
-          op::MakeBoundCheck(stage, dom_map, false, empty, vmap)));
+          schedule::MakeBoundCheck(stage, dom_map, vmap, false, empty)));
   return MergeNest(nest, provide);
 }
 }  // namespace tvm
diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc
index 4ba32785d..969a18ee9 100644
--- a/src/schedule/message_passing.cc
+++ b/src/schedule/message_passing.cc
@@ -7,10 +7,12 @@
 #include <tvm/ir.h>
 #include <tvm/ir_pass.h>
 #include "./message_passing.h"
+#include "../arithmetic/compute_expr.h"
 
 namespace tvm {
 namespace schedule {
 
+using namespace ir;
 using namespace arith;
 
 // result = ceil(a / b), where both a and b are positive integers
@@ -123,8 +125,8 @@ void PassUpIndex(const Stage& stage,
       Expr factor = dom_map.at(s->inner)->extent;
       Expr outer_min = dom_map.at(s->outer)->min;
       Expr inner_min = dom_map.at(s->inner)->min;
-      state[s->outer] = value / factor;
-      state[s->inner] = value % factor;
+      state[s->outer] = ComputeExpr<Div>(value, factor);
+      state[s->inner] = ComputeExpr<Mod>(value, factor);
       // add min if they exist
       if (!is_zero(outer_min)) {
         state[s->outer] = state[s->outer] + outer_min;
@@ -151,6 +153,51 @@ void PassUpIndex(const Stage& stage,
   }
 }
 
+void PassDownIndex(const Stage& stage,
+                   const Map<IterVar, Range>& dom_map,
+                   std::unordered_map<IterVar, Expr>* p_state,
+                   bool allow_missing) {
+  auto& state = *p_state;
+  for (IterVarRelation rel : stage->relations) {
+    if (const SplitNode* s = rel.as<SplitNode>()) {
+      if (!state.count(s->parent)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      Range r = dom_map.at(s->inner);
+      CHECK(is_zero(r->min));
+      Expr parent = state.at(s->parent);
+      Expr factor = r->extent;
+      state[s->outer] = ComputeExpr<Div>(parent, factor);
+      state[s->inner] = ComputeExpr<Mod>(parent, factor);
+    } else if (const FuseNode* s = rel.as<FuseNode>()) {
+      if (!state.count(s->inner) && !state.count(s->outer)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      Expr factor = dom_map.at(s->inner)->extent;
+      Expr outer_min = dom_map.at(s->outer)->min;
+      Expr inner_min = dom_map.at(s->inner)->min;
+      Expr inner = state.at(s->inner);
+      Expr outer = state.at(s->outer);
+      CHECK(is_zero(outer_min));
+      CHECK(is_zero(inner_min));
+      state[s->fused] = outer * factor + inner;
+    } else if (const RebaseNode* s = rel.as<RebaseNode>()) {
+      if (!state.count(s->parent)) {
+        CHECK(allow_missing);
+        continue;
+      }
+      Expr value = state.at(s->parent);
+      Expr parent_min = dom_map.at(s->parent)->min;
+      CHECK(is_zero(parent_min));
+      state[s->rebased] = value;
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
+
 // Domain message passing.
 void PassUpDomain(const SplitNode* s,
                   const std::unordered_map<IterVar, Range>& dom_map,
@@ -349,5 +396,94 @@ void PassDownBitMaskOr(const Stage& stage,
   }
 }
 
+
+/*!
+ * \brief Message passing to find if boundary checking on an IterVar is needed.
+ * \param s The stage to be used.
+ * \param dom_map The domain map of each iteration variable.
+ * \param p_state The message passing state, mapping IterVar->flag.
+ */
+void PassUpBoundCheck(const Stage& s,
+                      const Map<IterVar, Range>& dom_map,
+                      std::unordered_map<IterVar, bool>* p_state) {
+  auto& state = *p_state;
+  using Halide::Internal::can_prove;
+  for (size_t i = s->relations.size(); i != 0; --i) {
+    IterVarRelation rel = s->relations[i - 1];
+    if (rel.as<SplitNode>()) {
+      const SplitNode* s = rel.as<SplitNode>();
+      bool outer = state.at(s->outer);
+      bool inner = state.at(s->inner);
+
+      if (dom_map.count(s->inner) && dom_map.count(s->outer)) {
+        Expr factor = dom_map.at(s->inner)->extent;
+        Expr step = dom_map.at(s->outer)->extent;
+        if (outer || inner) {
+          state[s->parent] = true;
+        } else {
+          if (can_prove(dom_map.at(s->parent)->extent == factor * step)) {
+            state[s->parent] = false;
+          } else {
+            state[s->parent] = true;
+          }
+        }
+      } else {
+        state[s->parent] = true;
+      }
+    } else if (rel.as<FuseNode>()) {
+      const FuseNode* s = rel.as<FuseNode>();
+      bool fused = state.at(s->fused);
+      state[s->outer] = fused;
+      state[s->inner] = fused;
+    } else if (rel.as<RebaseNode>()) {
+      const RebaseNode* s = rel.as<RebaseNode>();
+      state[s->parent] = state.at(s->rebased);
+    } else {
+      LOG(FATAL) << "unknown relation type";
+    }
+  }
+}
+
+std::vector<Expr> MakeBoundCheck(
+    const Stage& stage,
+    const Map<IterVar, Range>& dom_map,
+    const std::unordered_map<IterVar, Expr>& value_map,
+    bool skip_ivar_domain,
+    const std::unordered_set<IterVar>& skip_iter) {
+  std::unordered_map<IterVar, bool> bound_state;
+  for (IterVar iv : stage->leaf_iter_vars) {
+    bound_state[iv] = false;
+  }
+  PassUpBoundCheck(stage, dom_map, &bound_state);
+
+  std::vector<Expr> preds;
+  std::unordered_map<const Variable*, IntSet> iset_dmap;
+
+  // setup domain map for set analysis
+  for (const auto& kv : dom_map) {
+    iset_dmap[kv.first->var.get()] = IntSet::range(kv.second);
+  }
+
+  for (IterVar iv : stage->op->root_iter_vars()) {
+    if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
+    Range dom = dom_map.at(iv);
+    if (bound_state.at(iv)) {
+      Expr value = ComputeExpr<Sub>(value_map.at(iv), dom->min);
+      Expr vmax = EvalSet(value, iset_dmap).max();
+      if (vmax.type() != value.type() || !can_prove(vmax < dom->extent)) {
+        preds.emplace_back(value < dom->extent);
+      }
+    }
+    CHECK(iv->dom.defined());
+    if (!skip_ivar_domain && !iv->dom.same_as(dom)) {
+      Expr value = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min);
+      Expr vmax = EvalSet(value, iset_dmap).max();
+      if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) {
+        preds.emplace_back(value < iv->dom->extent);
+      }
+    }
+  }
+  return preds;
+}
 }  // namespace schedule
 }  // namespace tvm
diff --git a/src/schedule/message_passing.h b/src/schedule/message_passing.h
index 5b7cf9d24..baf4a2415 100644
--- a/src/schedule/message_passing.h
+++ b/src/schedule/message_passing.h
@@ -45,6 +45,20 @@ void PassUpIndex(const Stage& stage,
                  std::unordered_map<IterVar, Expr>* p_state,
                  bool allow_missing = false);
 
+/*!
+ * \brief Downward inference of the index of each IterVar,
+ *  given the index assignment of the roots.
+ *
+ * \param stage The stage to operate on.
+ * \param dom_map The domain map of each iteration variable.
+ * \param p_state The index state of each IterVar.
+ * \param allow_missing Whether to allow missing values.
+ */
+void PassDownIndex(const Stage& stage,
+                   const Map<IterVar, Range>& dom_map,
+                   std::unordered_map<IterVar, Expr>* p_state,
+                   bool allow_missing = false);
+
 /*!
  * \brief Upward inference of the domain set of each IterVar,
  *  given the domain assignment of the leaves.
@@ -76,6 +90,24 @@ void PassUpBitMaskOr(const Stage& stage,
 void PassDownBitMaskOr(const Stage& stage,
                        std::unordered_map<IterVar, int>* p_state,
                        bool allow_missing = false);
+
+/*!
+ * \brief Create boundary check predicates given the remapped values of the root iter vars.
+ * \param stage The stage we operate on.
+ * \param dom_map The domain map of each iteration variable.
+ * \param value_map The value map of the root iter vars.
+ * \param skip_ivar_domain Whether to skip the check against each IterVar's original domain.
+ * \param skip_iter The set of variables whose bound condition can be skipped.
+ * \return The list of predicates that need to be checked.
+ */
+std::vector<Expr>
+MakeBoundCheck(
+    const Stage& stage,
+    const Map<IterVar, Range>& dom_map,
+    const std::unordered_map<IterVar, Expr>& value_map,
+    bool skip_ivar_domain,
+    const std::unordered_set<IterVar>& skip_iter);
+
 }  // namespace schedule
 }  // namespace tvm
 #endif  // TVM_SCHEDULE_MESSAGE_PASSING_H_
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index c5aca83d5..02ebc21e2 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -9,6 +9,7 @@
 #include <unordered_set>
 #include "./message_passing.h"
 #include "../pass/ir_util.h"
+#include "../arithmetic/compute_expr.h"
 
 namespace tvm {
 
@@ -38,6 +39,22 @@ class VarReplacer : public ir::IRMutator {
   const std::unordered_map<const Variable*, Expr>& vsub_;
 };
 
+Expr InjectPredicate(const Array<Expr>& predicates,
+                     Expr body) {
+  using ir::Reduce;
+  using ir::Select;
+  if (predicates.size() == 0) return body;
+  const Reduce* reduce = body.as<Reduce>();
+  if (reduce) {
+    std::shared_ptr<Reduce> n = std::make_shared<Reduce>(*reduce);
+    n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates);
+    return Expr(n);
+  }
+  return Select::make(arith::ComputeReduce<ir::And>(predicates),
+                      body,
+                      make_zero(body.type()));
+}
+
 // Replace data flow appears in all stages given the tensor change.
 // Also update vmap if subsequent dataflow need to be replaced.
 void ReplaceDataFlow(const Array<Stage>& stages,
@@ -99,52 +116,101 @@ Tensor Schedule::cache_read(const Tensor& tensor,
   return cache;
 }
 
-Tensor Schedule::cache_write(const Tensor& tensor,
-                             const std::string& scope) {
-  (*this)->InvalidateCache();
-  Stage orig_stage = operator[](tensor->op);
-  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
-  CHECK(compute)
-      << "cache write only take ComputeOp as writers";
-  CHECK_EQ(orig_stage->relations.size(), 0U)
-      << "Create cache_write before doing split/fuse/reorder";
-  compute = orig_stage->op.as<ComputeOpNode>();
-  CHECK(compute);
-  Array<Expr> args;
+
+// Cache write and relayout the data according to the loop pattern.
+Tensor CacheWriteWithReLayout(Schedule sch,
+                              const Tensor& tensor,
+                              const std::string& scope) {
+  sch->InvalidateCache();
+  Stage orig_stage = sch[tensor->op];
+  const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
+
+  std::unordered_set<IterVar> red_axis;
+  for (IterVar iv : compute->reduce_axis) {
+    red_axis.insert(iv);
+  }
+  std::unordered_map<IterVar, Range> dom_map;
   Array<IterVar> new_axis;
-  std::unordered_map<const Variable*, Expr> vsub;
+
   for (IterVar iv : compute->axis) {
-    args.push_back(iv->var);
-    IterVar new_iv = IterVarNode::make(
-        iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
-    new_axis.push_back(new_iv);
-    vsub[iv->var.get()] = new_iv->var;
+    dom_map[iv] = iv->dom;
+  }
+  schedule::PassDownDomain(orig_stage, &dom_map, true);
+  std::unordered_map<const Variable*, Expr> vsub;
+  std::unordered_map<const Variable*, Expr> vsub2newvar;
+  std::vector<Expr> predicates;
+  {
+    // The source->cache
+    std::unordered_map<IterVar, Expr> value_map;
+    for (IterVar iv : orig_stage->leaf_iter_vars) {
+      if (red_axis.count(iv)) continue;
+      CHECK_EQ(iv->iter_type, kDataPar)
+          << "Can only relayout with in data parallel dimensions";
+      Range dom = dom_map.at(iv);
+      IterVar new_iv = IterVarNode::make(
+          dom, iv->var.copy_with_suffix(".c"), iv->iter_type);
+      new_axis.push_back(new_iv);
+      if (is_one(dom->extent)) {
+        value_map[iv] = dom->min;
+      } else {
+        value_map[iv] = iv->var;
+        vsub2newvar[iv->var.get()] = new_iv->var;
+      }
+    }
+    // skip reduction iteration.
+    std::unordered_set<IterVar> skip_bound_check;
+    for (IterVar iv : compute->reduce_axis) {
+      skip_bound_check.insert(iv);
+    }
+    schedule::PassUpIndex(orig_stage, dom_map, &value_map, true);
+    predicates = schedule::MakeBoundCheck(
+        orig_stage, dom_map, value_map, true, skip_bound_check);
+    // The root axis
+    for (IterVar iv : compute->axis) {
+      vsub[iv->var.get()] = value_map.at(iv);
+    }
+  }
+  Expr body = VarReplacer(vsub).Mutate(compute->body[tensor->value_index]);
+  body = InjectPredicate(predicates, body);
+  body = VarReplacer(vsub2newvar).Mutate(body);
+  // The reader args
+  Array<Expr> args;
+  {
+    // cache->compute
+    std::unordered_map<IterVar, Expr> value_map;
+    for (IterVar iv : compute->axis) {
+      value_map[iv] = iv->var;
+    }
+    schedule::PassDownIndex(orig_stage, dom_map, &value_map, true);
+    for (IterVar iv : orig_stage->leaf_iter_vars) {
+      if (red_axis.count(iv)) continue;
+      args.push_back(value_map.at(iv));
+    }
   }
-  VarReplacer repl(vsub);
-  Expr body = repl.Mutate(compute->body[tensor->value_index]);
   Operation cache_op = ComputeOpNode::make(
       compute->name + "." + scope, compute->tag, new_axis, {body});
   Tensor cache_tensor = cache_op.output(0);
   Operation orig_new_op = ComputeOpNode::make(
       compute->name, compute->tag, compute->axis,
       {cache_tensor(args)});
-
+  // Replace the dataflow.
   std::unordered_map<Tensor, Tensor> vmap;
   vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
-  ReplaceDataFlow((*this)->stages, &vmap);
+  ReplaceDataFlow(sch->stages, &vmap);
   // mutate orig stage
   orig_stage->op = orig_new_op;
   orig_stage->all_iter_vars = orig_stage->op->root_iter_vars();
   orig_stage->leaf_iter_vars = orig_stage->all_iter_vars;
+  orig_stage->relations = Array<IterVarRelation>();
   // create schedule for new cached stage.
-  ArrayNode* stages = (*this)->stages.CopyOnWrite();
+  ArrayNode* stages = sch->stages.CopyOnWrite();
   size_t pos = FindNodeRef(stages, orig_stage);
   Stage cache_stage = Stage(cache_op);
   cache_stage.set_scope(scope);
   CHECK_LT(pos, stages->data.size());
   stages->data.insert(stages->data.begin() + pos,
                       cache_stage.node_);
-  (*this)->stage_map.Set(cache_op, cache_stage);
+  sch->stage_map.Set(cache_op, cache_stage);
   // Update group
   cache_stage->group = orig_stage->group;
   if (cache_stage->group.defined()) {
@@ -153,6 +219,19 @@ Tensor Schedule::cache_write(const Tensor& tensor,
   return cache_tensor;
 }
 
+Tensor Schedule::cache_write(const Tensor& tensor,
+                             const std::string& scope) {
+  (*this)->InvalidateCache();
+  Stage orig_stage = operator[](tensor->op);
+  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
+  CHECK(compute)
+      << "cache write only take ComputeOp as writers";
+  CHECK_EQ(compute->num_outputs(), 1)
+      << "cache write only support single output ComputeOp";
+
+  return CacheWriteWithReLayout(*this, tensor, scope);
+}
+
 void RebaseNonZeroMinLoop(const Schedule& sch) {
   std::unordered_map<IterVar, IterVar> rebase_map;
   for (Stage s : sch->stages) {
@@ -295,16 +374,23 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
   touch_map[axis] = 1;
   schedule::PassUpBitMaskOr(reduce_stage, &touch_map, true);
   schedule::PassDownBitMaskOr(reduce_stage, &touch_map, true);
+  // skip reduction iteration.
+  std::unordered_set<IterVar> skip_bound_check;
   // Verify normal axis are not touched.
   for (IterVar iv : compute_op->axis) {
     CHECK(!touch_map.count(iv))
         << "Factor axis touches normal axis.";
+    skip_bound_check.insert(iv);
   }
   // Get the replace index
   std::unordered_map<IterVar, Range> dom_map;
   std::unordered_map<IterVar, Expr> value_map;
   for (IterVar iv : compute_op->reduce_axis) {
-    if (touch_map.count(iv)) dom_map[iv] = iv->dom;
+    if (touch_map.count(iv)) {
+      dom_map[iv] = iv->dom;
+    } else {
+      skip_bound_check.insert(iv);
+    }
   }
   schedule::PassDownDomain(reduce_stage, &dom_map, true);
   for (IterVar iv : reduce_stage->leaf_iter_vars) {
@@ -318,6 +404,9 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
     }
   }
   schedule::PassUpIndex(reduce_stage, dom_map, &value_map, true);
+  std::vector<Expr> predicates = schedule::MakeBoundCheck(
+      reduce_stage, dom_map, value_map, true, skip_bound_check);
+
   // Get the factored op node.
   auto n = std::make_shared<ComputeOpNode>();
   n->name = compute_op->name + ".rf";
@@ -339,8 +428,11 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
   int idx = tensor->value_index;
   const Reduce* reduce = compute_op->body[idx].as<Reduce>();
   CHECK(reduce) << "Can only rfactor non-inline reductions";
-  Expr predicate = reduce->condition;
+  predicates.push_back(reduce->condition);
+  Expr predicate = arith::ComputeReduce<ir::And>(predicates);
+
   std::unordered_map<const Variable*, Expr> vsub;
+
   for (IterVar iv : compute_op->reduce_axis) {
     if (!touch_map.count(iv)) {
       n->reduce_axis.push_back(iv);
@@ -348,16 +440,9 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
       CHECK(value_map.count(iv));
       Expr index = value_map.at(iv);
       vsub[iv->var.get()] = index;
-      if (!index.same_as(iv->var)) {
-        Expr cond = (index < dom_map.at(iv)->extent);
-        if (is_one(predicate)) {
-          predicate = cond;
-        } else {
-          predicate = predicate && cond;
-        }
-      }
     }
   }
+
   // Copy touched axis.
   for (IterVar iv : reduce_stage->leaf_iter_vars) {
     if (touch_map.count(iv) && !iv.same_as(axis)) {
@@ -453,4 +538,5 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor,
   reduce_stage->relations = Array<IterVarRelation>();
   return factor_tensors;
 }
+
 }  // namespace tvm
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index 158d83e78..e9c23d74e 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -55,7 +55,6 @@ def test_schedule_scan():
     bounds = tvm.schedule.InferBound(s)
     assert(bounds[res.op.scan_axis].min.value == 1)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
-    print(stmt)
 
 def test_auto_inline():
     m = tvm.var('m')
@@ -160,7 +159,58 @@ def test_schedule_cache():
     stmt = tvm.schedule.ScheduleOps(s, bounds)
 
 
+def test_schedule_cache_relayout1():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A = tvm.placeholder((m, n), name='A')
+    B = tvm.placeholder((m, n), name='B')
+    C = tvm.compute((m, n), lambda i, j:  A(i, j) * B(i, j), name='C')
+
+    s = tvm.create_schedule(C.op)
+    s[C].reorder(C.op.axis[1], C.op.axis[0])
+    CC = s.cache_write(C, "global")
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
+def test_schedule_cache_relayout2():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A = tvm.placeholder((m*4, n), name='A')
+    B = tvm.placeholder((m*4, n), name='B')
+    C = tvm.compute(A.shape, lambda i, j:  A(i, j) * B(i, j), name='C')
+    s = tvm.create_schedule(C.op)
+    x, y = C.op.axis
+    xo, xi = s[C].split(x, factor=4)
+    s[C].reorder(xo, y, xi)
+    CC = s.cache_write(C, "global")
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
+def test_schedule_cache_relayout3():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A = tvm.placeholder((m*4, n), name='A')
+    B = tvm.placeholder((m*4, n), name='B')
+    k = tvm.reduce_axis((0, n), "k")
+    C = tvm.compute((A.shape[0],),
+                    lambda i: tvm.sum(A(i, k) * B(i, k), axis=k), name='C')
+    s = tvm.create_schedule(C.op)
+    x = C.op.axis[0]
+    xo, xi = s[C].split(x, factor=4)
+    CC = s.cache_write(C, "global")
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
 if __name__ == "__main__":
+    test_schedule_cache_relayout3()
+    test_schedule_cache_relayout2()
+    test_schedule_cache_relayout1()
     test_schedule_const_bound()
     test_scan_inline1()
     test_scan_inline2()
-- 
GitLab