From 4d2fc952bbec7f635944e48b1b62534a4a410a44 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Sun, 12 Nov 2017 16:30:02 -0800
Subject: [PATCH] [PASS] Fix vthread when extern access touching (#636)

---
 src/pass/inject_virtual_thread.cc             | 50 ++++++++++++++++---
 .../unittest/test_pass_inject_vthread.py      | 34 +++++++++++++
 2 files changed, 76 insertions(+), 8 deletions(-)

diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc
index 07f59cb1b..28e90ec48 100644
--- a/src/pass/inject_virtual_thread.cc
+++ b/src/pass/inject_virtual_thread.cc
@@ -15,11 +15,12 @@ namespace ir {
 // If expression is touched by var.
 class ExprTouched final : public IRVisitor {
  public:
-  explicit ExprTouched(const std::unordered_set<const Variable*> &touched)
-      : touched_var_(touched) {}
+  explicit ExprTouched(const std::unordered_set<const Variable*> &touched,
+                       bool check_write)
+      : touched_var_(touched), check_write_(check_write) {}
   void Visit(const NodeRef& n) final {
     // early stopping
-    if (expr_touched_) return;
+    if (expr_touched_ && !check_write_) return;
     IRVisitor::Visit(n);
   }
   void Visit_(const Load *op) final {
@@ -29,6 +30,24 @@ class ExprTouched final : public IRVisitor {
   void Visit_(const Variable *op) final {
     HandleUseVar(op);
   }
+  void Visit_(const Call *op) final {  // special-cases tvm_access_ptr calls
+    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
+      int rw_mask = 0;  // never left indeterminate, whatever happens to the CHECK
+      CHECK(arith::GetConstInt(op->args[4], &rw_mask));  // mask must be constant
+      const Variable* buffer_var = op->args[1].as<Variable>();
+      CHECK(buffer_var);  // args[1] must be a plain variable
+      // bit 0 of the mask marks a read, bit 1 marks a write
+      if (rw_mask & 1) {
+        HandleUseVar(buffer_var);
+      }
+      if (rw_mask & 2) {
+        HandleWriteVar(buffer_var);
+      }
+      this->Visit(op->args[2]);  // of the remaining args only args[2] is visited
+    } else {
+      IRVisitor::Visit_(op);
+    }
+  }
   void HandleUseVar(const Variable* var) {
     auto it = touched_var_.find(var);
     if (it != touched_var_.end()) {
@@ -40,36 +59,49 @@ class ExprTouched final : public IRVisitor {
       used_vars_.push_back(var);
     }
   }
+  void HandleWriteVar(const Variable* var) {  // remember var as written
+    write_vars_.emplace_back(var);
+  }
   // the fields.
   bool expr_touched_{false};
   std::vector<const Variable*> used_vars_;
+  std::vector<const Variable*> write_vars_;
   const std::unordered_set<const Variable*>& touched_var_;
+  bool check_write_;
 };
 
 // Analyze if the buffers are invariant to value of var
 class VarTouchedAnalysis : public IRVisitor {
  public:
   void Visit_(const LetStmt *op) {
-    ExprTouched tc(touched_var_);
+    ExprTouched tc(touched_var_, false);
     tc.Visit(op->value);
     Record(op->var.get(), tc);
     this->Visit(op->body);
   }
   void Visit_(const Store *op) {
-    ExprTouched tc(touched_var_);
+    ExprTouched tc(touched_var_, false);
     tc.Visit(op->value);
     tc.Visit(op->index);
     Record(op->buffer_var.get(), tc);
   }
   void Visit_(const For *op) {
-    ExprTouched tc(touched_var_);
+    ExprTouched tc(touched_var_, false);
     tc.Visit(op->min);
     tc.Visit(op->extent);
     Record(op->loop_var.get(), tc);
     this->Visit(op->body);
   }
+  // external function call: Record() every variable the evaluated expression writes
+  void Visit_(const Evaluate *op) {
+    ExprTouched toucher(touched_var_, /*check_write=*/true);
+    toucher.Visit(op->value);
+    for (size_t i = 0; i < toucher.write_vars_.size(); ++i) {
+      Record(toucher.write_vars_[i], toucher);
+    }
+  }
   void Visit_(const Allocate *op) {
-    ExprTouched tc(touched_var_);
+    ExprTouched tc(touched_var_, false);
     for (size_t i = 0; i < op->extents.size(); ++i) {
       tc.Visit(op->extents[i]);
     }
@@ -87,7 +119,9 @@ class VarTouchedAnalysis : public IRVisitor {
       touched_var_.insert(var);
     } else {
       for (const Variable* r : tc.used_vars_) {
-        affect_[r].push_back(var);
+        if (r != var) {
+          affect_[r].push_back(var);
+        }
       }
     }
   }
diff --git a/tests/python/unittest/test_pass_inject_vthread.py b/tests/python/unittest/test_pass_inject_vthread.py
index e4b3b51fb..502a55574 100644
--- a/tests/python/unittest/test_pass_inject_vthread.py
+++ b/tests/python/unittest/test_pass_inject_vthread.py
@@ -28,5 +28,39 @@ def test_vthread():
     stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("cthread"))
     assert len(stmt.body.body.extents) == 3
 
+
+def test_vthread_extern():
+    """InjectVirtualThread with buffers handed to an extern call via access_ptr."""
+    n = 100
+    m = 4
+    nthread = 2
+    def get_vthread(name):
+        tx = tvm.thread_axis(name)
+        ty = tvm.thread_axis(name)
+        ib = tvm.ir_builder.create()
+        with ib.for_range(0, n) as i:
+            ib.scope_attr(tx, "virtual_thread", nthread)
+            ib.scope_attr(ty, "virtual_thread", nthread)
+            A = ib.allocate("float32", m, name="A", scope="shared")
+            B = ib.allocate("float32", m, name="B", scope="shared")
+            C = ib.allocate("float32", m, name="C", scope="shared")
+            cbuffer = tvm.decl_buffer((m,), dtype=C.dtype, data=C.asnode())
+            abuffer = tvm.decl_buffer((m,), dtype=A.dtype, data=A.asnode())
+            bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode())
+            A[tx] = tx + 1.0
+            B[ty] = ty + 1.0
+            ib.emit(tvm.call_extern("int32", "Run",
+                                    abuffer.access_ptr("r"),
+                                    bbuffer.access_ptr("r"),
+                                    cbuffer.access_ptr("rw")))
+        return ib.get()
+
+    stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("vthread"))
+    assert stmt.body.body.extents[0].value == 2
+    assert stmt.body.body.body.body.body.body.extents[0].value == 2
+    assert len(stmt.body.body.body.body.body.body.extents) == 3
+
+
 if __name__ == "__main__":
+    test_vthread_extern()
     test_vthread()
-- 
GitLab