From 38d0835728f8b11c4467ec6a55a44ad9de24de7b Mon Sep 17 00:00:00 2001
From: xqdan <danxiaoqiang@126.com>
Date: Sun, 19 Aug 2018 02:18:29 +0800
Subject: [PATCH] #1592 [PASS] Fix missing mem CHECK in storage_rewrite (#1616)

---
 src/pass/storage_rewrite.cc                   |  6 ++
 .../unittest/test_pass_storage_rewrite.py     | 63 ++++++++++++-------
 2 files changed, 48 insertions(+), 21 deletions(-)

diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc
index 0170499e1..877216ed7 100644
--- a/src/pass/storage_rewrite.cc
+++ b/src/pass/storage_rewrite.cc
@@ -584,6 +584,12 @@ class StoragePlanRewriter : public IRMutator {
           e->new_alloc = Allocate::make(
               e->alloc_var, alloc_type, {combo_size}, const_true(),
               Evaluate::make(0));
+          if (e->scope.tag.length() != 0) {
+            MemoryInfo info = GetMemoryInfo(e->scope.to_string());
+            uint64_t total_elem = e->const_nbits / e->elem_type.bits();
+            CHECK_LE(total_elem * e->elem_type.bits(), info->max_num_bits)
+                << "Allocation exceed bound of memory tag " << e->scope.to_string();
+          }
         }
       }
     }
diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py
index 2bb029989..3c07a1f26 100644
--- a/tests/python/unittest/test_pass_storage_rewrite.py
+++ b/tests/python/unittest/test_pass_storage_rewrite.py
@@ -28,15 +28,30 @@ def test_storage_share():
     tvm.ir_pass.PostOrderVisit(stmt, verify)
     assert num_alloc[0] == 1
 
+def register_mem(scope_tb, max_bits):
+    #Register mem
+    @tvm.register_func("tvm.info.mem.%s" % scope_tb)
+    def mem_info_inp_buffer():
+        return tvm.make.node("MemoryInfo",
+                        unit_bits= 16,
+                        max_simd_bits=32,
+                        max_num_bits=max_bits,
+                        head_address=None)
+
 def test_alloc_seq():
+    scope_tb = "local.L0A"
+    max_bits = 1024 * 1024 * 1024
+
+    register_mem(scope_tb, max_bits)
+
     ib = tvm.ir_builder.create()
     n = tvm.var("n")
     with ib.for_range(0, n, name="i") as i:
         with ib.for_range(0, 10, name="j") as j:
-            A = ib.allocate("float32", 200, name="A", scope="local.L0A")
+            A = ib.allocate("float32", 200, name="A", scope=scope_tb)
             A[j] = 1.2
         with ib.for_range(0, 10, name="j") as j:
-            A = ib.allocate("float32", 200, name="B", scope="local.L0A")
+            A = ib.allocate("float32", 200, name="B", scope=scope_tb)
             A[j] = 1.3
 
     body = ib.get()
@@ -233,16 +248,9 @@ def test_parallel_alloc():
 
     assert(isinstance(body.body.body.body.body, tvm.stmt.Allocate))
 
-def test_inplace_rule2():
+def test_inplace_rule2(scope_tb = "local_TB2", max_bits = 1024 * 1024 * 1024):
     #Test Buffer
-    scope_tb = "local_TB2"
-    @tvm.register_func("tvm.info.mem.%s" % scope_tb)
-    def mem_info_inp_buffer():
-        return tvm.make.node("MemoryInfo",
-                        unit_bits= 16,
-                        max_simd_bits=32,
-                        max_num_bits=1024*1024*1024,
-                        head_address=None)
+    register_mem(scope_tb, max_bits)
     m = 10
     A = tvm.placeholder((m,), name='A')
     C = tvm.placeholder((m,), name='C')
@@ -275,16 +283,23 @@ def test_inplace_rule2():
     tvm.ir_pass.PostOrderVisit(stmt, verify)
     assert num_alloc[0] == 2
 
+def test_exceed_mem():
+    max_bits = 639
+    # The critical max_num_bits is between 639 and 640
+    loc = -1
+    try:
+        test_inplace_rule2("local_TEM", max_bits)
+    except Exception as e:
+        estr = str(e)
+        loc = estr.find('Allocation exceed bound of memory')
+        assert loc != -1
+
 def test_inplace_rule3():
     #Test Buffer
     scope_tb = "local_TB3"
-    @tvm.register_func("tvm.info.mem.%s" % scope_tb)
-    def mem_info_inp_buffer():
-        return tvm.make.node("MemoryInfo",
-                        unit_bits= 16,
-                        max_simd_bits=32,
-                        max_num_bits=1024*1024*1024,
-                        head_address=None)
+    max_bits=1024 * 1024 * 1024
+
+    register_mem(scope_tb, max_bits)
     m = 10
     B0 = tvm.placeholder((m,), name='B0')
     B1 = tvm.placeholder((m,), name='B1')
@@ -388,17 +403,22 @@ def test_alloc_seq_type():
     assert num_alloc[0] == 1
 
 def test_alloc_seq_type2():
+    scope_tb = "local.L0A2"
+    max_bits=1024 * 1024 * 1024
+
+    register_mem(scope_tb, max_bits)
+
     ib = tvm.ir_builder.create()
     n = tvm.var("n")
     with ib.for_range(0, n, name="i") as i:
         with ib.for_range(0, 10, name="j") as j:
-            A = ib.allocate("float32", 200, name="A", scope="local.L0A")
+            A = ib.allocate("float32", 200, name="A", scope=scope_tb)
             A[j] = 1.2
         with ib.for_range(0, 20, name="j") as j:
-            B = ib.allocate("int16", 400, name="B", scope="local.L0A")
+            B = ib.allocate("int16", 400, name="B", scope=scope_tb)
             B[j] = tvm.const(1, "int16")
         with ib.for_range(0, 10, name="j") as j:
-            C = ib.allocate("float32", 200, name="C", scope="local.L0A")
+            C = ib.allocate("float32", 200, name="C", scope=scope_tb)
             C[j] = 1.2
 
     body = ib.get()
@@ -465,6 +485,7 @@ if __name__ == "__main__":
     test_storage_combine()
     test_storage_share_gpu()
     test_inplace_rule2()
+    test_exceed_mem()
     test_inplace_rule3()
     test_alloc_seq_type()
     test_alloc_seq_type2()
-- 
GitLab