diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index b58df9d0481fa189ac434651814e033230b40c45..59d425287be0fac08a19fb66f107403394817c4c 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -82,12 +82,13 @@ Tensor Schedule::cache_read(const Tensor& tensor,
   }
   os << "." << scope;
 
-  Tensor cache = compute(tensor->shape, [&tensor](const Array<Var>& i) {
-      return tensor(Array<Expr>(i.begin(), i.end()));
-    }, os.str());
   std::unordered_map<Tensor, Tensor> vsub;
   Stage s = operator[](tensor->op);
   Tensor sugar_tensor = s->op.output(tensor->value_index);
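+  // Build the cache over the stage's current output (sugar_tensor) rather than the passed-in tensor handle.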
+  Tensor cache = compute(sugar_tensor->shape, [&sugar_tensor](const Array<Var>& i) {
+      return sugar_tensor(Array<Expr>(i.begin(), i.end()));
+    }, os.str());
   vsub[sugar_tensor] = cache;
 
   std::unordered_map<Tensor, Tensor> vmap;
diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py
index 1e4dda684eb387f61714a55f7afd3365c2bf12ac..d044db12686f8d752fc7e22bf8346e9be4342d7f 100644
--- a/tests/python/unittest/test_pass_storage_rewrite.py
+++ b/tests/python/unittest/test_pass_storage_rewrite.py
@@ -171,7 +171,50 @@ def test_parallel_alloc():
 
     assert(isinstance(body.body.body.body.body, tvm.stmt.Allocate))
 
-
+
+def test_inplace_rule2():
+    # Test custom buffer scope: register its memory info first.
+    scope_tb = "local_TB"
+    @tvm.register_func("tvm.info.mem.%s" % scope_tb)
+    def mem_info_inp_buffer():
+        return tvm.make.node("MemoryInfo",
+                        unit_bits=16,
+                        max_simd_bits=32,
+                        max_num_bits=1024*1024*1024,
+                        head_address=None)
+    m = 10
+    A = tvm.placeholder((m,), name='A')
+    C = tvm.placeholder((m,), name='C')
+    D = tvm.placeholder((m,), name='D')
+    A0 = tvm.compute((m,), lambda i: A[i] + C[i], name='A0')
+    A1 = tvm.compute((m,), lambda i: D[i] * D[i], name='A1')
+    A2 = tvm.compute((m,), lambda i: A0[i] + A1[i], name='A2')
+    B = tvm.compute((m,), lambda i: A2[i], name='B')
+    s = tvm.create_schedule(B.op)
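+    # Read the intermediates through cached copies in the custom scope.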
+    A0L = s.cache_read(A0, scope_tb, [A2])
+    A1L = s.cache_read(A1, scope_tb, [A2])
+    A2L = s.cache_read(A2, scope_tb, [B])
+    bounds = tvm.schedule.InferBound(s)
+    assert isinstance(bounds, tvm.container.Map)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
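+    # Declare explicit buffers for the external tensors so StorageFlatten can bind them.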
+    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
+    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
+    Cc = tvm.decl_buffer(C.shape, C.dtype, name='C')
+    Dd = tvm.decl_buffer(D.shape, D.dtype, name='D')
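+    # Flatten to flat memory accesses, simplify, then let StorageRewrite plan and fold storage.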
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cc, D: Dd}, 64)
+    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
+    stmt = tvm.ir_pass.Simplify(stmt)
+    stmt = tvm.ir_pass.StorageRewrite(stmt)
+    # verify inplace folding works: only two allocations remain
+    num_alloc = [0]
+    def verify(n):
+        if isinstance(n, tvm.stmt.Allocate):
+            num_alloc[0] += 1
+    tvm.ir_pass.PostOrderVisit(stmt, verify)
+    assert num_alloc[0] == 2
 
 if __name__ == "__main__":
     test_alloc_seq()
@@ -180,3 +223,4 @@ if __name__ == "__main__":
     test_parallel_alloc()
     test_storage_combine()
     test_storage_share_gpu()
+    test_inplace_rule2()