diff --git a/HalideIR b/HalideIR index bf96f8af0dfd1f79d258c7c1506f9ded932b94a9..eb2f7d604a611318fc685172847bcf5ba2fcf835 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit bf96f8af0dfd1f79d258c7c1506f9ded932b94a9 +Subproject commit eb2f7d604a611318fc685172847bcf5ba2fcf835 diff --git a/include/tvm/domain.h b/include/tvm/domain.h index 634a72b97be8d82779cd23c22aaa8c45aa38d770..56d90e3a7512e71bbba3e33b4330daa68c014680 100644 --- a/include/tvm/domain.h +++ b/include/tvm/domain.h @@ -95,13 +95,13 @@ class RDomainNode : public Node { RDomainNode(Array<Var> index, Domain domain) : index(index), domain(domain) { } - const char* type_key() const override { - return "RDomain"; - } void VisitAttrs(AttrVisitor* v) final { v->Visit("index", &index); v->Visit("domain", &domain); } + + static constexpr const char* _type_key = "RDomain"; + TVM_DECLARE_NODE_TYPE_INFO(RDomainNode); }; inline const RDomainNode* RDomain::operator->() const { diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index 468e1b13e27ecca5fed1aeb3aea5ccf474a85858..3106df1ffd0285be823b1282a0230f9ff9314a9f 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -6,7 +6,7 @@ #ifndef TVM_IR_MUTATOR_H_ #define TVM_IR_MUTATOR_H_ -#include <tvm/ir_node.h> +#include <tvm/ir_functor.h> #include <unordered_map> #include "./expr.h" @@ -16,7 +16,7 @@ namespace ir { * \brief a base class for mutator to iterative mutate the IR * * This IRMutator is implemented via IRFunctor instead of Visitor Pattern. - * This enables easy extensions of possible new IRNode. + * This enables easy extensions of possible new Node. * It also makes changing return types easier. * * \note If you want to return a different type other than Expr and Stmt, @@ -44,9 +44,9 @@ class IRMutator { /*! \brief destructor */ virtual ~IRMutator() {} /*! \brief functor type of expr mutation */ - using FMutateExpr = IRFunctor<Expr(const IRNodeRef&, const Expr&, IRMutator*)>; + using FMutateExpr = IRFunctor<Expr(const NodeRef&, const Expr&, IRMutator*)>; /*! \brief functor type of stmt mutation */ - using FMutateStmt = IRFunctor<Stmt(const IRNodeRef&, const Stmt&, IRMutator*)>; + using FMutateStmt = IRFunctor<Stmt(const NodeRef&, const Stmt&, IRMutator*)>; /*! \return internal vtable of expr */ static FMutateExpr& vtable_expr(); // NOLINT(*) /*! \return internal stmt of expr */ diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index c887c4b88bd5c88e27493018859d0f47093a4dc5..276bba9448f80265e32130e183efa17445508e73 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -9,7 +9,7 @@ #ifndef TVM_IR_PASS_H_ #define TVM_IR_PASS_H_ -#include <tvm/ir_node.h> +#include <tvm/ir_functor.h> #include <unordered_map> #include <vector> #include "./expr.h" diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h index 937808fe2201e01626749ea8acf53b2b51819a30..b64406d7ec4fecd49113517547ff43b846828d10 100644 --- a/include/tvm/ir_visitor.h +++ b/include/tvm/ir_visitor.h @@ -15,7 +15,7 @@ namespace ir { * \brief a base class for visitor to iterative traverse the IR * * This IRVisitor is implemented via IRFunctor - * This enables extensions of possible new IRNode. + * This enables extensions of possible new Node. * * \sa IRFunctor, PostOrderVisit */ @@ -24,14 +24,14 @@ class IRVisitor { /*! * \brief recursively visit an IR node */ - virtual void Visit(const IRNodeRef& node) { + virtual void Visit(const NodeRef& node) { static const FVisit& f = vtable(); if (node.defined()) f(node, this); } /*! \brief destructor */ virtual ~IRVisitor() {} /*! \brief functor type of visitor */ - using FVisit = IRFunctor<void(const IRNodeRef&, IRVisitor*)>; + using FVisit = IRFunctor<void(const NodeRef&, IRVisitor*)>; /*! \return internal vtable*/ static FVisit& vtable(); }; @@ -42,7 +42,7 @@ class IRVisitor { * \param node The ir to be visited. * \param fvisit The visitor function to be applied. */ -void PostOrderVisit(const IRNodeRef& node, std::function<void(const IRNodeRef&)> fvisit); +void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit); } // namespace ir } // namespace tvm diff --git a/include/tvm/operation.h b/include/tvm/operation.h index f005fcca2062872d2d2785f15a21e92ac208d551..841e9f4f25e8aabbd080075f8403d7cd794ef9e5 100644 --- a/include/tvm/operation.h +++ b/include/tvm/operation.h @@ -23,9 +23,6 @@ class ComputeOpNode : public OperationNode { /*! \brief constructor */ ComputeOpNode() {} - const char* type_key() const final { - return "ComputeOp"; - } size_t num_outputs() const final { return 1; } @@ -43,6 +40,9 @@ class ComputeOpNode : public OperationNode { std::string name, Array<Var> dim_var, Expr body); + + static constexpr const char* _type_key = "ComputeOp"; + TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode); }; diff --git a/include/tvm/schedule.h b/include/tvm/schedule.h index 1fbeb3a03a6caab2b72b426e0d720feafb83cbbd..3e861835bcf26bc4d92a400aa785da3fa2e6860c 100644 --- a/include/tvm/schedule.h +++ b/include/tvm/schedule.h @@ -62,6 +62,10 @@ class ScheduleNode : public Node { const char* type_key() const final { return "Schedule"; } + const uint32_t type_index() const final { + static uint32_t tidx = TypeKey2Index(type_key()); + return tidx; + } void VisitAttrs(AttrVisitor* v) final { v->Visit("scope", &scope); v->Visit("op", &op); diff --git a/include/tvm/split.h b/include/tvm/split.h index 27c9d77715cda831f1f3f7044da27b8679df85b7..47fde3c64e700046dd324ad6a48ba0fdebf5ddff 100644 --- a/include/tvm/split.h +++ b/include/tvm/split.h @@ -46,14 +46,15 @@ class DimSplitNode : public SplitNode { Expr factor; /*! \brief constructor */ DimSplitNode() {} - const char* type_key() const final { - return "DimSplit"; - } + void VisitAttrs(AttrVisitor* v) final { v->Visit("var", &var); v->Visit("factor", &factor); } static Split make(Var var, Expr factor); + + static constexpr const char* _type_key = "DimSplit"; + TVM_DECLARE_NODE_TYPE_INFO(DimSplitNode); }; // Implementations of inline functions diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index ce3d9db2314f59b8aff632c2359bacc77aeebb3e..55a92384d52f65dcdb804ba3fc6bf617bd62a7a3 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -104,9 +104,7 @@ class TensorNode : public FunctionBaseNode { int value_index{0}; /*! \brief constructor */ TensorNode() {} - const char* type_key() const final { - return "Tensor"; - } + void VisitAttrs(AttrVisitor* v) final { v->Visit("shape", &shape); v->Visit("name", &name); @@ -125,6 +123,9 @@ class TensorNode : public FunctionBaseNode { Type dtype, Operation op, int value_index); + + static constexpr const char* _type_key = "Tensor"; + TVM_DECLARE_NODE_TYPE_INFO(TensorNode); }; /*! diff --git a/include/tvm/tvm.h b/include/tvm/tvm.h index e825272bbc7e5caab1dd681beb4ea82702e022ac..5b7113dfd5ef0d0c0a2dbd9bae81937c5089efc6 100644 --- a/include/tvm/tvm.h +++ b/include/tvm/tvm.h @@ -9,5 +9,6 @@ #include "./base.h" #include "./expr.h" #include "./tensor.h" +#include "./operation.h" #endif // TVM_TVM_H_ diff --git a/src/c_api/c_api_function.cc b/src/c_api/c_api_function.cc index b7b3d9f956e3e750bb6986bc2c81cac53672c705..2fbfe2b9399713b33bf085771872a1061ab17b3b 100644 --- a/src/c_api/c_api_function.cc +++ b/src/c_api/c_api_function.cc @@ -26,9 +26,9 @@ TVM_REGISTER_API(_format_str) CHECK(args.at(0).type_id == kNodeHandle); std::ostringstream os; auto& sptr = args.at(0).sptr; - if (sptr->is_type<TensorNode>()) { + if (dynamic_cast<const TensorNode*>(sptr.get())) { os << args.at(0).operator Tensor(); - } else if (sptr->is_type<RDomainNode>()) { + } else if (dynamic_cast<const RDomainNode*>(sptr.get())) { os << args.at(0).operator RDomain(); } else if (dynamic_cast<const BaseExprNode*>(sptr.get())) { os << args.at(0).operator Expr(); diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index 2b98ca625ed55465a8450de29ece55c0d1eda938..2c44ea8fe4d1ddda469dcbecdf3aefddbdf2c593 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -22,7 +22,7 @@ namespace { using namespace Halide::Internal; // const expr -inline Expr ReturnSelfExpr(const IRNodeRef&, const Expr& e, IRMutator*) { +inline Expr ReturnSelfExpr(const NodeRef&, const Expr& e, IRMutator*) { return e; } diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc index d05baf8c5c81d0ad0ca4bf7ffefa7dbbb60da24a..d9ae0416b8094f6a36c6c46ba7d68bceae03b6d1 100644 --- a/src/pass/ir_visitor.cc +++ b/src/pass/ir_visitor.cc @@ -12,9 +12,9 @@ namespace { // visitor to implement apply class IRApplyVisit : public IRVisitor { public: - explicit IRApplyVisit(std::function<void(const IRNodeRef&)> f) : f_(f) {} + explicit IRApplyVisit(std::function<void(const NodeRef&)> f) : f_(f) {} - void Visit(const IRNodeRef& node) final { + void Visit(const NodeRef& node) final { if (visited_.count(node.get()) != 0) return; visited_.insert(node.get()); IRVisitor::Visit(node); @@ -22,13 +22,13 @@ class IRApplyVisit : public IRVisitor { } private: - std::function<void(const IRNodeRef&)> f_; + std::function<void(const NodeRef&)> f_; std::unordered_set<const Node*> visited_; }; } // namespace -void PostOrderVisit(const IRNodeRef& node, std::function<void(const IRNodeRef&)> fvisit) { +void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit) { IRApplyVisit(fvisit).Visit(node); } @@ -42,7 +42,7 @@ namespace { using namespace Halide::Internal; -void NoOp(const IRNodeRef& n, IRVisitor* v) { +void NoOp(const NodeRef& n, IRVisitor* v) { } inline void VisitArray(Array<Expr> arr, IRVisitor* v) { diff --git a/src/pass/schedule_ops.cc b/src/pass/schedule_ops.cc index fe9ede17c76b0dbf9e42dd93b698c505312157ec..4f6277c488d51f674e25623041eb71446e557564 100644 --- a/src/pass/schedule_ops.cc +++ b/src/pass/schedule_ops.cc @@ -5,21 +5,37 @@ #include <tvm/ir.h> #include <tvm/ir_mutator.h> #include <tvm/ir_pass.h> +#include "./scope.h" namespace tvm { namespace ir { namespace { -Stmt MakeCompute(const ComputeOpNode* op, const Array<Split>& splits) { - Tensor output; - std::vector<Expr> args(op->dim_var.size()); - for (size_t i = 0; i < args.size(); ++i) { - args[i] = op->dim_var[i]; +/*! + * \brief make nest loops given list of stmt, whose body is not defined. + * \param nest A list of For and LetStmt, whose body is not defined. + * \param body The inner-most body of the loop + */ +Stmt MakeLoop(std::vector<Stmt>&& nest, Stmt body) { + while (!nest.empty()) { + Stmt s = std::move(nest.back()); nest.pop_back(); + if (s.as<For>()) { + auto n = std::make_shared<For>(*s.as<For>()); + n->body = body; + body = Stmt(n); + } else if (s.as<LetStmt>()) { + auto n = std::make_shared<LetStmt>(*s.as<LetStmt>()); + n->body = body; + body = Stmt(n); + } else if (s.as<AttrStmt>()) { + auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>()); + n->body = body; + body = Stmt(n); + } else { + LOG(FATAL) << "not supported nest type"; + } } - Array<Expr> values{op->body}; - Stmt stmt = Provide::make(output, values, args); - // add splits from ousside most to outsidemost to innermost - return stmt; + return body; } diff --git a/src/pass/scope.h b/src/pass/scope.h new file mode 100644 index 0000000000000000000000000000000000000000..36a38d67c55eb7857f8079496dc6f9a69aaf55d8 --- /dev/null +++ b/src/pass/scope.h @@ -0,0 +1,84 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file scope.h + * \brief attribute scope data structure, + * defines attributes on current domain + */ +#ifndef TVM_PASS_SCOPE_H_ +#define TVM_PASS_SCOPE_H_ + +#include <tvm/ir.h> +#include <unordered_map> +#include <vector> +#include <string> + +namespace tvm { +namespace ir { + +/*! + * \brief Attribute scope of Nodes in the IR. + * \tparam ValueType The value of of the scope. + */ +template<typename K, typename V> +class Scope { + public: + /*! + * \brief Push value to scope + * \param key the key to be pushed. + * \param v The value to be pushed. + */ + inline void Push(const K& key, V v) { + data_[key].emplace_back(v); + } + /*! + * \brief Pop value from scope. + * \param key the key to be poped + */ + inline void Pop(const K& key) { + auto& v = data_[key]; + CHECK_NE(v.size(), 0); + v.pop_back(); + } + + /*! + * \brief Get value from the scope + * \param key the key to fetch. + * \return The value to be fetched. + */ + inline V operator[](const K& key) const { + const auto it = data_.find(key); + CHECK(it != data_.end() && it->second.size() != 0) + << "cannot find value in scope"; + return it->second.back(); + } + + private: + std::unordered_map<K, std::vector<V> > data_; +}; + +/*! \brief Attribute key for specific attribute */ +struct AttrKey { + /*! \brief The node of the attribute */ + NodeRef node; + /*! \brief The type key of the attribute. */ + std::string type_key; + // overload operator == + inline bool operator==(const AttrKey& other) const { + return node == other.node && type_key == other.type_key; + } +}; +} // namespace ir +} // namespace tvm + +namespace std { +template <> +struct hash<::tvm::ir::AttrKey> { + std::size_t operator()(const ::tvm::ir::AttrKey& k) const { + size_t lhs = k.node.hash(); + size_t rhs = std::hash<std::string>()(k.type_key); + lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2); + return lhs; + } +}; +} // namespace std +#endif // TVM_PASS_SCOPE_H_ diff --git a/src/pass/ssa.cc b/src/pass/ssa.cc index 44b2454de44af678fb07eb7bee53492ab3ad0f06..9b01e24fee388be18122d38f967a7c36d2b84265 100644 --- a/src/pass/ssa.cc +++ b/src/pass/ssa.cc @@ -17,7 +17,7 @@ namespace { // global functor to get var definition from struct FGetVarDef { - using FType = IRFunctor<VarExpr (const IRNodeRef&)>; + using FType = IRFunctor<VarExpr (const NodeRef&)>; static FType& vtable() { // NOLINT(*) static FType inst; return inst; } @@ -37,8 +37,8 @@ TVM_STATIC_IR_FUNCTOR(FGetVarDef, vtable) }); struct FSetVarDef { - using FTypeExpr = IRFunctor<Expr (const IRNodeRef&, VarExpr)>; - using FTypeStmt = IRFunctor<Stmt (const IRNodeRef&, VarExpr)>; + using FTypeExpr = IRFunctor<Expr (const NodeRef&, VarExpr)>; + using FTypeStmt = IRFunctor<Stmt (const NodeRef&, VarExpr)>; static FTypeExpr& vtable_expr() { // NOLINT(*) static FTypeExpr inst; return inst; } @@ -69,7 +69,7 @@ class IRVerifySSA : public IRVisitor { public: bool is_ssa{true}; - void Visit(const IRNodeRef& n) final { + void Visit(const NodeRef& n) final { if (!is_ssa) return; static auto& fget_var_def = FGetVarDef::vtable(); if (fget_var_def.can_dispatch(n)) { diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index 1f07ce5a20297bb5477a38441e2f9bea3fe605cb..8e0e68e3f5c280e38c2e0ff3cb4ca42ef73ca4fa 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -1,7 +1,7 @@ #include <dmlc/logging.h> #include <gtest/gtest.h> #include <tvm/tvm.h> -#include <tvm/ir_node.h> +#include <tvm/ir_functor.h> TEST(IRF, Basic) { using namespace Halide::Internal; @@ -9,7 +9,7 @@ TEST(IRF, Basic) { Var x("x"); auto z = x + 1; - IRFunctor<int(const IRNodeRef& n, int b)> f; + IRFunctor<int(const NodeRef& n, int b)> f; LOG(INFO) << "x"; f.set_dispatch<Variable>([](const Variable* n, int b) { return b; diff --git a/tests/cpp/ir_visitor_test.cc b/tests/cpp/ir_visitor_test.cc index adee708daaa1551f11c10f173954e16d3b5331b8..0a649a09304c7b0129c7e1749866d72bf22f4289 100644 --- a/tests/cpp/ir_visitor_test.cc +++ b/tests/cpp/ir_visitor_test.cc @@ -11,7 +11,7 @@ TEST(IRVisitor, CountVar) { Var x("x"), y; auto z = x + 1 + y + y; - ir::PostOrderVisit(z, [&n_var](const IRNodeRef& n) { + ir::PostOrderVisit(z, [&n_var](const NodeRef& n) { if (n.as<Variable>()) ++n_var; }); CHECK_EQ(n_var, 2);