diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index f2f47b8cb4b0a68167a8616cce2fe1564c9fcda1..834b1baf364faa8b080b89defd3a315191ad8b28 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -47,23 +47,27 @@ struct CommReducer : public NodeRef {
  * binary operator with identity element
  */
 struct CommReducerNode : public Node {
-  /*! \brief The arguments of reducer */
-  Array<Var> args;
+  /*! \brief The left argument of reducer */
+  Array<Var> lhs;
+  /*! \brief The right argument of reducer */
+  Array<Var> rhs;
   /*! \brief The result of reducer */
-  Expr result;
+  Array<Expr> result;
   /*!
    * \brief The identity element of reducer, which leaves other
    *  elements unchanged when combined with it, with respect to
    *  the binary operation this reducer uses.
    */
-  Expr identity_element;
+  Array<Expr> identity_element;
   /*! \brief Function call operator to combine a and b */
-  Expr operator()(Expr a, Expr b) const;
+  Array<Expr> operator()(Array<Expr> a, Array<Expr> b) const;
   /*! \brief construct CommReducer from lhs, rhs, result and identity_element */
-  static CommReducer make(Array<Var> args, Expr result, Expr identity_element);
+  static CommReducer make(Array<Var> lhs, Array<Var> rhs,
+                          Array<Expr> result, Array<Expr> identity_element);
   void VisitAttrs(AttrVisitor* v) final {
-    v->Visit("args", &args);
+    v->Visit("lhs", &lhs);
+    v->Visit("rhs", &rhs);
     v->Visit("result", &result);
     v->Visit("identity_element", &identity_element);
   }
@@ -84,7 +88,7 @@ struct Reduce : public ExprNode<Reduce> {
   /*! \brief The commutative combiner */
   CommReducer combiner;
   /*! \brief The source operand */
-  Expr source;
+  Array<Expr> source;
   /*! \brief The reduction axis */
   Array<IterVar> axis;
   /*!
@@ -92,18 +96,22 @@ struct Reduce : public ExprNode<Reduce> {
    *  Only add the body to reduction if condition is true.
    */
   Expr condition;
+  /*! \brief the index of this reduce node */
+  int value_index;
   /*! \brief construct expr from op and rdom */
   static Expr make(CommReducer combiner,
-                   Expr src,
+                   Array<Expr> src,
                    Array<IterVar> rdom,
-                   Expr condition = const_true());
+                   Expr condition,
+                   int value_index);
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("dtype", &type);
     v->Visit("source", &source);
     v->Visit("axis", &axis);
     v->Visit("condition", &condition);
+    v->Visit("value_index", &value_index);
   }
   static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
   static constexpr const char* _type_key = "Reduce";
@@ -292,11 +300,12 @@ constexpr const char* tvm_global_barrier_kinit = "tvm_global_barrier_kinit";
/*!
 * \brief See pseudo code
 *
- * Expr tvm_thread_allreduce(CommReducer combiner, Expr value, Expr cond,
- *                           Var thread_idx1, thread_idx2...) {
+ * void tvm_thread_allreduce(UIntImm size, Expr source0, ..., Expr cond,
+ *                           Var reduce_temp0, ..., Var thread_idx1, ...) {
 *   // constrained by the other thread_idx remaining the same.
- *   return reduce(combiner, value, cond,
- *     over [thread_idx1, thread_idx2] passed by any caller)
+ *   // reduce_temp is used to save the intermediate result.
+ *   reduce_temp0, ... = reduce(combiner, source0, ..., cond
+ *     over [thread_idx1, thread_idx2] passed by any caller)
 * }
 */
constexpr const char* tvm_thread_allreduce = "tvm_thread_allreduce";
diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 7e323870413768ef2df909cae90413beff828812..fc0bc1f1abd28930ccaf8e531f1bbc0ea59d214d 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -96,10 +96,10 @@ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map);
/*!
 * \brief inline all calls of f in stmt.
 *
+ * \param stmt The statement to apply inline optimization.
 * \param f The function reference to be inlined
 * \param args The arguments variable of the function.
- * \param body The defintion body of the function.
- * \param stmt The statement to apply inline optimization.
+ * \param body The definition body of the function.
 * \return The result stmt
 *
 * \note All the passes in this file use SSA form and output SSA form.
diff --git a/include/tvm/operation.h b/include/tvm/operation.h
index eb0ee37569f1c28238461234d90644c0234d4677..0533bdcea6fba8c25eb81d9bfa43435cc64486cc 100644
--- a/include/tvm/operation.h
+++ b/include/tvm/operation.h
@@ -182,7 +182,7 @@ class ComputeOpNode : public OperationNode {
   /*! \brief IterVar on each reduction axis, if the body is a Reduce */
   Array<IterVar> reduce_axis;
   /*! \brief the compute expression */
-  Expr body;
+  Array<Expr> body;
   /*! \brief constructor */
   ComputeOpNode() {}
   // override functions
@@ -218,7 +218,7 @@ class ComputeOpNode : public OperationNode {
   }
   static Operation make(std::string name,
                         Array<IterVar> axis,
-                        Expr body);
+                        Array<Expr> body);
   static constexpr const char* _type_key = "ComputeOp";
   TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
@@ -358,6 +358,9 @@ class ExternOpNode : public OperationNode {
 /*! \brief The compute function to specify the input source of a Tensor */
 using FCompute = std::function<Expr (const Array<Var>& i)>;

+/*! \brief The compute function to specify the input sources of multiple Tensors */
+using FBatchCompute = std::function<Array<Expr> (const Array<Var>& i)>;
+
 /*!
 * \brief create a place holder tensor.
 * \param shape The shape of the tensor.
@@ -377,6 +380,15 @@ Tensor placeholder(Array<Expr> shape,
 */
Tensor compute(Array<Expr> shape, FCompute fcompute, std::string name = "tensor");

+/*!
+ * \brief Construct new tensors by computing over shape,
+ *  using the computation rule: result_tensor[axis] = fcompute(axis)
+ * \param shape Shape of the tensors.
+ * \param fcompute The compute function to create the tensors.
+ * \param name The optional name of the tensors.
+ */
+Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string name = "tensor");
+
 /*!
 * \brief Construct new tensors by scan.
 *
diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index 4479c4fbee80d0f3e85685d949f3e45055802735..9f8a4bd51f2f99124644d8e7a85e685611076c9a 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -252,15 +252,15 @@ class Schedule : public NodeRef {
  /*!
   * \brief Factor a reduction axis in tensor's schedule to be an explicit axis.
   *  This will create a new stage that generates the new tensor with axis
-   *  as the first dimension. The tensor's body wil be rewriten as a reduction
+   *  as the first dimension. The tensor's body will be rewritten as a reduction
   *  over the factored tensor.
   *
   * \param tensor The tensor to be factored.
   * \param axis The reduction axis in tensor's schedule to be factored.
-   * \return The created factored tensor.
+   * \return The created factored tensors.
   */
-  Tensor rfactor(const Tensor& tensor,
-                 const IterVar& axis);
+  Array<Tensor> rfactor(const Tensor& tensor,
+                        const IterVar& axis);
  /*!
   * \brief Normalize the schedule.
   *  This is needed before bound inference.
diff --git a/python/tvm/api.py b/python/tvm/api.py index 2ef18d210342a2429f193b44f675d594de9d6390..7ea6a8e81e6ba036ecb73ffcaba5cd1d543a966f 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -174,10 +174,14 @@ def compute(shape, fcompute, name="compute"): dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)] body = fcompute(*[v.var for v in dim_var]) + if not isinstance(body, (list, tuple)): + body = [body] body = convert(body) op_node = _api_internal._ComputeOp( name, dim_var, body) - return op_node.output(0) + num = op_node.num_outputs + outputs = tuple(op_node.output(i) for i in range(num)) + return outputs[0] if num == 1 else outputs def scan(init, update, state_placeholder, inputs=None, name="scan"): @@ -525,18 +529,45 @@ def comm_reducer(fcombine, fidentity, name="reduce"): return res def _make_reduce(expr, axis, where=None): - expr = convert(expr) - dtype = expr.dtype code = fcombine.__code__ assert fcombine.__code__.co_argcount == 2 - arg_vars = [var(name, dtype) for name in code.co_varnames] - result = fcombine(*[v for v in arg_vars]) + expr = convert(expr) + if isinstance(expr, _collections.Array): + size = len(expr) + larr = [] + rarr = [] + dtypes = [] + for i in range(size): + dtype = expr[i].dtype + dtypes.append(dtype) + lname = code.co_varnames[0] + '_' + str(i) + larr.append(var(lname, dtype)) + rname = code.co_varnames[1] + '_' + str(i) + rarr.append(var(rname, dtype)) + lhs = convert(larr) + rhs = convert(rarr) + result = fcombine(lhs, rhs) + id_elem = fidentity(*dtypes) + else: + assert isinstance(expr, _expr.Expr) + size = 1 + dtype = expr.dtype + lvar = var(code.co_varnames[0], dtype) + rvar = var(code.co_varnames[1], dtype) + result = [fcombine(lvar, rvar)] + id_elem = [fidentity(dtype)] + lhs = convert([lvar]) + rhs = convert([rvar]) + expr = convert([expr]) result = convert(result) - id_elem = fidentity(dtype) - assert isinstance(id_elem, _expr.Expr) - combiner = _make.CommReducer(arg_vars, result, id_elem) - axis = axis if isinstance(axis, list) else [axis] - return _make.Reduce(combiner, expr, axis, where) + id_elem = convert(id_elem) + combiner = _make.CommReducer(lhs, rhs, result, id_elem) + axis = convert(axis if isinstance(axis, list) else [axis]) + if where is None: + where = convert(True) + outputs = tuple(_make.Reduce(combiner, expr, axis, where, i) + for i in range(size)) + return outputs[0] if size == 1 else outputs def reducer(expr, axis, where=None, *args): if isinstance(axis, (_schedule.IterVar, list)): diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py index 8b2e63f9c884079423a9e318bcd45a0eec871269..e9c8a179c95f4252ecd9838d2a5b582b39df9dbc 100644 --- a/python/tvm/schedule.py +++ b/python/tvm/schedule.py @@ -181,7 +181,7 @@ class Schedule(NodeBase): """ Factor a reduction axis in tensor's schedule to be an explicit axis. This will create a new stage that generated the new tensor with axis - as the first dimension. The tensor's body wil be rewriten as a reduction + as the first dimension. The tensor's body will be rewritten as a reduction over the factored tensor. Parameters @@ -193,10 +193,11 @@ class Schedule(NodeBase): Returns ------- - tfactor : Tensor + tfactor : Tensor or Array of Tensor The created factored tensor. 
""" - return _api_internal._ScheduleRFactor(self, tensor, axis) + factored = _api_internal._ScheduleRFactor(self, tensor, axis) + return factored[0] if len(factored) == 1 else factored @register_node diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 00ab79b19167fe1b2fd92b567fb7e17b9b2b8370..f66652b99157d5a29eaade5f7e8caef46f576a27 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -68,11 +68,13 @@ TVM_REGISTER_API("make.Call") }); TVM_REGISTER_API("make.CommReducer") -.set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = CommReducerNode::make(args[0], args[1], args[2]); +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = CommReducerNode::make(args[0], + args[1], + args[2], + args[3]); }); - // make from two arguments #define REGISTER_MAKE1(Node) \ TVM_REGISTER_API("make."#Node) \ @@ -112,7 +114,7 @@ TVM_REGISTER_API("make.CommReducer") *ret = Node::make(a, b); \ }) -REGISTER_MAKE4(Reduce); +REGISTER_MAKE5(Reduce); REGISTER_MAKE4(AttrStmt); REGISTER_MAKE2(IntImm); diff --git a/src/lang/expr.cc b/src/lang/expr.cc index 795f34fe673fcd61a7b0e7b61cce79ce59b2561d..9e0feb44479f90482b6070b55171b38ce44a1577 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -50,24 +50,27 @@ Expr sum(Expr source, Array<IterVar> rdom) { Var x("x"), y("y"); Expr result = ir::Add::make(x, y); Expr identity_element = make_zero(source.type()); - ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element); - return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true)); + ir::CommReducer combiner = + ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); } Expr max(Expr source, Array<IterVar> rdom) { Var x("x"), y("y"); Expr result = ir::Max::make(x, y); Expr identity_element = source.type().min(); - ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element); - return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true)); + ir::CommReducer combiner = + ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); } Expr min(Expr source, Array<IterVar> rdom) { Var x("x"), y("y"); Expr result = ir::Min::make(x, y); Expr identity_element = source.type().max(); - ir::CommReducer combiner = ir::CommReducerNode::make({x, y}, result, identity_element); - return ir::Reduce::make(combiner, source, rdom, make_const(Bool(1), true)); + ir::CommReducer combiner = + ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); + return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); } std::ostream& operator<<(std::ostream& os, const NodeRef& n) { // NOLINT(*) diff --git a/src/lang/ir.cc b/src/lang/ir.cc index 52ba225253dae7f08f76d2852b060ac323c2c24a..e7903333562f9fd3f640f74a1298b775ec465be1 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -9,6 +9,7 @@ #include <ir/IR.h> #include <ir/IRPrinter.h> #include <memory> +#include "../pass/ir_util.h" namespace Halide { namespace Internal { @@ -25,23 +26,20 @@ void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const { TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch<Reduce>([](const Reduce *op, IRPrinter *p) { p->stream << "reduce(combiner=" - << op->combiner - << ", "; - p->print(op->source); + << op->combiner; + p->stream << ", source=" << op->source; p->stream << ", axis=" << op->axis; - if (!is_const(op->condition, 1)) { - p->stream << ", where=" << 
op->condition; - } + p->stream << ", where=" << op->condition; + p->stream << ", value_index=" << op->value_index; p->stream << ")"; }); TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) .set_dispatch<CommReducerNode>([](const CommReducerNode *op, IRPrinter *p) { - p->stream << "comm_reducer(result=" - << op->result - << ", args=" << op->args - << ", identity_element=" - << op->identity_element + p->stream << "comm_reducer(result=" << op->result + << ", lhs=" << op->lhs + << ", rhs=" << op->rhs + << ", identity_element=" << op->identity_element << ")"; }); } // namespace Internal @@ -50,23 +48,34 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) namespace tvm { namespace ir { -CommReducer CommReducerNode::make(Array<Var> args, Expr result, Expr identity_element) { +CommReducer CommReducerNode::make(Array<Var> lhs, + Array<Var> rhs, + Array<Expr> result, + Array<Expr> identity_element) { auto node = std::make_shared<CommReducerNode>(); - node->args = args; + node->lhs = lhs; + node->rhs = rhs; node->result = result; node->identity_element = identity_element; return CommReducer(node); } -Expr CommReducerNode::operator()(Expr a, Expr b) const { +Array<Expr> CommReducerNode::operator()(Array<Expr> a, Array<Expr> b) const { + CHECK_EQ(a.size(), b.size()); + CHECK_EQ(lhs.size(), a.size()); + CHECK_EQ(rhs.size(), b.size()); Map<Var, Expr> value_map; - value_map.Set(args[0], a); - value_map.Set(args[1], b); - return Substitute(result, value_map); + for (size_t i = 0; i < a.size(); ++i) { + value_map.Set(lhs[i], a[i]); + value_map.Set(rhs[i], b[i]); + } + return UpdateArray(result, [&value_map] (const Expr& e) { + return Substitute(e, value_map); + }); } -Expr Reduce::make(CommReducer combiner, Expr source, - Array<IterVar> axis, Expr condition) { +Expr Reduce::make(CommReducer combiner, Array<Expr> source, + Array<IterVar> axis, Expr condition, int value_index) { for (size_t i = 0; i < axis.size(); ++i) { CHECK_EQ(axis[i]->iter_type, kCommReduce) << "Can only take axis created by reduce_axis"; @@ -79,11 +88,12 @@ Expr Reduce::make(CommReducer combiner, Expr source, for (size_t i = 0; i < axis.size(); ++i) { CHECK(axis[i].defined()); } - n->type = source.type(); - n->combiner = combiner; - n->source = source; - n->axis = axis; + n->type = source[value_index].type(); + n->combiner = std::move(combiner); + n->source = std::move(source); + n->axis = std::move(axis); n->condition = condition; + n->value_index = value_index; return Expr(n); } diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc index a2d3b25e25e080423fb006cdf3bdcdb3a03c0886..be594a6b6e4a2d2433d4de14d779025e77483128 100644 --- a/src/op/compute_op.cc +++ b/src/op/compute_op.cc @@ -24,7 +24,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) TVM_REGISTER_NODE_TYPE(ComputeOpNode); int ComputeOpNode::num_outputs() const { - return 1; + return body.size(); } Array<IterVar> ComputeOpNode::root_iter_vars() const { @@ -36,13 +36,14 @@ Array<IterVar> ComputeOpNode::root_iter_vars() const { return ret; } -Type ComputeOpNode::output_dtype(size_t i) const { - CHECK_EQ(i, 0U); - return body.type(); +Type ComputeOpNode::output_dtype(size_t idx) const { + CHECK_LT(idx, num_outputs()); + return body[idx].type(); } -Array<Expr> ComputeOpNode::output_shape(size_t i) const { - CHECK_EQ(i, 0U); +Array<Expr> ComputeOpNode::output_shape(size_t idx) const { + CHECK_LT(idx, num_outputs()); + // for now, all outputs of ComputeOp have the same shape std::vector<Expr> shape; for (size_t i = 0; i < axis.size(); ++i) { const Range& r = axis[i]->dom; @@ -65,18 +66,55 @@ Tensor 
compute(Array<Expr> shape, FCompute fcompute, std::string name) { args.push_back(axis.back()->var); } - return ComputeOpNode::make(name, axis, fcompute(args)).output(0); + return ComputeOpNode::make(name, axis, {fcompute(args)}).output(0); +} + +Array<Tensor> compute(Array<Expr> shape, FBatchCompute fcompute, std::string name) { + auto op_node = std::make_shared<ComputeOpNode>(); + // compute dimension. + size_t ndim = shape.size(); + std::vector<IterVar> axis; + std::vector<Var> args; + for (size_t i = 0; i < ndim; ++i) { + std::ostringstream os; + os << "ax" << i; + axis.emplace_back(IterVarNode::make( + Range(0, shape[i]), Var(os.str(), shape[i].type()), kDataPar)); + args.push_back(axis.back()->var); + } + + Operation op = ComputeOpNode::make(name, axis, fcompute(args)); + Array<Tensor> outputs; + for (int idx = 0; idx < op->num_outputs(); ++idx) { + outputs.push_back(op.output(idx)); + } + return outputs; +} + +bool ReduceEqual(const ir::Reduce* a, const ir::Reduce* b) { + return (a->combiner.same_as(b->combiner)) && + (a->source.same_as(b->source)) && + (a->axis.same_as(b->axis)) && + (a->condition.same_as(b->condition)); } Operation ComputeOpNode::make(std::string name, Array<IterVar> axis, - Expr body) { + Array<Expr> body) { auto n = std::make_shared<ComputeOpNode>(); n->name = name; n->axis = axis; n->body = body; - if (n->body->is_type<ir::Reduce>()) { - n->reduce_axis = n->body.as<ir::Reduce>()->axis; + if (n->body[0]->is_type<ir::Reduce>()) { + const ir::Reduce* reduce = n->body[0].as<ir::Reduce>(); + for (size_t i = 1; i < n->body.size(); ++i) { + const ir::Reduce* reduce_ = n->body[i].as<ir::Reduce>(); + CHECK(reduce_); + CHECK(ReduceEqual(reduce_, reduce)) + << "The Reduce inputs of ComputeOp should " + << "have the same attribute except value_index"; + } + n->reduce_axis = reduce->axis; } return Operation(n); } @@ -85,16 +123,18 @@ Operation ComputeOpNode::make(std::string name, Array<Tensor> ComputeOpNode::InputTensors() const { Array<Tensor> ret; std::unordered_set<Tensor> visited; - ir::PostOrderVisit(body, [&ret, &visited](const NodeRef& n) { - const ir::Call *call = n.as<ir::Call>(); - if (call != nullptr && call->func.defined()) { - Tensor t = Operation(call->func.node_).output(call->value_index); - if (!visited.count(t)) { - ret.push_back(t); - visited.insert(t); + for (auto& e : body) { + ir::PostOrderVisit(e, [&ret, &visited](const NodeRef& n) { + const ir::Call *call = n.as<ir::Call>(); + if (call != nullptr && call->func.defined()) { + Tensor t = Operation(call->func.node_).output(call->value_index); + if (!visited.count(t)) { + ret.push_back(t); + visited.insert(t); + } } - } - }); + }); + } return ret; } @@ -102,9 +142,11 @@ Operation ComputeOpNode::ReplaceInputs( const Operation& self, const std::unordered_map<Tensor, Tensor>& rmap) const { CHECK_EQ(self.operator->(), this); - Expr new_body = op::ReplaceTensor(this->body, rmap); - if (!new_body.same_as(this->body)) { - return ComputeOpNode::make(name, axis, new_body); + Array<Expr> arr = UpdateArray(this->body, [&rmap] (const Expr& e) { + return op::ReplaceTensor(e, rmap); + }); + if (!arr.same_as(this->body)) { + return ComputeOpNode::make(name, axis, arr); } else { return self; } @@ -127,7 +169,7 @@ void ComputeOpNode::PropBoundToInputs( } } }; - ir::PostOrderVisit(body, fvisit); + for (auto& e : body) ir::PostOrderVisit(e, fvisit); } void ComputeOpNode::GatherBound( @@ -151,34 +193,50 @@ Stmt ComputeOpNode::BuildRealize( const std::unordered_map<IterVar, Range>& realize_map, const Stmt& realize_body) const 
{ CHECK_EQ(self.operator->(), this); - Tensor t = self.output(0); Halide::Internal::Region bounds; for (IterVar iv : this->axis) { bounds.push_back(realize_map.at(iv)); } - return ir::Realize::make(t->op, t->value_index, t->dtype, - bounds, const_true(), realize_body); + Stmt realize = realize_body; + for (int i = self->num_outputs(); i > 0; --i) { + Tensor t = self.output(i-1); + realize = ir::Realize::make(t->op, t->value_index, + t->dtype, bounds, const_true(), realize); + } + return realize; } // Build a reduction body. void MakeReduction(const ComputeOpNode* op, - const Tensor& t, + const Array<Tensor>& tensors, Stmt* init, Stmt* provide) { - Stmt no_op = Evaluate::make(0); - std::vector<Stmt> nest; Array<Expr> args; for (IterVar iv : op->axis) { args.push_back(iv->var); } - const Reduce* reduce = op->body.as<Reduce>(); + std::vector<Stmt> inits, provides; + + size_t size = op->body.size(); + const Reduce* reduce = op->body[0].as<Reduce>(); CHECK(reduce); const CommReducerNode* combiner = reduce->combiner.as<CommReducerNode>(); CHECK(combiner); - Expr init_value = combiner->identity_element; - Expr update_value = (*combiner)(t(args), reduce->source); - *init = Provide::make(t->op, t->value_index, init_value, args); - *provide = Provide::make(t->op, t->value_index, update_value, args); + Array<Expr> lhs; + for (size_t i = 0; i < size; ++i) { + lhs.push_back(tensors[i](args)); + } + Array<Expr> init_value = combiner->identity_element; + Array<Expr> update_value = (*combiner)(lhs, reduce->source); + for (size_t i = 0; i < size; ++i) { + Tensor t = tensors[i]; + inits.emplace_back(Provide::make( + t->op, t->value_index, init_value[i], args)); + provides.emplace_back(Provide::make( + t->op, t->value_index, update_value[i], args)); + } + *init = Block::make(inits); + *provide = Block::make(provides); if (!is_one(reduce->condition)) { *provide = IfThenElse::make(reduce->condition, *provide); } @@ -225,22 +283,36 @@ Stmt MakeCrossThreadReduction( for (IterVar iv : self->axis) { args.push_back(iv->var); } - const Reduce* reduce = self->body.as<Reduce>(); - CHECK(reduce); std::unordered_map<IterVar, Expr> value_map; auto nest = op::MakeLoopNest( stage, dom_map, 0, false, std::unordered_set<IterVar>(), &value_map); auto conds = op::MakeBoundCheck( stage, dom_map, false, std::unordered_set<IterVar>(), value_map); - Expr cond = reduce->condition; + + size_t size = self->body.size(); + CHECK_GT(size, 0); + std::vector<const Reduce*> reduces(size); + for (size_t i = 0; i < size; ++i) { + const Reduce* reduce = self->body[i].as<Reduce>(); + CHECK(reduce); + reduces[i] = reduce; + } + Expr cond = reduces[0]->condition; for (Expr v : conds) { cond = cond && v; } - Var res_handle("reduce_temp", Handle()); Array<Expr> freduce_args; - freduce_args.push_back(reduce->source); + freduce_args.push_back(make_const(UInt(32), size)); + for (size_t i = 0; i < size; ++i) { + freduce_args.push_back(reduces[0]->source[i]); + } freduce_args.push_back(cond); + std::vector<Var> res_handles(size); + for (size_t idx = 0; idx < size; ++idx) { + res_handles[idx] = Var("reduce_temp" + std::to_string(idx), Handle()); + freduce_args.push_back(res_handles[idx]); + } for (IterVar iv : stage->leaf_iter_vars) { if (iv->iter_type == kCommReduce) { @@ -257,28 +329,33 @@ Stmt MakeCrossThreadReduction( if (stage->store_predicate.defined()) { thread_head_check.emplace_back(stage->store_predicate); } - Type t = reduce->type; - Expr pred = const_true(t.lanes()); - Stmt reduce_body = Store::make(res_handle, - Call::make( - reduce->type, 
+ + Stmt reduce_body = Evaluate::make(Call::make( + Handle(), ir::intrinsic::tvm_thread_allreduce, - freduce_args, Call::Intrinsic), - 0, pred); + freduce_args, Call::Intrinsic)); reduce_body = AttrStmt::make( - reduce->combiner, + reduces[0]->combiner, attr::reduce_scope, - make_zero(reduce->type), + make_zero(Handle()), reduce_body); - Stmt assign_body = Provide::make( - stage->op, 0, Load::make(reduce->type, res_handle, 0, pred), args); + std::vector<Stmt> assigns(size); + for (size_t idx = 0; idx < size; ++idx) { + Type t = reduces[idx]->type; + assigns[idx] = Provide::make( + stage->op, idx, + Load::make(t, res_handles[idx], 0, const_true(t.lanes())), args); + } + Stmt assign_body = Block::make(assigns); assign_body = MergeNest(op::MakeIfNest(thread_head_check), assign_body); assign_body = MergeNest(op::MakeIfNest(conds), assign_body); - Stmt body = Allocate::make( - res_handle, reduce->type, {1}, const_true(), - Block::make(reduce_body, assign_body)); - body = AttrStmt::make( - res_handle, attr::storage_scope, StringImm::make("local"), body); + Stmt body = Block::make(reduce_body, assign_body); + for (size_t idx = size; idx != 0; --idx) { + body = Allocate::make( + res_handles[idx - 1], reduces[idx - 1]->type, {1}, const_true(), body); + body = AttrStmt::make( + res_handles[idx - 1], attr::storage_scope, StringImm::make("local"), body); + } body = Substitute(body, value_map); return MergeNest(nest, body); } @@ -289,7 +366,7 @@ Stmt MakeProvide(const ComputeOpNode* op, for (IterVar iv : op->axis) { args.push_back(iv->var); } - return Provide::make(t->op, t->value_index, op->body, args); + return Provide::make(t->op, t->value_index, op->body[t->value_index], args); } Stmt ComputeOpNode::BuildProvide( @@ -301,12 +378,24 @@ Stmt ComputeOpNode::BuildProvide( // specially handle cross thread reduction. 
return MakeCrossThreadReduction(this, stage, dom_map); } - Stmt init, provide; + + size_t size = this->body.size(); + Stmt init; + Stmt provide; if (this->reduce_axis.size() == 0) { - provide = MakeProvide(this, stage->op.output(0)); + std::vector<Stmt> provides; + for (size_t i = 0; i < size; ++i) { + provides.emplace_back(MakeProvide(this, stage->op.output(i))); + } + provide = Block::make(provides); } else { - MakeReduction(this, stage->op.output(0), &init, &provide); + Array<Tensor> source; + for (size_t i = 0; i < size; ++i) { + source.push_back(stage->op.output(i)); + } + MakeReduction(this, source, &init, &provide); } + // make loop nest std::unordered_map<IterVar, Expr> value_map; auto nest = op::MakeLoopNest( @@ -357,7 +446,7 @@ Stmt ComputeOpNode::BuildProvide( for (auto& e : preds) e = likely(e); init_nest.push_back(op::MakeIfNest(preds)); init = Substitute(init, init_value_map); - init = MergeNest(init_nest, init); + init = MergeNest(init_nest, init); // common nest std::vector<std::vector<Stmt> > common(nest.begin(), nest.begin() + begin_loop + 1); std::vector<std::vector<Stmt> > reduce(nest.begin() + begin_loop + 1, nest.end()); diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index ce179b79188affcda881018a99c61633ab770c21..b12f6648dffcdcb5842cafffda7d26c8eebe6d7e 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -4,6 +4,7 @@ */ #include <tvm/ir.h> #include <tvm/ir_mutator.h> +#include "./ir_util.h" namespace tvm { namespace ir { @@ -17,19 +18,7 @@ IRMutator::FMutateStmt& IRMutator::vtable_stmt() { // NOLINT(*) } inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) { - std::vector<Expr> new_arr(arr.size()); - bool changed = false; - for (size_t i = 0; i < arr.size(); i++) { - Expr old_elem = arr[i]; - Expr new_elem = m->Mutate(old_elem); - if (!new_elem.same_as(old_elem)) changed = true; - new_arr[i] = new_elem; - } - if (!changed) { - return arr; - } else { - return Array<Expr>(new_arr); - } + return UpdateArray(arr, [&m] (const Expr& e) { return m->Mutate(e); }); } inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) { @@ -323,14 +312,15 @@ DEFINE_BIOP_EXPR_MUTATE_(Or) Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) { Array<IterVar> new_axis = MutateIterVarArr(op->axis, this); - Expr new_source = this->Mutate(op->source); + Array<Expr> new_source = MutateArray(op->source, this); Expr new_cond = this->Mutate(op->condition); if (op->axis.same_as(new_axis) && op->source.same_as(new_source) && op->condition.same_as(new_cond)) { return e; } else { - return Reduce::make(op->combiner, new_source, new_axis, new_cond); + return Reduce::make( + op->combiner, new_source, new_axis, new_cond, op->value_index); } } diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h index 1982f977365f2bce182434f559e06ba8cbb5be88..472b408e32d52015d0321d707a49770beae65161 100644 --- a/src/pass/ir_util.h +++ b/src/pass/ir_util.h @@ -12,6 +12,32 @@ namespace tvm { namespace ir { +/*! 
+ * \brief update array with a unary function
+ * \param arr array
+ * \param fupdate a unary function
+ * \tparam T type of array element
+ * \tparam F type of the unary function
+ * \return the new array if any element was updated,
+ *  otherwise the original array
+ */
+template<typename T, typename F>
+inline Array<T> UpdateArray(Array<T> arr, F fupdate) {
+  std::vector<T> new_arr(arr.size());
+  bool changed = false;
+  for (size_t i = 0; i < arr.size(); ++i) {
+    T old_elem = arr[i];
+    T new_elem = fupdate(old_elem);
+    if (!new_elem.same_as(old_elem)) changed = true;
+    new_arr[i] = new_elem;
+  }
+  if (!changed) {
+    return arr;
+  } else {
+    return Array<T>(new_arr);
+  }
+}
+
 /*!
  * \brief combine the nest stmt, whose body is not defined.
  * \param nest A list of For and LetStmt, whose body is not defined.
diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc
index bb1b3678e0d836c2f96888fec3a2359d6a72276a..bae93f9d00b69bada1986046d3a1386c43530cb9 100644
--- a/src/pass/ir_visitor.cc
+++ b/src/pass/ir_visitor.cc
@@ -133,7 +133,7 @@ DEFINE_BINOP_VISIT_(Or)
 void IRVisitor::Visit_(const Reduce* op) {
   VisitRDom(op->axis, this);
-  this->Visit(op->source);
+  VisitArray(op->source, this);
 }
 void IRVisitor::Visit_(const Cast* op) {
diff --git a/src/pass/lower_thread_allreduce.cc b/src/pass/lower_thread_allreduce.cc
index f9f99e8fcd265dbf42041a25ddc1f708d4b3d374..1e59723d59d5a184adebee0dafeacb02b1a89958 100644
--- a/src/pass/lower_thread_allreduce.cc
+++ b/src/pass/lower_thread_allreduce.cc
@@ -45,12 +45,12 @@ class ThreadAllreduceBuilder : public IRMutator {
       return IRMutator::Mutate_(op, s);
     }
   }
-  Stmt Mutate_(const Store* op, const Stmt& s) final {
+  Stmt Mutate_(const Evaluate* op, const Stmt& s) final {
     Stmt stmt = IRMutator::Mutate_(op, s);
-    op = stmt.as<Store>();
+    op = stmt.as<Evaluate>();
     const Call* call = op->value.as<Call>();
     if (call && call->is_intrinsic(intrinsic::tvm_thread_allreduce)) {
-      return MakeAllreduce(op, call);
+      return MakeAllreduce(call);
     } else {
       return stmt;
     }
@@ -97,18 +97,34 @@ class ThreadAllreduceBuilder : public IRMutator {
     }
   };
   // make allreduce.
- Stmt MakeAllreduce(const Store* op, const Call* call) { + Stmt MakeAllreduce(const Call* call) { CHECK(!reduce_combiner_.empty()); const CommReducerNode *combiner = reduce_combiner_.back(); - Expr init = combiner->identity_element; - Expr value = call->args[0]; - Expr cond = call->args[1]; - if (!is_one(cond)) { - value = Select::make(cond, value, init); + size_t size = combiner->result.size(); + + const UIntImm *size_of_args = call->args[0].as<UIntImm>(); + CHECK(size_of_args) << call->args[0]->type_key(); + CHECK_EQ(size, size_of_args->value); + Array<Expr> inits = combiner->identity_element; + std::vector<Expr> values(size); + std::vector<Type> types(size); + Expr cond = call->args[size+1]; + for (size_t idx = 0; idx < size; ++idx) { + values[idx] = call->args[1+idx]; + if (!is_one(cond)) { + values[idx] = Select::make(cond, values[idx], inits[idx]); + } + types[idx] = values[idx].type(); + } + std::vector<const Variable*> buffers(size); + for (size_t idx = 0; idx < size; ++idx) { + const Variable* buffer = call->args[2+size+idx].as<Variable>(); + CHECK(buffer); + buffers[idx] = buffer; } std::unordered_set<const Variable*> reduce_set; - for (size_t i = 2; i < call->args.size(); ++i) { + for (size_t i = 2 + 2 * size; i < call->args.size(); ++i) { const Variable* v = call->args[i].as<Variable>(); CHECK(v); reduce_set.insert(v); @@ -143,40 +159,50 @@ class ThreadAllreduceBuilder : public IRMutator { int threadx_extent = 1; Expr reduce_index = FlattenThread(vred, &reduce_extent); Expr group_index = FlattenThread(vpar, &group_extent); - Expr pred = const_true(value.type().lanes()); if (reduce_extent == 1) { // special case, no reduction is needed. - return Store::make(op->buffer_var, value, 0, pred); + std::vector<Stmt> stores(size); + for (size_t i = 0; i < size; ++i) { + Expr pred = const_true(types[i].lanes()); + Var buffer_var(call->args[2+size+i].node_); + stores[i] = Store::make(buffer_var, values[i], 0, pred); + } + return Block::make(stores); } // Whether the threadIdx.x is involved in reduction. 
if (vred[0].scope.dim_index == 0) { threadx_extent = vred[0].extent; } - Var shared_buf("red_buf", Handle()); std::vector<Stmt> seq; - seq.emplace_back(Store::make( - shared_buf, value, - BufIndex(reduce_index, group_index, reduce_extent), pred)); + std::vector<Var> shared_bufs(size); + for (size_t idx = 0; idx < size; ++idx) { + shared_bufs[idx] = Var("red_buf"+std::to_string(idx), Handle()); + Expr pred = const_true(types[idx].lanes()); + seq.emplace_back(Store::make( + shared_bufs[idx], values[idx], + BufIndex(reduce_index, group_index, reduce_extent), pred)); + } seq.emplace_back(SyncThread("shared")); seq.emplace_back(MakeBufAllreduce( - combiner, value.type(), shared_buf, + combiner, types, shared_bufs, reduce_index, group_index, reduce_extent, threadx_extent)); - CHECK(!load_remap_.count(op->buffer_var.get())); - load_remap_[op->buffer_var.get()] = - Load::make( - value.type(), shared_buf, - BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent), - pred); - alloc_remap_[op->buffer_var.get()] = - Allocate::make(shared_buf, value.type(), - {Expr(group_extent), Expr(reduce_extent)}, - pred, Evaluate::make(0)); + for (size_t idx = 0; idx < size; ++idx) { + CHECK(!load_remap_.count(buffers[idx])); + Expr pred = const_true(types[idx].lanes()); + load_remap_[buffers[idx]] = Load::make( + types[idx], shared_bufs[idx], + BufIndex(make_zero(reduce_index.type()), group_index, reduce_extent), pred); + alloc_remap_[buffers[idx]] = Allocate::make( + shared_bufs[idx], types[idx], + {Expr(group_extent), Expr(reduce_extent)}, + pred, Evaluate::make(0)); + } return MergeSeq(seq); } // make allreduce. Stmt MakeBufAllreduce(const CommReducerNode *combiner, - Type type, - Var shared_buf, + const std::vector<Type>& types, + const Array<Var>& shared_bufs, Expr reduce_index, Expr group_index, int reduce_extent, @@ -189,14 +215,23 @@ class ThreadAllreduceBuilder : public IRMutator { CHECK_GT(reduce_align, 1); std::vector<Stmt> seq; + size_t size = shared_bufs.size(); Expr buf_index = BufIndex(reduce_index, group_index, reduce_extent); // make reduction auto freduce = [&](int offset) { - Expr b = Load::make( - type, shared_buf, - BufIndex(reduce_index + offset, group_index, reduce_extent), const_true()); - Expr a = Load::make(type, shared_buf, buf_index, const_true()); - return Store::make(shared_buf, (*combiner)(a, b), buf_index, const_true()); + Array<Expr> a, b; + for (size_t i = 0; i < size; ++i) { + b.push_back(Load::make(types[i], shared_bufs[i], + BufIndex(reduce_index + offset, group_index, reduce_extent), + const_true())); + a.push_back(Load::make(types[i], shared_bufs[i], buf_index, const_true())); + } + Array<Expr> ret = (*combiner)(a, b); + std::vector<Stmt> stores(size); + for (size_t i = 0; i < size; ++i) { + stores[i] = Store::make(shared_bufs[i], ret[i], buf_index, const_true()); + } + return Block::make(stores); }; // Step one, check for if (reduce_align > reduce_extent) { diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc index ff4c5912dae0ecf74e2105a0af157705dc4a0757..f807b92dceaa7ccec6d6535f4dce75936e86cc86 100644 --- a/src/pass/storage_flatten.cc +++ b/src/pass/storage_flatten.cc @@ -157,7 +157,9 @@ class StorageFlattener : public IRMutator { CHECK_EQ(extern_buf_remap_.size(), 0U); for (size_t i = 0; i < ext_op->output_placeholders.size(); ++i) { TensorKey key{func, static_cast<int>(i)}; - CHECK(buf_map_.count(key)); + CHECK(buf_map_.count(key)) + << "Cannot find allocated buffer for " << key.f + << "(" << key.value_index << ")"; 
extern_buf_remap_[ext_op->output_placeholders[i]->data.get()] = buf_map_.at(key).buffer->data; } diff --git a/src/schedule/auto_inline_elem_wise.cc b/src/schedule/auto_inline_elem_wise.cc index 312be19a08747f66d16cb55d7b8603c6ae417797..9fd073c0ac7afca44907e05d2bd822c6dd9c1a74 100644 --- a/src/schedule/auto_inline_elem_wise.cc +++ b/src/schedule/auto_inline_elem_wise.cc @@ -46,7 +46,7 @@ class ElemWiseDetector : public ir::IRVisitor { bool IsElemWise(const Operation& op) { if (const ComputeOpNode* compute = op.as<ComputeOpNode>()) { ElemWiseDetector v = ElemWiseDetector(compute->axis); - v.Visit(compute->body); + for (auto& e : compute->body) v.Visit(e); return v.is_elem_wise_; } return false; diff --git a/src/schedule/graph.cc b/src/schedule/graph.cc index 0fcf21def2fca9e1583cc9e502d0ac10bedde004..da0aeb0eccaa3f21fef70fb56297473459757457 100644 --- a/src/schedule/graph.cc +++ b/src/schedule/graph.cc @@ -260,7 +260,9 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) { } } }; - ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit); + for (auto& e : op.as<ComputeOpNode>()->body) { + ir::PostOrderVisit(e, fvisit); + } } } return reach; @@ -321,11 +323,14 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) { } } } else if (op.as<ComputeOpNode>()) { - std::unordered_map<const Node*, TensorDimKey> vmap; + std::unordered_map<const Node*, std::vector<TensorDimKey> > vmap; const auto& axis = op.as<ComputeOpNode>()->axis; - Tensor t = op.output(0); for (size_t i = 0; i < axis.size(); ++i) { - vmap[axis[i]->var.get()] = TensorDimKey(t, i); + std::vector<TensorDimKey> keys; + for (int j = 0; j < op->num_outputs(); ++j) { + keys.emplace_back(op.output(j), i); + } + vmap[axis[i]->var.get()] = std::move(keys); } auto fvisit = [&vmap, &f_merge_key, &exact_reach, &fail_set]( const NodeRef& n) { @@ -335,7 +340,10 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) { auto it = vmap.find(call->args[i].get()); TensorDimKey src(call, static_cast<int>(i)); if (it != vmap.end()) { - f_merge_key(it->second, src); + const std::vector<TensorDimKey>& keys = it->second; + for (const auto& key : keys) { + f_merge_key(key, src); + } } else { if (exact_reach.count(src)) { fail_set.insert(exact_reach.at(src)); @@ -344,7 +352,9 @@ Map<IterVar, Expr> ScanFixPointAnalysis(const Operation& scan_op) { } } }; - ir::PostOrderVisit(op.as<ComputeOpNode>()->body, fvisit); + for (auto& e : op.as<ComputeOpNode>()->body) { + ir::PostOrderVisit(e, fvisit); + } } } ReachGraph reach; diff --git a/src/schedule/graph.h b/src/schedule/graph.h index 7908dc9e1de64a4eccae1fc97031415490d94cf7..50d35355cc6482faf2317a1da2f8e26265fa3569 100644 --- a/src/schedule/graph.h +++ b/src/schedule/graph.h @@ -27,7 +27,7 @@ using ReadGraph = Map<Operation, Array<Tensor> >; using AttachPath = Map<Operation, Array<IterVar> >; /*! - * \brief The map beteen tensor and operation it feeds to. + * \brief The map between tensor and operation it feeds to. */ using FeedGraph = std::unordered_map<Tensor, std::vector<Operation> >; @@ -46,7 +46,7 @@ ReadGraph CreateReadGraph(const Array<Operation>& roots); * The operations contains node which input-reachable from any inputs * output reachable to any outputs. * - * The inputs won't be included in the subgraph, the outputs will be inclued. + * The inputs won't be included in the subgraph, the outputs will be included. * * \param outputs The outputs of the subgraph * \param inputs The inputs to the subgraph. 
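The C++ plumbing above (the tuple-aware CommReducer, the per-output value_index on Reduce, the multi-output ComputeOp, and the batched allreduce lowering) is what the Python frontend builds on. As a minimal, hedged sketch of the intended frontend usage — it mirrors the argmax tests added later in this diff, and the placeholder names are purely illustrative:

    import tvm

    # x and y are (index, value) pairs; Select keeps the pair with the larger value.
    def fcombine(x, y):
        lhs = tvm.make.Select(x[1] >= y[1], x[0], y[0])
        rhs = tvm.make.Select(x[1] >= y[1], x[1], y[1])
        return lhs, rhs

    # the identity element is also a pair: an invalid index and the smallest value
    def fidentity(t0, t1):
        return tvm.const(-1, t0), tvm.min_value(t1)

    argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax')

    m, n = tvm.var('m'), tvm.var('n')
    idx = tvm.placeholder((m, n), name='idx', dtype='int32')
    val = tvm.placeholder((m, n), name='val', dtype='float32')
    k = tvm.reduce_axis((0, n), 'k')
    # one Reduce with two sources; T0 and T1 are outputs 0 and 1 of a single ComputeOp
    T0, T1 = tvm.compute((m,), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='T')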
diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc
index 25319dc24eff7d755b326094185d7b5c0a6ad695..d24ba17bf6e3e97c276d5dbb8e7f1b7a77fac553 100644
--- a/src/schedule/schedule_dataflow_rewrite.cc
+++ b/src/schedule/schedule_dataflow_rewrite.cc
@@ -8,6 +8,7 @@
 #include <tvm/ir_pass.h>
 #include <unordered_set>
 #include "./message_passing.h"
+#include "../pass/ir_util.h"

 namespace tvm {

@@ -120,13 +121,13 @@ Tensor Schedule::cache_write(const Tensor& tensor,
     vsub[iv->var.get()] = new_iv->var;
   }
   VarReplacer repl(vsub);
-  Expr body = repl.Mutate(compute->body);
+  Expr body = repl.Mutate(compute->body[tensor->value_index]);
   Operation cache_op = ComputeOpNode::make(
-      compute->name + "." + scope, new_axis, body);
+      compute->name + "." + scope, new_axis, {body});
   Tensor cache_tensor = cache_op.output(0);
   Operation orig_new_op = ComputeOpNode::make(
       compute->name, compute->axis,
-      cache_tensor(args));
+      {cache_tensor(args)});

   std::unordered_map<Tensor, Tensor> vmap;
   vmap[orig_stage->op.output(0)] = orig_new_op.output(0);
@@ -198,14 +199,15 @@ void RebaseNonZeroMinLoop(const Schedule& sch) {

 void InjectInline(ScheduleNode* sch) {
   sch->InvalidateCache();
-  std::vector<Expr> new_body(sch->stages.size());
+  std::vector<Array<Expr>> new_body(sch->stages.size());
+  std::vector<bool> changed(sch->stages.size(), false);
   // inline all the ops
   for (size_t i = sch->stages.size(); i != 0; --i) {
     Stage stage = sch->stages[i - 1];
     if (stage->attach_type == kInline) {
       stage->attach_type = kInlinedAlready;
       Array<Var> args;
-      Expr body;
+      Array<Expr> body;
       {
         // setup args
         const ComputeOpNode* compute = stage->op.as<ComputeOpNode>();
@@ -220,11 +222,14 @@ void InjectInline(ScheduleNode* sch) {
         Stage s = sch->stages[j];
         const ComputeOpNode* compute = s->op.as<ComputeOpNode>();
         if (compute) {
-          if (!new_body[j].defined()) {
+          if (!new_body[j].size()) {
             new_body[j] = s->op.as<ComputeOpNode>()->body;
           }
-          new_body[j] = ir::Inline(ir::Evaluate::make(new_body[j]),
-                                   stage->op, args, body).as<ir::Evaluate>()->value;
+          for (size_t k = 0; k < body.size(); ++k) {
+            changed[j] = true;
+            new_body[j].Set(k, ir::Inline(ir::Evaluate::make(new_body[j][k]),
+                                          stage->op, args, body[k]).as<ir::Evaluate>()->value);
+          }
         }
       }
     }
@@ -234,19 +239,21 @@ void InjectInline(ScheduleNode* sch) {
   for (size_t i = 0; i < sch->stages.size(); ++i) {
     Stage s = sch->stages[i];
     if (s->attach_type == kInlinedAlready) continue;
-    if (new_body[i].defined()) {
+    if (new_body[i].size()) {
       // Logics from ReplaceDataFlow
       const ComputeOpNode* compute = sch->stages[i]->op.as<ComputeOpNode>();
       CHECK(compute);
       Operation op = s->op;
-      if (!new_body[i].same_as(compute->body)) {
+      if (changed[i]) {
         op = ComputeOpNode::make(
             compute->name, compute->axis, new_body[i]);
       }
       op = op->ReplaceInputs(op, repl);
       if (!op.same_as(s->op)) {
-        repl[s->op.output(0)] = op.output(0);
-        s->op = op;
+        // map every old output before updating s->op, otherwise outputs
+        // beyond index 0 would be looked up on the already-replaced op
+        for (int idx = 0; idx < s->op->num_outputs(); ++idx) {
+          repl[s->op.output(idx)] = op.output(idx);
+        }
+        s->op = op;
       }
     } else {
       Operation op = s->op->ReplaceInputs(s->op, repl);
@@ -268,15 +275,15 @@ Schedule Schedule::normalize() {
 }

 // Handle reduction factor.
-Tensor Schedule::rfactor(const Tensor& tensor, - const IterVar& axis) { +Array<Tensor> Schedule::rfactor(const Tensor& tensor, + const IterVar& axis) { (*this)->InvalidateCache(); using ir::Reduce; CHECK_EQ(axis->iter_type, kCommReduce) << "Can only factor reduction axis"; Stage reduce_stage = operator[](tensor->op); const ComputeOpNode* compute_op = reduce_stage->op.as<ComputeOpNode>(); - CHECK(compute_op) << "Can only factor ComputeOp"; + CHECK(compute_op) << "Can only factor ComputeOp"; ArrayNode* leaf_vars = reduce_stage->leaf_iter_vars.CopyOnWrite(); { size_t axis_pos = FindNodeRef(leaf_vars, axis); @@ -329,7 +336,8 @@ Tensor Schedule::rfactor(const Tensor& tensor, } } // predicate generation, copy not touched axis. - const Reduce* reduce = compute_op->body.as<Reduce>(); + int idx = tensor->value_index; + const Reduce* reduce = compute_op->body[idx].as<Reduce>(); CHECK(reduce) << "Can only rfactor non-inline reductions"; Expr predicate = reduce->condition; std::unordered_map<const Variable*, Expr> vsub; @@ -359,10 +367,18 @@ Tensor Schedule::rfactor(const Tensor& tensor, n->reduce_axis.push_back(IterVar(ncpy)); } } - n->body = Reduce::make(reduce->combiner, - VarReplacer(vsub).Mutate(reduce->source), - n->reduce_axis, - predicate); + VarReplacer replacer(vsub); + Array<Expr> new_source = ir::UpdateArray(reduce->source, + [&replacer] (const Expr& e) { return replacer.Mutate(e); }); + std::vector<Expr> body; + for (size_t idx = 0; idx < reduce->source.size(); ++idx) { + body.emplace_back(Reduce::make(reduce->combiner, + new_source, + n->reduce_axis, + predicate, + idx)); + } + n->body = Array<Expr>(body); // refresh relations, keep the un-touched relations. Array<IterVarRelation> rels; for (IterVarRelation rel : reduce_stage->relations) { @@ -397,26 +413,44 @@ Tensor Schedule::rfactor(const Tensor& tensor, // Replace the old reduction. IterVar repl_red_axis = reduce_axis( dom_map.at(axis), axis->var->name_hint + ".v"); - Tensor factor_tensor = factor_op.output(0); - Tensor old_tensor = reduce_stage->op.output(0); - Tensor repl_tensor = compute(old_tensor->shape, [&](const Array<Var>& i) { + Array<Tensor> factor_tensors; + Array<Tensor> old_tensors; + int size = factor_op->num_outputs(); + for (int idx = 0; idx < size; ++idx) { + factor_tensors.push_back(factor_op.output(idx)); + old_tensors.push_back(reduce_stage->op.output(idx)); + } + Array<Tensor> repl_tensors = compute(old_tensors[0]->shape, + [&](const Array<Var>& i) { Array<Expr> indices; indices.push_back(repl_red_axis->var); for (Var v : i) { indices.push_back(v); } - return Reduce::make(reduce->combiner, - factor_tensor(indices), {repl_red_axis}, const_true()); - }, old_tensor->op->name + ".repl"); + Array<Expr> factor_exprs; + for (int idx = 0; idx < size; ++idx) { + factor_exprs.push_back(factor_tensors[idx](indices)); + } + Array<Expr> reductions; + Array<IterVar> axis = {repl_red_axis}; + Expr cond = const_true(); + for (int idx = 0; idx < size; ++idx) { + reductions.push_back(Reduce::make(reduce->combiner, + factor_exprs, axis, cond, idx)); + } + return reductions; + }, reduce_stage->op->name + ".repl"); std::unordered_map<Tensor, Tensor> vmap; - vmap[old_tensor] = repl_tensor; + for (int idx = 0; idx < size; ++idx) { + vmap[old_tensors[idx]] = repl_tensors[idx]; + } ReplaceDataFlow((*this)->stages, &vmap); // revamp the reduction stage. 
- reduce_stage->op = repl_tensor->op; - reduce_stage->all_iter_vars = repl_tensor->op->root_iter_vars(); + reduce_stage->op = repl_tensors[0]->op; + reduce_stage->all_iter_vars = repl_tensors[0]->op->root_iter_vars(); reduce_stage->leaf_iter_vars = reduce_stage->all_iter_vars; reduce_stage->relations = Array<IterVarRelation>(); - return factor_tensor; + return factor_tensors; } } // namespace tvm diff --git a/src/schedule/schedule_ops.cc b/src/schedule/schedule_ops.cc index edbc2878a5a963d5de2d6482de592fd931a8b507..347cf69884b6934eaf13dfa8250e7126e1c996d7 100644 --- a/src/schedule/schedule_ops.cc +++ b/src/schedule/schedule_ops.cc @@ -253,7 +253,7 @@ class SchedulePostProc : public IRMutator { // This must be checked for all ops, including scan. if (!s->op.same_as(s->origin_op)) { for (int i = 0; i < s->op->num_outputs(); ++i) { - Tensor target = s->origin_op.output(0); + Tensor target = s->origin_op.output(i); AddReplace(s->op.output(i), target, target, s->origin_op); } diff --git a/tests/python/integration/test_reduce.py b/tests/python/integration/test_reduce.py index bf3be27fd3bf2040c69237c0d4249c300c18d5ee..ffc4b79b58d26415a544ebb6caa6fe50dddbe857 100644 --- a/tests/python/integration/test_reduce.py +++ b/tests/python/integration/test_reduce.py @@ -49,7 +49,6 @@ def test_reduce_prims(): test_prim(tvm.max, np.amax) - def test_rfactor(): n = tvm.convert(1027) A = tvm.placeholder((n,), name='A') @@ -128,7 +127,115 @@ def test_rfactor_threads(): check_target("metal") check_target("opencl") +def test_argmax(): + def fcombine(x, y): + lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0]) + rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1]) + return lhs, rhs + + def fidentity(t0, t1): + return tvm.const(-1, t0), tvm.min_value(t1) + + argmax = tvm.comm_reducer(fcombine, + fidentity, + name='argmax') + m = tvm.var('m') + n = tvm.var('n') + idx = tvm.placeholder((m, n), name='idx', dtype='int32') + val = tvm.placeholder((m, n), name='val', dtype='float32') + k = tvm.reduce_axis((0, n), 'k') + T0, T1 = tvm.compute((m,), lambda i: argmax((idx[i,k], val[i,k]), axis=k), name='T') + s = tvm.create_schedule(T0.op) + + def check_target(): + device = 'cpu' + if not tvm.module.enabled(device): + print("skip because %s is not enabled.." 
% device) + return + ctx = tvm.context(device, 0) + fapi = tvm.lower(s, args=[idx, val, T0, T1]) + fargmax = tvm.build(fapi, + target='llvm', + name="argmax") + + mm = 12 + nn = 16 + np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0) + np_val = np.random.uniform(size=(mm, nn)).astype('float32') + np_res = np.argmax(np_val, axis=1) + + nd_idx = tvm.nd.array(np_idx, ctx) + nd_val = tvm.nd.array(np_val, ctx) + nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx) + nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx) + fargmax(nd_idx, nd_val, nd_res0, nd_res1) + np.testing.assert_allclose(np_res, nd_res0.asnumpy()) + + check_target() + + +def test_rfactor_argmax(): + def fcombine(x, y): + lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0]) + rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1]) + return lhs, rhs + + def fidentity(t0, t1): + return tvm.const(-1, t0), tvm.min_value(t1) + + argmax = tvm.comm_reducer(fcombine, + fidentity, + name='argmax') + + nn = 1027 + mm = 10 + n = tvm.convert(nn) + m = tvm.convert(mm) + A0 = tvm.placeholder((m, n), name='A0', dtype='int32') + A1 = tvm.placeholder((m, n), name='A1', dtype='float32') + k = tvm.reduce_axis((0, n)) + B0, B1 = tvm.compute((m,), lambda i: argmax((A0[i, k], A1[i, k]), axis=k), name='B') + + # schedule + s = tvm.create_schedule(B0.op) + nthread = 16 + ko, kf = s[B0].split(k, factor=nthread) + BF0, BF1 = s.rfactor(B0, kf) + bx, ty = s[B0].split(s[B0].op.axis[0], factor=nthread) + s[B0].bind(bx, tvm.thread_axis("blockIdx.x")) + s[B0].bind(ty, tvm.thread_axis("threadIdx.y")) + tx = s[B0].op.reduce_axis[0] + thread_x = tvm.thread_axis("threadIdx.x") + s[B0].bind(tx, thread_x) + s[BF0.op].compute_at(s[B0], tx) + s[B0].set_store_predicate(thread_x.var.equal(0)) + + def check_target(device): + if not tvm.module.enabled(device): + print("skip because %s is not enabled.." 
% device)
+            return
+        ctx = tvm.context(device, 0)
+        fapi = tvm.lower(s, args=[A0, A1, B0, B1])
+        fargmax = tvm.build(fapi,
+                            target=device,
+                            name="argmax")
+
+        np_idx = np.repeat(np.arange(nn, dtype='int32').reshape(1, nn), mm, axis=0)
+        np_val = np.random.uniform(size=(mm, nn)).astype('float32')
+        np_res = np.argmax(np_val, axis=1)
+
+        nd_idx = tvm.nd.array(np_idx, ctx)
+        nd_val = tvm.nd.array(np_val, ctx)
+        nd_res0 = tvm.nd.array(np.zeros(mm, dtype='int32'), ctx)
+        nd_res1 = tvm.nd.array(np.zeros(mm, dtype='float32'), ctx)
+        fargmax(nd_idx, nd_val, nd_res0, nd_res1)
+        np.testing.assert_allclose(np_res, nd_res0.asnumpy())
+
+    check_target("cuda")
+
 if __name__ == "__main__":
     test_rfactor_threads()
     test_rfactor()
     test_reduce_prims()
+    test_argmax()
+    test_rfactor_argmax()
diff --git a/tests/python/unittest/test_lang_schedule.py b/tests/python/unittest/test_lang_schedule.py
index 5fed5a23f750bc6d0ddbf6eab2bb233eb2c2c89a..1b0eac15fe0782ba1201facba7e35b0e48f2dd4e 100644
--- a/tests/python/unittest/test_lang_schedule.py
+++ b/tests/python/unittest/test_lang_schedule.py
@@ -101,8 +101,8 @@ def test_rfactor():
     s = tvm.create_schedule(B.op)
     BF = s.rfactor(B, k1)
     assert(tuple(BF.shape) == (n, n))
-    assert(set(BF.op.body.axis) == set([k2]))
-    assert(s[B].op.body.axis[0].dom.extent == n)
+    assert(set(BF.op.body[0].axis) == set([k2]))
+    assert(s[B].op.body[0].axis[0].dom.extent == n)
     assert(len(s[B].all_iter_vars) == 2)
     # schedule with split
     s = tvm.create_schedule(B.op)
@@ -111,9 +111,9 @@ def test_rfactor():
     BF = s.rfactor(B, ki)
     assert(BF.shape[0].value == 4)
     assert(BF.shape[1] == n)
-    assert(BF.op.body.axis[0] == k2)
-    assert(BF.op.body.axis[1].var == ko.var)
-    assert(s[B].op.body.axis[0].dom.extent.value == 4)
+    assert(BF.op.body[0].axis[0] == k2)
+    assert(BF.op.body[0].axis[1].var == ko.var)
+    assert(s[B].op.body[0].axis[0].dom.extent.value == 4)

 if __name__ == "__main__":
diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py
index e81b15ffa653c34df4604bdc83180e26866f3de0..9160baec3789379965c293bdbc094674cf53c0ad 100644
--- a/tests/python/unittest/test_lang_tensor.py
+++ b/tests/python/unittest/test_lang_tensor.py
@@ -118,6 +118,43 @@ def test_extern_multi_out():
     assert(len(res) == 2)
     assert(res[1].value_index == 1)

+def test_tuple_inputs():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A0 = tvm.placeholder((m, n), name='A0')
+    A1 = tvm.placeholder((m, n), name='A1')
+    T0, T1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='T')
+    s = tvm.create_schedule(T0.op)
+
+    for i in range(len(T0.shape)):
+        assert(T0.shape[i] == T1.shape[i])
+    assert(T0.op == T1.op)
+    assert(T0.value_index == 0)
+    assert(T1.value_index == 1)
+
+def test_tuple_with_different_deps():
+    m = tvm.var('m')
+    n = tvm.var('n')
+    A0 = tvm.placeholder((m, n), name='A0')
+    A1 = tvm.placeholder((m, n), name='A1')
+    B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] * 2, A1[i, j] * 3), name='B')
+    C = tvm.compute((m, n), lambda i, j: B0[i, j] + 4, name='C')
+
+    s = tvm.create_schedule(C.op)
+    xo, xi = s[C].split(C.op.axis[0], factor=10)
+    s[B0.op].compute_at(s[C], xo)
+    sch = s.normalize()
+    bounds = tvm.schedule.InferBound(sch)
+    stmt = tvm.schedule.ScheduleOps(sch, bounds)
+
+    def get_B1_realize(x):
+        if isinstance(x, tvm.stmt.Realize) and \
+                x.func == B1.op and x.value_index == 1:
+            ret.append(x)
+    ret = []
+    tvm.ir_pass.PostOrderVisit(stmt, get_B1_realize)
+
+    assert stmt.node == C.op and len(ret) == 1

 if __name__ == "__main__":
     test_conv1d()
@@ -128,3 +165,5 @@ if __name__ == "__main__":
     test_scan_multi_out()
     test_extern()
     test_extern_multi_out()
+    test_tuple_inputs()
+    test_tuple_with_different_deps()
diff --git a/tests/python/unittest/test_pass_inline.py b/tests/python/unittest/test_pass_inline.py
index 1988d54083c767feb4758f48908fb761ff394beb..398c0d34d58d6336c1bc45629be65375eb50ecac 100644
--- a/tests/python/unittest/test_pass_inline.py
+++ b/tests/python/unittest/test_pass_inline.py
@@ -6,7 +6,7 @@ def test_inline():
     T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
     stmt = tvm.make.Evaluate(T[10] + 11 * T[100])
     stmt = tvm.ir_pass.Inline(
-        stmt, T.op, [x.var for x in T.op.axis], T.op.body)
+        stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
     print(stmt)
     assert(tvm.ir_pass.VerifySSA(stmt))

@@ -25,7 +25,7 @@ def test_inline2():
     T = tvm.compute((m,), lambda i,: A[i] + 10, name='T')
     stmt = tvm.make.Evaluate(tvm.exp(T[10]) + 11 * T[100])
     stmt = tvm.ir_pass.Inline(
-        stmt, T.op, [x.var for x in T.op.axis], T.op.body)
+        stmt, T.op, [x.var for x in T.op.axis], T.op.body[0])
     def check(op):
         if isinstance(op, tvm.expr.Call):
             assert op.func != T.op
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index eceba04f5b0b1b1dcaec3b892e4626eb6db46195..297800e1632d18dc3e9a8797c97cbd801203cce0 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -89,7 +89,7 @@ def test_inline_mixed():
     def check(x):
         if isinstance(x, tvm.expr.Call):
             assert x.func != A2
-    tvm.ir_pass.PostOrderVisit(s[C].op.body, check)
+    tvm.ir_pass.PostOrderVisit(s[C].op.body[0], check)


 def test_scan_inline1():
diff --git a/tutorials/python/reduction.py b/tutorials/python/reduction.py
index 1bdc1f9b8e756f93d93edf4caad6f39f17f8225a..e7295cb927a32a0d41f48dfab7e921420fa5fb92 100644
--- a/tutorials/python/reduction.py
+++ b/tutorials/python/reduction.py
@@ -125,6 +125,8 @@ np.testing.assert_allclose(
   b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)

 ######################################################################
+# .. _general-reduction:
+#
 # Define General Commutative Reduction Operation
 # ----------------------------------------------
 # Besides the built-in reduction operations like :any:`tvm.sum`,
@@ -140,6 +142,12 @@ A = tvm.placeholder((n, m), name='A')
 k = tvm.reduce_axis((0, m), name='k')
 B = tvm.compute((n,), lambda i: product(A[i, k], axis=k), name='B')

+######################################################################
+# .. note::
+#
+#   Sometimes we would like to perform a reduction that involves multiple
+#   values, such as :code:`argmax`, which can be done with tuple inputs.
+#   See :ref:`reduction-with-tuple-inputs` for more details.

 ######################################################################
 # Summary
diff --git a/tutorials/python/tuple_inputs_operation.py b/tutorials/python/tuple_inputs_operation.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c101a59e86e2f04c97b43a68c00fa5f258e719e
--- /dev/null
+++ b/tutorials/python/tuple_inputs_operation.py
@@ -0,0 +1,103 @@
+"""
+Compute and Reduction with Tuple Inputs
+=======================================
+**Author**: `Ziheng Jiang <https://github.com/ZihengJiang>`_
+
+Often we want to compute multiple outputs with the same shape within
+a single loop, or perform a reduction that involves multiple values,
+such as :code:`argmax`. These problems can be addressed with tuple inputs.
+
+In this tutorial, we will introduce the usage of tuple inputs in TVM.
+""" +from __future__ import absolute_import, print_function + +import tvm +import numpy as np + +###################################################################### +# Describe Batchwise Computation +# ------------------------------ +# For operators which have the same shape, we can put them together as +# the inputs of :any:`tvm.compute`, if we wish they can be scheduled +# together in the next schedule procedure. +# +n = tvm.var("n") +m = tvm.var("m") +A0 = tvm.placeholder((m, n), name='A0') +A1 = tvm.placeholder((m, n), name='A1') +B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] + 2, A1[i, j] * 3), name='B') + +# The generated IR code would be: +s = tvm.create_schedule(B0.op) +print(tvm.lower(s, [A0, A1, B0, B1], simple_mode=True)) + +###################################################################### +# .. _reduction-with-tuple-inputs: +# +# Describe Reduction with Collaborative Inputs +# -------------------------------------------- +# Sometimes, we requires multiple inputs to express some reduction +# operators, and the inputs will collaborate together, e.g. :code:`argmax`. +# In the reduction procedure, :code:`argmax` need to compare the value of +# operands, also need to keep the index of operand. It can be expressed +# with :any:`comm_reducer` as below: + +# x and y are the operands of reduction, both of them is a tuple of index +# and value. +def fcombine(x, y): + lhs = tvm.make.Select((x[1] >= y[1]), x[0], y[0]) + rhs = tvm.make.Select((x[1] >= y[1]), x[1], y[1]) + return lhs, rhs + +# our identity element also need to be a tuple, so `fidentity` accepts +# two types as inputs. +def fidentity(t0, t1): + return tvm.const(-1, t0), tvm.min_value(t1) + +argmax = tvm.comm_reducer(fcombine, fidentity, name='argmax') + +# describe the reduction computation +m = tvm.var('m') +n = tvm.var('n') +idx = tvm.placeholder((m, n), name='idx', dtype='int32') +val = tvm.placeholder((m, n), name='val', dtype='int32') +k = tvm.reduce_axis((0, n), 'k') +T0, T1 = tvm.compute((m, ), lambda i: argmax((idx[i, k], val[i, k]), axis=k), name='T') + +# the generated IR code would be: +s = tvm.create_schedule(T0.op) +print(tvm.lower(s, [idx, val, T0, T1], simple_mode=True)) + +###################################################################### +# .. note:: +# +# For ones who are not familiar with reduction, please refer to +# :ref:`general-reduction`. + +###################################################################### +# Schedule Operation with Tuple Inputs +# ------------------------------------ +# It is worth mentioning that although you will get multiple outputs +# with one batch operation, but they can only be scheduled together +# in terms of operation. + +n = tvm.var("n") +m = tvm.var("m") +A0 = tvm.placeholder((m, n), name='A0') +B0, B1 = tvm.compute((m, n), lambda i, j: (A0[i, j] + 2, A0[i, j] * 3), name='B') +A1 = tvm.placeholder((m, n), name='A1') +C = tvm.compute((m, n), lambda i, j: A1[i, j] + B0[i, j], name='C') + +s = tvm.create_schedule(C.op) +s[B0].compute_at(s[C], C.op.axis[0]) +# as you can see in the below generated IR code: +print(tvm.lower(s, [A0, A1, C], simple_mode=True)) + +###################################################################### +# Summary +# ------- +# This tutorial introduces the usage of tuple inputs operation. +# +# - Describe normal batchwise computation. +# - Describe reduction operation with tuple inputs. +# - Notice that you can only schedule computation in terms of operation instead of tensor.