From 53428606c8813c442d7b05d89d564cf0dd3e3585 Mon Sep 17 00:00:00 2001
From: Gaoxiong <40658249+gaoxiong1@users.noreply.github.com>
Date: Mon, 15 Oct 2018 07:56:00 +0800
Subject: [PATCH] support double buffer to use in ir builder DSL(#1897) (#1898)

---
 src/pass/storage_flatten.cc                   |  3 +-
 .../unittest/test_pass_storage_flatten.py     | 33 +++++++++++++++++++
 2 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
index 993f6294e..8c2105829 100644
--- a/src/pass/storage_flatten.cc
+++ b/src/pass/storage_flatten.cc
@@ -59,7 +59,8 @@ class StorageFlattener : public IRMutator {
     if (op->attr_key == attr::realize_scope) {
       storage_scope_[op->node.get()] = op->value.as<StringImm>()->value;
       return this->Mutate(op->body);
-    } else if (op->attr_key == attr::double_buffer_scope) {
+    } else if (op->attr_key == attr::double_buffer_scope &&
+               op->node.node_->derived_from<OperationNode>()) {
       Operation func(op->node.node_);
       Stmt body = Mutate(op->body);
       for (int i = 0; i < func->num_outputs(); ++i) {
diff --git a/tests/python/unittest/test_pass_storage_flatten.py b/tests/python/unittest/test_pass_storage_flatten.py
index 4e2feed23..655df1da4 100644
--- a/tests/python/unittest/test_pass_storage_flatten.py
+++ b/tests/python/unittest/test_pass_storage_flatten.py
@@ -51,8 +51,41 @@ def test_flatten_storage_align():
     stmt = tvm.ir_pass.Simplify(stmt)
     assert(stmt.body.extents[0].value == 17 * 8)
 
+def test_flatten_double_buffer():
+    dtype = 'int64'
+    n = 100
+    m = 4
+    tx = tvm.thread_axis("threadIdx.x")
+    ib = tvm.ir_builder.create()
+    A = ib.pointer("float32", name="A")
+    C = ib.pointer("float32", name="C")
+    ib.scope_attr(tx, "thread_extent", 1)
+    with ib.for_range(0, n) as i:
+        B = ib.allocate("float32", m, name="B", scope="shared")
+        with ib.new_scope():
+            ib.scope_attr(B.asnode(), "double_buffer_scope", 1)
+            with ib.for_range(0, m) as j:
+                B[j] = A[i * 4 + j]
+        with ib.for_range(0, m) as j:
+            C[j] = B[j] + 1
+
+    stmt = ib.get()
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {}, 64)
+    stmt = tvm.ir_pass.InjectDoubleBuffer(stmt, 2)
+    stmt = tvm.ir_pass.Simplify(stmt)
+    assert isinstance(stmt.body.body, tvm.stmt.Allocate)
+    assert stmt.body.body.extents[0].value == 2
+    f = tvm.ir_pass.MakeAPI(stmt, "db", [A.asnode(), C.asnode()], 2, True)
+    f = tvm.ir_pass.ThreadSync(f, "shared")
+    count = [0]
+    def count_sync(op):
+        if isinstance(op, tvm.expr.Call) and op.name == "tvm_storage_sync":
+            count[0] += 1
+    tvm.ir_pass.PostOrderVisit(f.body, count_sync)
+    assert count[0] == 4
 
 if __name__ == "__main__":
     test_flatten_storage_align()
     test_flatten2()
     test_flatten_prefetch()
+    test_flatten_double_buffer()
-- 
GitLab