diff --git a/include/tvm/expr.h b/include/tvm/expr.h index 59429074f1320c4fc2208356b01f19bd6e5c96cb..d7f53bf3e4340efb998d869ebd934b2e4ed1c4dd 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -133,13 +133,13 @@ std::ostream& operator<<(std::ostream& os, const NodeRef& n); // NOLINT(*) */ class IterVarNode : public Node { public: - /*! \brief The looping variable */ - Var var; /*! * \brief the domain of iteration, if known, can be None * For the intermediate schedule node, before schedule. */ Range dom; + /*! \brief The looping variable */ + Var var; /*! * \brief additional tag on the iteration variable, * set this if this is binded already to a known thread tag. @@ -147,12 +147,13 @@ class IterVarNode : public Node { std::string thread_tag; void VisitAttrs(AttrVisitor* v) final { - v->Visit("var", &var); v->Visit("dom", &dom); + v->Visit("var", &var); v->Visit("thread_tag", &thread_tag); } - static IterVar make(Var var, Range dom, std::string thread_tag); + static IterVar make(Range dom, Var var, std::string thread_tag); + static constexpr const char* _type_key = "IterVar"; TVM_DECLARE_NODE_TYPE_INFO(IterVarNode); }; diff --git a/include/tvm/operation.h b/include/tvm/operation.h index 5024308c986c5e2b0a59cbe4ea88e45af72baf85..d738f880247fb9ec39c7a93dd3d92e7f75251de2 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -17,6 +17,8 @@ namespace tvm { */ class ComputeOpNode : public OperationNode { public: + /*! \brief Iteration variables over the dimensions */ + Array<IterVar> dim_var; /*! \brief the compute expression */ Expr body; /*! \brief constructor */ @@ -25,19 +27,18 @@ class ComputeOpNode : public OperationNode { size_t num_outputs() const final { return 1; } + Array<IterVar> root_iter_vars() const final; std::string output_name(size_t i) const final; Type output_dtype(size_t i) const final; Array<Expr> output_shape(size_t i) const final; void VisitAttrs(AttrVisitor* v) final { - v->Visit("domain", &domain); v->Visit("name", &name); v->Visit("dim_var", &dim_var); v->Visit("body", &body); } - static Operation make(Domain domain, - std::string name, - Array<Var> dim_var, + static Operation make(std::string name, + Array<IterVar> dim_var, Expr body); static constexpr const char* _type_key = "ComputeOp"; diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index 3e861835bcf26bc4d92a400aa785da3fa2e6860c..015b595ea013dcf8e20e00945367b6bdc7345821 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -8,15 +8,14 @@ #include <string> #include "./base.h" -#include "./split.h" #include "./operation.h" namespace tvm { // Node container for Schedule class ScheduleNode; -// Node container for AttachSpec -class AttachSpecNode; +// Node container for IterVarRelation +class IterVarRelationNode; /*! \brief the attachment type */ enum AttachType : int { @@ -38,42 +37,132 @@ class Schedule : public NodeRef { inline const ScheduleNode* operator->() const; }; +/*! + * \brief The schedule relation between IterVars + * can be Split, Fuse. + */ +class IterVarRelation : public NodeRef { + public: + IterVarRelation() {} + explicit IterVarRelation(std::shared_ptr<Node> n) : NodeRef(n) {} + /*! + * \brief access the internal node container + * \return the pointer to the internal node container + */ + inline const IterVarRelationNode* operator->() const; +}; + // defintion of node containers -/*! \brief represents the schedule of the tensor */ +/*! + * \brief represents the schedule of the tensor + * + * A schedule is a Directed acylic hypergraph. + * With each node is represented by a IterVar, + * and each hyper-edge is represented by a IterVarRelation. + * + * The relations can be Split/Fuse. + * + * The current data structure stores the hyper graph in its + * bipartite representation. + * + * The relations connects the IterVars in the graph. + */ class ScheduleNode : public Node { public: /*! \brief The operation to be scheduled */ Operation op; /*! \brief The thread scope level of the schedule */ std::string scope; - /*! \brief Splits over iteration domains */ - Array<Split> splits; + /*! \brief All the nodes in the iter var */ + Array<IterVar> all_iter_vars; + /*! + * \brief The current leafs in the schedule. + * Operations can only be performed in leaves. + */ + Array<IterVar> leaf_iter_vars; + /*! \brief The relation bwteen of IterVars */ + Array<IterVarRelation> relations; /*! \brief The attachment type of the schedule */ AttachType attach_type; /*! - * \brief The attach point of this schedule, if it is a split - * \note This is not a cyclic dependency, - * because split do not refer back to parent schedule. + * \brief The attach point of this schedule. */ - Split attach_parent; + IterVar attach_parent; /*! \brief the schedules that this schedule depend on */ Array<Schedule> children; - // the type key - const char* type_key() const final { - return "Schedule"; - } - const uint32_t type_index() const final { - static uint32_t tidx = TypeKey2Index(type_key()); - return tidx; - } + void VisitAttrs(AttrVisitor* v) final { v->Visit("scope", &scope); v->Visit("op", &op); - v->Visit("splits", &splits); + v->Visit("all_iter_vars", &all_iter_vars); + v->Visit("leaf_iter_vars", &leaf_iter_vars); + v->Visit("relations", &relations); v->Visit("attach_type", &attach_type); v->Visit("attach_parent", &attach_parent); v->Visit("children", &children); } + + static constexpr const char* _type_key = "Schedule"; + TVM_DECLARE_NODE_TYPE_INFO(ScheduleNode); +}; + +/*! \brief base node of iteration var */ +class IterVarRelationNode : public Node { +}; + +/*! + * \brief Split the parent domain into product of + * outer and iter. + */ +class SplitNode : public IterVarRelationNode { + public: + /*! \brief The parent domain */ + IterVar parent; + /*! \brief The outer domain */ + IterVar outer; + /*! \brief The inner domain */ + IterVar inner; + /*! \brief The split factor */ + Expr factor; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("parent", &parent); + v->Visit("outer", &outer); + v->Visit("inner", &inner); + v->Visit("factor", &factor); + } + + static IterVarRelation make( + IterVar parent, IterVar outer, + IterVar inner, Expr factor); + + static constexpr const char* _type_key = "Split"; + TVM_DECLARE_NODE_TYPE_INFO(SplitNode); +}; + +/*! + * \brief Fuse two domains into one domain. + */ +class FuseNode : public IterVarRelationNode { + public: + /*! \brief The outer domain */ + IterVar outer; + /*! \brief The inner domain */ + IterVar inner; + /*! \brief The target domain */ + IterVar fused; + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("outer", &outer); + v->Visit("inner", &inner); + v->Visit("fused", &fused); + } + + static IterVarRelation make( + IterVar outer, IterVar inner, IterVar fused); + + static constexpr const char* _type_key = "Fuse"; + TVM_DECLARE_NODE_TYPE_INFO(FuseNode); }; // implementations @@ -81,5 +170,9 @@ inline const ScheduleNode* Schedule::operator->() const { return static_cast<const ScheduleNode*>(node_.get()); } +inline const IterVarRelationNode* IterVarRelation::operator->() const { + return static_cast<const IterVarRelationNode*>(node_.get()); +} + } // namespace tvm #endif // TVM_SCHEDULE_H_ diff --git a/include/tvm/split.h b/include/tvm/split.h deleted file mode 100644 index acc637bd85f8d74c1b162a938ef61bbd95eb47ee..0000000000000000000000000000000000000000 --- a/include/tvm/split.h +++ /dev/null @@ -1,65 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file split.h - * \brief Define a split over Domain or RDomain - */ -#ifndef TVM_SPLIT_H_ -#define TVM_SPLIT_H_ - -#include "./base.h" -#include "./expr.h" - -namespace tvm { - -// internal node container for split. -class SplitNode; - -/*! \brief Split over input domain */ -class Split : public NodeRef { - public: - /*! \brief default constructor */ - Split() {} - explicit Split(std::shared_ptr<Node> n) : NodeRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const SplitNode* operator->() const; -}; - -/*! - * \brief base class of split node, - * specifies a split over domain - * split also defines how to generate - */ -class SplitNode : public Node { - public: - /*! \brief the variable to be splitted on */ - Var var; -}; - -/*! \brief simple split node that splits over one dimension */ -class DimSplitNode : public SplitNode { - public: - /*! \brief The factor of the split */ - Expr factor; - /*! \brief constructor */ - DimSplitNode() {} - - void VisitAttrs(AttrVisitor* v) final { - v->Visit("var", &var); - v->Visit("factor", &factor); - } - static Split make(Var var, Expr factor); - - static constexpr const char* _type_key = "DimSplit"; - TVM_DECLARE_NODE_TYPE_INFO(DimSplitNode); -}; - -// Implementations of inline functions -inline const SplitNode* Split::operator->() const { - return static_cast<const SplitNode*>(node_.get()); -} - -} // namespace tvm -#endif // TVM_SPLIT_H_ diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index 8898f5b3645aa18462be893b044fa366125c7521..3df636f338edd07ee31107d30db0982f3615b051 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -174,12 +174,10 @@ class TensorNode : public FunctionBaseNode { */ class OperationNode : public Node { public: - /*! \brief The domain of iteration of this op. */ - Domain domain; - /*! \brief iter-Var over the dimensions */ - Array<Var> dim_var; /*! \brief optional name of the operation */ std::string name; + /*! \return the list of iteration variable at root */ + virtual Array<IterVar> root_iter_vars() const = 0; /*! \return number of outputs of this op */ virtual size_t num_outputs() const = 0; /*! \return name of i-th output */ diff --git a/python/tvm/function.py b/python/tvm/function.py index 01e8ee22a01fa9db267b64c3f9d0491c569ad60d..b1f91bd44de5dec4c7f4fa2362deef82ad418f0e 100644 --- a/python/tvm/function.py +++ b/python/tvm/function.py @@ -83,11 +83,11 @@ def compute(shape, fcompute, name="TensorCompute"): arg_names = fcompute.__code__.co_varnames if ndim != len(arg_names): raise ValueError("fcompute do not match dimension") - dim_var = [Var(x) for x in arg_names] - body = fcompute(*dim_var) - dom = [Range(0, x) for x in shape] + + dim_var = [IterVar((0, s), x) for x, s in zip(arg_names, shape)] + body = fcompute(*[v.var for v in dim_var]) op_node = _function_internal._ComputeOp( - dom, name, dim_var, body) + name, dim_var, body) return _function_internal._Tensor( shape, name, body.dtype, op_node, 0) diff --git a/src/c_api/c_api_lang.cc b/src/c_api/c_api_lang.cc index 61e348245266e8cdb4f9bd1c3c96597287f78b22..af6ce62e6f08a08b2ed1cbe0e0b22d8ebf0306aa 100644 --- a/src/c_api/c_api_lang.cc +++ b/src/c_api/c_api_lang.cc @@ -5,7 +5,6 @@ */ #include <tvm/expr.h> #include <tvm/tensor.h> -#include <tvm/split.h> #include <tvm/schedule.h> #include "./c_api_registry.h" @@ -89,8 +88,7 @@ TVM_REGISTER_API(_ComputeOp) .set_body([](const ArgStack& args, RetValue *ret) { *ret = ComputeOpNode::make(args.at(0), args.at(1), - args.at(2), - args.at(3)); + args.at(2)); }); @@ -100,11 +98,6 @@ TVM_REGISTER_API(_IterVar) }); -TVM_REGISTER_API(_DimSplit) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = DimSplitNode::make(args.at(0), args.at(1)); - }); - TVM_REGISTER_API(_Schedule) .set_body([](const ArgStack& args, RetValue *ret) { *ret = Schedule(args.at(0), args.at(1)); diff --git a/src/lang/expr.cc b/src/lang/expr.cc index 30f0f5b157a38b7585687331bece32617d0e0fca..df235014844672dc95dcd8182bf71247ee439814 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -24,12 +24,12 @@ Range Range::make_with_min_extent(Expr min, Expr extent) { } IterVar::IterVar(Range dom, std::string var_name, std::string thread_tag) - : IterVar(IterVarNode::make(Var(var_name, Int(32)), dom, thread_tag)) {} + : IterVar(IterVarNode::make(dom, Var(var_name, Int(32)), thread_tag)) {} -IterVar IterVarNode::make(Var var, Range dom, std::string thread_tag) { +IterVar IterVarNode::make(Range dom, Var var, std::string thread_tag) { std::shared_ptr<IterVarNode> n = std::make_shared<IterVarNode>(); - n->var = var; n->dom = dom; + n->var = var; n->thread_tag = thread_tag; return IterVar(n); } diff --git a/src/lang/operation.cc b/src/lang/operation.cc index dec710b1ffbfd82aeaa89ad983bda483e837b1c7..522a35a93353c31d480d4f39f724cef7435a88bd 100644 --- a/src/lang/operation.cc +++ b/src/lang/operation.cc @@ -13,32 +13,25 @@ Tensor Compute(Array<Expr> shape, FCompute fcompute, std::string name) { auto op_node = std::make_shared<ComputeOpNode>(); // compute dimension. size_t ndim = shape.size(); - std::vector<Var> dim_index; + std::vector<IterVar> dim_var; + std::vector<Var> args; for (size_t i = 0; i < ndim; ++i) { std::ostringstream os; os << "dim_var" << i; - dim_index.push_back(Var(os.str())); + dim_var.push_back(IterVar(Range(0, shape[i]), os.str())); + args.push_back(dim_var.back()->var); } - std::vector<Range> dom; - for (size_t i = 0; i < ndim; ++i) { - dom.push_back(Range(0, shape[i])); - } - - op_node->dim_var = Array<Var>(dim_index); - op_node->domain = Domain(dom); - op_node->body = fcompute(op_node->dim_var); + op_node->dim_var = Array<IterVar>(dim_var); + op_node->body = fcompute(args); op_node->name = name; - return Operation(op_node).output(0); } -Operation ComputeOpNode::make(Domain domain, - std::string name, - Array<Var> dim_var, +Operation ComputeOpNode::make(std::string name, + Array<IterVar> dim_var, Expr body) { auto n = std::make_shared<ComputeOpNode>(); - n->domain = domain; n->name = name; n->dim_var = dim_var; n->body = body; @@ -55,6 +48,10 @@ Tensor Operation::output(size_t i) const { return Tensor(node); } +Array<IterVar> ComputeOpNode::root_iter_vars() const { + return dim_var; +} + std::string ComputeOpNode::output_name(size_t i) const { CHECK_EQ(i, 0); return name; @@ -68,8 +65,9 @@ Type ComputeOpNode::output_dtype(size_t i) const { Array<Expr> ComputeOpNode::output_shape(size_t i) const { CHECK_EQ(i, 0); std::vector<Expr> shape; - for (size_t i = 0; i < domain.size(); ++i) { - shape.push_back(domain[i]->extent); + for (size_t i = 0; i < dim_var.size(); ++i) { + const Range& r = dim_var[i]->dom; + shape.push_back(r->extent); } return Array<Expr>(shape); } diff --git a/src/lang/schedule.cc b/src/lang/schedule.cc index 1477541a66cec15c8be29d2aa7132230d1937cda..1292c2d8a9256147ce8a65b3902fe271f789cc4f 100644 --- a/src/lang/schedule.cc +++ b/src/lang/schedule.cc @@ -13,6 +13,28 @@ Schedule::Schedule(Operation op, std::string scope) { node_ = n; } +IterVarRelation SplitNode::make( + IterVar parent, IterVar outer, + IterVar inner, Expr factor) { + auto n = std::make_shared<SplitNode>(); + n->parent = parent; + n->outer = outer; + n->inner = inner; + n->factor = factor; + return IterVarRelation(n); +} + +IterVarRelation FuseNode::make( + IterVar outer, IterVar inner, IterVar fused) { + auto n = std::make_shared<FuseNode>(); + n->outer = outer; + n->inner = inner; + n->fused = fused; + return IterVarRelation(n); +} + TVM_REGISTER_NODE_TYPE(ScheduleNode); +TVM_REGISTER_NODE_TYPE(SplitNode); +TVM_REGISTER_NODE_TYPE(FuseNode); } // namespace tvm diff --git a/src/lang/split.cc b/src/lang/split.cc deleted file mode 100644 index 55fbba613025ae7f50926fb17c1b974caaf094b6..0000000000000000000000000000000000000000 --- a/src/lang/split.cc +++ /dev/null @@ -1,20 +0,0 @@ -/*! - * Copyright (c) 2016 by Contributors - * \file split.cc - */ -#include <tvm/split.h> - -namespace tvm { - -Split DimSplitNode::make(Var var, - Expr factor) { - auto n = std::make_shared<DimSplitNode>(); - CHECK_EQ(factor.type().lanes(), 1); - n->var = var; - n->factor = factor; - return Split(n); -} - -TVM_REGISTER_NODE_TYPE(DimSplitNode); - -} // namespace tvm diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index 1c4e46f720bed97ec67fafabf0423ce45696f265..3bdafa75bb367a32c9b01a25230c8ff879130a1c 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -53,7 +53,8 @@ inline Array<IterVar> MutateRDom(Array<IterVar> rdom, IRMutator *m) { if (!r->min.same_as(new_min)) changed = true; if (!r->extent.same_as(new_extent)) changed = true; new_dom[i] = IterVarNode::make( - v->var, Range::make_with_min_extent(new_min, new_extent), v->thread_tag); + Range::make_with_min_extent(new_min, new_extent), + v->var, v->thread_tag); } if (!changed) { return rdom; diff --git a/src/pass/schedule_ops.cc b/src/pass/schedule_ops.cc index 366f30c1acf51fc4340bbfa59de56853381238f9..e11909489a48af2b25bf69fcb703589716544e4f 100644 --- a/src/pass/schedule_ops.cc +++ b/src/pass/schedule_ops.cc @@ -38,32 +38,6 @@ Stmt MakeLoop(std::vector<Stmt>&& nest, Stmt body) { return body; } -void MakeLoop(const DimSplitNode* op, - const Split& s, - Scope<AttrKey, Expr>* pscope, - std::vector<Stmt>* nest) { - auto& scope = *pscope; - Expr out_min = scope[{op->var, "min"}]; - Expr out_ext = scope[{op->var, "extent"}]; - Expr stride = op->factor; - Var offset(s->var->name_hint + ".offset", Int(32)); - // for loop with stride - // TODO(tqchen) split the loop to deal with tails - nest->emplace_back( - For::make( - offset, out_min, out_ext, - ForType::Parallel, DeviceAPI::None, Stmt())); - Expr in_min = offset + out_min; - Expr in_ext = min(stride, out_ext - offset); - // declare min and extent of the corresponding variable - nest->emplace_back(AttrStmt::make(op->var, "min", in_min, Stmt())); - nest->emplace_back(AttrStmt::make(op->var, "extent", in_ext, Stmt())); - // declare this is the loop - nest->emplace_back(AttrStmt::make(s, "split", 0, Stmt())); - // setup the scope. - pscope->Push({op->var, "min"}, in_min); - pscope->Push({op->var, "extent"}, in_ext); -} Stmt MakePipeline(const Schedule& sch, Stmt body) {