diff --git a/HalideIR b/HalideIR index 7f1d811972bccc26f651ea2289d88bcadea9fe9f..bf96f8af0dfd1f79d258c7c1506f9ded932b94a9 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit 7f1d811972bccc26f651ea2289d88bcadea9fe9f +Subproject commit bf96f8af0dfd1f79d258c7c1506f9ded932b94a9 diff --git a/include/tvm/ir.h b/include/tvm/ir.h index 0ba993b022653947d843a90294a0195fbfe3ed77..3ca1f81e28c9dcdf3697cca639dea58c05afbc3e 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -17,6 +17,7 @@ namespace tvm { namespace ir { using Halide::Internal::ExprNode; +using Halide::Internal::StmtNode; using Halide::Internal::IRNodeType; using Halide::Internal::ForType; @@ -47,6 +48,34 @@ struct Reduce : public ExprNode<Reduce> { static constexpr const char* Min = "Min"; }; +/*! + * \brief Define certain auxiliary attribute for the body to be a symbolic value. + * This is used to insert hint(shape, storage, split) about certain scopes. + */ +struct AttrStmt : public StmtNode<AttrStmt> { + /*! \brief this is attribute about certain node */ + NodeRef node; + /*! \brief the type key of the attribute */ + std::string type_key; + /*! \brief The attribute value, value is well defined at current scope. */ + Expr value; + /*! \brief The body statement to be executed */ + Stmt body; + + /*! \brief construct expr from name and rdom */ + static Stmt make(NodeRef node, std::string type_key, Expr value, Stmt body); + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("node", &node); + v->Visit("type_key", &type_key); + v->Visit("value", &value); + v->Visit("body", &body); + } + + static const IRNodeType _type_info = IRNodeType::ExtensionExpr; + static constexpr const char* _type_key = "AttrStmt"; +}; + // Reuse IR node defintiion from HalideIR using Halide::Internal::IntImm; using Halide::Internal::UIntImm; diff --git a/include/tvm/operation.h b/include/tvm/operation.h index 8bf72ab54630713073e5ecd91814d69d5b9cabec..f005fcca2062872d2d2785f15a21e92ac208d551 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -32,6 +32,7 @@ class ComputeOpNode : public OperationNode { std::string output_name(size_t i) const final; Type output_dtype(size_t i) const final; Array<Expr> output_shape(size_t i) const final; + void VisitAttrs(AttrVisitor* v) final { v->Visit("domain", &domain); v->Visit("name", &name); diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index 6195cddfd1eb2506ed0d3d62c99cb4bf1aeec14b..1fbeb3a03a6caab2b72b426e0d720feafb83cbbd 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -38,42 +38,7 @@ class Schedule : public NodeRef { inline const ScheduleNode* operator->() const; }; -/*! \brief schedule container */ -class AttachSpec : public NodeRef { - public: - AttachSpec() {} - explicit AttachSpec(std::shared_ptr<Node> n) : NodeRef(n) {} - /*! - * \brief access the internal node container - * \return the pointer to the internal node container - */ - inline const AttachSpecNode* operator->() const; -}; - // defintion of node containers - -/*! \brief The attach specification of each subschedule */ -class AttachSpecNode : public Node { - public: - /*! \brief The attachment type */ - AttachType attach_type; - /*! - * \brief The split to be attached to, - * only valid when attach_type is kRoot - */ - Split attach_split; - /*! \brief the child schedule to be attached. */ - Schedule schedule; - const char* type_key() const final { - return "AttachSpec"; - } - void VisitAttrs(AttrVisitor* v) final { - v->Visit("attach_type", &attach_type); - v->Visit("attach_split", &attach_split); - v->Visit("schedule", &schedule); - } -}; - /*! \brief represents the schedule of the tensor */ class ScheduleNode : public Node { public: @@ -83,8 +48,17 @@ class ScheduleNode : public Node { std::string scope; /*! \brief Splits over iteration domains */ Array<Split> splits; - /*! \brief attach specifications */ - Array<AttachSpec> attachs; + /*! \brief The attachment type of the schedule */ + AttachType attach_type; + /*! + * \brief The attach point of this schedule, if it is a split + * \note This is not a cyclic dependency, + * because split do not refer back to parent schedule. + */ + Split attach_parent; + /*! \brief the schedules that this schedule depend on */ + Array<Schedule> children; + // the type key const char* type_key() const final { return "Schedule"; } @@ -92,7 +66,9 @@ class ScheduleNode : public Node { v->Visit("scope", &scope); v->Visit("op", &op); v->Visit("splits", &splits); - v->Visit("attachs", &attachs); + v->Visit("attach_type", &attach_type); + v->Visit("attach_parent", &attach_parent); + v->Visit("children", &children); } }; @@ -101,9 +77,5 @@ inline const ScheduleNode* Schedule::operator->() const { return static_cast<const ScheduleNode*>(node_.get()); } -inline const AttachSpecNode* AttachSpec::operator->() const { - return static_cast<const AttachSpecNode*>(node_.get()); -} - } // namespace tvm #endif // TVM_SCHEDULE_H_ diff --git a/src/c_api/c_api_ir.cc b/src/c_api/c_api_ir.cc index 94b65c230fed8d7447bad37cda231881af870c86..151f9da7a1b3e3a760b9b31200d53f8ba167ef35 100644 --- a/src/c_api/c_api_ir.cc +++ b/src/c_api/c_api_ir.cc @@ -29,13 +29,6 @@ TVM_REGISTER_API(_make_For) args.at(5)); }); -TVM_REGISTER_API(_make_Reduce) -.set_body([](const ArgStack& args, RetValue *ret) { - *ret = Reduce::make(args.at(0), - args.at(1), - args.at(2)); - }); - TVM_REGISTER_API(_make_Call) .set_body([](const ArgStack& args, RetValue *ret) { *ret = Call::make(args.at(0), @@ -54,22 +47,6 @@ TVM_REGISTER_API(_make_Allocate) args.at(4)); }); -TVM_REGISTER_API(_make_LetStmt) -.set_body([](const ArgStack& args, RetValue *ret) { - if (args.size() == 3) { - *ret = LetStmt::make(args.at(0), - args.at(1), - args.at(2)); - } else { - CHECK_EQ(args.size(), 5); - *ret = LetStmt::make(args.at(0), - args.at(1), - args.at(2), - args.at(3), - args.at(4)); - } - }); - // make from two arguments #define REGISTER_MAKE1(Node) \ TVM_REGISTER_API(_make_## Node) \ @@ -89,6 +66,12 @@ TVM_REGISTER_API(_make_LetStmt) *ret = Node::make(args.at(0), args.at(1), args.at(2)); \ }) \ +#define REGISTER_MAKE4(Node) \ + TVM_REGISTER_API(_make_## Node) \ + .set_body([](const ArgStack& args, RetValue *ret) { \ +*ret = Node::make(args.at(0), args.at(1), args.at(2), args.at(3)); \ + }) \ + #define REGISTER_MAKE_BINARY_OP(Node) \ TVM_REGISTER_API(_make_## Node) \ .set_body([](const ArgStack& args, RetValue *ret) { \ @@ -99,6 +82,9 @@ TVM_REGISTER_API(_make_LetStmt) .add_argument("lhs", "Expr", "left operand") \ .add_argument("rhs", "Expr", "right operand") +REGISTER_MAKE3(Reduce); +REGISTER_MAKE4(AttrStmt); + REGISTER_MAKE2(IntImm); REGISTER_MAKE2(UIntImm); REGISTER_MAKE2(FloatImm); @@ -123,6 +109,7 @@ REGISTER_MAKE3(Select); REGISTER_MAKE3(Ramp); REGISTER_MAKE2(Broadcast); REGISTER_MAKE3(Let); +REGISTER_MAKE3(LetStmt); REGISTER_MAKE2(AssertStmt); REGISTER_MAKE3(ProducerConsumer); REGISTER_MAKE3(Store); diff --git a/src/lang/ir.cc b/src/lang/ir.cc index 65fa69bf2d95152c02cedf7ced8f7e3e62d418d4..ecceb6dd1803ddd3d7f0bc7373affb1bf6708377 100644 --- a/src/lang/ir.cc +++ b/src/lang/ir.cc @@ -18,10 +18,16 @@ namespace Halide { namespace Internal { using tvm::ir::Reduce; +using tvm::ir::AttrStmt; template<> void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const { - LOG(FATAL) << "Reduce do not work with IRVisitor yet"; + LOG(FATAL) << "Reduce do not work with old Visitor, use IRFunctor style visitor"; +} + +template<> +void StmtNode<AttrStmt>::accept(IRVisitor *v, const Stmt&) const { + LOG(FATAL) << "AttrStmt do not work with old Visitor, use IRFunctor style visitor"; } TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) @@ -33,15 +39,20 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) p->stream << ", rdom=" << op->rdom << ")"; }); +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch<AttrStmt>([](const AttrStmt *op, IRPrinter *p) { + p->stream << "attr " << op->type_key << " = "; + p->print(op->value); + p->stream << '\n'; + p->print(op->body); +}); + } // namespace Internal } // namespace Halide namespace tvm { namespace ir { -// reduce -TVM_REGISTER_NODE_TYPE(Reduce); - Expr Reduce::make(std::string op, Expr source, RDomain rdom) { auto n = std::make_shared<Reduce>(); CHECK(source.defined()); @@ -52,9 +63,17 @@ Expr Reduce::make(std::string op, Expr source, RDomain rdom) { return Expr(n); } +Stmt AttrStmt::make(NodeRef node, std::string type_key, Expr value, Stmt body) { + auto n = std::make_shared<AttrStmt>(); + n->node = node; + n->type_key = type_key; + n->value = value; + n->body = body; + return Stmt(n); +} -// HalideIR node -using namespace Halide::Internal; +TVM_REGISTER_NODE_TYPE(Reduce); +TVM_REGISTER_NODE_TYPE(AttrStmt); TVM_REGISTER_NODE_TYPE(FloatImm); TVM_REGISTER_NODE_TYPE(IntImm); diff --git a/src/lang/operation.cc b/src/lang/operation.cc index 49625f158b2805b6c399ffdca641666eb4b80898..dec710b1ffbfd82aeaa89ad983bda483e837b1c7 100644 --- a/src/lang/operation.cc +++ b/src/lang/operation.cc @@ -74,8 +74,6 @@ Array<Expr> ComputeOpNode::output_shape(size_t i) const { return Array<Expr>(shape); } - - TVM_REGISTER_NODE_TYPE(ComputeOpNode); } // namespace tvm diff --git a/src/lang/schedule.cc b/src/lang/schedule.cc index ba663a87dfe186c61b054ecc9585bdebdb58f020..1477541a66cec15c8be29d2aa7132230d1937cda 100644 --- a/src/lang/schedule.cc +++ b/src/lang/schedule.cc @@ -13,7 +13,6 @@ Schedule::Schedule(Operation op, std::string scope) { node_ = n; } -TVM_REGISTER_NODE_TYPE(AttachSpecNode); TVM_REGISTER_NODE_TYPE(ScheduleNode); } // namespace tvm diff --git a/src/pass/inline.cc b/src/pass/inline.cc index 669324225a44e130c699ad538701c79acc1d3edc..1fe16372d5ef5d7256c64302184d8cd4335664a0 100644 --- a/src/pass/inline.cc +++ b/src/pass/inline.cc @@ -19,11 +19,12 @@ class IRInline : public IRMutator { : f_(f), args_(args), body_(body) {} Expr Mutate(Expr expr) final { + expr = IRMutator::Mutate(expr); const Call* call = expr.as<Call>(); if (call != nullptr && call->func == f_) { return InlineCall(call); } else { - return IRMutator::Mutate(expr); + return expr; } } diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index db918c6f2f48cb7e7ac021e682664975d4e405c7..2b98ca625ed55465a8450de29ece55c0d1eda938 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -72,6 +72,18 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) } }); +TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt) +.set_dispatch<AttrStmt>([](const AttrStmt* op, const Stmt& s, IRMutator* m) { + Expr value = m->Mutate(op->value); + Stmt body = m->Mutate(op->body); + if (value.same_as(op->value) && + body.same_as(op->body)) { + return s; + } else { + return AttrStmt::make(op->node, op->type_key, op->value, op->body); + } + }); + TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr) .set_dispatch<IntImm>(ReturnSelfExpr) .set_dispatch<UIntImm>(ReturnSelfExpr) diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc index f116ee72f0f1c257215743085e2a361f89b58f9b..d05baf8c5c81d0ad0ca4bf7ffefa7dbbb60da24a 100644 --- a/src/pass/ir_visitor.cc +++ b/src/pass/ir_visitor.cc @@ -65,6 +65,12 @@ TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) v->Visit(op->source); }); +TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) +.set_dispatch<AttrStmt>([](const AttrStmt* op, IRVisitor* v) { + v->Visit(op->value); + v->Visit(op->body); + }); + TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable) .set_dispatch<IntImm>(NoOp) .set_dispatch<UIntImm>(NoOp) diff --git a/src/pass/schedule_ops.cc b/src/pass/schedule_ops.cc index e713858264cf83a071ee4d5ac7bd021ed4645bf2..74ae8ec6ae57206f3e04b2ca9d71a7cc2140a1c1 100644 --- a/src/pass/schedule_ops.cc +++ b/src/pass/schedule_ops.cc @@ -13,11 +13,19 @@ namespace { // inject the operator's realization on the stmt. class InjectRealize : public IRMutator { public: - explicit InjectRealize(std::vector<Tensor> tensors) - : tensors_(tensors) {} - std::vector<Tensor> tensors_; -}; + explicit InjectRealize(Schedule sch) + : sch_(sch) {} + + Stmt Mutate(Stmt stmt) final { + stmt = IRMutator::Mutate(stmt); + const For* op = stmt.as<For>(); + return stmt; + } + private: + // the operations to be carried + Schedule sch_; +}; } // namespace } // namespace ir diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py index 4ae1ad35f60a5c018b066837b446cfdc8e38f3b3..36033c2625a92592dff0682e723a21316f9676d4 100644 --- a/tests/python/test_basic.py +++ b/tests/python/test_basic.py @@ -22,10 +22,15 @@ def test_let(): x = tvm.Var('x') y = tvm.Var('y') stmt = tvm.make.LetStmt( - x, 10, tvm.make.Evaluate(x + 1), y, "stride") - assert stmt.attr_of_node == y - print(stmt) + x, 10, tvm.make.Evaluate(x + 1)); +def test_attr(): + x = tvm.Var('x') + y = tvm.Var('y') + stmt = tvm.make.AttrStmt( + y, "stride", 10, tvm.make.Evaluate(x + 1)); + assert stmt.node == y + print(stmt) def test_basic(): a = tvm.Var('a') @@ -44,6 +49,8 @@ def test_stmt(): if __name__ == "__main__": + test_attr() + test_const() test_make() test_ir()