From 0c72ca970e5e3862bd2914bc9802335618ce15f0 Mon Sep 17 00:00:00 2001
From: tqchen <tianqi.tchen@gmail.com>
Date: Mon, 5 Dec 2016 11:21:32 +0100
Subject: [PATCH] Finish schedule operation

---
 HalideIR                      |  2 +-
 include/tvm/schedule.h        |  8 ++-
 python/tvm/function.py        |  8 +--
 python/tvm/schedule.py        | 99 +++++++++++++++++++++++++++++++++--
 python/tvm/tensor.py          |  2 +
 src/c_api/c_api_lang.cc       | 49 +++++++++++++++++
 src/c_api/c_api_registry.h    |  2 +-
 src/lang/expr.cc              |  4 +-
 src/lang/schedule.cc          | 94 ++++++++++++++++++++++++++++++---
 tests/python/test_schedule.py | 36 ++++++++-----
 10 files changed, 272 insertions(+), 32 deletions(-)

diff --git a/HalideIR b/HalideIR
index 29fd3defa..ea1a81be8 160000
--- a/HalideIR
+++ b/HalideIR
@@ -1 +1 @@
-Subproject commit 29fd3defa3dbf810e52dbc2ecd3933604989dcc8
+Subproject commit ea1a81be8baa43665f6ebd4d75d51c081283ebc8
diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h
index ad23f5135..259ef2d17 100644
--- a/include/tvm/schedule.h
+++ b/include/tvm/schedule.h
@@ -50,16 +50,19 @@ class Schedule : public NodeRef {
    * \brief specify the schedule to be computed at the parent schedule's scope.
    * \param parent The parent schedule.
    * \param scope The iteration point to carry the schedule.
+   * \return reference to self.
    */
  Schedule& compute_at(Schedule parent, IterVar scope);   // NOLINT(*)
  /*!
   * \brief Compute the function inline, attach it at parent.
   * \param parent The parent schedule to be attached to.
+   * \return reference to self.
   */
  Schedule& compute_inline(Schedule parent);   // NOLINT(*)
  /*!
   * \brief Compute the function at root, attach it to its parent.
   * \param parent The parent schedule to be attached to.
+   * \return reference to self.
   */
  Schedule& compute_root(Schedule parent);   // NOLINT(*)
  /*!
@@ -68,7 +71,7 @@ class Schedule : public NodeRef {
   * \param p_outer The result outer domain
   * \param p_inner The result inner domain.
   * \param factor The split factor of the loop.
-   * \param outer The generated
+   * \return reference to self.
   */
  Schedule& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor);   // NOLINT(*)
  /*!
@@ -80,6 +83,7 @@ class Schedule : public NodeRef {
   * \param p_inner The result inner domain.
   * \param factor Optional, the factor of the split,
   *  factor must be provided such that factor * outer.extent >= parent.extent.
+   * \return reference to self.
   */
  Schedule& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr());   // NOLINT(*)
  /*!
@@ -87,11 +91,13 @@ class Schedule : public NodeRef {
   * \param inner The inner domain to be fused
   * \param outer The outer domain to be fused.
   * \param p_target The result target domain.
+   * \return reference to self.
   */
  Schedule& fuse(IterVar inner, IterVar outer, IterVar* p_target);   // NOLINT(*)
  /*!
   * \brief Reorder the iteration
   * \param order The order of iteration variable.
+   * \return reference to self.
   */
  Schedule& reorder(const Array<IterVar>& order);   // NOLINT(*)
 };
diff --git a/python/tvm/function.py b/python/tvm/function.py
index b1f91bd44..7088d051a 100644
--- a/python/tvm/function.py
+++ b/python/tvm/function.py
@@ -79,6 +79,9 @@ def compute(shape, fcompute, name="TensorCompute"):
     tensor: tensor.Tensor
         The created tensor
     """
+    if isinstance(shape, _expr.Expr):
+        shape = (shape, )
+
     ndim = len(shape)
     arg_names = fcompute.__code__.co_varnames
     if ndim != len(arg_names):
@@ -86,6 +89,7 @@ def compute(shape, fcompute, name="TensorCompute"):
 
     dim_var = [IterVar((0, s), x) for x, s in zip(arg_names, shape)]
     body = fcompute(*[v.var for v in dim_var])
+    body = convert(body)
     op_node = _function_internal._ComputeOp(
         name, dim_var, body)
     return _function_internal._Tensor(
@@ -174,8 +178,4 @@ def Schedule(tensor, scope="global"):
     return _function_internal._Schedule(tensor, scope)
 
 
-def Split(dim, factor, over_rdom=False):
-    return _function_internal._DimSplit(dim, factor, over_rdom)
-
-
 _init_function_module("tvm")
diff --git a/python/tvm/schedule.py b/python/tvm/schedule.py
index 5e368b8b1..dee3f3309 100644
--- a/python/tvm/schedule.py
+++ b/python/tvm/schedule.py
@@ -4,13 +4,106 @@ from ._ctypes._api import NodeBase, register_node
 from . import _function_internal
 
 @register_node
-class DimSplit(NodeBase):
+class Split(NodeBase):
     pass
 
 @register_node
-class AttachSpec(NodeBase):
+class Fuse(NodeBase):
     pass
 
 @register_node
 class Schedule(NodeBase):
-    pass
+    def split(self, parent, factor=None, outer=None):
+        """Split the parent iteration variable, either by a split factor or by a given outer IterVar.
+
+        Parameters
+        ----------
+        parent : IterVar
+            The parent iter var.
+
+        factor : Expr, optional
+            The splitting factor.
+
+        outer : IterVar, optional
+            The outer split variable.
+
+        Returns
+        -------
+        outer : IterVar
+            The outer variable of iteration.
+
+        inner : IterVar
+            The inner variable of iteration.
+        """
+        if outer is not None:
+            if outer.thread_tag == '':
+                raise ValueError("split by outer must have special thread_tag")
+            if outer.dom is None:
+                raise ValueError("split by outer must have specified domain")
+            inner = _function_internal._ScheduleSplitByOuter(self, parent, outer, factor)
+        else:
+            if factor is None:
+                raise ValueError("either outer or factor needs to be provided")
+            outer, inner = _function_internal._ScheduleSplitByFactor(self, parent, factor)
+        return outer, inner
+
+    def fuse(self, inner, outer):
+        """Fuse inner and outer into a single iteration variable.
+
+        Parameters
+        ----------
+        outer : IterVar
+            The outer variable of iteration.
+
+        inner : IterVar
+            The inner variable of iteration.
+
+        Returns
+        -------
+        fused : IterVar
+            The fused variable of iteration.
+        """
+        return _function_internal._ScheduleFuse(self, inner, outer)
+
+    def compute_at(self, parent, scope):
+        """Attach the schedule at parent's scope.
+
+        Parameters
+        ----------
+        parent : Schedule
+            The parent schedule
+
+        scope : IterVar
+            The loop scope to be attached to.
+ """ + _function_internal._ScheduleComputeAt(self, parent, scope) + + def compute_inline(self, parent): + """Attach the schedule at parent, and mark it as inline + + Parameters + ---------- + parent : Schedule + The parent schedule + """ + _function_internal._ScheduleComputeInline(self, parent) + + def compute_root(self, parent): + """Attach the schedule at parent, and mark it as root + + Parameters + ---------- + parent : Schedule + The parent schedule + """ + _function_internal._ScheduleComputeInline(self, parent) + + def reorder(self, *args): + """reorder the arguments in the specified order. + + Parameters + ---------- + args : list of IterVar + The order to be ordered + """ + _function_internal._ScheduleReorder(self, args) diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index ab1bdbe90..957080006 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -7,6 +7,8 @@ from . import expr as _expr class TensorSlice(SliceBase, _expr.ExprOp): """Auxiliary data structure for enable slicing syntax from tensor.""" def __init__(self, tensor, indices): + if not isinstance(indices, tuple): + indices = (indices,) self.tensor = tensor self.indices = indices diff --git a/src/c_api/c_api_lang.cc b/src/c_api/c_api_lang.cc index af6ce62e6..6ee137a7b 100644 --- a/src/c_api/c_api_lang.cc +++ b/src/c_api/c_api_lang.cc @@ -103,4 +103,53 @@ TVM_REGISTER_API(_Schedule) *ret = Schedule(args.at(0), args.at(1)); }); +TVM_REGISTER_API(_ScheduleSplitByFactor) +.set_body([](const ArgStack& args, RetValue *ret) { + IterVar outer, inner; + args.at(0).operator Schedule() + .split(args.at(1), &outer, &inner, args.at(2)); + *ret = Array<IterVar>({outer, inner}); + }); + +TVM_REGISTER_API(_ScheduleSplitByOuter) +.set_body([](const ArgStack& args, RetValue *ret) { + IterVar inner; + args.at(0).operator Schedule() + .split(args.at(1), args.at(2), &inner, args.at(3)); + *ret = inner; + }); + +TVM_REGISTER_API(_ScheduleFuse) +.set_body([](const ArgStack& args, RetValue *ret) { + IterVar fused; + args.at(0).operator Schedule() + .split(args.at(1), args.at(2), &fused); + *ret = fused; + }); + +TVM_REGISTER_API(_ScheduleComputeAt) +.set_body([](const ArgStack& args, RetValue *ret) { + args.at(0).operator Schedule() + .compute_at(args.at(1), args.at(2)); + }); + +TVM_REGISTER_API(_ScheduleComputeInline) +.set_body([](const ArgStack& args, RetValue *ret) { + args.at(0).operator Schedule() + .compute_inline(args.at(1)); + }); + +TVM_REGISTER_API(_ScheduleComputeRoot) +.set_body([](const ArgStack& args, RetValue *ret) { + args.at(0).operator Schedule() + .compute_root(args.at(1)); + }); + +TVM_REGISTER_API(_ScheduleReorder) +.set_body([](const ArgStack& args, RetValue *ret) { + args.at(0).operator Schedule() + .reorder(args.at(1)); + }); + + } // namespace tvm diff --git a/src/c_api/c_api_registry.h b/src/c_api/c_api_registry.h index 3e57d43c2..9edce8ea6 100644 --- a/src/c_api/c_api_registry.h +++ b/src/c_api/c_api_registry.h @@ -115,7 +115,7 @@ class APIVariantValue { CHECK_EQ(type_id, kNodeHandle); // use dynamic RTTI for safety CHECK(dynamic_cast<typename T::ContainerType*>(sptr.get())) - << "wrong type specified"; + << "wrong type specified, expected " << typeid(typename T::ContainerType).name(); return T(sptr); } inline operator Expr() const { diff --git a/src/lang/expr.cc b/src/lang/expr.cc index df2350148..3082bd579 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -57,7 +57,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) if (op->var->name_hint.length() != 0) { p->stream << op->var->name_hint << ", "; } 
-    p->stream << op->dom;
+    if (op->dom.defined()) {
+      p->stream << op->dom;
+    }
     if (op->thread_tag.length() != 0) {
       p->stream << ", " << op->thread_tag;
     }
diff --git a/src/lang/schedule.cc b/src/lang/schedule.cc
index beb5ff1f1..47f5ee744 100644
--- a/src/lang/schedule.cc
+++ b/src/lang/schedule.cc
@@ -17,12 +17,38 @@ size_t FindIterVar(ArrayNode* array_node, const IterVar& v) {
   return array_node->data.size();
 }
 
-size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* const IterVar& v) {
-  size_t pos = Find(leaf_iter_vars, parent);
+size_t FindLeafVar(ArrayNode* all_vars, ArrayNode* leaf_vars, const IterVar& v) {
+  size_t pos = FindIterVar(leaf_vars, v);
+  if (pos < leaf_vars->data.size()) return pos;
+
+  if (FindIterVar(all_vars, v) < all_vars->data.size()) {
+    LOG(FATAL) << "Operate on iter var " << v
+               << " that has already been split";
+  } else {
+    LOG(FATAL) << "Operate on iter var " << v
+               << " that is not part of the schedule";
+  }
+  return 0;
+}
+
+void Split(ScheduleNode* self, IterVar parent,
+           IterVar outer, IterVar inner, Expr factor) {
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+  size_t pos = FindLeafVar(all_vars, leaf_vars, parent);
+
+  self->relations.push_back(SplitNode::make(parent, outer, inner, factor));
+  // add vars to all vars
+  all_vars->data.push_back(outer.node_);
+  all_vars->data.push_back(inner.node_);
+  // replace the position.
+  leaf_vars->data.erase(leaf_vars->data.begin() + pos);
+  leaf_vars->data.insert(leaf_vars->data.begin() + pos, inner.node_);
+  leaf_vars->data.insert(leaf_vars->data.begin() + pos, outer.node_);
 }
 
+}  // namespace
+
 Schedule::Schedule(Operation op, std::string scope) {
   auto n = std::make_shared<ScheduleNode>();
   n->op = op;
@@ -36,6 +62,14 @@ Schedule& Schedule::compute_at(Schedule parent, IterVar scope) {   // NOLINT(*)
   CHECK_EQ((*this)->attach_type, kNone);
   (*this)->attach_type = kScope;
   (*this)->attach_parent = scope;
+  bool found = false;
+  for (size_t i = 0; i < parent->leaf_iter_vars.size(); ++i) {
+    if (scope == parent->leaf_iter_vars[i]) {
+      found = true; break;
+    }
+  }
+  CHECK(found)
+      << "Cannot compute at an iteration variable that is not part of parent leaf vars";
   parent->children.push_back(*this);
   return *this;
 }
@@ -56,17 +90,63 @@ Schedule& Schedule::compute_root(Schedule parent) {   // NOLINT(*)
 
 Schedule& Schedule::split(
     IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor) {   // NOLINT(*)
-  ScheduleNode* self = operator->();
-  ArrayNode* leaf_iter_vars = self->leaf_iter_vars.CopyOnWrite();
+  // placeholder for the split results.
+  IterVar outer(Range(), parent->var->name_hint + ".outer");
+  IterVar inner(Range(), parent->var->name_hint + ".inner");
+  *p_outer = outer; *p_inner = inner;
-  CHECK(pos != leaf_iter_vars->data.size())
-      << "Cannot find IterVar " << parent << " in the active leaf vars"
-      << " this means "
+  Split(operator->(), parent, outer, inner, factor);
 
   return *this;
 }
+
+Schedule& Schedule::split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor) {   // NOLINT(*)
+  // placeholder for the split results.
+  IterVar inner(Range(), parent->var->name_hint + ".inner");
+  *p_inner = inner;
+  Split(operator->(), parent, outer, inner, factor);
+  return *this;
+}
+
+Schedule& Schedule::fuse(IterVar inner, IterVar outer, IterVar* p_target) {   // NOLINT(*)
+  IterVar fused(Range(), outer->var->name_hint + "."
+                inner->var->name_hint + ".fused");
+  ScheduleNode* self = operator->();
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+
+  self->relations.push_back(FuseNode::make(inner, outer, fused));
+  all_vars->data.push_back(fused.node_);
+
+  size_t pos_inner = FindLeafVar(all_vars, leaf_vars, inner);
+  size_t pos_outer = FindLeafVar(all_vars, leaf_vars, outer);
+  CHECK_EQ(pos_inner, pos_outer + 1)
+      << "Can only fuse iterations that are consecutive";
+  leaf_vars->data.erase(leaf_vars->data.begin() + pos_outer,
+                        leaf_vars->data.begin() + pos_inner + 1);
+  leaf_vars->data.insert(leaf_vars->data.begin() + pos_outer,
+                         fused.node_);
+  return *this;
+}
+
+Schedule& Schedule::reorder(const Array<IterVar>& order) {   // NOLINT(*)
+  ScheduleNode* self = operator->();
+  ArrayNode* all_vars = self->all_iter_vars.CopyOnWrite();
+  ArrayNode* leaf_vars = self->leaf_iter_vars.CopyOnWrite();
+  std::vector<size_t> pos;
+  for (size_t i = 0; i < order.size(); ++i) {
+    pos.push_back(FindLeafVar(all_vars, leaf_vars, order[i]));
+  }
+  std::vector<std::shared_ptr<Node> > temp;
+  for (size_t i = 0; i < pos.size(); ++i) {
+    temp.emplace_back(leaf_vars->data[pos[i]]);
+  }
+  std::sort(pos.begin(), pos.end());
+  for (size_t i = 0; i < pos.size(); ++i) {
+    leaf_vars->data[pos[i]] = temp[i];
+  }
+  return *this;
+}
 
 IterVarRelation SplitNode::make(
     IterVar parent, IterVar outer,
diff --git a/tests/python/test_schedule.py b/tests/python/test_schedule.py
index 5d3f76131..773be8b55 100644
--- a/tests/python/test_schedule.py
+++ b/tests/python/test_schedule.py
@@ -6,28 +6,36 @@ def test_schedule_create():
     l = tvm.Var('l')
     A = tvm.placeholder((m, l), name='A')
     B = tvm.placeholder((n, l), name='B')
+    AA = tvm.compute((m, l), lambda i, j: A[i, j])
     T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))
 
-    Tsch = tvm.Schedule(T.op, scope="shared")
-    Asch = tvm.Schedule(A.op)
+    sch_T = tvm.Schedule(T.op, scope="shared")
+    sch_A = tvm.Schedule(AA.op, scope="global")
 
-    T.op.
+    xo, xi = sch_T.split(T.op.dim_var[0], factor=10)
+    xi1, xi2 = sch_T.split(xi, factor=2)
+    sch_A.compute_at(sch_T, xi1)
+    xo, xi = sch_A.split(AA.op.dim_var[0], factor=10)
 
-    xo, xi = sch.split(sch.dim_var[0], factor)
-    Asch.compute_at(Tsch, xi)
+    sch_T.reorder(xi2, xi1)
+    assert T.op.dim_var[1] in sch_T.leaf_iter_vars
 
-    xf = sch.fuse(xo, xi)
-
-
-    tk1 = tvm.Split(T.op.dim_var[0], 10)
-    assert isinstance(sch, tvm.schedule.Schedule)
-    assert isinstance(tk1, tvm.schedule.DimSplit)
+def test_reorder():
+    m = tvm.Var('m')
+    A = tvm.placeholder((m,), name='A')
+    T = tvm.compute(m, lambda i: A[i+1])
 
-    print(tk1.var)
-    print(sch.scope)
-    print(sch.attachs)
+    sch_T = tvm.Schedule(T.op, scope="shared")
+    xo, xi = sch_T.split(T.op.dim_var[0], factor=10)
+    xi1, xi2 = sch_T.split(xi, factor=2)
+    order = (xi2, xi1, xo)
+    assert tuple(sch_T.leaf_iter_vars) != order
+    sch_T.reorder(*order)
+    assert tuple(sch_T.leaf_iter_vars) == order
 
 if __name__ == "__main__":
     test_schedule_create()
+    test_reorder()
+
--
GitLab
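
For reference, a minimal usage sketch of the schedule API this patch finishes, modeled on tests/python/test_schedule.py above. It is not part of the patch and assumes a Python build of exactly this revision (tvm.Var, tvm.placeholder, tvm.compute, tvm.Schedule, dim_var and leaf_iter_vars as registered here); all tensor and variable names below are illustrative only, and later TVM versions expose a different scheduling API.

    import tvm

    # Describe the computation first, then build one schedule per operation.
    m = tvm.Var('m')
    n = tvm.Var('n')
    l = tvm.Var('l')
    A = tvm.placeholder((m, l), name='A')
    B = tvm.placeholder((n, l), name='B')
    AA = tvm.compute((m, l), lambda i, j: A[i, j])                  # staged copy of A
    T = tvm.compute((m, n, l), lambda i, j, k: A(i, k) * B(j, k))

    sch_T = tvm.Schedule(T.op, scope="shared")
    sch_AA = tvm.Schedule(AA.op, scope="global")

    # split by a factor: replaces T.op.dim_var[0] with (outer, inner)
    # in sch_T.leaf_iter_vars and returns the two new IterVars.
    xo, xi = sch_T.split(T.op.dim_var[0], factor=10)

    # fuse two adjacent leaf vars back into one; the inner var must
    # directly follow the outer var in the current leaf order.
    fused = sch_T.fuse(xi, xo)

    # reorder the leaf iteration variables into the requested order.
    sch_T.reorder(T.op.dim_var[1], fused, T.op.dim_var[2])

    # attach AA's computation at a loop level of T's schedule.
    sch_AA.compute_at(sch_T, fused)

    print(tuple(sch_T.leaf_iter_vars))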