From e4a513035a87c62732d80392bb1fbaea210e0ecd Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Tue, 9 Jan 2018 17:39:16 -0800
Subject: [PATCH] [PASS] Fix storage rewrite merge rule for special tag memory
 (#770)

---
 src/pass/storage_rewrite.cc                   | 28 +++++++++++--------
 .../unittest/test_pass_storage_rewrite.py     | 23 +++++++++++++++
 2 files changed, 40 insertions(+), 11 deletions(-)

diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc
index 5e7abdda2..7215c3f97 100644
--- a/src/pass/storage_rewrite.cc
+++ b/src/pass/storage_rewrite.cc
@@ -766,14 +766,15 @@ class StoragePlanRewriter : public IRMutator {
     const uint64_t match_range = 16;
     uint64_t const_nbits = static_cast<uint64_t>(
         op->constant_allocation_size() * op->type.bits() * op->type.lanes());
-    if (scope.rank > 1 || op->type.is_handle()) {
-      return NewAlloc(op, attach_scope, scope, const_nbits);
-    }
     // disable reuse of small arrays, they will be lowered to registers in LLVM
-    if (const_nbits > 0  &&
-        const_nbits <= 32 &&
-        scope.tag.length() == 0) {
-      return NewAlloc(op, attach_scope, scope, const_nbits);
+    // These rules only apply when the scope has no special memory tag
+    if (scope.tag.length() == 0) {
+      if (scope.rank > 1 || op->type.is_handle()) {
+        return NewAlloc(op, attach_scope, scope, const_nbits);
+      }
+      if (const_nbits > 0  &&  const_nbits <= 32) {
+        return NewAlloc(op, attach_scope, scope, const_nbits);
+      }
     }
     if (const_nbits != 0) {
       // constant allocation.
@@ -818,10 +819,15 @@ class StoragePlanRewriter : public IRMutator {
     CHECK(it != alloc_map_.end());
     StorageEntry* e = it->second;
     CHECK_NE(e->allocs.size(), 0U);
-    // Disable sharing of local memory.
-    if (e->scope.rank > 1 || e->allocs[0]->type.is_handle()) return;
-    // disable reuse of small arrays
-    if (e->const_nbits > 0 && e->const_nbits <= 32) return;
+
+    // disable reuse of small arrays, they will be lowered to registers in LLVM
+    // These rules only apply when the scope has no special memory tag
+    if (e->scope.tag.length() == 0) {
+      // Disable sharing of local memory.
+      if (e->scope.rank > 1 || e->allocs[0]->type.is_handle()) return;
+      // disable reuse of small arrays
+      if (e->const_nbits > 0 && e->const_nbits <= 32) return;
+    }
     // normal free.
     if (e->const_nbits != 0) {
       const_free_map_.insert({e->const_nbits, e});
diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py
index d3f6307f8..1e4dda684 100644
--- a/tests/python/unittest/test_pass_storage_rewrite.py
+++ b/tests/python/unittest/test_pass_storage_rewrite.py
@@ -28,6 +28,28 @@ def test_storage_share():
     tvm.ir_pass.PostOrderVisit(stmt, verify)
     assert num_alloc[0] == 1
 
+def test_alloc_seq():  # two same-sized "local.L0A" allocations made in sequence should end up as one Allocate
+    ib = tvm.ir_builder.create()
+    n = tvm.var("n")  # symbolic extent of the outer loop
+    with ib.for_range(0, n, name="i") as i:
+        with ib.for_range(0, 10, name="j") as j:
+            A = ib.allocate("float32", 200, name="A", scope="local.L0A")  # first buffer: 200 float32 in scope "local.L0A"
+            A[j] = 1.2
+        with ib.for_range(0, 10, name="j") as j:
+            A = ib.allocate("float32", 200, name="B", scope="local.L0A")  # second buffer "B": same dtype, size and scope (Python name A is rebound)
+            A[j] = 1.3
+
+    body = ib.get()
+    body = tvm.ir_pass.StorageRewrite(body)  # pass under test
+    num_alloc = [0]  # one-element list so the nested callback can mutate the counter
+    def verify(n):  # n is the visited IR node here; it shadows the tvm.var "n" above
+        if isinstance(n, tvm.stmt.Allocate):
+            num_alloc[0] += 1
+            assert n.extents[0].value == 200  # every remaining allocation keeps the original extent
+    tvm.ir_pass.PostOrderVisit(body, verify)
+    assert num_alloc[0] == 1  # exactly one Allocate node remains after the rewrite
+
+
 
 def test_inplace_rule():
     m = 10
@@ -152,6 +174,7 @@ def test_parallel_alloc():
 
 
 if __name__ == "__main__":
+    test_alloc_seq()
     test_inplace_rule()
     test_storage_share()
     test_parallel_alloc()
-- 
GitLab