diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index 6284906c1a4ed7fb7166e925f885dcc4ebbbe0c5..594e8b8a7c168ff62cb597084fa8a3147f172046 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -301,6 +301,23 @@ class Schedule : public NodeRef {
    *  User can further call compute_inline to inline the original layout and keep
    *  the data stored in the transformed layout.
+   * \param tensor The tensors to be produced.
+   * \param scope The scope of the storage.
+   * \return The created tensors.
+   */
+  EXPORT Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
+  /*!
+   * \brief Create a cache write tensor for producing tensor.
+   *  The tensor will take over body of original tensor op.
+   *
+   *  This function can be used to do data layout transformation.
+   *  If there is a split/fuse/reorder on the data parallel axis of tensor
+   *  before cache_write is called. The intermediate cache stores
+   *  the data in the layout as the iteration order of leave axis.
+   *  The data will be transformed back to the original layout in the original tensor.
+   *  User can further call compute_inline to inline the original layout and keep
+   *  the data stored in the transformed layout.
+   *
    * \param tensor The tensor to be produced.
    * \param scope The scope of the storage.
    * \return The created tensor.
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index 6305cc1d5e69a44d108ebf018c7603d50d14a245..f54ae188c1309a496b7d57d93979206d9ff13036 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -292,8 +292,8 @@ class Schedule(NodeBase):
-        tensor : Tensor
-            The tensor to be feed to.
+        tensor : Tensor, list or tuple
+            The tensors to be fed to. All the tensors must be produced by one ComputeOp
         scope : str
             The scope of cached
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
index 03d085aa8d5924d8097341cfd92943100ca6d905..ba158a7c3f795b462c5162c1d38d4bb39002a570 100644
--- a/src/api/api_lang.cc
+++ b/src/api/api_lang.cc
@@ -425,8 +425,13 @@ TVM_REGISTER_API("_ScheduleCacheRead")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = args[0].operator Schedule()
-        .cache_write(args[1], args[2]);
+    if (args[1].IsNodeType<Tensor>()) {
+      *ret = args[0].operator Schedule()
+          .cache_write(args[1].operator Tensor(), args[2]);
+    } else {
+      *ret = args[0].operator Schedule()
+          .cache_write(args[1].operator Array<Tensor>(), args[2]);
+    }
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index 1d9dfed4e5e16fe2e1324cfd28f399d4bbe418bf..89390dcc048d992e6f1646360626312b2008f702 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -78,6 +78,13 @@ void ReplaceDataFlow(const Array<Stage>& stages,
+inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
+  return (a->combiner.same_as(b->combiner)) &&
+         (a->source.same_as(b->source)) &&
+         (a->axis.same_as(b->axis)) &&
+         (a->condition.same_as(b->condition));
 Tensor Schedule::cache_read(const Tensor& tensor,
                             const std::string& scope,
                             const Array<Operation>& readers) {
@@ -128,15 +135,15 @@ Tensor Schedule::cache_read(const Tensor& tensor,
   return cache;
 // Cache write and relayout the data according to loop pattern
-Tensor CacheWriteWithReLayout(Schedule sch,
-                              const Tensor& tensor,
+Array<Tensor> CacheWriteWithReLayout(Schedule sch,
+                              const Array<Tensor>& tensor_array,
                               const std::string& scope) {
+  size_t tensor_size = tensor_array.size();
+  Tensor tensor = tensor_array[0];
   Stage orig_stage = sch[tensor->op];
   const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
   std::unordered_set<IterVar> red_axis;
   for (IterVar iv : compute->reduce_axis) {
@@ -182,9 +189,34 @@ Tensor CacheWriteWithReLayout(Schedule sch,
       vsub[iv->var.get()] = value_map.at(iv);
-  Expr body = VarReplacer(vsub).Mutate(compute->body[tensor->value_index]);
-  body = InjectPredicate(predicates, body);
-  body = VarReplacer(vsub2newvar).Mutate(body);
+  Expr body;
+  Array<Expr> body_list;
+  const ir::Reduce* first_reduce = nullptr;
+  for (auto cbody : compute->body) {
+    body = VarReplacer(vsub).Mutate(cbody);
+    body = InjectPredicate(predicates, body);
+    body = VarReplacer(vsub2newvar).Mutate(body);
+    // Reduce nodes in ONE ComputeOp must be the same except for value_index.
+    // This is correct only if the original body ensures the Reduce nodes are the same.
+    if (body->is_type<ir::Reduce>()) {
+      const ir::Reduce* reduce_body = body.as<ir::Reduce>();
+      if (first_reduce != nullptr) {
+        CHECK(ReduceEqual(reduce_body, first_reduce));
+        body = ir::Reduce::make(first_reduce->combiner,
+                                first_reduce->source,
+                                first_reduce->axis,
+                                first_reduce->condition,
+                                reduce_body->value_index);
+      } else {
+        first_reduce = reduce_body;
+      }
+    } else {
+      CHECK(first_reduce == nullptr)
+        << "cannot mix reduce and other node in ONE compute bodys";
+    }
+    body_list.push_back(body);
+  }
   // The reader args
   Array<Expr> args;
@@ -200,16 +232,25 @@ Tensor CacheWriteWithReLayout(Schedule sch,
   Operation cache_op = ComputeOpNode::make(
-      compute->name + "." + scope, compute->tag, new_axis, {body});
-  Tensor cache_tensor = cache_op.output(0);
+      compute->name + "." + scope, compute->tag, new_axis, body_list);
+  Array<Tensor> cache_tensor_list;
+  Array<Expr> cache_expr_list;
+  for (size_t i = 0; i < tensor_size; i++) {
+    Tensor cache_tensor = cache_op.output(i);
+    cache_tensor_list.push_back(cache_tensor);
+    cache_expr_list.push_back(cache_tensor(args));
+  }
   Operation orig_new_op = ComputeOpNode::make(
-      compute->name, compute->tag, compute->axis,
-      {cache_tensor(args)});
+      compute->name, compute->tag, compute->axis, cache_expr_list);
   // The replace of the dataflow
   std::unordered_map<Tensor, Tensor> vmap;
   std::unordered_map<Tensor, Tensor> rvmap;
   vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
   rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
+  // Redirect every output of the original op (not only output 0) to the
+  // matching output of the rewritten op, and record the reverse mapping.
+  for (size_t i = 0; i < tensor_size; i++) {
+    vmap[orig_stage->op.output(i)] = orig_new_op.output(i);
+    rvmap[orig_new_op.output(i)] = orig_stage->op.output(i);
+  }
   ReplaceDataFlow(sch->stages, &vmap, &rvmap);
   // mutate orig stage
   orig_stage->op = orig_new_op;
@@ -230,7 +271,26 @@ Tensor CacheWriteWithReLayout(Schedule sch,
   if (cache_stage->group.defined()) {
-  return cache_tensor;
+  return cache_tensor_list;
+// Cache write for a multi-output ComputeOp.
+// tensor_array must hold exactly as many tensors as the producing op has
+// outputs, and all of them must belong to the same stage; the checks below
+// enforce this before the rewrite is delegated to CacheWriteWithReLayout.
+Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
+                                    const std::string& scope) {
+  (*this)->InvalidateCache();
+  CHECK(tensor_array.size() > 0)
+      << "size of tensor_array must be greater than 0";
+  Tensor tensor = tensor_array[0];
+  Stage orig_stage = operator[](tensor->op);
+  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
+  // Guard the pointer before it is dereferenced below: the cast may yield
+  // nullptr when the producer is not a ComputeOp.
+  CHECK(compute)
+      << "cache write only take ComputeOp as writers";
+  CHECK(static_cast<size_t>(compute->num_outputs()) == tensor_array.size())
+      << "size of input tensor list must be same as number of stage outputs";
+  // Every remaining tensor must come from the same stage as the first one.
+  for (size_t i = 1; i < tensor_array.size(); i++) {
+    Stage tmp_stage = operator[](tensor_array[i]->op);
+    CHECK(orig_stage.same_as(tmp_stage))
+        << "Input tensor list must be generated by ONE computeOp";
+  }
+  return CacheWriteWithReLayout(*this, tensor_array, scope);
 Tensor Schedule::cache_write(const Tensor& tensor,
@@ -243,7 +303,7 @@ Tensor Schedule::cache_write(const Tensor& tensor,
   CHECK_EQ(compute->num_outputs(), 1)
       << "cache write only support single output ComputeOp";
-  return CacheWriteWithReLayout(*this, tensor, scope);
+  return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
 void RebaseNonZeroMinLoop(const Schedule& sch) {
@@ -289,13 +349,6 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
-inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
-  return (a->combiner.same_as(b->combiner)) &&
-         (a->source.same_as(b->source)) &&
-         (a->axis.same_as(b->axis)) &&
-         (a->condition.same_as(b->condition));
 void InjectInline(ScheduleNode* sch) {
diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py
index cebacdc2b9a5454699c66e03a7ca43d8adfd872a..e9c69bb57d9f9f72758f98b4478e72e5200b09da 100644
--- a/tests/python/integration/test_ewise.py
+++ b/tests/python/integration/test_ewise.py
@@ -39,6 +39,49 @@ def test_exp():
+def test_multiple_cache_write():
+    # graph
+    n = tvm.convert(1024)
+    A0 = tvm.placeholder((n,), name='A0', dtype = "float32")
+    A1 = tvm.placeholder((n,), name='A1', dtype = "float32")
+    B0, B1 = tvm.compute((n,), 
+            lambda *i: (A0(*i) + A1(*i), A0(*i) * A1(*i)), 
+            name='B')
+    C = tvm.compute((n,), lambda *i: B0(*i) + B1(*i), 
+            name='C')
+    s = tvm.create_schedule(C.op)
+    # create iter var and assign them tags.
+    num_thread = 8
+    B0_cache, B1_cache = s.cache_write([B0, B1], "local")
+    bx, tx = s[C].split(C.op.axis[0], factor=num_thread)
+    s[B0].compute_at(s[C], bx)
+    s[B0_cache].compute_at(s[C], bx)
+    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+    # one line to build the function.
+    def check_device(device, host="stackvm"):
+        if not tvm.module.enabled(host):
+            return
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            return
+        func = tvm.build(s, [A0, A1, C],
+                         device, host,
+                         name="multiple_cache_write")
+        ctx = tvm.context(device, 0)
+        # launch the kernel.
+        n = 1024
+        a0 = tvm.nd.array(np.random.uniform(size=n).astype(A0.dtype), ctx)
+        a1 = tvm.nd.array(np.random.uniform(size=n).astype(A1.dtype), ctx)
+        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+        func(a0, a1, c)
+        np.testing.assert_allclose(
+            c.asnumpy(), a0.asnumpy() + a1.asnumpy() + (a0.asnumpy() * a1.asnumpy()), 
+            rtol=1e-5)
+    check_device("cuda", "llvm")
+    check_device("vulkan")
+    check_device("opencl")
 def test_log_pow_llvm():
     # graph
@@ -199,6 +242,7 @@ def try_warp_memory():
 if __name__ == "__main__":
+    test_multiple_cache_write()
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index ecb9c0acb7f8cc73119240500621dc50d12ba1d0..8e6f4090d4038f8400ca7aca3385d462ad3f002f 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -249,6 +249,20 @@ def test_schedule_cache_relayout3():
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
+def test_schedule_cache_relayout4():
+    def _compute(*indice):
+        return A(*indice) + 1, B(*indice) / 2
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A = tvm.placeholder((m*4, n), name='A')
+    B = tvm.placeholder((m*4, n), name='B')
+    C1, C2 = tvm.compute(A.shape, _compute, name='C')
+    s = tvm.create_schedule([C1.op, C2.op])
+    C1_cache, C2_cache = s.cache_write([C1, C2], "local")
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
 def test_schedule_bound_condition():
    A = tvm.placeholder((64,), name='A', dtype="float32")
@@ -265,6 +279,7 @@ def test_schedule_bound_condition():
 if __name__ == "__main__":
+    test_schedule_cache_relayout4()