From c55865b46ae5fd37699d05caf373144cdcf8b21c Mon Sep 17 00:00:00 2001
From: libing4752 <libing475211023@sjtu.edu.cn>
Date: Thu, 14 Jun 2018 01:53:04 +0800
Subject: [PATCH] fix copro_sync.cc errors of ctx (#1274)

---
 src/pass/coproc_sync.cc                       |  2 +-
 .../python/unittest/test_pass_storage_sync.py | 37 +++++++++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)

diff --git a/src/pass/coproc_sync.cc b/src/pass/coproc_sync.cc
index 28be8aba2..b3e64a989 100644
--- a/src/pass/coproc_sync.cc
+++ b/src/pass/coproc_sync.cc
@@ -385,7 +385,7 @@ class CoProcInstDepDetector : public IRVisitor {
                  &(curr_state_.exit_push),
                  &(curr_state_.enter_pop));
       curr_state_.enter_ctx = first_state_.enter_ctx;
-      curr_state_.exit_ctx = last_state_.enter_ctx;
+      curr_state_.exit_ctx = last_state_.exit_ctx;
     }
     std::swap(first_state_, temp_first);
     std::swap(last_state_, temp_last);
diff --git a/tests/python/unittest/test_pass_storage_sync.py b/tests/python/unittest/test_pass_storage_sync.py
index ce9e2f9a4..2286dd53e 100644
--- a/tests/python/unittest/test_pass_storage_sync.py
+++ b/tests/python/unittest/test_pass_storage_sync.py
@@ -78,7 +78,44 @@ def test_coproc_sync2():
     stmt = ib.get()
     stmt = tvm.ir_pass.CoProcSync(stmt)
 
+def test_coproc_sync3():
+    def __check_list(tvm_array, py_list):
+        for ti, li in zip(tvm_array, py_list):
+            if ti.value != li:
+                return False
+        return True
+
+    ib = tvm.ir_builder.create()
+    n = tvm.var("n")
+    cp = tvm.thread_axis((0, 1), "cop")
+    A = ib.allocate("float32", 128, name="A", scope="global.cache")
+    with ib.for_range(0, n, name="i") as i:
+        with ib.for_range(0, n, name="i") as j:
+            with ib.new_scope():
+                ib.scope_attr(cp, "coproc_scope", 1)
+                A[i] = 1.0
+            with ib.new_scope():
+                ib.scope_attr(cp, "coproc_scope", 2)
+                A[i] = 1.0
+    with ib.new_scope():
+        ib.scope_attr(cp, "coproc_scope", 3)
+        A[0] = 0.0
+   
+    stmt = ib.get()
+    stmt = tvm.ir_pass.CoProcSync(stmt)
+    slist = tvm.make.stmt_list(stmt.first.body.body)
+    push_st = slist[2]
+    slist = tvm.make.stmt_list(slist[-1])
+    pop_st = slist[0].body.first
+
+    assert(push_st.value.name == "cop.coproc_dep_push")
+    assert(__check_list(push_st.value.args, [2,3]))
+    assert(pop_st.value.name == "cop.coproc_dep_pop")
+    assert(__check_list(pop_st.value.args, [2,3]))
+    
+
 if __name__ == "__main__":
     test_coproc_sync()
     test_storage_sync()
     test_coproc_sync2()
+    test_coproc_sync3()
-- 
GitLab