diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index 89d98770b3107f06a0c2603c86d3677b3e012817..3c17e47030cfd52f3c0ac113bf8a01ae660b131d 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -98,7 +98,7 @@ Array<Tensor> compute(Array<Expr> shape,
   return outputs;
 }
 
-bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
+inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
   return (a->combiner.same_as(b->combiner)) &&
          (a->source.same_as(b->source)) &&
          (a->axis.same_as(b->axis)) &&
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index 02ebc21e235a41fc08227b3a70f124a32561bfa4..a8dc4edf57f12eb238d210a569c9df28b18f980a 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -275,10 +275,17 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {
   }
 }
 
+inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
+  return (a->combiner.same_as(b->combiner)) &&
+         (a->source.same_as(b->source)) &&
+         (a->axis.same_as(b->axis)) &&
+         (a->condition.same_as(b->condition));
+}
+
 void InjectInline(ScheduleNode* sch) {
   sch->InvalidateCache();
 
-  std::vector<Array<Expr>> new_body(sch->stages.size());
+  std::vector<Array<Expr> > new_body(sch->stages.size());
   std::vector<bool> changed(sch->stages.size(), false);
   // inline all the ops
   for (size_t i = sch->stages.size(); i != 0; --i) {
@@ -286,7 +293,7 @@ void InjectInline(ScheduleNode* sch) {
     if (stage->attach_type == kInline) {
      stage->attach_type = kInlinedAlready;
      Array<Var> args;
-      Array<Expr> body;
+      Expr body;
      {
        // setup args
        const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
@@ -295,7 +302,9 @@ void InjectInline(ScheduleNode* sch) {
        for (auto iv : compute->axis) {
          args.push_back(iv->var);
        }
-        body = compute->body;
+        CHECK_EQ(compute->body.size(), 1U)
+            << "can only inline compute op with 1 output";
+        body = compute->body[0];
      }
      for (size_t j = i; j < sch->stages.size(); ++j) {
        Stage s = sch->stages[j];
@@ -304,10 +313,39 @@ void InjectInline(ScheduleNode* sch) {
          if (!new_body[j].size()) {
            new_body[j] = s->op.as<ComputeOpNode>()->body;
          }
-          for (size_t k = 0; k < body.size(); ++k) {
-            changed[j] = true;
-            new_body[j].Set(k, ir::Inline(ir::Evaluate::make(new_body[j][k]),
-                                          stage->op, args, body[k]).as<ir::Evaluate>()->value);
+          if (new_body[j][0]->is_type<ir::Reduce>()) {
+            // specially handle reduction inline for multiple reductions.
+            const ir::Reduce* reduce = new_body[j][0].as<ir::Reduce>();
+            for (size_t k = 1; k < new_body[j].size(); ++k) {
+              const ir::Reduce* reduce_ = new_body[j][k].as<ir::Reduce>();
+              CHECK(reduce_);
+              CHECK(ReduceEqual(reduce_, reduce))
+                  << "The Reduce inputs of ComputeOp should "
+                  << "have the same attribute except value_index";
+            }
+            Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][0]),
+                                        stage->op, args, body).as<ir::Evaluate>()->value;
+            if (!new_value.same_as(new_body[j][0])) {
+              changed[j] = true;
+              const ir::Reduce* r = new_value.as<ir::Reduce>();
+              CHECK(r != nullptr);
+              CHECK_EQ(new_body[j].size(), r->source.size());
+              for (size_t k = 0; k < new_body[j].size(); ++k) {
+                std::shared_ptr<ir::Reduce> n = std::make_shared<ir::Reduce>(*r);
+                n->value_index = static_cast<int>(k);
+                n->type = r->source[k].type();
+                new_body[j].Set(k, Expr(n));
+              }
+            }
+          } else {
+            for (size_t k = 0; k < new_body[j].size(); ++k) {
+              Expr new_value = ir::Inline(ir::Evaluate::make(new_body[j][k]),
+                                          stage->op, args, body).as<ir::Evaluate>()->value;
+              if (!new_value.same_as(new_body[j][k])) {
+                new_body[j].Set(k, new_value);
+                changed[j] = true;
+              }
+            }
          }
        }
      }
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index 68b11d5c2ad5597524e5b6652483de58f4ce055e..252cea21f7cec58ae0f92ddf3e790ecbd1bdfeea 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -56,6 +56,28 @@ def test_schedule_scan():
     assert(bounds[res.op.scan_axis].min.value == 1)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
 
+def test_inline_multi_reduce():
+    def argmax_comp(x, y):
+        idx = tvm.select((x[1] >= y[1]), x[0], y[0])
+        val = tvm.select((x[1] >= y[1]), x[1], y[1])
+        return idx, val
+    def argmax_init(idx_typ, val_typ):
+        return tvm.const(-1, idx_typ), tvm.min_value(val_typ)
+
+    argmax = tvm.comm_reducer(argmax_comp, argmax_init, name='argmax')
+    m = tvm.var('m')
+    n = tvm.var('n')
+    val = tvm.placeholder((m, n), name='val', dtype='float32')
+    val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val[i, j]), name='val2')
+    k = tvm.reduce_axis((0, n), 'k')
+    T_idx, T_val = tvm.compute((m, ), lambda i: argmax((k.var, val2[i, k]), axis=k), name='T')
+    s = tvm.create_schedule(T_idx.op)
+    s[val2].compute_inline()
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
 def test_auto_inline():
     m = tvm.var('m')
     n = tvm.var('n')
@@ -207,6 +229,7 @@ def test_schedule_cache_relayout3():
 
 
 if __name__ == "__main__":
+    test_inline_multi_reduce()
     test_schedule_cache_relayout3()
     test_schedule_cache_relayout2()
     test_schedule_cache_relayout1()
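Note: a minimal usage sketch of what this patch enables, written against the same pre-1.0 tvm Python API as `test_inline_multi_reduce` above (the tensor names and the `argmax` reducer are illustrative, not part of the patch). Previously, `InjectInline` rewrote each output independently with `body[k]`, which breaks when the outputs are fan-outs of a single multi-output `Reduce`; now output 0 is rewritten once and cloned per output with only `value_index` (and the element type) changed.

```python
import tvm

# Two-output argmax reducer: elements are (index, value) pairs and the
# combiner keeps the pair with the larger value.
argmax = tvm.comm_reducer(
    lambda x, y: (tvm.select(x[1] >= y[1], x[0], y[0]),
                  tvm.select(x[1] >= y[1], x[1], y[1])),
    lambda idx_t, val_t: (tvm.const(-1, idx_t), tvm.min_value(val_t)),
    name='argmax')

m, n = tvm.var('m'), tvm.var('n')
val = tvm.placeholder((m, n), name='val', dtype='float32')
# Producer stage that will be inlined into the multi-output reduction.
val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val[i, j]), name='val2')
k = tvm.reduce_axis((0, n), 'k')
T_idx, T_val = tvm.compute((m,), lambda i: argmax((k.var, val2[i, k]), axis=k),
                           name='T')

s = tvm.create_schedule(T_idx.op)
# Exercises the new Reduce-aware branch of InjectInline: both outputs end up
# as copies of one rewritten Reduce, with value_index 0 and 1 respectively.
s[val2].compute_inline()
print(tvm.lower(s, [val, T_idx, T_val], simple_mode=True))
```

The `ReduceEqual` guard is what makes this sound: all outputs of a multi-output `ComputeOpNode` must share the same combiner, source, axis, and condition, so rewriting output 0 and re-stamping `value_index` on each copy preserves the op's semantics.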