From 4d2fc952bbec7f635944e48b1b62534a4a410a44 Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Sun, 12 Nov 2017 16:30:02 -0800 Subject: [PATCH] [PASS] Fix vthread when extern access touching (#636) --- src/pass/inject_virtual_thread.cc | 50 ++++++++++++++++--- .../unittest/test_pass_inject_vthread.py | 34 +++++++++++++ 2 files changed, 76 insertions(+), 8 deletions(-) diff --git a/src/pass/inject_virtual_thread.cc b/src/pass/inject_virtual_thread.cc index 07f59cb1b..28e90ec48 100644 --- a/src/pass/inject_virtual_thread.cc +++ b/src/pass/inject_virtual_thread.cc @@ -15,11 +15,12 @@ namespace ir { // If expression is touched by var. class ExprTouched final : public IRVisitor { public: - explicit ExprTouched(const std::unordered_set<const Variable*> &touched) - : touched_var_(touched) {} + explicit ExprTouched(const std::unordered_set<const Variable*> &touched, + bool check_write) + : touched_var_(touched), check_write_(check_write) {} void Visit(const NodeRef& n) final { // early stopping - if (expr_touched_) return; + if (expr_touched_ && !check_write_) return; IRVisitor::Visit(n); } void Visit_(const Load *op) final { @@ -29,6 +30,24 @@ class ExprTouched final : public IRVisitor { void Visit_(const Variable *op) final { HandleUseVar(op); } + void Visit_(const Call *op) final { + if (op->is_intrinsic(intrinsic::tvm_access_ptr)) { + int rw_mask; + CHECK(arith::GetConstInt(op->args[4], &rw_mask)); + const Variable* buffer_var = op->args[1].as<Variable>(); + CHECK(buffer_var); + // read + if (rw_mask & 1) { + HandleUseVar(buffer_var); + } + if (rw_mask & 2) { + HandleWriteVar(buffer_var); + } + this->Visit(op->args[2]); + } else { + IRVisitor::Visit_(op); + } + } void HandleUseVar(const Variable* var) { auto it = touched_var_.find(var); if (it != touched_var_.end()) { @@ -40,36 +59,49 @@ class ExprTouched final : public IRVisitor { used_vars_.push_back(var); } } + void HandleWriteVar(const Variable* var) { + write_vars_.push_back(var); + } // the fields. bool expr_touched_{false}; std::vector<const Variable*> used_vars_; + std::vector<const Variable*> write_vars_; const std::unordered_set<const Variable*>& touched_var_; + bool check_write_; }; // Analyze if the buffers are invariant to value of var class VarTouchedAnalysis : public IRVisitor { public: void Visit_(const LetStmt *op) { - ExprTouched tc(touched_var_); + ExprTouched tc(touched_var_, false); tc.Visit(op->value); Record(op->var.get(), tc); this->Visit(op->body); } void Visit_(const Store *op) { - ExprTouched tc(touched_var_); + ExprTouched tc(touched_var_, false); tc.Visit(op->value); tc.Visit(op->index); Record(op->buffer_var.get(), tc); } void Visit_(const For *op) { - ExprTouched tc(touched_var_); + ExprTouched tc(touched_var_, false); tc.Visit(op->min); tc.Visit(op->extent); Record(op->loop_var.get(), tc); this->Visit(op->body); } + // external function call + void Visit_(const Evaluate *op) { + ExprTouched tc(touched_var_, true); + tc.Visit(op->value); + for (const Variable* var : tc.write_vars_) { + Record(var, tc); + } + } void Visit_(const Allocate *op) { - ExprTouched tc(touched_var_); + ExprTouched tc(touched_var_, false); for (size_t i = 0; i < op->extents.size(); ++i) { tc.Visit(op->extents[i]); } @@ -87,7 +119,9 @@ class VarTouchedAnalysis : public IRVisitor { touched_var_.insert(var); } else { for (const Variable* r : tc.used_vars_) { - affect_[r].push_back(var); + if (r != var) { + affect_[r].push_back(var); + } } } } diff --git a/tests/python/unittest/test_pass_inject_vthread.py b/tests/python/unittest/test_pass_inject_vthread.py index e4b3b51fb..502a55574 100644 --- a/tests/python/unittest/test_pass_inject_vthread.py +++ b/tests/python/unittest/test_pass_inject_vthread.py @@ -28,5 +28,39 @@ def test_vthread(): stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("cthread")) assert len(stmt.body.body.extents) == 3 + +def test_vthread_extern(): + dtype = 'int64' + n = 100 + m = 4 + nthread = 2 + def get_vthread(name): + tx = tvm.thread_axis(name) + ty = tvm.thread_axis(name) + ib = tvm.ir_builder.create() + with ib.for_range(0, n) as i: + ib.scope_attr(tx, "virtual_thread", nthread) + ib.scope_attr(ty, "virtual_thread", nthread) + A = ib.allocate("float32", m, name="A", scope="shared") + B = ib.allocate("float32", m, name="B", scope="shared") + C = ib.allocate("float32", m, name="C", scope="shared") + cbuffer = tvm.decl_buffer((m,), dtype=C.dtype, data=C.asnode()) + abuffer = tvm.decl_buffer((m,), dtype=A.dtype, data=A.asnode()) + bbuffer = tvm.decl_buffer((m,), dtype=B.dtype, data=B.asnode()) + A[tx] = tx + 1.0 + B[ty] = ty + 1.0 + ib.emit(tvm.call_extern("int32", "Run", + abuffer.access_ptr("r"), + bbuffer.access_ptr("r"), + cbuffer.access_ptr("rw"))) + return ib.get() + + stmt = tvm.ir_pass.InjectVirtualThread(get_vthread("vthread")) + assert stmt.body.body.extents[0].value == 2 + assert stmt.body.body.body.body.body.body.extents[0].value == 2 + assert len(stmt.body.body.body.body.body.body.extents) == 3 + + if __name__ == "__main__": + test_vthread_extern() test_vthread() -- GitLab