diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index 6284906c1a4ed7fb7166e925f885dcc4ebbbe0c5..594e8b8a7c168ff62cb597084fa8a3147f172046 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -301,6 +301,23 @@ class Schedule : public NodeRef {
    * User can further call compute_inline to inline the original layout and keep
    * the data stored in the transformed layout.
    *
+   * \param tensor The tensors to be produced.
+   * \param scope The scope of the storage.
+   * \return The created tensors.
+   */
+  EXPORT Array<Tensor> cache_write(const Array<Tensor>& tensor, const std::string& scope);
+  /*!
+   * \brief Create a cache write tensor for producing tensor.
+   *  The tensor will take over the body of the original tensor op.
+   *
+   *  This function can be used to do data layout transformation.
+   *  If there is a split/fuse/reorder on the data parallel axis of the tensor
+   *  before cache_write is called, the intermediate cache stores
+   *  the data in the layout of the iteration order of the leaf axes.
+   *  The data will be transformed back to the original layout in the original tensor.
+   *  User can further call compute_inline to inline the original layout and keep
+   *  the data stored in the transformed layout.
+   *
    * \param tensor The tensor to be produced.
    * \param scope The scope of the storage.
    * \return The created tensor.
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index 6305cc1d5e69a44d108ebf018c7603d50d14a245..f54ae188c1309a496b7d57d93979206d9ff13036 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -292,8 +292,8 @@ class Schedule(NodeBase):
 
         Parameters
         ----------
-        tensor : Tensor
-            The tensor to be feed to.
+        tensor : Tensor, list or tuple
+            The tensors to be fed to. All the tensors must be produced by one ComputeOp.
 
         scope : str
             The scope of cached
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
index 03d085aa8d5924d8097341cfd92943100ca6d905..ba158a7c3f795b462c5162c1d38d4bb39002a570 100644
--- a/src/api/api_lang.cc
+++ b/src/api/api_lang.cc
@@ -425,8 +425,13 @@ TVM_REGISTER_API("_ScheduleCacheRead")
 
 TVM_REGISTER_API("_ScheduleCacheWrite")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    *ret = args[0].operator Schedule()
-        .cache_write(args[1], args[2]);
+    if (args[1].IsNodeType<Tensor>()) {
+      *ret = args[0].operator Schedule()
+          .cache_write(args[1].operator Tensor(), args[2]);
+    } else {
+      *ret = args[0].operator Schedule()
+          .cache_write(args[1].operator Array<Tensor>(), args[2]);
+    }
   });
 
 TVM_REGISTER_API("_ScheduleRFactor")
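Note: the FFI dispatch above is what lets the Python-side cache_write accept either a single tensor or a list without further wrapper changes. A minimal usage sketch of the two call forms (the names here are illustrative; the list form mirrors the tests added later in this patch):

    import tvm

    n = tvm.var("n")
    A = tvm.placeholder((n,), name="A")
    # one ComputeOp producing two outputs
    B0, B1 = tvm.compute((n,), lambda i: (A[i] + 1, A[i] * 2), name="B")
    s = tvm.create_schedule([B0.op, B1.op])

    # new list form: all outputs of the op are cached together
    B0_cache, B1_cache = s.cache_write([B0, B1], "local")

    # the single-tensor form is unchanged for single-output ops
    C = tvm.compute((n,), lambda i: A[i] - 1, name="C")
    sc = tvm.create_schedule(C.op)
    C_cache = sc.cache_write(C, "local")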
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index 1d9dfed4e5e16fe2e1324cfd28f399d4bbe418bf..89390dcc048d992e6f1646360626312b2008f702 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -78,6 +78,13 @@ void ReplaceDataFlow(const Array<Stage>& stages,
   }
 }
 
+inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
+  return (a->combiner.same_as(b->combiner)) &&
+         (a->source.same_as(b->source)) &&
+         (a->axis.same_as(b->axis)) &&
+         (a->condition.same_as(b->condition));
+}
+
 Tensor Schedule::cache_read(const Tensor& tensor,
                             const std::string& scope,
                             const Array<Operation>& readers) {
@@ -128,15 +135,15 @@ Tensor Schedule::cache_read(const Tensor& tensor,
   return cache;
 }
-
 // Cache write and relayout the data according to loop pattern
-Tensor CacheWriteWithReLayout(Schedule sch,
-                              const Tensor& tensor,
+Array<Tensor> CacheWriteWithReLayout(Schedule sch,
+                                     const Array<Tensor>& tensor_array,
                               const std::string& scope) {
+  size_t tensor_size = tensor_array.size();
   sch->InvalidateCache();
+  Tensor tensor = tensor_array[0];
   Stage orig_stage = sch[tensor->op];
   const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>();
-
   std::unordered_set<IterVar> red_axis;
   for (IterVar iv : compute->reduce_axis) {
     red_axis.insert(iv);
   }
@@ -182,9 +189,34 @@ Tensor CacheWriteWithReLayout(Schedule sch,
       vsub[iv->var.get()] = value_map.at(iv);
     }
   }
-  Expr body = VarReplacer(vsub).Mutate(compute->body[tensor->value_index]);
-  body = InjectPredicate(predicates, body);
-  body = VarReplacer(vsub2newvar).Mutate(body);
+
+  Expr body;
+  Array<Expr> body_list;
+  const ir::Reduce* first_reduce = nullptr;
+  for (auto cbody : compute->body) {
+    body = VarReplacer(vsub).Mutate(cbody);
+    body = InjectPredicate(predicates, body);
+    body = VarReplacer(vsub2newvar).Mutate(body);
+    // Reduce nodes in one ComputeOp must be the same except for value_index.
+    // This holds only if the original body keeps all Reduce nodes identical.
+    if (body->is_type<ir::Reduce>()) {
+      const ir::Reduce* reduce_body = body.as<ir::Reduce>();
+      if (first_reduce != nullptr) {
+        CHECK(ReduceEqual(reduce_body, first_reduce));
+        body = ir::Reduce::make(first_reduce->combiner,
+                                first_reduce->source,
+                                first_reduce->axis,
+                                first_reduce->condition,
+                                reduce_body->value_index);
+      } else {
+        first_reduce = reduce_body;
+      }
+    } else {
+      CHECK(first_reduce == nullptr)
+          << "cannot mix reduce and other node in one compute body";
+    }
+    body_list.push_back(body);
+  }
   // The reader args
   Array<Expr> args;
   {
@@ -200,16 +232,23 @@ Tensor CacheWriteWithReLayout(Schedule sch,
     }
   }
   Operation cache_op = ComputeOpNode::make(
-      compute->name + "." + scope, compute->tag, new_axis, {body});
-  Tensor cache_tensor = cache_op.output(0);
+      compute->name + "." + scope, compute->tag, new_axis, body_list);
+  Array<Tensor> cache_tensor_list;
+  Array<Expr> cache_expr_list;
+  for (size_t i = 0; i < tensor_size; i++) {
+    Tensor cache_tensor = cache_op.output(i);
+    cache_tensor_list.push_back(cache_tensor);
+    cache_expr_list.push_back(cache_tensor(args));
+  }
   Operation orig_new_op = ComputeOpNode::make(
-      compute->name, compute->tag, compute->axis,
-      {cache_tensor(args)});
+      compute->name, compute->tag, compute->axis, cache_expr_list);
   // The replace of the dataflow
   std::unordered_map<Tensor, Tensor> vmap;
   std::unordered_map<Tensor, Tensor> rvmap;
-  vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
-  rvmap[orig_new_op.output(0)] = orig_stage->op.output(0);
+  for (size_t i = 0; i < tensor_size; i++) {
+    vmap[orig_stage->op.output(i)] = orig_new_op.output(i);
+    rvmap[orig_new_op.output(i)] = orig_stage->op.output(i);
+  }
   ReplaceDataFlow(sch->stages, &vmap, &rvmap);
   // mutate orig stage
   orig_stage->op = orig_new_op;
@@ -230,7 +269,26 @@ Tensor CacheWriteWithReLayout(Schedule sch,
   if (cache_stage->group.defined()) {
     ++cache_stage->group->num_child_stages;
   }
-  return cache_tensor;
+  return cache_tensor_list;
+}
+
+Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array,
+                                    const std::string& scope) {
+  (*this)->InvalidateCache();
+  CHECK(tensor_array.size() > 0)
+      << "size of tensor_array must be greater than 0";
+  Tensor tensor = tensor_array[0];
+  Stage orig_stage = operator[](tensor->op);
+  const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>();
+  CHECK(static_cast<size_t>(compute->num_outputs()) == tensor_array.size())
+      << "size of input tensor list must be same as number of stage outputs";
+  for (size_t i = 1; i < tensor_array.size(); i++) {
+    Stage tmp_stage = operator[](tensor_array[i]->op);
+    CHECK(orig_stage.same_as(tmp_stage))
+        << "Input tensor list must be generated by ONE ComputeOp";
+  }
+
+  return CacheWriteWithReLayout(*this, tensor_array, scope);
 }
 
 Tensor Schedule::cache_write(const Tensor& tensor,
@@ -243,7 +301,7 @@ Tensor Schedule::cache_write(const Tensor& tensor,
   CHECK_EQ(compute->num_outputs(), 1)
       << "cache write only support single output ComputeOp";
 
-  return CacheWriteWithReLayout(*this, tensor, scope);
+  return (CacheWriteWithReLayout(*this, {tensor}, scope))[0];
 }
 
 void RebaseNonZeroMinLoop(const Schedule& sch) {
@@ -289,13 +347,6 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
   }
 }
 
-inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
-  return (a->combiner.same_as(b->combiner)) &&
-         (a->source.same_as(b->source)) &&
-         (a->axis.same_as(b->axis)) &&
-         (a->condition.same_as(b->condition));
-}
-
 void InjectInline(ScheduleNode* sch) {
   sch->InvalidateCache();
 
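The ReduceEqual check added above encodes an IR invariant: a tuple reduction is represented as one ir::Reduce node per output, identical except for value_index, and CacheWriteWithReLayout now relies on that when it rebuilds the cache bodies. A sketch of a compute that exercises this path, using the stock comm_reducer argmax pattern from TVM's reduction tests (fcombine/fidentity shapes as in those tests):

    import tvm

    # (index, value) argmax as a commutative tuple reduction
    def fcombine(x, y):
        lhs = tvm.select((x[1] >= y[1]), x[0], y[0])
        rhs = tvm.select((x[1] >= y[1]), x[1], y[1])
        return lhs, rhs

    def fidentity(t0, t1):
        return tvm.const(-1, t0), tvm.min_value(t1)

    argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')

    m = tvm.var('m')
    n = tvm.var('n')
    idx = tvm.placeholder((m, n), name='idx', dtype='int32')
    val = tvm.placeholder((m, n), name='val', dtype='float32')
    k = tvm.reduce_axis((0, n), 'k')
    # T0's and T1's bodies are the same Reduce node except for value_index,
    # which is exactly what the CHECK(ReduceEqual(...)) above asserts.
    T0, T1 = tvm.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='T')
    s = tvm.create_schedule(T0.op)
    T0_cache, T1_cache = s.cache_write([T0, T1], "local")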
diff --git a/tests/python/integration/test_ewise.py b/tests/python/integration/test_ewise.py
index cebacdc2b9a5454699c66e03a7ca43d8adfd872a..e9c69bb57d9f9f72758f98b4478e72e5200b09da 100644
--- a/tests/python/integration/test_ewise.py
+++ b/tests/python/integration/test_ewise.py
@@ -39,6 +39,49 @@ def test_exp():
     check_device("vulkan")
 
 
+def test_multiple_cache_write():
+    # graph
+    n = tvm.convert(1024)
+    A0 = tvm.placeholder((n,), name='A0', dtype="float32")
+    A1 = tvm.placeholder((n,), name='A1', dtype="float32")
+    B0, B1 = tvm.compute((n,),
+                         lambda *i: (A0(*i) + A1(*i), A0(*i) * A1(*i)),
+                         name='B')
+    C = tvm.compute((n,), lambda *i: B0(*i) + B1(*i),
+                    name='C')
+    s = tvm.create_schedule(C.op)
+    # create iter var and assign them tags.
+    num_thread = 8
+    B0_cache, B1_cache = s.cache_write([B0, B1], "local")
+    bx, tx = s[C].split(C.op.axis[0], factor=num_thread)
+    s[B0].compute_at(s[C], bx)
+    s[B0_cache].compute_at(s[C], bx)
+    s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
+    s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
+    # build and check on each enabled device.
+    def check_device(device, host="stackvm"):
+        if not tvm.module.enabled(host):
+            return
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            return
+        func = tvm.build(s, [A0, A1, C],
+                         device, host,
+                         name="multiple_cache_write")
+        # launch the kernel.
+        n = 1024
+        a0 = tvm.nd.array(np.random.uniform(size=n).astype(A0.dtype), ctx)
+        a1 = tvm.nd.array(np.random.uniform(size=n).astype(A1.dtype), ctx)
+        c = tvm.nd.array(np.zeros(n, dtype=C.dtype), ctx)
+        func(a0, a1, c)
+        np.testing.assert_allclose(
+            c.asnumpy(), a0.asnumpy() + a1.asnumpy() + (a0.asnumpy() * a1.asnumpy()),
+            rtol=1e-5)
+
+    check_device("cuda", "llvm")
+    check_device("vulkan")
+    check_device("opencl")
+
 
 def test_log_pow_llvm():
     # graph
@@ -199,6 +242,7 @@ def try_warp_memory():
 if __name__ == "__main__":
     test_exp()
     try_warp_memory()
+    test_multiple_cache_write()
     test_add()
     test_log_pow_llvm()
     test_popcount()
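One scheduling detail the integration test leans on without spelling it out: outputs of a single ComputeOp share one Stage, so s[B0].compute_at(...) also places B1, and s[B0_cache] likewise covers B1_cache. The calls are equivalent to going through the op (sketch):

    # s[B0] is shorthand for s[B0.op]; B1 comes from the same op,
    # so a single compute_at schedules both outputs.
    s[B0.op].compute_at(s[C], bx)
    s[B0_cache.op].compute_at(s[C], bx)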
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index ecb9c0acb7f8cc73119240500621dc50d12ba1d0..8e6f4090d4038f8400ca7aca3385d462ad3f002f 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -249,6 +249,20 @@ def test_schedule_cache_relayout3():
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
 
+def test_schedule_cache_relayout4():
+    def _compute(*indice):
+        return A(*indice) + 1, B(*indice) / 2
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A = tvm.placeholder((m*4, n), name='A')
+    B = tvm.placeholder((m*4, n), name='B')
+    C1, C2 = tvm.compute(A.shape, _compute, name='C')
+    s = tvm.create_schedule([C1.op, C2.op])
+    C1_cache, C2_cache = s.cache_write([C1, C2], "local")
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
 
 def test_schedule_bound_condition():
     A = tvm.placeholder((64,), name='A', dtype="float32")
@@ -265,6 +279,7 @@ def test_schedule_bound_condition():
 if __name__ == "__main__":
     test_schedule_middle_cache()
     test_inline_multi_reduce()
+    test_schedule_cache_relayout4()
     test_schedule_cache_relayout3()
     test_schedule_cache_relayout2()
     test_schedule_cache_relayout1()
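To see what the rewrite produces for the mixed-expression case, lowering the relayout4 schedule shows the new dataflow: one C.local stage computing both A + 1 and B / 2 in the cache layout, with the original C stage reduced to copies from it. A quick inspection sketch (to be added inside the test; the printed IR shape is illustrative):

    print(tvm.lower(s, [A, B, C1, C2], simple_mode=True))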