diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index b58df9d0481fa189ac434651814e033230b40c45..59d425287be0fac08a19fb66f107403394817c4c 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -82,12 +82,12 @@ Tensor Schedule::cache_read(const Tensor& tensor, } os << "." << scope; - Tensor cache = compute(tensor->shape, [&tensor](const Array<Var>& i) { - return tensor(Array<Expr>(i.begin(), i.end())); - }, os.str()); std::unordered_map<Tensor, Tensor> vsub; Stage s = operator[](tensor->op); Tensor sugar_tensor = s->op.output(tensor->value_index); + Tensor cache = compute(sugar_tensor->shape, [&sugar_tensor](const Array<Var>& i) { + return sugar_tensor(Array<Expr>(i.begin(), i.end())); + }, os.str()); vsub[sugar_tensor] = cache; std::unordered_map<Tensor, Tensor> vmap; diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py index 1e4dda684eb387f61714a55f7afd3365c2bf12ac..d044db12686f8d752fc7e22bf8346e9be4342d7f 100644 --- a/tests/python/unittest/test_pass_storage_rewrite.py +++ b/tests/python/unittest/test_pass_storage_rewrite.py @@ -171,7 +171,47 @@ def test_parallel_alloc(): assert(isinstance(body.body.body.body.body, tvm.stmt.Allocate)) - +def test_inplace_rule2(): + #Test Buffer + scope_tb = "local_TB" + @tvm.register_func("tvm.info.mem.%s" % scope_tb) + def mem_info_inp_buffer(): + return tvm.make.node("MemoryInfo", + unit_bits= 16, + max_simd_bits=32, + max_num_bits=1024*1024*1024, + head_address=None) + m = 10 + A = tvm.placeholder((m,), name='A') + C = tvm.placeholder((m,), name='C') + D = tvm.placeholder((m,), name='D') + A0 = tvm.compute((m,), lambda i: A[i] + C[i], name='A0') + A1 = tvm.compute((m,), lambda i: D[i] * D[i], name='A1') + A2 = tvm.compute((m,), lambda i: A0[i] + A1[i], name='A2') + B = tvm.compute((m,), lambda i: A2[i], name='B') + s = tvm.create_schedule(B.op) + A0L = 
s.cache_read(A0, scope_tb, [A2]) + A1L = s.cache_read(A1, scope_tb, [A2]) + A2L = s.cache_read(A2, scope_tb, [B]) + bounds = tvm.schedule.InferBound(s) + assert isinstance(bounds, tvm.container.Map) + stmt = tvm.schedule.ScheduleOps(s, bounds) + Ab = tvm.decl_buffer(A.shape, A.dtype, name='A') + Bb = tvm.decl_buffer(B.shape, B.dtype, name='B') + Cc = tvm.decl_buffer(C.shape, B.dtype, name='C') + Dd = tvm.decl_buffer(D.shape, B.dtype, name='D') + stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb, C: Cc, D:Dd}, 64) + stmt = tvm.ir_pass.CanonicalSimplify(stmt) + stmt = tvm.ir_pass.Simplify(stmt) + stmt = tvm.ir_pass.StorageRewrite(stmt) + # verify that there are only two allocations. + # verify inplace folding works + num_alloc = [0] + def verify(n): + if isinstance(n, tvm.stmt.Allocate): + num_alloc[0] += 1 + tvm.ir_pass.PostOrderVisit(stmt, verify) + assert num_alloc[0] == 2 if __name__ == "__main__": test_alloc_seq() @@ -180,3 +220,4 @@ if __name__ == "__main__": test_parallel_alloc() test_storage_combine() test_storage_share_gpu() + test_inplace_rule2()