diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index 3c17e47030cfd52f3c0ac113bf8a01ae660b131d..11731361148d6b528e4f5a3c10ddc39ad93e5c74 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -24,6 +24,13 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(ComputeOpNode);
 
+inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
+  return (a->combiner.same_as(b->combiner)) &&
+         (a->source.same_as(b->source)) &&
+         (a->axis.same_as(b->axis)) &&
+         (a->condition.same_as(b->condition));
+}
+
 int ComputeOpNode::num_outputs() const {
   return body.size();
 }
@@ -98,13 +105,6 @@ Array<Tensor> compute(Array<Expr> shape,
   return outputs;
 }
 
-inline bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) {
-  return (a->combiner.same_as(b->combiner)) &&
-         (a->source.same_as(b->source)) &&
-         (a->axis.same_as(b->axis)) &&
-         (a->condition.same_as(b->condition));
-}
-
 Operation ComputeOpNode::make(std::string name,
                               std::string tag,
                               Array<IterVar> axis,
@@ -151,9 +151,35 @@ Operation ComputeOpNode::ReplaceInputs(
     const Operation& self,
     const std::unordered_map<Tensor, Tensor>& rmap) const {
   CHECK_EQ(self.operator->(), this);
-  Array<Expr> arr = UpdateArray(this->body, [&rmap] (const Expr& e) {
-      return op::ReplaceTensor(e, rmap);
-    });
+  Array<Expr> arr;
+  if (this->body[0]->is_type<ir::Reduce>()) {
+    // Specially handle reduce so the replaced op
+    // still shares all the components.
+    const ir::Reduce* reduce = this->body[0].as<ir::Reduce>();
+    for (size_t i = 1; i < this->body.size(); ++i) {
+      const ir::Reduce* reduce_ = this->body[i].as<ir::Reduce>();
+      CHECK(reduce_);
+      CHECK(ReduceEqual(reduce_, reduce))
+          << "The Reduce inputs of ComputeOp should "
+          << "have the same attributes except value_index";
+    }
+    Expr new_reduce = op::ReplaceTensor(this->body[0], rmap);
+    if (!new_reduce.same_as(this->body[0])) {
+      const ir::Reduce* r = new_reduce.as<ir::Reduce>();
+      for (size_t k = 0; k < this->body.size(); ++k) {
+        std::shared_ptr<ir::Reduce> n = std::make_shared<ir::Reduce>(*r);
+        n->value_index = static_cast<int>(k);
+        n->type = r->source[k].type();
+        arr.push_back(Expr(n));
+      }
+    } else {
+      arr = this->body;
+    }
+  } else {
+    arr = UpdateArray(this->body, [&rmap] (const Expr& e) {
+        return op::ReplaceTensor(e, rmap);
+      });
+  }
   if (!arr.same_as(this->body)) {
     return ComputeOpNode::make(name, tag, axis, arr);
   } else {
diff --git a/src/op/op_util.cc b/src/op/op_util.cc
index cd0d5e43625e90f290f31cfb797d23a8d3ec56b4..7cf6711d2270c5209ecd7d0729cb53b7fc04b49a 100644
--- a/src/op/op_util.cc
+++ b/src/op/op_util.cc
@@ -162,6 +162,7 @@ class TensorReplacer : public ir::IRMutator {
  public:
   explicit TensorReplacer(const std::unordered_map<Tensor, Tensor>& vmap)
       : vmap_(vmap) {}
+
   Expr Mutate_(const ir::Call* op, const Expr& e) {
     if (op->call_type == ir::Call::Halide) {
       Tensor t = Operation(op->func.node_).output(op->value_index);
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index 252cea21f7cec58ae0f92ddf3e790ecbd1bdfeea..a85db2a23e86cb1dd39598b84b282aedcd32ea19 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -68,16 +68,18 @@ def test_inline_multi_reduce():
     m = tvm.var('m')
     n = tvm.var('n')
     val = tvm.placeholder((m, n), name='val', dtype='float32')
-    val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val[i, j]), name='val2')
+    val1 = tvm.compute((m, n), lambda i, j: val[i, j]+1, name='val1')
+    val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val1[i, j]), name='val2')
     k = tvm.reduce_axis((0, n), 'k')
     T_idx, T_val = tvm.compute((m, ), lambda i: argmax((k.var, val2[i, k]), axis=k), name='T')
     s = tvm.create_schedule(T_idx.op)
-    s[val2].compute_inline()
+    s[val1].compute_inline()
     s = s.normalize()
     bounds = tvm.schedule.InferBound(s)
     stmt = tvm.schedule.ScheduleOps(s, bounds)
 
+
 def test_auto_inline():
     m = tvm.var('m')
     n = tvm.var('n')
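
Note for reviewers: below is a minimal, self-contained sketch of the scenario the updated test_inline_multi_reduce covers, written against the tvm 0.x Python API this test file already uses. The argmax_combine/argmax_identity helpers and the final print(stmt) are illustrative additions, not part of this patch; the test presumably defines its argmax reducer the same way via tvm.comm_reducer.

import tvm

# Assumed argmax reducer: the combiner carries (index, value) pairs so the
# reduction has two outputs.
def argmax_combine(x, y):
    idx = tvm.select(x[1] >= y[1], x[0], y[0])
    val = tvm.select(x[1] >= y[1], x[1], y[1])
    return idx, val

def argmax_identity(idx_t, val_t):
    return tvm.const(-1, idx_t), tvm.min_value(val_t)

argmax = tvm.comm_reducer(argmax_combine, argmax_identity, name='argmax')

m = tvm.var('m')
n = tvm.var('n')
val = tvm.placeholder((m, n), name='val', dtype='float32')
val1 = tvm.compute((m, n), lambda i, j: val[i, j] + 1, name='val1')
val2 = tvm.compute((m, n), lambda i, j: tvm.exp(val1[i, j]), name='val2')
k = tvm.reduce_axis((0, n), 'k')
# T is a single ComputeOp whose body holds two Reduce expressions that
# differ only in value_index.
T_idx, T_val = tvm.compute((m,), lambda i: argmax((k.var, val2[i, k]), axis=k), name='T')

s = tvm.create_schedule(T_idx.op)
# Inlining val1 routes through ComputeOpNode::ReplaceInputs; the patched
# code rebuilds both body elements from one replaced Reduce so they keep
# sharing combiner/source/axis/condition.
s[val1].compute_inline()
s = s.normalize()
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)
print(stmt)

Without the ReplaceInputs change, each body element was rewritten independently, so the two Reduce expressions stopped sharing their component nodes and tripped the same_as-based consistency check in compute_op.cc; the sketch above exercises exactly that path.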