From 53428606c8813c442d7b05d89d564cf0dd3e3585 Mon Sep 17 00:00:00 2001 From: Gaoxiong <40658249+gaoxiong1@users.noreply.github.com> Date: Mon, 15 Oct 2018 07:56:00 +0800 Subject: [PATCH] support double buffer to use in ir builder DSL(#1897) (#1898) --- src/pass/storage_flatten.cc | 3 +- .../unittest/test_pass_storage_flatten.py | 33 +++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index 993f6294e..8c2105829 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -59,7 +59,8 @@ class StorageFlattener : public IRMutator { if (op->attr_key == attr::realize_scope) { storage_scope_[op->node.get()] = op->value.as<StringImm>()->value; return this->Mutate(op->body); - } else if (op->attr_key == attr::double_buffer_scope) { + } else if (op->attr_key == attr::double_buffer_scope && + op->node.node_->derived_from<OperationNode>()) { Operation func(op->node.node_); Stmt body = Mutate(op->body); for (int i = 0; i < func->num_outputs(); ++i) { diff --git a/tests/python/unittest/test_pass_storage_flatten.py b/tests/python/unittest/test_pass_storage_flatten.py index 4e2feed23..655df1da4 100644 --- a/tests/python/unittest/test_pass_storage_flatten.py +++ b/tests/python/unittest/test_pass_storage_flatten.py @@ -51,8 +51,41 @@ def test_flatten_storage_align(): stmt = tvm.ir_pass.Simplify(stmt) assert(stmt.body.extents[0].value == 17 * 8) +def test_flatten_double_buffer(): + dtype = 'int64' + n = 100 + m = 4 + tx = tvm.thread_axis("threadIdx.x") + ib = tvm.ir_builder.create() + A = ib.pointer("float32", name="A") + C = ib.pointer("float32", name="C") + ib.scope_attr(tx, "thread_extent", 1) + with ib.for_range(0, n) as i: + B = ib.allocate("float32", m, name="B", scope="shared") + with ib.new_scope(): + ib.scope_attr(B.asnode(), "double_buffer_scope", 1) + with ib.for_range(0, m) as j: + B[j] = A[i * 4 + j] + with ib.for_range(0, m) as j: + C[j] = B[j] + 1 + + stmt = ib.get() + stmt = tvm.ir_pass.StorageFlatten(stmt, {}, 64) + stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, 2) + stmt = tvm.ir_pass.Simplify(stmt) + assert isinstance(stmt.body.body, tvm.stmt.Allocate) + assert stmt.body.body.extents[0].value == 2 + f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asnode(), C.asnode()], 2, True) + f = tvm.ir_pass.ThreadSync(f, "shared") + count = [0] + def count_sync(op): + if isinstance(op, tvm.expr.Call) and op.name == "tvm_storage_sync": + count[0] += 1 + tvm.ir_pass.PostOrderVisit(f.body, count_sync) + assert count[0] == 4 if __name__ == "__main__": test_flatten_storage_align() test_flatten2() test_flatten_prefetch() + test_flatten_double_buffer() -- GitLab