From 400c1c483e7aa9aaebd93f3552d1e8e31697e497 Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Thu, 7 Sep 2017 15:50:25 -0700 Subject: [PATCH] [SCHEDULE] Enhance cache_write to enable layout change. (#432) * [SCHEDULE] Enhance cache_write to enable layout change. * more tests --- include/tvm/schedule.h | 11 +- python/tvm/schedule.py | 8 + src/op/compute_op.cc | 9 +- src/op/cross_thread_reduction.cc | 6 +- src/op/op_util.cc | 85 ---------- src/op/op_util.h | 17 +- src/op/scan_op.cc | 2 +- src/schedule/message_passing.cc | 140 +++++++++++++++- src/schedule/message_passing.h | 32 ++++ src/schedule/schedule_dataflow_rewrite.cc | 154 ++++++++++++++---- .../unittest/test_schedule_schedule_ops.py | 52 +++++- 11 files changed, 368 insertions(+), 148 deletions(-) diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index aeb5ffa66..957b425a9 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -284,8 +284,15 @@ class Schedule : public NodeRef { /*! * \brief Create a cache write tensor for producing tensor. * The tensor will take over the body of the original tensor op. - * The original tensor's body will be changed to an identity read - * from the corresponding cache. + * + * This function can be used to do data layout transformation. + * If there is a split/fuse/reorder on the data parallel axes of the tensor + * before cache_write is called, the intermediate cache stores + * the data in the layout of the iteration order of the leaf axes. + * The data will be transformed back to the original layout in the original tensor. + * The user can further call compute_inline to inline the original layout and keep + * the data stored in the transformed layout. + * * \param tensor The tensor to be produced. * \param scope The scope of the storage. * \return The created tensor. diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index ecaeb50bc..26be2de1a 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -248,6 +248,14 @@ class Schedule(NodeBase): This will mutate the body of the tensor. A new cache stage will be created before feeding into the tensor. + This function can be used to support data layout transformation. + If there is a split/fuse/reorder on the data parallel axes of the tensor + before cache_write is called, the intermediate cache stores + the data in the layout of the iteration order of the leaf axes. + The data will be transformed back to the original layout in the original tensor. + The user can further call compute_inline to inline the original layout and keep + the data stored in the transformed layout. 
+ Parameters ---------- tensor : Tensor diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index c7e1b54a4..89d98770b 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -383,8 +383,9 @@ ComputeLoopNest ComputeLoopNest::make( // make main loop nest ret.main_nest = op::MakeLoopNest( stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap); - ret.main_predicates = op::MakeBoundCheck(stage, dom_map, false, - std::unordered_set<IterVar>(), ret.main_vmap); + ret.main_predicates = schedule::MakeBoundCheck( + stage, dom_map, ret.main_vmap, false, + std::unordered_set<IterVar>()); for (auto& e : ret.main_predicates) { e = likely(e); } @@ -424,8 +425,8 @@ ComputeLoopNest ComputeLoopNest::make( ret.init_nest = op::MakeLoopNest( stage, dom_map, begin_loop, true, skip_iter, &(ret.init_vmap)); - ret.init_predicates = op::MakeBoundCheck( - stage, dom_map, true, skip_iter, ret.init_vmap); + ret.init_predicates = schedule::MakeBoundCheck( + stage, dom_map, ret.init_vmap, true, skip_iter); for (auto& e : ret.init_predicates) { e = likely(e); } diff --git a/src/op/cross_thread_reduction.cc b/src/op/cross_thread_reduction.cc index e79b81c8c..6eec3bd69 100644 --- a/src/op/cross_thread_reduction.cc +++ b/src/op/cross_thread_reduction.cc @@ -21,9 +21,9 @@ Stmt MakeCrossThreadReduction( std::unordered_map<IterVar, Expr> value_map; auto nest = op::MakeLoopNest( stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map); - auto conds = op::MakeBoundCheck( - stage, dom_map, false, - std::unordered_set<IterVar>(), value_map); + auto conds = schedule::MakeBoundCheck( + stage, dom_map, value_map, false, + std::unordered_set<IterVar>()); size_t size = self->body.size(); CHECK_GT(size, 0); diff --git a/src/op/op_util.cc b/src/op/op_util.cc index ea64bacdc..cd0d5e436 100644 --- a/src/op/op_util.cc +++ b/src/op/op_util.cc @@ -147,91 +147,6 @@ MakeLoopNest(const Stage& stage, return nest; } - -/*! - * \brief message passing to find if boundary checking on IterVar is needed. - * \param s The stage to be used. 
- * \param p_state The message passing state - * IterVar->flag - */ -void PassUpBoundCheck(const Stage& s, - const Map<IterVar, Range>& dom_map, - std::unordered_map<IterVar, bool>* p_state) { - auto& state = *p_state; - using Halide::Internal::can_prove; - for (size_t i = s->relations.size(); i != 0; --i) { - IterVarRelation rel = s->relations[i - 1]; - if (rel.as<SplitNode>()) { - const SplitNode* s = rel.as<SplitNode>(); - bool outer = state.at(s->outer); - bool inner = state.at(s->inner); - Expr factor = dom_map.at(s->inner)->extent; - Expr step = dom_map.at(s->outer)->extent; - - if (outer || inner) { - state[s->parent] = true; - } else { - if (can_prove(dom_map.at(s->parent)->extent == factor * step)) { - state[s->parent] = false; - } else { - state[s->parent] = true; - } - } - } else if (rel.as<FuseNode>()) { - const FuseNode* s = rel.as<FuseNode>(); - bool fused = state.at(s->fused); - state[s->outer] = fused; - state[s->inner] = fused; - } else if (rel.as<RebaseNode>()) { - const RebaseNode* s = rel.as<RebaseNode>(); - state[s->parent] = state.at(s->rebased); - } else { - LOG(FATAL) << "unknown relation type"; - } - } -} - -std::vector<Expr> MakeBoundCheck( - const Stage& stage, - const Map<IterVar, Range>& dom_map, - bool skip_ivar_domain, - const std::unordered_set<IterVar>& skip_iter, - const std::unordered_map<IterVar, Expr>& value_map) { - std::unordered_map<IterVar, bool> bound_state; - for (IterVar iv : stage->leaf_iter_vars) { - bound_state[iv] = false; - } - PassUpBoundCheck(stage, dom_map, &bound_state); - std::vector<Expr> preds; - std::unordered_map<const Variable*, IntSet> iset_dmap; - - // setup domain map for set analysis - for (const auto& kv : dom_map) { - iset_dmap[kv.first->var.get()] = IntSet::range(kv.second); - } - - for (IterVar iv : stage->op->root_iter_vars()) { - if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue; - Range dom = dom_map.at(iv); - if (bound_state.at(iv)) { - Expr value = ComputeExpr<Sub>(value_map.at(iv), dom->min); - Expr vmax = EvalSet(value, iset_dmap).max(); - if (vmax.type() != value.type() || !can_prove(vmax < dom->extent)) { - preds.emplace_back(value < dom->extent); - } - } - CHECK(iv->dom.defined()); - if (!skip_ivar_domain && !iv->dom.same_as(dom)) { - Expr value = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min); - Expr vmax = EvalSet(value, iset_dmap).max(); - if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) { - preds.emplace_back(value < iv->dom->extent); - } - } - } - return preds; -} - std::vector<Stmt> MakeIfNest(const std::vector<Expr>& predicates) { Stmt no_op = Evaluate::make(0); std::vector<Stmt> nest; diff --git a/src/op/op_util.h b/src/op/op_util.h index 165113863..783fbb989 100644 --- a/src/op/op_util.h +++ b/src/op/op_util.h @@ -13,6 +13,7 @@ #include <vector> #include "../pass/ir_util.h" #include "../pass/arg_binder.h" +#include "../schedule/message_passing.h" namespace tvm { namespace op { @@ -36,22 +37,6 @@ MakeLoopNest(const Stage& stage, bool new_loop_var, const std::unordered_set<IterVar>& skip_iter, std::unordered_map<IterVar, Expr>* p_value_map); -/*! - * \brief Create boundary check condition for given stage. - * - * \param stage The stage to create a loop nest. - * \param dom_map The range of each iter var. - * \param skip_ivar_domain Whether we can skip check for IterVar's original domain. - * \param skip_iter Whether skip certain iteration. - * \param value_map The result value of each IterVar. - * \return List of predicates that we need to check. 
- */ -std::vector<Expr> -MakeBoundCheck(const Stage& stage, - const Map<IterVar, Range>& dom_map, - bool skip_ivar_domain, - const std::unordered_set<IterVar>& skip_iter, - const std::unordered_map<IterVar, Expr>& value_map); /*! * \brief Create a nest of if checking the predicates. diff --git a/src/op/scan_op.cc b/src/op/scan_op.cc index f03eb95f1..48565b6eb 100644 --- a/src/op/scan_op.cc +++ b/src/op/scan_op.cc @@ -274,7 +274,7 @@ Stmt ScanOpNode::BuildProvide( nest[begin_scan].push_back(init); nest.push_back( op::MakeIfNest( - op::MakeBoundCheck(stage, dom_map, false, empty, vmap))); + schedule::MakeBoundCheck(stage, dom_map, vmap, false, empty))); return MergeNest(nest, provide); } } // namespace tvm diff --git a/src/schedule/message_passing.cc b/src/schedule/message_passing.cc index 4ba32785d..969a18ee9 100644 --- a/src/schedule/message_passing.cc +++ b/src/schedule/message_passing.cc @@ -7,10 +7,12 @@ #include <tvm/ir.h> #include <tvm/ir_pass.h> #include "./message_passing.h" +#include "../arithmetic/compute_expr.h" namespace tvm { namespace schedule { +using namespace ir; using namespace arith; // result = ceil((a / b)), both a and b are positive integer @@ -123,8 +125,8 @@ void PassUpIndex(const Stage& stage, Expr factor = dom_map.at(s->inner)->extent; Expr outer_min = dom_map.at(s->outer)->min; Expr inner_min = dom_map.at(s->inner)->min; - state[s->outer] = value / factor; - state[s->inner] = value % factor; + state[s->outer] = ComputeExpr<Div>(value, factor); + state[s->inner] = ComputeExpr<Mod>(value, factor); // add min if they exist if (!is_zero(outer_min)) { state[s->outer] = state[s->outer] + outer_min; @@ -151,6 +153,51 @@ void PassUpIndex(const Stage& stage, } } +void PassDownIndex(const Stage& stage, + const Map<IterVar, Range>& dom_map, + std::unordered_map<IterVar, Expr>* p_state, + bool allow_missing) { + auto& state = *p_state; + for (IterVarRelation rel : stage->relations) { + if (const SplitNode* s = rel.as<SplitNode>()) { + if (!state.count(s->parent)) { + CHECK(allow_missing); + continue; + } + Range r = dom_map.at(s->inner); + CHECK(is_zero(r->min)); + Expr parent = state.at(s->parent); + Expr factor = r->extent; + state[s->outer] = ComputeExpr<Div>(parent, factor); + state[s->inner] = ComputeExpr<Mod>(parent, factor); + } else if (const FuseNode* s = rel.as<FuseNode>()) { + if (!state.count(s->inner) && !state.count(s->outer)) { + CHECK(allow_missing); + continue; + } + Expr factor = dom_map.at(s->inner)->extent; + Expr outer_min = dom_map.at(s->outer)->min; + Expr inner_min = dom_map.at(s->inner)->min; + Expr inner = state.at(s->inner); + Expr outer = state.at(s->outer); + CHECK(is_zero(outer_min)); + CHECK(is_zero(inner_min)); + state[s->fused] = outer * factor + inner; + } else if (const RebaseNode* s = rel.as<RebaseNode>()) { + if (!state.count(s->parent)) { + CHECK(allow_missing); + continue; + } + Expr value = state.at(s->parent); + Expr parent_min = dom_map.at(s->parent)->min; + CHECK(is_zero(parent_min)); + state[s->rebased] = value; + } else { + LOG(FATAL) << "unknown relation type"; + } + } +} + // Domain message passing. void PassUpDomain(const SplitNode* s, const std::unordered_map<IterVar, Range>& dom_map, @@ -349,5 +396,94 @@ void PassDownBitMaskOr(const Stage& stage, } } + +/*! + * \brief message passing to find if boundary checking on IterVar is needed. + * \param s The stage to be used. 
+ * \param p_state The message passing state + * IterVar->flag + */ +void PassUpBoundCheck(const Stage& s, + const Map<IterVar, Range>& dom_map, + std::unordered_map<IterVar, bool>* p_state) { + auto& state = *p_state; + using Halide::Internal::can_prove; + for (size_t i = s->relations.size(); i != 0; --i) { + IterVarRelation rel = s->relations[i - 1]; + if (rel.as<SplitNode>()) { + const SplitNode* s = rel.as<SplitNode>(); + bool outer = state.at(s->outer); + bool inner = state.at(s->inner); + + if (dom_map.count(s->inner) && dom_map.count(s->outer)) { + Expr factor = dom_map.at(s->inner)->extent; + Expr step = dom_map.at(s->outer)->extent; + if (outer || inner) { + state[s->parent] = true; + } else { + if (can_prove(dom_map.at(s->parent)->extent == factor * step)) { + state[s->parent] = false; + } else { + state[s->parent] = true; + } + } + } else { + state[s->parent] = true; + } + } else if (rel.as<FuseNode>()) { + const FuseNode* s = rel.as<FuseNode>(); + bool fused = state.at(s->fused); + state[s->outer] = fused; + state[s->inner] = fused; + } else if (rel.as<RebaseNode>()) { + const RebaseNode* s = rel.as<RebaseNode>(); + state[s->parent] = state.at(s->rebased); + } else { + LOG(FATAL) << "unknown relation type"; + } + } +} + +std::vector<Expr> MakeBoundCheck( + const Stage& stage, + const Map<IterVar, Range>& dom_map, + const std::unordered_map<IterVar, Expr>& value_map, + bool skip_ivar_domain, + const std::unordered_set<IterVar>& skip_iter) { + std::unordered_map<IterVar, bool> bound_state; + for (IterVar iv : stage->leaf_iter_vars) { + bound_state[iv] = false; + } + PassUpBoundCheck(stage, dom_map, &bound_state); + + std::vector<Expr> preds; + std::unordered_map<const Variable*, IntSet> iset_dmap; + + // setup domain map for set analysis + for (const auto& kv : dom_map) { + iset_dmap[kv.first->var.get()] = IntSet::range(kv.second); + } + + for (IterVar iv : stage->op->root_iter_vars()) { + if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue; + Range dom = dom_map.at(iv); + if (bound_state.at(iv)) { + Expr value = ComputeExpr<Sub>(value_map.at(iv), dom->min); + Expr vmax = EvalSet(value, iset_dmap).max(); + if (vmax.type() != value.type() || !can_prove(vmax < dom->extent)) { + preds.emplace_back(value < dom->extent); + } + } + CHECK(iv->dom.defined()); + if (!skip_ivar_domain && !iv->dom.same_as(dom)) { + Expr value = ComputeExpr<Sub>(value_map.at(iv), iv->dom->min); + Expr vmax = EvalSet(value, iset_dmap).max(); + if (vmax.type() != value.type() || !can_prove(vmax < iv->dom->extent)) { + preds.emplace_back(value < iv->dom->extent); + } + } + } + return preds; +} } // namespace schedule } // namespace tvm diff --git a/src/schedule/message_passing.h b/src/schedule/message_passing.h index 5b7cf9d24..baf4a2415 100644 --- a/src/schedule/message_passing.h +++ b/src/schedule/message_passing.h @@ -45,6 +45,20 @@ void PassUpIndex(const Stage& stage, std::unordered_map<IterVar, Expr>* p_state, bool allow_missing = false); +/*! + * \brief Downward inference of index of each IterVar, + * given index assignment of roots. + * + * \param stage The stage to operate on. + * \param dom_map The domain map of each iteration variable's domain. + * \param p_state The index state of each IterVar. + * \param allow_missing Whether allow missing value. + */ +void PassDownIndex(const Stage& stage, + const Map<IterVar, Range>& dom_map, + std::unordered_map<IterVar, Expr>* p_state, + bool allow_missing = false); + /*! * \param Upward inference of domain set of each IterVar. 
* given domain assignment of the leaves, @@ -76,6 +90,24 @@ void PassUpBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state, bool allow_missing = false); void PassDownBitMaskOr(const Stage& stage, std::unordered_map<IterVar, int>* p_state, bool allow_missing = false); + +/*! + * \brief Create boundary check predicates given the remapped values of the root iter vars. + * \param stage The stage we operate on. + * \param dom_map The domain map of each iter var. + * \param value_map The value map of the root iter vars. + * \param skip_ivar_domain Whether we can skip the check for an IterVar's original domain. + * \param skip_iter The set of variables for which to skip the bound check. + * \return List of predicates that we need to check. + */ +std::vector<Expr> +MakeBoundCheck( + const Stage& stage, + const Map<IterVar, Range>& dom_map, + const std::unordered_map<IterVar, Expr>& value_map, + bool skip_ivar_domain, + const std::unordered_set<IterVar>& skip_iter); + } // namespace schedule } // namespace tvm #endif // TVM_SCHEDULE_MESSAGE_PASSING_H_ diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index c5aca83d5..02ebc21e2 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -9,6 +9,7 @@ #include <unordered_set> #include "./message_passing.h" #include "../pass/ir_util.h" +#include "../arithmetic/compute_expr.h" namespace tvm { @@ -38,6 +39,22 @@ class VarReplacer : public ir::IRMutator { const std::unordered_map<const Variable*, Expr>& vsub_; }; +Expr InjectPredicate(const Array<Expr>& predicates, + Expr body) { + using ir::Reduce; + using ir::Select; + if (predicates.size() == 0) return body; + const Reduce* reduce = body.as<Reduce>(); + if (reduce) { + std::shared_ptr<Reduce> n = std::make_shared<Reduce>(*reduce); + n->condition = n->condition && arith::ComputeReduce<ir::And>(predicates); + return Expr(n); + } + return Select::make(arith::ComputeReduce<ir::And>(predicates), + body, + make_zero(body.type())); +} + // Replace data flow appears in all stages given the tensor change. // Also update vmap if subsequent dataflow need to be replaced. 
void ReplaceDataFlow(const Array<Stage>& stages, @@ -99,52 +116,101 @@ Tensor Schedule::cache_read(const Tensor& tensor, return cache; } -Tensor Schedule::cache_write(const Tensor& tensor, - const std::string& scope) { - (*this)->InvalidateCache(); - Stage orig_stage = operator[](tensor->op); - const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>(); - CHECK(compute) - << "cache write only take ComputeOp as writers"; - CHECK_EQ(orig_stage->relations.size(), 0U) - << "Create cache_write before doing split/fuse/reorder"; - compute = orig_stage->op.as<ComputeOpNode>(); - CHECK(compute); - Array<Expr> args; + +// Cache write and relayout the data according to the loop pattern +Tensor CacheWriteWithReLayout(Schedule sch, + const Tensor& tensor, + const std::string& scope) { + sch->InvalidateCache(); + Stage orig_stage = sch[tensor->op]; + const ComputeOpNode* compute = orig_stage->op.as<ComputeOpNode>(); + + std::unordered_set<IterVar> red_axis; + for (IterVar iv : compute->reduce_axis) { + red_axis.insert(iv); + } + std::unordered_map<IterVar, Range> dom_map; Array<IterVar> new_axis; - std::unordered_map<const Variable*, Expr> vsub; + for (IterVar iv : compute->axis) { - args.push_back(iv->var); - IterVar new_iv = IterVarNode::make( - iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type); - new_axis.push_back(new_iv); - vsub[iv->var.get()] = new_iv->var; + dom_map[iv] = iv->dom; + } + schedule::PassDownDomain(orig_stage, &dom_map, true); + std::unordered_map<const Variable*, Expr> vsub; + std::unordered_map<const Variable*, Expr> vsub2newvar; + std::vector<Expr> predicates; + { + // The source->cache + std::unordered_map<IterVar, Expr> value_map; + for (IterVar iv : orig_stage->leaf_iter_vars) { + if (red_axis.count(iv)) continue; + CHECK_EQ(iv->iter_type, kDataPar) + << "Can only relayout within data parallel dimensions"; + Range dom = dom_map.at(iv); + IterVar new_iv = IterVarNode::make( + dom, iv->var.copy_with_suffix(".c"), iv->iter_type); + new_axis.push_back(new_iv); + if (is_one(dom->extent)) { + value_map[iv] = dom->min; + } else { + value_map[iv] = iv->var; + vsub2newvar[iv->var.get()] = new_iv->var; + } + } + // skip reduction iterations. + std::unordered_set<IterVar> skip_bound_check; + for (IterVar iv : compute->reduce_axis) { + skip_bound_check.insert(iv); + } + schedule::PassUpIndex(orig_stage, dom_map, &value_map, true); + predicates = schedule::MakeBoundCheck( + orig_stage, dom_map, value_map, true, skip_bound_check); + // The root axis + for (IterVar iv : compute->axis) { + vsub[iv->var.get()] = value_map.at(iv); + } + } + Expr body = VarReplacer(vsub).Mutate(compute->body[tensor->value_index]); + body = InjectPredicate(predicates, body); + body = VarReplacer(vsub2newvar).Mutate(body); + // The reader args + Array<Expr> args; + { + // cache->compute + std::unordered_map<IterVar, Expr> value_map; + for (IterVar iv : compute->axis) { + value_map[iv] = iv->var; + } + schedule::PassDownIndex(orig_stage, dom_map, &value_map, true); + for (IterVar iv : orig_stage->leaf_iter_vars) { + if (red_axis.count(iv)) continue; + args.push_back(value_map.at(iv)); + } } - VarReplacer repl(vsub); - Expr body = repl.Mutate(compute->body[tensor->value_index]); Operation cache_op = ComputeOpNode::make( compute->name + "." 
+ scope, compute->tag, new_axis, {body}); Tensor cache_tensor = cache_op.output(0); Operation orig_new_op = ComputeOpNode::make( compute->name, compute->tag, compute->axis, {cache_tensor(args)}); - + // Replace the dataflow std::unordered_map<Tensor, Tensor> vmap; vmap[orig_stage->op.output(0)] = orig_new_op.output(0); - ReplaceDataFlow((*this)->stages, &vmap); + ReplaceDataFlow(sch->stages, &vmap); // mutate orig stage orig_stage->op = orig_new_op; orig_stage->all_iter_vars = orig_stage->op->root_iter_vars(); orig_stage->leaf_iter_vars = orig_stage->all_iter_vars; + orig_stage->relations = Array<IterVarRelation>(); // create schedule for new cached stage. - ArrayNode* stages = (*this)->stages.CopyOnWrite(); + ArrayNode* stages = sch->stages.CopyOnWrite(); size_t pos = FindNodeRef(stages, orig_stage); Stage cache_stage = Stage(cache_op); cache_stage.set_scope(scope); CHECK_LT(pos, stages->data.size()); stages->data.insert(stages->data.begin() + pos, cache_stage.node_); - (*this)->stage_map.Set(cache_op, cache_stage); + sch->stage_map.Set(cache_op, cache_stage); // Update group cache_stage->group = orig_stage->group; if (cache_stage->group.defined()) { @@ -153,6 +219,19 @@ Tensor Schedule::cache_write(const Tensor& tensor, return cache_tensor; } +Tensor Schedule::cache_write(const Tensor& tensor, + const std::string& scope) { + (*this)->InvalidateCache(); + Stage orig_stage = operator[](tensor->op); + const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>(); + CHECK(compute) + << "cache_write only takes ComputeOp as writer"; + CHECK_EQ(compute->num_outputs(), 1) + << "cache_write only supports single-output ComputeOp"; + + return CacheWriteWithReLayout(*this, tensor, scope); +} + void RebaseNonZeroMinLoop(const Schedule& sch) { std::unordered_map<IterVar, IterVar> rebase_map; for (Stage s : sch->stages) { @@ -295,16 +374,23 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, touch_map[axis] = 1; schedule::PassUpBitMaskOr(reduce_stage, &touch_map, true); schedule::PassDownBitMaskOr(reduce_stage, &touch_map, true); + // skip reduction iterations. + std::unordered_set<IterVar> skip_bound_check; // Verify normal axis are not touched. for (IterVar iv : compute_op->axis) { CHECK(!touch_map.count(iv)) << "Factor axis touches normal axis."; + skip_bound_check.insert(iv); } // Get the replace index std::unordered_map<IterVar, Range> dom_map; std::unordered_map<IterVar, Expr> value_map; for (IterVar iv : compute_op->reduce_axis) { - if (touch_map.count(iv)) dom_map[iv] = iv->dom; + if (touch_map.count(iv)) { + dom_map[iv] = iv->dom; + } else { + skip_bound_check.insert(iv); + } } schedule::PassDownDomain(reduce_stage, &dom_map, true); for (IterVar iv : reduce_stage->leaf_iter_vars) { @@ -318,6 +404,9 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, } } schedule::PassUpIndex(reduce_stage, dom_map, &value_map, true); + std::vector<Expr> predicates = schedule::MakeBoundCheck( + reduce_stage, dom_map, value_map, true, skip_bound_check); + // Get the factored op node. 
auto n = std::make_shared<ComputeOpNode>(); n->name = compute_op->name + ".rf"; @@ -339,8 +428,11 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, int idx = tensor->value_index; const Reduce* reduce = compute_op->body[idx].as<Reduce>(); CHECK(reduce) << "Can only rfactor non-inline reductions"; - Expr predicate = reduce->condition; + predicates.push_back(reduce->condition); + Expr predicate = arith::ComputeReduce<ir::And>(predicates); + std::unordered_map<const Variable*, Expr> vsub; + for (IterVar iv : compute_op->reduce_axis) { if (!touch_map.count(iv)) { n->reduce_axis.push_back(iv); } else { CHECK(value_map.count(iv)); Expr index = value_map.at(iv); vsub[iv->var.get()] = index; - if (!index.same_as(iv->var)) { - Expr cond = (index < dom_map.at(iv)->extent); - if (is_one(predicate)) { - predicate = cond; - } else { - predicate = predicate && cond; - } - } } } + // Copy touched axis. for (IterVar iv : reduce_stage->leaf_iter_vars) { if (touch_map.count(iv) && !iv.same_as(axis)) { @@ -453,4 +538,5 @@ Array<Tensor> Schedule::rfactor(const Tensor& tensor, reduce_stage->relations = Array<IterVarRelation>(); return factor_tensors; } + } // namespace tvm diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index 158d83e78..e9c23d74e 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -55,7 +55,6 @@ def test_schedule_scan(): bounds = tvm.schedule.InferBound(s) assert(bounds[res.op.scan_axis].min.value == 1) stmt = tvm.schedule.ScheduleOps(s, bounds) - print(stmt) def test_auto_inline(): m = tvm.var('m') @@ -160,7 +159,58 @@ def test_schedule_cache(): stmt = tvm.schedule.ScheduleOps(s, bounds) +def test_schedule_cache_relayout1(): + m = tvm.var('m') + n = tvm.var('n') + A = tvm.placeholder((m, n), name='A') + B = tvm.placeholder((m, n), name='B') + C = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='C') + + s = tvm.create_schedule(C.op) + s[C].reorder(C.op.axis[1], C.op.axis[0]) + CC = s.cache_write(C, "global") + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + + +def test_schedule_cache_relayout2(): + m = tvm.var('m') + n = tvm.var('n') + A = tvm.placeholder((m*4, n), name='A') + B = tvm.placeholder((m*4, n), name='B') + C = tvm.compute(A.shape, lambda i, j: A(i, j) * B(i, j), name='C') + s = tvm.create_schedule(C.op) + x, y = C.op.axis + xo, xi = s[C].split(x, factor=4) + s[C].reorder(xo, y, xi) + CC = s.cache_write(C, "global") + s = s.normalize() + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + + +def test_schedule_cache_relayout3(): + m = tvm.var('m') + n = tvm.var('n') + A = tvm.placeholder((m*4, n), name='A') + B = tvm.placeholder((m*4, n), name='B') + k = tvm.reduce_axis((0, n), "k") + C = tvm.compute((A.shape[0],), + lambda i: tvm.sum(A(i, k) * B(i, k), axis=k), name='C') + s = tvm.create_schedule(C.op) + x = C.op.axis[0] + xo, xi = s[C].split(x, factor=4) + CC = s.cache_write(C, "global") + s = s.normalize() + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + + if __name__ == "__main__": + test_schedule_cache_relayout3() + test_schedule_cache_relayout2() + test_schedule_cache_relayout1() test_schedule_const_bound() test_scan_inline1() test_scan_inline2() -- GitLab
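For reference, a minimal usage sketch of the layout-changing cache_write this patch introduces, modeled directly on test_schedule_cache_relayout2 above (the shapes and the split factor are arbitrary illustrative choices, not part of the patch):

import tvm

# Element-wise product whose write layout we want to transform.
m = tvm.var('m')
n = tvm.var('n')
A = tvm.placeholder((m*4, n), name='A')
B = tvm.placeholder((m*4, n), name='B')
C = tvm.compute(A.shape, lambda i, j: A(i, j) * B(i, j), name='C')

s = tvm.create_schedule(C.op)
x, y = C.op.axis
# Split/reorder the data parallel axes *before* cache_write: the cache
# stage CC then stores its data in the (xo, y, xi) iteration order,
# while the original stage C writes it back in the original (x, y) layout.
xo, xi = s[C].split(x, factor=4)
s[C].reorder(xo, y, xi)
CC = s.cache_write(C, "global")

s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)

When the original tensor is an intermediate stage rather than a final output, calling compute_inline on it keeps downstream readers on the transformed layout, as the new cache_write docstring describes.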