From e4a513035a87c62732d80392bb1fbaea210e0ecd Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Tue, 9 Jan 2018 17:39:16 -0800 Subject: [PATCH] [PASS] Fix storage rewrite merge rule for special tag memory (#770) --- src/pass/storage_rewrite.cc | 28 +++++++++++-------- .../unittest/test_pass_storage_rewrite.py | 23 +++++++++++++++ 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 5e7abdda2..7215c3f97 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -766,14 +766,15 @@ class StoragePlanRewriter : public IRMutator { const uint64_t match_range = 16; uint64_t const_nbits = static_cast<uint64_t>( op->constant_allocation_size() * op->type.bits() * op->type.lanes()); - if (scope.rank > 1 || op->type.is_handle()) { - return NewAlloc(op, attach_scope, scope, const_nbits); - } // disable reuse of small arrays, they will be lowered to registers in LLVM - if (const_nbits > 0 && - const_nbits <= 32 && - scope.tag.length() == 0) { - return NewAlloc(op, attach_scope, scope, const_nbits); + // These rules only apply if we are using non-special memory + if (scope.tag.length() == 0) { + if (scope.rank > 1 || op->type.is_handle()) { + return NewAlloc(op, attach_scope, scope, const_nbits); + } + if (const_nbits > 0 && const_nbits <= 32) { + return NewAlloc(op, attach_scope, scope, const_nbits); + } } if (const_nbits != 0) { // constant allocation. @@ -818,10 +819,15 @@ class StoragePlanRewriter : public IRMutator { CHECK(it != alloc_map_.end()); StorageEntry* e = it->second; CHECK_NE(e->allocs.size(), 0U); - // Disable sharing of local memory. 
- if (e->scope.rank > 1 || e->allocs[0]->type.is_handle()) return; - // disable reuse of small arrays - if (e->const_nbits > 0 && e->const_nbits <= 32) return; + + // disable reuse of small arrays, they will be lowered to registers in LLVM + // These rules only apply if we are using non-special memory + if (e->scope.tag.length() == 0) { + // Disable sharing of local memory. + if (e->scope.rank > 1 || e->allocs[0]->type.is_handle()) return; + // disable reuse of small arrays + if (e->const_nbits > 0 && e->const_nbits <= 32) return; + } // normal free. if (e->const_nbits != 0) { const_free_map_.insert({e->const_nbits, e}); diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py index d3f6307f8..1e4dda684 100644 --- a/tests/python/unittest/test_pass_storage_rewrite.py +++ b/tests/python/unittest/test_pass_storage_rewrite.py @@ -28,6 +28,28 @@ def test_storage_share(): tvm.ir_pass.PostOrderVisit(stmt, verify) assert num_alloc[0] == 1 +def test_alloc_seq(): + ib = tvm.ir_builder.create() + n = tvm.var("n") + with ib.for_range(0, n, name="i") as i: + with ib.for_range(0, 10, name="j") as j: + A = ib.allocate("float32", 200, name="A", scope="local.L0A") + A[j] = 1.2 + with ib.for_range(0, 10, name="j") as j: + A = ib.allocate("float32", 200, name="B", scope="local.L0A") + A[j] = 1.3 + + body = ib.get() + body = tvm.ir_pass.StorageRewrite(body) + num_alloc = [0] + def verify(n): + if isinstance(n, tvm.stmt.Allocate): + num_alloc[0] += 1 + assert n.extents[0].value == 200 + tvm.ir_pass.PostOrderVisit(body, verify) + assert num_alloc[0] == 1 + + def test_inplace_rule(): m = 10 @@ -152,6 +174,7 @@ def test_parallel_alloc(): if __name__ == "__main__": + test_alloc_seq() test_inplace_rule() test_storage_share() test_parallel_alloc() -- GitLab