From c492737826261a84b37138a215ac4fe7170a8c72 Mon Sep 17 00:00:00 2001
From: libing4752 <libing475211023@sjtu.edu.cn>
Date: Thu, 4 Jan 2018 05:37:56 +0800
Subject: [PATCH] modified schedule_dataflow_rewrite.cc to fix Stale Tensor
 during Dataflow Rewrite #738 (#747)

* modified schedule_dataflow_rewrite.cc to fix losing tensor problem

* modified schedule_dataflow_rewrite.cc for lint scan

* modified schedule_dataflow_rewrite.cc for lint scan

* using tensor's value_index to index output of stage op
---
 src/schedule/schedule_dataflow_rewrite.cc     |  4 +++-
 .../unittest/test_schedule_schedule_ops.py    | 20 +++++++++++++++++++
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index d1a69ecf0..b58df9d04 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -86,7 +86,9 @@ Tensor Schedule::cache_read(const Tensor& tensor,
       return tensor(Array<Expr>(i.begin(), i.end()));
     }, os.str());
   std::unordered_map<Tensor, Tensor> vsub;
-  vsub[tensor] = cache;
+  Stage s = operator[](tensor->op);
+  Tensor sugar_tensor = s->op.output(tensor->value_index);
+  vsub[sugar_tensor] = cache;
 
   std::unordered_map<Tensor, Tensor> vmap;
   for (Operation op : readers) {
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index a85db2a23..03b8dbf48 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -182,6 +182,25 @@ def test_schedule_cache():
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
 
+def test_schedule_middle_cache():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A = tvm.placeholder((m, n), name='A')
+    B = tvm.placeholder((m, n), name='B')
+
+    C = tvm.compute((m, n), lambda i, j:  A(i, j) * B(i, j), name='C')
+    D = tvm.compute((m, n), lambda i, j:  C(i , j) , name='D')
+
+    s = tvm.create_schedule(D.op)
+    AA = s.cache_read(A, "local", readers=[C])
+    BB = s.cache_read(B, "local", readers=[C])
+    CC = s.cache_read(C, "local", readers=[D])
+    DD = s.cache_write(D, "local")
+    #s[AA].compute_at(s[CC], CC.op.axis[0])
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
 
 def test_schedule_cache_relayout1():
     m = tvm.var('m')
@@ -231,6 +250,7 @@ def test_schedule_cache_relayout3():
 
 
 if __name__ == "__main__":
+    test_schedule_middle_cache()
     test_inline_multi_reduce()
     test_schedule_cache_relayout3()
     test_schedule_cache_relayout2()
-- 
GitLab