diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 5e7abdda2112dc58b67b2c5ab14d0edad6d56071..7215c3f97a439992d0b4244e57662d5113100750 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -766,14 +766,15 @@ class StoragePlanRewriter : public IRMutator { const uint64_t match_range = 16; uint64_t const_nbits = static_cast<uint64_t>( op->constant_allocation_size() * op->type.bits() * op->type.lanes()); - if (scope.rank > 1 || op->type.is_handle()) { - return NewAlloc(op, attach_scope, scope, const_nbits); - } // disable reuse of small arrays, they will be lowered to registers in LLVM - if (const_nbits > 0 && - const_nbits <= 32 && - scope.tag.length() == 0) { - return NewAlloc(op, attach_scope, scope, const_nbits); + // This rules only apply if we are using non special memory + if (scope.tag.length() == 0) { + if (scope.rank > 1 || op->type.is_handle()) { + return NewAlloc(op, attach_scope, scope, const_nbits); + } + if (const_nbits > 0 && const_nbits <= 32) { + return NewAlloc(op, attach_scope, scope, const_nbits); + } } if (const_nbits != 0) { // constant allocation. @@ -818,10 +819,15 @@ class StoragePlanRewriter : public IRMutator { CHECK(it != alloc_map_.end()); StorageEntry* e = it->second; CHECK_NE(e->allocs.size(), 0U); - // Disable sharing of local memory. - if (e->scope.rank > 1 || e->allocs[0]->type.is_handle()) return; - // disable reuse of small arrays - if (e->const_nbits > 0 && e->const_nbits <= 32) return; + + // disable reuse of small arrays, they will be lowered to registers in LLVM + // This rules only apply if we are using non special memory + if (e->scope.tag.length() == 0) { + // Disable sharing of local memory. + if (e->scope.rank > 1 || e->allocs[0]->type.is_handle()) return; + // disable reuse of small arrays + if (e->const_nbits > 0 && e->const_nbits <= 32) return; + } // normal free. if (e->const_nbits != 0) { const_free_map_.insert({e->const_nbits, e}); diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py index d3f6307f821f7743e871e4ee7584f8c5ea51ca35..1e4dda684eb387f61714a55f7afd3365c2bf12ac 100644 --- a/tests/python/unittest/test_pass_storage_rewrite.py +++ b/tests/python/unittest/test_pass_storage_rewrite.py @@ -28,6 +28,28 @@ def test_storage_share(): tvm.ir_pass.PostOrderVisit(stmt, verify) assert num_alloc[0] == 1 +def test_alloc_seq(): + ib = tvm.ir_builder.create() + n = tvm.var("n") + with ib.for_range(0, n, name="i") as i: + with ib.for_range(0, 10, name="j") as j: + A = ib.allocate("float32", 200, name="A", scope="local.L0A") + A[j] = 1.2 + with ib.for_range(0, 10, name="j") as j: + A = ib.allocate("float32", 200, name="B", scope="local.L0A") + A[j] = 1.3 + + body = ib.get() + body = tvm.ir_pass.StorageRewrite(body) + num_alloc = [0] + def verify(n): + if isinstance(n, tvm.stmt.Allocate): + num_alloc[0] += 1 + assert n.extents[0].value == 200 + tvm.ir_pass.PostOrderVisit(body, verify) + assert num_alloc[0] == 1 + + def test_inplace_rule(): m = 10 @@ -152,6 +174,7 @@ def test_parallel_alloc(): if __name__ == "__main__": + test_alloc_seq() test_inplace_rule() test_storage_share() test_parallel_alloc()