diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 1a5470489ce2179a3841b0fb723bba4ad4461d05..2319f8baec00a8244fa0e5172a6499a6258a022b 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -429,6 +429,16 @@ inline const TTypeNode* ExprNode::type_as() const { return node; } +/*! + * \brief Print node as text format. + * \param node The node to be printed. + * \param annotate An optional callback function for attaching + * additional comment block to an expr. + * \return The text representation. + */ +std::string RelayPrint( + const NodeRef& node, + runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr); } // namespace relay } // namespace tvm #endif // TVM_RELAY_EXPR_H_ diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 59ad52ccf3fdbd1eb165a240536a0d794e92def3..f25785d39eeb19bdb22f3cd5f29840cee5c66aec 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -161,6 +161,8 @@ class TypedPackedFunc<R(Args...)> { using TSelf = TypedPackedFunc<R(Args...)>; /*! \brief default constructor */ TypedPackedFunc() {} + /*! \brief constructor from null */ + TypedPackedFunc(std::nullptr_t null) {} // NOLINT(*) /*! * \brief construct by wrap a PackedFunc * diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py index 5a92eb57d209115c1850c4a9983f6be8840d053a..012315b40f5104b3809e079ba4db42698384c878 100644 --- a/python/tvm/relay/base.py +++ b/python/tvm/relay/base.py @@ -22,15 +22,20 @@ def register_relay_node(type_key=None): class RelayNode(NodeBase): - def astext(self): + """Base class of all relay node.""" + def astext(self, annotate=None): """Get the text format of the expression. Returns ------- text : str The text format of the expression. + + annotate: Optional[relay.Expr->str] + Optional annotate function to provide additional + information in the comment block. """ - return _expr._text_print(self) + return _expr.RelayPrint(self, annotate) @register_relay_node diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index c48ec90e9e12536330d2a676b431eb7b5399abb0..0f33e86ab5cdcfbe3f8a8c2f3d8c81cabe77e470 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -173,11 +173,13 @@ def build(func, else: tophub_context = autotvm.util.EmptyContext() + cfg = BuildConfig.current + with tophub_context: func = optimize(func) # Fuse ops before running code gen func = ir_pass.infer_type(func) - func = ir_pass.fuse_ops(func) + func = ir_pass.fuse_ops(func, cfg.opt_level) # Graph code generation func = ir_pass.infer_type(func) graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 43cff0bac57a06e00aa1031e826ae1ddcf57e627..f82ea09a102ae75ec51124b5d891bc3a34cfcdba 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -6,7 +6,6 @@ from numbers import Number as _Number import numpy as _np from .base import RelayNode, register_relay_node from . import _make -from . import _expr from . import ty as _ty from .._ffi import base as _base from .. import nd as _nd @@ -477,7 +476,7 @@ class TupleWrapper(object): text : str The text format of the tuple expression. 
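To illustrate how the new `astext(annotate=...)` hook and the `opt_level` plumbing above are meant to be used from the Python side, here is a minimal sketch; the concrete program is made up for illustration and mirrors the pattern used in the updated tests further down.

```python
from tvm import relay

x = relay.var("x", shape=(3, 4), dtype="float32")
func = relay.Function([x], relay.exp(relay.add(x, x)))
func = relay.ir_pass.infer_type(func)

# Default printing falls back to the "# ty=..." type comments.
print(func.astext())

# With an annotate callback, the returned string is used for the
# trailing comment block of each expression instead.
print(func.astext(annotate=lambda expr: str(expr.checked_type)))
```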
""" - return _expr._text_print(self.tuple_value) + return self.tuple_value.astext() def __getitem__(self, index): if index >= len(self): diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 274761f0a27bf5f26839c120be5233b62303b7c0..b1a76d6fae6fb0ad4a126407d7f49299431b30f8 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -259,7 +259,7 @@ def structural_hash(value): raise TypeError(msg) -def fuse_ops(expr): +def fuse_ops(expr, opt_level=1): """Fuse operators in expr together. Parameters @@ -267,9 +267,12 @@ def fuse_ops(expr): expr : tvm.relay.Expr The input expression. + opt_level : int + The level of fuse optimization. + Returns ------- transformed_expr : tvm.relay.Expr Transformed expression, containing fused result. """ - return _ir_pass.FuseOps(expr) + return _ir_pass.FuseOps(expr, opt_level) diff --git a/src/common/arena.h b/src/common/arena.h index e8d4b2e23e37588f7836b7d19687565499735b32..c5da093a70b881d45ff07ea33d6f7f37033ac572 100644 --- a/src/common/arena.h +++ b/src/common/arena.h @@ -38,11 +38,29 @@ class Arena { /*! * \brief Allocate a space from Arena for type T * \param T the data type to be allocated + * \note The space of T is not initialized. */ template<typename T> - T* Alloc() { + T* allocate_() { return static_cast<T*>(Alloc(sizeof(T), alignof(T))); } + /*! + * \brief Create a new instance of type T. + * \param args The constructor argument. + * \tparam T the type to be created. + * \tparam Args Arguments to the constructor. + * + * \return The allocated object. + * \note The type T must be simple type, or only contain + * memory allocated from the same arena. + * Otherwise the destructor needs to be called explicitly. + */ + template<typename T, typename... Args> + T* make(Args&&... args) { + T* ptr = allocate_<T>(); + new (ptr) T(std::forward<Args>(args)...); + return ptr; + } private: // page size 16 KB @@ -87,6 +105,44 @@ class Arena { } }; +/*! + * \brief Link list node + * \tparam T the content data type + */ +template<typename T> +struct LinkNode { + /*! \brief The content value */ + T value; + /*! \brief pointer to the next location */ + LinkNode<T>* next{nullptr}; +}; +/*! + * \brief LinkedList structure + * \tparam T the content data type + * \note This is a simple data structure that can be used together with the arena. + * \sa LinkNode + */ +template<typename T> +struct LinkedList { + /*! \brief Head pointer */ + LinkNode<T>* head{nullptr}; + /*! \brief Tail pointer */ + LinkNode<T>* tail{nullptr}; + /*! + * \brief Push a new node to the end of the linked list. + * \param node The node to be pushed. 
+ */ + void Push(LinkNode<T>* node) { + node->next = nullptr; + if (this->tail != nullptr) { + this->tail->next = node; + this->tail = node; + } else { + head = tail = node; + } + } +}; + } // namespace common } // namespace tvm #endif // TVM_COMMON_ARENA_H_ diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 38e3f6c2a7b81f0d9c11ee018e7ff35238b68d5c..dc094e00e05b38eea6e313128fab4614614f8638 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -109,6 +109,29 @@ class ScheduleGetter : return {}; } + Array<Tensor> VisitExpr_(const ConstantNode* op) final { + CHECK(op->is_scalar()); + void* data = op->data->data; + DataType dtype = TVMType2Type(op->data->dtype); + Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) { + if (dtype == Int(32)) { + return make_const(dtype, static_cast<const int32_t*>(data)[0]); + } else if (dtype == Int(64)) { + return make_const(dtype, static_cast<const int64_t*>(data)[0]); + } else if (dtype == Float(32)) { + return make_const(dtype, static_cast<const float*>(data)[0]); + } else if (dtype == Float(64)) { + return make_const(dtype, static_cast<const double*>(data)[0]); + } else if (dtype == Bool()) { + return make_const(dtype, static_cast<const uint8_t*>(data)[0]); + } else { + LOG(FATAL) << "not handled"; + return tvm::Expr(); + } + }); + return {value}; + } + Array<Tensor> VisitExpr_(const CallNode* call_node) final { static auto fcompute = Op::GetAttr<FTVMCompute>("FTVMCompute"); diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index f28db371706e26677db3105d1f1529dfbf31e652..93ed76bed3c28e0cdfaf43fd64561f4160cab60c 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -125,6 +125,8 @@ class TextPrinter : public TypeFunctor<void (const Type&, std::ostream& os)>, // NOLINT(*) public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*) public: + explicit TextPrinter(runtime::TypedPackedFunc<std::string(Expr)> annotate) + : annotate_(annotate) {} /*! * \brief Print a node to string. * \param node. @@ -279,11 +281,11 @@ class TextPrinter : TextValue VisitExpr_(const CallNode* op) final { // possibly through meta-data - TextValue call_op = GetValue(op->op); std::vector<TextValue> args; for (Expr arg : op->args) { args.emplace_back(GetValue(arg)); } + TextValue call_op = GetValue(op->op); TextValue id = this->AllocTempVar(); this->PrintIndent(); @@ -532,7 +534,9 @@ class TextPrinter : */ void PrintOptionalInfo(const Expr& expr) { // additional information in comment. - if (expr->checked_type_.defined()) { + if (annotate_ != nullptr) { + stream_ << " # " << annotate_(expr); + } else if (expr->checked_type_.defined()) { stream_ << " # ty="; this->PrintType(expr->checked_type(), stream_); } @@ -678,7 +682,10 @@ class TextPrinter : name = "%" + name; } TextValue val(GetUniqueName(name)); - CHECK(!memo_.count(var)) << "Duplicated variable " << var; + // still print if ir is malformed, but show the error. + if (memo_.count(var)) { + memo_[var] = TextValue(val.name + "-malformed-ir"); + } memo_[var] = val; return val; } @@ -686,6 +693,8 @@ class TextPrinter : private: class AttrPrinter; friend class AttrPrinter; + /*! \brief additional comment function */ + runtime::TypedPackedFunc<std::string(Expr)> annotate_; /*! \brief meta data context */ TextMetaDataContext meta_; /*! 
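The `ConstantNode` handler added to compile_engine.cc above only inlines scalar constants of a handful of simple dtypes. Below is a hypothetical NumPy-based restatement of that dispatch, just to make the rule explicit; the helper name and error messages are made up and are not part of this patch.

```python
import numpy as np

# Dtypes the code generator is prepared to read back out of the constant
# tensor; anything else hits the LOG(FATAL) branch in the C++ code.
SIMPLE_DTYPES = ("int32", "int64", "float32", "float64", "bool")


def inline_scalar_constant(value):
    """Hypothetical helper mirroring the scalar-constant dispatch above."""
    arr = np.asarray(value)
    if arr.ndim != 0:
        raise ValueError("only scalar constants are inlined")
    if arr.dtype.name not in SIMPLE_DTYPES:
        raise ValueError("dtype %s not handled" % arr.dtype)
    return arr.item()


assert inline_scalar_constant(np.float32(1.5)) == 1.5
assert inline_scalar_constant(True) is True
```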
\brief Check whether scope is still valid */ @@ -776,12 +785,15 @@ void TextPrinter::PrintCallAttrs(const Expr& op, os << ", " << meta_.GetMetaNode(attrs); } -std::string RelayPrint(const NodeRef& node) { - return TextPrinter().Print(node); +std::string RelayPrint(const NodeRef& node, + runtime::TypedPackedFunc<std::string(Expr)> annotate) { + return TextPrinter(annotate).Print(node); } -TVM_REGISTER_API("relay._expr._text_print") -.set_body_typed<std::string(const NodeRef&)>(RelayPrint); +TVM_REGISTER_API("relay._expr.RelayPrint") +.set_body_typed<std::string( + const NodeRef&, + runtime::TypedPackedFunc<std::string(Expr)>)>(RelayPrint); } // namespace relay } // namespace tvm
diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc index e757118f33f27cc8903bb145e3591294cf300bbf..038f34df576087a88661e02090ead848b7b25109 100644 --- a/src/relay/pass/fold_scale_axis.cc +++ b/src/relay/pass/fold_scale_axis.cc @@ -10,6 +10,7 @@ #include <tvm/relay/attrs/nn.h> #include <tvm/relay/expr_functor.h> #include "pattern_util.h" +#include "pass_util.h" #include "../op/nn/layout.h" namespace tvm { namespace relay { @@ -580,23 +581,6 @@ using FBackwardTransform = TypedPackedFunc< //---------------------------------------------- // Generic Visitors for FScaleAxisBackward //---------------------------------------------- -/*! - * \brief Get reference counter of each internal ExprNode in body. - * \param body The body expression. - * \return The reference count mapping. - */ -std::unordered_map<const Node*, size_t> -GetExprRefCount(const Expr& body) { - class ExprRefCounter : private ExprVisitor { - public: - std::unordered_map<const Node*, size_t> - Get(const Expr& body) { - this->VisitExpr(body); - return std::move(this->visit_counter_); - } - }; - return ExprRefCounter().Get(body); -} class BackwardPrep : private ExprVisitor { public:
diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 2bd16a4f840fb8962b9814c9971ded65260e87f9..2503bd5f53faee2ecd2325d5b6043368aa1efa52 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -9,13 +9,686 @@ #include <tvm/ir_operator.h> #include <tvm/relay/pass.h> #include <tvm/relay/expr_functor.h> +#include <tvm/relay/op_attr_types.h> +#include "../../common/arena.h" + namespace tvm { namespace relay { -// Simple fuser that only makes each operator function as primitive. -class SimpleFuser : public ExprMutator { +/* + Note on Fusing algorithm: + + The main challenge of a general fuser is handling diamond-shaped branches: + in the following graph, conv2d can be fused into the elemwise add. + + conv2d + / | \ + / | \ + op op op + \ | / + \ | / + elemwise add + | + + However, at the point of conv2d we do not necessarily know that all of its future paths + will merge at the elemwise add. The new fuser algorithm applies post-dominator analysis. + The immediate post-dominator of a node is the closest node through which every future path passes. + In the above case, the elemwise add is the post-dominator of conv2d. The general algorithm is as follows: + + - Construct a DAG from the dataflow graph for dominator analysis. + - Construct a post-dominator tree which gives the immediate post-dominator of each node. + - Run the fusion algorithm with the given post-dominator information. + + Note that, because we run the analysis on a DAG, we use a single-pass post-dominator + tree construction algorithm via LCA, which is simpler than the full version that handles cycles.
+ + The fusion algorithm traverses from each node and checks if it can be fused to its + immediate post dominator. It has to check the following things: + + - CheckPath: check all the path between a node and its immediate post-dominator + satiesfies the fuse condition. + - Note that these intermediate node can already be fused with another nodes, the algorithm + will still run correctly. + - CommitFuse: mark all the nodes between source and post-dominator as the same group. + - We use an Union-Find data structure to manage the groups. +*/ +using common::LinkNode; +using common::LinkedList; + +/*! + * \brief Indexed data flow graph in forward direction. + * This is a temporary data structure used for operator fusion analysis. + * + * This data structure only captures the dataflow fragement and + * could ignore blocks like let by simply ordering each dataflow block + * and mark the output node as extern_ref; + */ +class IndexedForwardGraph { + public: + struct Node; + /*! + * The forward edge in the dataflow graph. + */ + struct Edge { + /*! \brief The corresponding node */ + Node* node{nullptr}; + /*! \brief The respective pattern of this op */ + OpPatternKind pattern{kOpaque}; + }; + /*! \brief A node in the graph. */ + struct Node { + /*! \brief weak reference to the corresponding edge. */ + const tvm::Node* ref{nullptr}; + /*! \brief The index of the node in topological order. */ + size_t index{0}; + /*! \brief Whether this node is referenced by external source */ + bool extern_ref{false}; + /*! \brief The general pattern in the node */ + OpPatternKind pattern{kOpaque}; + /*! \brief The outputs of the node. */ + LinkedList<Edge> outputs; + }; + /*! \brief The node map that maps node to graph */ + std::unordered_map<const tvm::Node*, Node*> node_map; + /*! \brief All the nodes in post DFS order */ + std::vector<Node*> post_dfs_order; + + /*! \brief Dump the graph into string. */ + void DebugDump() { + std::ostringstream os; + for (size_t i = 0; i < post_dfs_order.size(); ++i) { + Node* node = post_dfs_order[i]; + os << "node[" << i << "], " + << GetRef<NodeRef>(node->ref) + << " outputs=["; + for (auto* link = node->outputs.head; link != nullptr; link = link->next) { + os << link->value.node->index << ", "; + } + os << "]\n"; + } + LOG(INFO) << os.str(); + } + /*! + * \brief create a indexed forward graph. + * \param arena The arena used for data allocation. + * \param body The body of the expression to create a graph. + */ + static IndexedForwardGraph Create(common::Arena* arena, const Expr& body); + + private: + class Creator; +}; + +// Creator of post dominator tree of the dataflow +class IndexedForwardGraph::Creator : private ExprVisitor { + public: + explicit Creator(common::Arena* arena) + : arena_(arena) {} + + IndexedForwardGraph Prepare(const Expr& body) { + this->Update(body, nullptr, kOpaque); + this->VisitExpr(body); + return std::move(graph_); + } + + private: + /*! \brief allocator of all the internal node object */ + common::Arena* arena_; + // The output. + IndexedForwardGraph graph_; + // attribute equal comparator + AttrsEqual attr_equal_; + // Update the message stored at the node. 
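A toy Python rendering of the `IndexedForwardGraph` just defined, using the diamond from the note above; the pattern constants informally mirror OpPatternKind, and the graph is hand-built rather than derived from a real Relay program.

```python
# Informal stand-ins for OpPatternKind; only the relative order matters.
ELEMWISE, BROADCAST, INJECTIVE, COMM_REDUCE, OUT_EWISE_FUSABLE, OPAQUE = range(6)


class Node:
    """Toy counterpart of IndexedForwardGraph::Node."""
    def __init__(self, name, pattern):
        self.name = name
        self.pattern = pattern
        self.index = None        # position in post-DFS order
        self.outputs = []        # forward edges: (consumer, edge_pattern)
        self.extern_ref = False  # referenced from outside the dataflow block


def add_edge(src, dst, pattern):
    src.outputs.append((dst, pattern))


# The diamond from the note: conv2d -> {op, op} -> elemwise add.
conv = Node("conv2d", OUT_EWISE_FUSABLE)
lhs = Node("op_a", ELEMWISE)
rhs = Node("op_b", ELEMWISE)
add = Node("add", ELEMWISE)
for mid in (lhs, rhs):
    add_edge(conv, mid, ELEMWISE)
    add_edge(mid, add, ELEMWISE)
add.extern_ref = True

post_dfs_order = [conv, lhs, rhs, add]
for i, node in enumerate(post_dfs_order):
    node.index = i
```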
+ void Update(const Expr& node, + IndexedForwardGraph::Node* parent, + OpPatternKind pattern) { + const tvm::Node* key = node.get(); + IndexedForwardGraph::Node* current; + auto it = graph_.node_map.find(key); + if (it != graph_.node_map.end()) { + current = it->second; + } else { + current = arena_->make<IndexedForwardGraph::Node>(); + graph_.node_map[key] = current; + } + if (parent != nullptr) { + auto* link = arena_->make<LinkNode<IndexedForwardGraph::Edge> >(); + link->value.node = parent; + link->value.pattern = pattern; + current->outputs.Push(link); + } else { + current->extern_ref = true; + } + } + void AddNode(const tvm::Node* key) { + auto it = graph_.node_map.find(key); + CHECK(it != graph_.node_map.end()) + << "Cannot find node " << GetRef<NodeRef>(key); + IndexedForwardGraph::Node* node = it->second; + CHECK(node->ref == nullptr); + node->ref = key; + node->index = graph_.post_dfs_order.size(); + graph_.post_dfs_order.push_back(node); + } + + // Post order tree + void VisitExpr_(const FunctionNode* op) { + for (auto param : op->params) { + this->Update(param, nullptr, kOpaque); + } + this->Update(op->body, nullptr, kOpaque); + ExprVisitor::VisitExpr_(op); + } + + void VisitExpr_(const ConstantNode* op) { + this->AddNode(op); + Node* node = graph_.node_map.at(op); + DataType dtype = TVMType2Type(op->data->dtype); + // This rule must be consistent with code generator. + bool is_simple_const = ( + dtype == Int(32) || + dtype == Int(64) || + dtype == Float(32) || + dtype == Float(64) || + dtype == Bool()); + if (op->is_scalar() && is_simple_const) { + node->pattern = kElemWise; + } else { + // for now, mark non-scalar constant + // as opaque, we will not choose to fuse it. + node->pattern = kOpaque; + } + } + + void VisitExpr_(const CallNode* call) { + CHECK(graph_.node_map.count(call)); + Node* node = graph_.node_map.at(call); + static auto fpattern = + Op::GetAttr<TOpPattern>("TOpPattern"); + // setup pattern. + OpPatternKind op_pattern = kOpaque; + if (const OpNode* opnode = call->op.as<OpNode>()) { + op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]); + } + node->pattern = op_pattern; + const auto* rtype = call->checked_type().as<TensorTypeNode>(); + // pass the message back to all the children it references. + for (size_t i = 0; i < call->args.size(); ++i) { + const auto* arg_type = + call->args[i]->checked_type().as<TensorTypeNode>(); + // specifically check if result type + OpPatternKind edge_pattern = op_pattern; + if (edge_pattern == kBroadcast && + arg_type != nullptr && + rtype != nullptr && + attr_equal_(rtype->shape, arg_type->shape)) { + edge_pattern = kElemWise; + } + this->Update(call->args[i], node, edge_pattern); + } + ExprVisitor::VisitExpr_(call); + this->AddNode(call); + } + + void VisitExpr_(const TupleNode* op) { + for (const Expr& field : op->fields) { + this->Update(field, nullptr, kOpaque); + } + ExprVisitor::VisitExpr_(op); + this->AddNode(op); + } + + void VisitExpr_(const TupleGetItemNode* op) { + CHECK(graph_.node_map.count(op)); + Node* node = graph_.node_map.at(op); + this->Update(op->tuple, node, kOpaque); + ExprVisitor::VisitExpr_(op); + this->AddNode(op); + } + + void VisitExpr_(const VarNode* op) { + this->AddNode(op); + } + + void VisitExpr_(const LetNode* op) { + // do not fuse through let. 
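One detail of the `CallNode` visitor above worth restating: a broadcast op only contributes a broadcast edge to arguments whose shape differs from the output; when the shapes already match, the edge is downgraded to elemwise. A small sketch of that rule, with shapes as plain tuples rather than Relay types:

```python
def edge_pattern(op_pattern, arg_shape, out_shape):
    """Pattern recorded on the edge from an argument to this call."""
    if op_pattern == "broadcast" and arg_shape == out_shape:
        return "elemwise"
    return op_pattern


assert edge_pattern("broadcast", (2, 3), (2, 3)) == "elemwise"
assert edge_pattern("broadcast", (1, 3), (2, 3)) == "broadcast"
```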
+ this->Update(op->var, nullptr, kOpaque); + this->Update(op->value, nullptr, kOpaque); + this->Update(op->body, nullptr, kOpaque); + ExprVisitor::VisitExpr_(op); + this->AddNode(op); + } + + void VisitExpr_(const IfNode* op) { + // do not fuse through if. + this->Update(op->cond, nullptr, kOpaque); + this->Update(op->true_branch, nullptr, kOpaque); + this->Update(op->false_branch, nullptr, kOpaque); + ExprVisitor::VisitExpr_(op); + this->AddNode(op); + } +}; + +IndexedForwardGraph IndexedForwardGraph::Create( + common::Arena* arena, const Expr& body) { + return Creator(arena).Prepare(body); +} + +/*! + * \brief Dominator tree that represent domination or + * post domination relation of the node. + */ +class DominatorTree { public: + /*! + * \brief A node in the dominator tree. + */ + struct Node { + /*! \brief The node in the tree */ + IndexedForwardGraph::Node* gnode{nullptr}; + /*! \brief parent of the tree */ + Node* parent{nullptr}; + /*! \brief current depth*/ + int depth{0}; + /*! \brief aggregated pattern to parent */ + OpPatternKind pattern{kOpaque}; + }; + // index -> node. + std::vector<Node*> nodes; + /*! + * \brief compute a post dominator relation for a given dataflow graph. + * \param arena The arena used for node allocation. + * \param graph The graph to be analyze. + * \return The dominator tree of the graph. + * \note This algorithm makes use of the fact that graph is DAG, + * and runs a single pass algorithm via LCA. + */ + static DominatorTree PostDom(common::Arena* arena, + const IndexedForwardGraph& graph); + + private: + // Combine pattern together. + static OpPatternKind CombinePattern( + OpPatternKind lhs, OpPatternKind rhs) { + if (lhs > rhs) return lhs; + return rhs; + } + /*! + * \brief Find the least common acenstor of the two nodes. + * \param lhs The left node. + * \param rhs The right node. + * \param edge_pattern + * The combined edge pattern across all the parents. + * \return The least common acenstor of thw two. + */ + static Node* LeastCommonAcenstor( + Node* lhs, + Node* rhs, + OpPatternKind* edge_pattern) { + while (lhs != rhs) { + if (lhs == nullptr) return nullptr; + if (rhs == nullptr) return nullptr; + if (lhs->depth < rhs->depth) { + edge_pattern[0] = CombinePattern( + edge_pattern[0], rhs->pattern); + rhs = rhs->parent; + } else if (rhs->depth < lhs->depth) { + edge_pattern[0] = CombinePattern( + edge_pattern[0], lhs->pattern); + lhs = lhs->parent; + } else { + lhs = lhs->parent; + rhs = rhs->parent; + edge_pattern[0] = CombinePattern( + edge_pattern[0], lhs->pattern); + edge_pattern[0] = CombinePattern( + edge_pattern[0], rhs->pattern); + } + } + return lhs; + } +}; + +DominatorTree DominatorTree::PostDom(common::Arena* arena, + const IndexedForwardGraph& graph) { + DominatorTree tree; + tree.nodes.resize(graph.post_dfs_order.size(), nullptr); + // reverse topo order + for (size_t i = graph.post_dfs_order.size(); i != 0; --i) { + size_t index = i - 1; + Node* tnode = arena->make<Node>(); + auto* gnode = graph.post_dfs_order[index]; + tnode->gnode = gnode; + if (gnode->extern_ref) { + tnode->depth = 1; + tnode->parent = nullptr; + tnode->pattern = kOpaque; + } else { + // find the LCAs of all outputs. 
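Since the single-pass construction relies on the DAG property, here is a compact Python sketch of the same idea: walk nodes in reverse post-DFS order and take the LCA, in the partially built tree, of all of a node's consumers as its immediate post-dominator. The index-based representation and example graph are made up, and edge patterns are ignored to keep it short.

```python
def post_dom_parents(outputs):
    """outputs[i]: indices of the consumers of node i (post-DFS order, DAG)."""
    n = len(outputs)
    parent = [None] * n
    depth = [0] * n

    def lca(a, b):
        while a != b:
            if a is None or b is None:
                return None
            if depth[a] < depth[b]:
                b = parent[b]
            elif depth[b] < depth[a]:
                a = parent[a]
            else:
                a, b = parent[a], parent[b]
        return a

    for i in reversed(range(n)):
        if not outputs[i]:        # graph output: root of the post-dominator tree
            depth[i] = 1
            continue
        p = outputs[i][0]
        for o in outputs[i][1:]:
            p = lca(p, o)
        assert p is not None      # holds within a single dataflow block
        parent[i] = p
        depth[i] = depth[p] + 1
    return parent


# Diamond 0 -> {1, 2} -> 3: the add (node 3) post-dominates everything.
assert post_dom_parents([[1, 2], [3], [3], []]) == [3, 3, 3, None]
```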
+ OpPatternKind pattern = kElemWise; + Node* parent = nullptr; + for (auto link = gnode->outputs.head; link != nullptr; link= link->next) { + size_t oindex = link->value.node->index; + CHECK_LT(oindex, tree.nodes.size()); + Node* onode = tree.nodes[oindex]; + CHECK(onode != nullptr); + if (parent != nullptr) { + parent = LeastCommonAcenstor(parent, onode, &pattern); + } else { + parent = onode; + } + pattern = CombinePattern(pattern, link->value.pattern); + } + CHECK(parent != nullptr); + tnode->depth = parent->depth + 1; + tnode->parent = parent; + tnode->pattern = pattern; + } + tree.nodes[index] = tnode; + } + return tree; +} + +/*! + * \brief A partition of the graph marked by union find data structure. + */ +class GraphPartitioner { + public: + explicit GraphPartitioner(common::Arena* arena, int opt_level) + : arena_(arena), opt_level_(opt_level) {} + /*! + * \brief Group as a union find data structure. + */ + struct Group { + /*! \brief The parent in the union find data structure. */ + Group* parent{nullptr}; + /*! \brief The pattern of the group */ + OpPatternKind pattern; + /*! \brief reference to the root node. */ + const tvm::Node* root_ref{nullptr}; + /*! + * \brief Reference to the master node, + * this field is not nullptr only if pattern is kOutEWiseFusable. + */ + const tvm::Node* master_ref{nullptr}; + /*! + * \brief Find the group root, perform path compression + * \return The root type node. + */ + Group* FindRoot() { + // fast path + if (this->parent == nullptr) return this; + // slow path with path compression. + Group* root = this; + while (root->parent != nullptr) { + root = root->parent; + } + for (Group* p = this; p != root;) { + Group* parent = p->parent; + p->parent = root; + p = parent; + } + return root; + } + }; + /*! + * \brief Partition a graph. + * \return group assignments of each node. + */ + std::vector<Group*> Partition(const IndexedForwardGraph& graph); + + private: + /*! \brief The internal arena for temporary space. */ + common::Arena* arena_; + /*! \brief optimization level for fuse operation. */ + int opt_level_; + /*! \brief The internal groups. */ + std::vector<Group*> groups_; + /*! \brief internal field used for deduplication */ + std::unordered_set<IndexedForwardGraph::Node*> visited_; + // Internal implelementation of CheckPath + template<typename F> + bool CheckPath_(IndexedForwardGraph::Node* src, + IndexedForwardGraph::Node* sink, + F fcond) { + if (visited_.count(src)) return true; + visited_.insert(src); + Group* gnode = groups_[src->index]; + CHECK(gnode != nullptr); + gnode = gnode->FindRoot(); + if (!fcond(gnode->pattern, src == sink)) return false; + if (src == sink) return true; + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + if (!CheckPath_(link->value.node, sink, fcond)) return false; + } + return true; + } + /*! + * \brief Check all the node between src and sink satisfies fcond. + * + * src and sink are not checked. + * + * \param src The source node. + * \param sink The termination node. + * \param fcond The condition to be checked. + * \tparam F the condition function. + * \note sink must be a post-dominator of src. 
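The `Group` struct above is a plain union-find node. A minimal Python equivalent of `FindRoot` with path compression, plus the merge direction used by `MergeFromTo` (child into parent), ignoring the pattern/master_ref bookkeeping:

```python
class Group:
    """Toy GraphPartitioner::Group: union-find with path compression."""
    def __init__(self, pattern):
        self.parent = None
        self.pattern = pattern

    def find_root(self):
        root = self
        while root.parent is not None:
            root = root.parent
        node = self
        while node is not root:   # path compression
            nxt = node.parent
            node.parent = root
            node = nxt
        return root


def merge_from_to(child, parent):
    child, parent = child.find_root(), parent.find_root()
    if child is not parent:
        child.parent = parent


a, b, c = Group("elemwise"), Group("broadcast"), Group("out_ewise_fusable")
merge_from_to(a, b)
merge_from_to(b, c)
assert a.find_root() is c and b.find_root() is c
```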
+ */ + template<typename F> + bool CheckPath(IndexedForwardGraph::Node* src, + IndexedForwardGraph::Node* sink, + F fcond) { + CHECK(!src->extern_ref); + visited_.clear(); + CHECK(src != sink); + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + if (!CheckPath_(link->value.node, sink, fcond)) return false; + } + return true; + } + // Combine two patterns together. + static OpPatternKind CombinePattern( + OpPatternKind lhs, OpPatternKind rhs) { + if (lhs > kBroadcast && rhs > kBroadcast) { + LOG(FATAL) << "Cannot merge two complex group together"; + } + if (lhs > rhs) return lhs; + return rhs; + } + /*! + * \brief Merge the child group to the parent. + * \param child The child group. + * \param parent The parent group. + */ + void MergeFromTo(Group* child, Group* parent) { + child = child->FindRoot(); + parent = parent->FindRoot(); + if (child == parent) return; + child->parent = parent; + // update master ref and pattern + if (child->master_ref != nullptr) { + CHECK(parent->master_ref == nullptr); + parent->master_ref = child->master_ref; + parent->pattern = CombinePattern( + child->pattern, parent->pattern); + } + } + // Internal implelementation of CommitFuse + void CommitFuse_(IndexedForwardGraph::Node* src, + IndexedForwardGraph::Node* sink, + Group* target) { + if (src == sink) return; + if (visited_.count(src)) return; + visited_.insert(src); + Group* gnode = groups_[src->index]; + CHECK(gnode != nullptr); + // merge the current group to the parent if possible. + MergeFromTo(gnode, target); + for (auto link = src->outputs.head; link != nullptr; link = link->next) { + CommitFuse_(link->value.node, sink, target);; + } + } + /*! + * \brief Commit fusion operation. + * \param src The source node. + * \param sink The termination node. + * \tparam group the group to be committed. + * \note sink must be a post-dominator of src. + */ + void CommitFuse(IndexedForwardGraph::Node* src, + IndexedForwardGraph::Node* sink) { + Group* target = groups_[sink->index]; + visited_.clear(); + CHECK(src != sink); + CommitFuse_(src, sink, target); + } + + // Initialize the groups. + void InitGroups(const IndexedForwardGraph& graph) { + groups_.resize(graph.post_dfs_order.size()); + for (size_t nid = 0; nid < groups_.size(); ++nid) { + const auto* graph_node = graph.post_dfs_order[nid]; + auto* group_node = arena_->make<Group>(); + group_node->pattern = graph_node->pattern; + group_node->root_ref = graph_node->ref; + // set master ref if necessary. + if (group_node->pattern == kOutEWiseFusable) { + group_node->master_ref = graph_node->ref; + } + groups_[nid] = group_node; + } + } + + // execute the fusion algorithm. + void RunFuse(const IndexedForwardGraph& graph, + const DominatorTree& post_dom_tree, + int phase) { + for (size_t nid = 0; nid < groups_.size(); ++nid) { + // the group of current node has been specified already. + auto* graph_node = graph.post_dfs_order[nid]; + auto* dom_node = post_dom_tree.nodes[nid]; + Group* group_node = groups_[nid]; + CHECK(group_node != nullptr); + // no actions for opaque nodes + if (group_node->pattern == kOpaque) continue; + // no actions needed if the current node have no dominator + if (dom_node->parent == nullptr) continue; + CHECK(!graph_node->extern_ref); + // Skip if current node is already fused to the parent. 
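`CheckPath` and `CommitFuse` (the latter appears just below) always come as a pair: first verify a predicate on every node along every path from the source to its post-dominator, then union all of those nodes into the sink's group. A rough Python sketch over adjacency dictionaries, with plain integers standing in for group objects:

```python
def check_path(src, sink, outputs, pattern, fcond, _visited=None):
    """True if every node on every src->sink path satisfies fcond."""
    visited = set() if _visited is None else _visited
    for node in outputs[src]:
        if node in visited:
            continue
        visited.add(node)
        # The sink itself is checked too, but with is_sink=True.
        if not fcond(pattern[node], node == sink):
            return False
        if node != sink and not check_path(node, sink, outputs, pattern, fcond, visited):
            return False
    return True


def commit_fuse(node, sink, outputs, group, _visited=None):
    """Union node and everything between it and sink into sink's group."""
    if node == sink:
        return
    visited = set() if _visited is None else _visited
    if node in visited:
        return
    visited.add(node)
    group[node] = group[sink]
    for nxt in outputs[node]:
        commit_fuse(nxt, sink, outputs, group, visited)


# Diamond 0 -> {1, 2} -> 3 with elemwise ops in between.
outputs = {0: [1, 2], 1: [3], 2: [3], 3: []}
pattern = {0: "out_ewise_fusable", 1: "elemwise", 2: "elemwise", 3: "elemwise"}
group = {i: i for i in outputs}
if check_path(0, 3, outputs, pattern,
              lambda kind, is_sink: kind in ("elemwise", "broadcast")):
    commit_fuse(0, 3, outputs, group)
assert all(group[i] == 3 for i in outputs)
```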
+ size_t dom_parent_gindex = dom_node->parent->gnode->index; + if (groups_[dom_parent_gindex] != nullptr && + group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) { + continue; + } + // Try to fuse current node to its post-dominator. + if (group_node->pattern == kOutEWiseFusable) { + if (phase != 0) continue; + // Path for OutEWiseFusable: conv2d + // Check if the dominator relation is elemwise. + if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) { + CHECK(dom_node->parent->gnode != nullptr); + // The fuse can be executed if all the intermediate ops are still broadcast. + auto fcond = [](OpPatternKind kind, bool is_sink) { + return kind <= kBroadcast; + }; + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } + } else if (group_node->pattern <= kBroadcast) { + // The fuse can be executed if all the intermediate ops are still broadcast. + auto fcond = [](OpPatternKind kind, bool is_sink) { + if (!is_sink) { + return kind <= kBroadcast; + } else { + return (kind <= kBroadcast || + kind == kCommReduce || + kind == kOutEWiseFusable); + } + }; + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } else if (group_node->pattern == kInjective) { + // defer injective fusion to second phase. + // so conv2d always finishes fusing. + if (phase != 1) continue; + // Check if all path are injective. + auto fcond = [](OpPatternKind kind, bool is_sink) { + return kind <= kInjective; + }; + if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) { + CommitFuse(graph_node, dom_node->parent->gnode); + } + } else { + // do nothing. + CHECK(group_node->pattern == kCommReduce); + } + } + } +}; + +std::vector<GraphPartitioner::Group*> +GraphPartitioner::Partition(const IndexedForwardGraph& graph) { + this->InitGroups(graph); + if (opt_level_ == 0) return std::move(groups_); + // get post dominator tree + auto post_dom_tree = DominatorTree::PostDom(arena_, graph); + // run fusion algorithm. + for (int phase = 0; phase < 2; ++phase) { + this->RunFuse(graph, post_dom_tree, phase); + } + return std::move(groups_); +} + +class FuseMutator : private ExprMutator { + public: + // Run the transform + Expr Transform(const Expr& body, int fuse_opt_level) { + // setup the group map. + auto graph = IndexedForwardGraph::Create(&arena_, body); + auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition( + graph); + for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) { + CHECK(graph.post_dfs_order[nid]->ref != nullptr); + gmap_[graph.post_dfs_order[nid]->ref] = groups[nid]; + } + // The following line can be used for debug. + // this->DebugDumpGroup(body); + return this->Mutate(body); + } + + + private: + /*! \brief Temporary information from each group. */ + struct GroupInfo { + public: + // The parameters of the function. + Array<Var> params; + // The arguments to call the functions. + Array<Expr> arguments; + // Get a new parameter or allocate an old one + Var GetOrAllocParam(const Expr& expr, const Type& type) { + // run linear scan as most fused groups contain only a few inputs. + for (size_t i = 0; i < arguments.size(); ++i) { + if (expr.same_as(arguments[i])) return params[i]; + } + // create a new parameter. + std::ostringstream os; + os << "p" << params.size(); + auto var = VarNode::make(os.str(), type); + params.push_back(var); + arguments.push_back(expr); + return var; + } + }; + /*! \brief Internal arena. 
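The three branches of `RunFuse` above can be summarised as a small table of path conditions keyed by the source group's pattern and the fusion phase. This is only a condensed restatement, with informal integer constants in place of OpPatternKind:

```python
ELEMWISE, BROADCAST, INJECTIVE, COMM_REDUCE, OUT_EWISE_FUSABLE, OPAQUE = range(6)


def path_condition(src_pattern, phase):
    """Predicate every node on the path to the post-dominator must pass."""
    if src_pattern == OUT_EWISE_FUSABLE and phase == 0:
        # e.g. conv2d: the whole path must stay elemwise/broadcast.
        return lambda kind, is_sink: kind <= BROADCAST
    if src_pattern <= BROADCAST:
        # elemwise/broadcast may additionally end in a reduction or an
        # out-elemwise-fusable op at the sink itself.
        return lambda kind, is_sink: (
            kind <= BROADCAST if not is_sink
            else kind in (ELEMWISE, BROADCAST, COMM_REDUCE, OUT_EWISE_FUSABLE))
    if src_pattern == INJECTIVE and phase == 1:
        # Injective chains are fused in the second phase, after conv2d fusion.
        return lambda kind, is_sink: kind <= INJECTIVE
    return None  # opaque and reduction sources do not start a fusion themselves


assert path_condition(OUT_EWISE_FUSABLE, 0)(BROADCAST, False)
assert path_condition(ELEMWISE, 0)(COMM_REDUCE, True)
assert path_condition(OUT_EWISE_FUSABLE, 1) is None
```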
*/ + common::Arena arena_; + /*! \brief The group assignment map. */ + std::unordered_map<const Node*, GraphPartitioner::Group*> gmap_; + /* \brief Internal group information map. */ + std::unordered_map<GraphPartitioner::Group*, GroupInfo> ginfo_; // Skip primitive function. Expr VisitExpr_(const FunctionNode* fn_node) { NodeRef res = FunctionGetAttr(GetRef<Function>(fn_node), "Primitive"); @@ -26,48 +699,74 @@ class SimpleFuser : public ExprMutator { return ExprMutator::VisitExpr_(fn_node); } } - + // Transform calls. Expr VisitExpr_(const CallNode* call) { if (call->op.as<OpNode>()) { - // Placeholder fusion algorithm which abstracts - // single definitions into functions only. - Array<Var> params; - Array<Expr> inner_args; - Array<Expr> args; - - int param_number = 0; + // If it is a primitive op call + // then we must have a group assignment for it already. + CHECK(gmap_.count(call)); + auto* ret_group = gmap_.at(call)->FindRoot(); + Array<Expr> new_args; for (auto arg : call->args) { - std::ostringstream os; - os << "p" << param_number++; auto type = arg->checked_type(); - auto var = VarNode::make(os.str(), type); - params.push_back(var); - inner_args.push_back(var); - args.push_back(this->Mutate(arg)); + CHECK(gmap_.count(arg.get())) + << "cannot find group of " << arg; + auto* arg_group = gmap_.at(arg.get())->FindRoot(); + Expr new_arg = this->Mutate(arg); + + if (ret_group != arg_group) { + Var param = ginfo_[ret_group].GetOrAllocParam(new_arg, type); + new_args.push_back(param); + } else { + new_args.push_back(new_arg); + } + } + auto new_call = CallNode::make( + call->op, new_args, call->attrs, call->type_args); + + if (ret_group->root_ref == call) { + // This is the root of the group + // create the new call node. + const GroupInfo& ginfo = ginfo_[ret_group]; + auto func = FunctionNode::make( + ginfo.params, new_call, call->checked_type(), {}); + func = FunctionSetAttr(func, "Primitive", tvm::Integer(1)); + return CallNode::make(func, ginfo.arguments, Attrs()); + } else { + // This is an intermediate node of a fused function + // simply return the new call. + return new_call; } - auto body = CallNode::make(call->op, inner_args, call->attrs); - auto func = FunctionNode::make( - params, body, call->checked_type(), {}); - func = FunctionSetAttr(func, "Primitive", tvm::Integer(1)); - return CallNode::make(func, args, Attrs()); } else { return ExprMutator::VisitExpr_(call); } } + // Debug function, dump the group assignment in text. + void DebugDumpGroup(const Expr& body) { + std::string text = RelayPrint(body, [this](const Expr& expr) -> std::string { + auto it = gmap_.find(expr.get()); + if (it == gmap_.end()) return ""; + std::ostringstream os; + auto *group = it->second->FindRoot(); + os << "group=" << group; + return os.str(); + }); + LOG(INFO) << "Dump of group info:\n" << text; + } }; -Expr FuseOps(const Expr& expr) { +Expr FuseOps(const Expr& expr, int fuse_opt_level) { // First we convert all chains of fusable ops into // abstracted functions which we mark as primtive // then we convert these primtive functions into // new operators. 
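The `CallNode` handling above is what ultimately rewrites each group into a function marked Primitive that is called with the group's outside arguments. A small end-to-end sketch of invoking the pass (the same shape of program as test_fuse_simple further down):

```python
from tvm import relay

x = relay.var("x", shape=(10, 20))
func = relay.Function([x], relay.exp(relay.add(x, relay.const(1, "float32"))))
func = relay.ir_pass.infer_type(func)
fused = relay.ir_pass.fuse_ops(func, opt_level=2)
fused = relay.ir_pass.infer_type(fused)
# The printed text is expected to contain an inner fn marked Primitive,
# called with %x, instead of the original add/exp chain.
print(fused.astext())
```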
- return SimpleFuser().Mutate(expr); + return FuseMutator().Transform(expr, fuse_opt_level); } TVM_REGISTER_API("relay._ir_pass.FuseOps") .set_body([](TVMArgs args, TVMRetValue *ret) { - *ret = FuseOps(args[0]); + *ret = FuseOps(args[0], args[1]); }); } // namespace relay } // namespace tvm diff --git a/src/relay/pass/pass_util.h b/src/relay/pass/pass_util.h new file mode 100644 index 0000000000000000000000000000000000000000..bf52297e8930f46feb0df111e196f3d247dbf0ae --- /dev/null +++ b/src/relay/pass/pass_util.h @@ -0,0 +1,27 @@ +/*! + * Copyright (c) 2018 by Contributors. + * + * \file tvm/relay/pass/pass_util.h + * \brief Utilities for writing + */ +#ifndef TVM_RELAY_PASS_PASS_UTIL_H_ +#define TVM_RELAY_PASS_PASS_UTIL_H_ + +#include <tvm/relay/op.h> +#include <tvm/relay/expr.h> +#include <tvm/relay/attrs/transform.h> + +namespace tvm { +namespace relay { + +/*! + * \brief Get reference counter of each internal ExprNode in body. + * \param body The body expression. + * \return The reference count mapping. + */ +std::unordered_map<const Node*, size_t> +GetExprRefCount(const Expr& body); + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PASS_PASS_UTIL_H_ diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index b224a099aee1ee2d2c7ceb525b6dc897d0bb2f27..5cabfbdabc49e6c1b884d76fe675654ac14e7550 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -442,6 +442,9 @@ class TypeInferencer::Resolver : public ExprMutator { VarNode* new_var =( std::is_base_of<VarNode, T>::value ? static_cast<VarNode*>(new_e.node_.get()) : nullptr); + FunctionNode* new_fn =( + std::is_base_of<FunctionNode, T>::value ? + static_cast<FunctionNode*>(new_e.node_.get()) : nullptr); // check if we need update the new_e bool need_update_type = !checked_type.same_as(new_e->checked_type_); @@ -454,7 +457,17 @@ class TypeInferencer::Resolver : public ExprMutator { update_missing_type_annotation_ && !new_var->type_annotation.defined()); - if (!need_update_type && !need_update_var && !need_update_call) return new_e; + bool need_update_fn = ( + std::is_base_of<FunctionNode, T>::value && + update_missing_type_annotation_ && + !new_fn->ret_type.defined()); + + if (!need_update_type && + !need_update_var && + !need_update_call && + !need_update_fn) { + return new_e; + } if (!new_e.node_.unique()) { // Copy on write optimization @@ -467,6 +480,9 @@ class TypeInferencer::Resolver : public ExprMutator { new_var = ( std::is_base_of<VarNode, T>::value ? static_cast<VarNode*>(new_e.node_.get()) : nullptr); + new_fn = ( + std::is_base_of<FunctionNode, T>::value ? + static_cast<FunctionNode*>(new_e.node_.get()) : nullptr); } // attach the information. @@ -483,6 +499,11 @@ class TypeInferencer::Resolver : public ExprMutator { if (need_update_var) { new_var->type_annotation = checked_type; } + if (need_update_fn) { + auto* fn_type = checked_type.as<FuncTypeNode>(); + CHECK(fn_type != nullptr); + new_fn->ret_type = fn_type->ret_type; + } return new_e; } diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc index 3ca161d23f728e593b01a64e7703f3f6ae743c65..e1efcbbdd0b91d2ce09ff096dcaa4428b6fc58cf 100644 --- a/src/relay/pass/type_solver.cc +++ b/src/relay/pass/type_solver.cc @@ -85,18 +85,18 @@ Type TypeSolver::Unify(const Type& dst, const Type& src) { void TypeSolver::AddConstraint(const TypeConstraint& constraint) { if (auto *op = constraint.as<TypeRelationNode>()) { // create a new relation node. 
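The Resolver change above means type inference now also fills in a missing return-type annotation on functions. A small sketch of the observable effect from Python (exact printed form not guaranteed):

```python
from tvm import relay

x = relay.var("x", shape=(3,), dtype="float32")
f = relay.Function([x], relay.add(x, x))   # no explicit ret_type given
f = relay.ir_pass.infer_type(f)
# After inference, the return type annotation should be populated
# from the checked type rather than left unresolved.
print(f.ret_type)
```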
- RelationNode* rnode = make<RelationNode>(); + RelationNode* rnode = arena_.make<RelationNode>(); rnode->rel = GetRef<TypeRelation>(op); rel_nodes_.push_back(rnode); // populate the type information. for (size_t i = 0; i < op->args.size(); ++i) { // insert link to the type list - LinkNode<TypeNode*>* tlink = make<LinkNode<TypeNode*> >(); + LinkNode<TypeNode*>* tlink = arena_.make<LinkNode<TypeNode*> >(); TypeNode* tnode = GetTypeNode(op->args[i]); tlink->value = tnode; rnode->type_list.Push(tlink); // insert type->relation node - LinkNode<RelationNode*>* rlink = make<LinkNode<RelationNode*> >(); + LinkNode<RelationNode*>* rlink = arena_.make<LinkNode<RelationNode*> >(); rlink->value = rnode; tnode->rel_list.Push(rlink); } diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h index 30f82f980a751244a0eadd3dda411680491636a5..2f311c9b9810229b54cf178197abf608619269b8 100644 --- a/src/relay/pass/type_solver.h +++ b/src/relay/pass/type_solver.h @@ -16,6 +16,8 @@ namespace tvm { namespace relay { +using common::LinkNode; +using common::LinkedList; /*! * \brief Interface of type solver used in type inference. * @@ -69,41 +71,6 @@ class TypeSolver { // Internally the solver maintains a bipartite graph of Relation and Types. // All the object in the structure is managed by a arena allocator // which releases the memory upon distruction of the type solver. - /*! - * \brief Link list node - * \tparam T the content data type - */ - template<typename T> - struct LinkNode { - /*! \brief The content value */ - T value; - /*! \brief pointer to the next location */ - LinkNode<T>* next{nullptr}; - }; - /*! - * \brief LinkedList structure - * \tparam T the content data type - */ - template<typename T> - struct LinkedList { - /*! \brief Head pointer */ - LinkNode<T>* head{nullptr}; - /*! \brief Tail pointer */ - LinkNode<T>* tail{nullptr}; - /*! - * \brief Push a new node to the end of the linked list. - * \param node The node to be pushed. - */ - void Push(LinkNode<T>* node) { - node->next = nullptr; - if (this->tail != nullptr) { - this->tail->next = node; - this->tail = node; - } else { - head = tail = node; - } - } - }; /*! * \brief type node struct * TypeNode implements a union-find data structure(via parent) @@ -164,18 +131,6 @@ class TypeSolver { common::Arena arena_; /*! \brief Reporter that reports back to self */ TypeReporter reporter_; - /*! - * \brief Create function to create a new node ptr via arena - * \tparam The type parameter - * \return The node pointer. - */ - template<typename T> - T* make() { - T* ptr = arena_.Alloc<T>(); - // call constructor - new (ptr) T(); - return ptr; - } /*! * \brief GetTypeNode that is corresponds to t. * if it do not exist, create a new one. @@ -186,7 +141,7 @@ class TypeSolver { if (it != tmap_.end()) { return it->second->FindRoot(); } else { - TypeNode* n = make<TypeNode>(); + TypeNode* n = arena_.make<TypeNode>(); type_nodes_.push_back(n); n->resolved_type = t; tmap_[t] = n; diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc index 51ef0377868f18f74a9f4573cec66ba997fad914..ebc4e6fc16e614b82a866fc5a3fe43b4a29a8c1c 100644 --- a/src/relay/pass/util.cc +++ b/src/relay/pass/util.cc @@ -129,5 +129,23 @@ TVM_REGISTER_API("relay._ir_pass.free_type_vars") } }); +/*! + * \brief Get reference counter of each internal ExprNode in body. + * \param body The body expression. + * \return The reference count mapping. 
+ */ +std::unordered_map<const Node*, size_t> +GetExprRefCount(const Expr& body) { + class ExprRefCounter : private ExprVisitor { + public: + std::unordered_map<const Node*, size_t> + Get(const Expr& body) { + this->VisitExpr(body); + return std::move(this->visit_counter_); + } + }; + return ExprRefCounter().Get(body); +} + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index dd790a6d7d873e5c9fd82f12ef40430d0a50d23a..d12804d512f09321e2d733917f2fd1666ce6e9eb 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -33,6 +33,7 @@ def test_env(): text = env.astext() assert "def @myf" in text assert "%1 = add(%0, %0) # ty=float32" in text + show(env.astext(annotate=lambda x: str(x.checked_type.dtype))) show(text) diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py index 1b57bdce0e0c9b839a5e6591406a456c17caa99e..a5a7a05a974cd4ee518ed2ca96408cd974c4f3f3 100644 --- a/tests/python/relay/test_pass_fold_scale_axis.py +++ b/tests/python/relay/test_pass_fold_scale_axis.py @@ -46,6 +46,8 @@ def test_fold_fwd_simple(): weight = relay.var("weight", type_dict["weight"]) y1_folded = relay.ir_pass.forward_fold_scale_axis(y1) y1_expected = expected(x, weight, in_bias, in_scale, channels) + y1_folded = relay.ir_pass.infer_type(y1_folded) + y1_expected = relay.ir_pass.infer_type(y1_expected) assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 2) @@ -113,6 +115,8 @@ def test_fold_fwd_dual_path(): type_dict = {x.name_hint:x.checked_type for x in y1.params} weight = relay.var("weight", type_dict["weight"]) y1_expected = expected(x, weight, in_bias, in_scale, channels) + y1_folded = relay.ir_pass.infer_type(y1_folded) + y1_expected = relay.ir_pass.infer_type(y1_expected) assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) check((2, 4, 10, 3), 3) @@ -194,6 +198,8 @@ def test_fold_bwd_simple(): weight = relay.var("weight", type_dict["weight"]) y1_folded = relay.ir_pass.backward_fold_scale_axis(y1) y1_expected = expected(x, weight, out_bias, out_scale, channels) + y1_folded = relay.ir_pass.infer_type(y1_folded) + y1_expected = relay.ir_pass.infer_type(y1_expected) assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 8) @@ -255,6 +261,8 @@ def test_fold_bwd_dual_path(): weight = relay.var("weight", type_dict["weight"]) y1_folded = relay.ir_pass.backward_fold_scale_axis(y1) y1_expected = expected(x, weight, out_bias, out_scale, channels) + y1_folded = relay.ir_pass.infer_type(y1_folded) + y1_expected = relay.ir_pass.infer_type(y1_expected) assert relay.ir_pass.alpha_equal(y1_folded, y1_expected) check((2, 4, 10, 10), 8) diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py index 2bbc1dce9693db4aa04d773764fe89a204c8e25a..19bec20ac4afd90ee85556b3abcd8b215c9c90af 100644 --- a/tests/python/relay/test_pass_fuse_ops.py +++ b/tests/python/relay/test_pass_fuse_ops.py @@ -3,15 +3,103 @@ from tvm import relay def test_fuse_simple(): """Simple testcase.""" - x = relay.var("x", shape=(10, 20)) - y = relay.add(x, x) - z = relay.exp(y) + def before(): + x = relay.var("x", shape=(10, 20)) + y = relay.add(x, relay.const(1, "float32")) + z = relay.exp(y) + return relay.Function([x], z) + + def expected(): + x = relay.var("p", shape=(10, 20)) + y = relay.add(x, relay.const(1, "float32")) + z = relay.exp(y) + f1 = relay.Function([x], z) + 
x = relay.var("x", shape=(10, 20)) + y = relay.Call(f1, [x]) + return relay.Function([x], y) + + z = before() z = relay.ir_pass.infer_type(z) - zz = relay.ir_pass.fuse_ops(z) + zz = relay.ir_pass.fuse_ops(z, opt_level=2) + zz = relay.ir_pass.infer_type(zz) zz = relay.ir_pass.fuse_ops(zz) zz = relay.ir_pass.infer_type(zz) - zz.astext() + after = relay.ir_pass.infer_type(expected()) + assert relay.ir_pass.alpha_equal(zz, after) + + + +def test_conv2d_fuse(): + """Test fusion case of conv2d""" + def before(dshape): + x = relay.var("x", shape=dshape) + y = relay.nn.conv2d(x, relay.var("w1"), + kernel_size=(3, 3), + padding=(1, 1), + channels=16) + # this is the next dominator. + y1 = relay.add(relay.const(1, "float32"), y) + y = relay.add(y, y1) + # second path + z2 = relay.nn.conv2d(y, relay.var("w2"), + kernel_size=(1, 1), + padding=(0,0), + channels=16) + z3 = relay.nn.conv2d(y, relay.var("w3"), + kernel_size=(3, 3), + padding=(1,1), + channels=16) + # add can only be fused to z1 + z = relay.add(z2, z3) + return relay.Function(relay.ir_pass.free_vars(z), z) + + def expected(dshape): + # segment 1 + x = relay.var("p0", shape=dshape) + w = relay.var("p1") + y = relay.nn.conv2d(x, w, + kernel_size=(3, 3), + padding=(1, 1), + channels=16) + y1 = relay.add(relay.const(1, "float32"), y) + y = relay.add(y, y1) + f1 = relay.Function([x, w], y) + # segment 2 + x = relay.var("p0", shape=dshape) + w = relay.var("p1") + z2 = relay.nn.conv2d(x, w, + kernel_size=(3, 3), + padding=(1,1), + channels=16) + f2 = relay.Function([x, w], z2) + # segment 3 + x = relay.var("p0", shape=dshape) + w = relay.var("p1") + offset = relay.var("p2", shape=dshape) + z3 = relay.nn.conv2d(x, w, + kernel_size=(1, 1), + padding=(0, 0), + channels=16) + z3 = relay.add(z3, offset) + f3 = relay.Function([x, w, offset], z3) + # compose + x = relay.var("x", shape=dshape) + y = relay.Call(f1, [x, relay.var("w1")]) + z2 = relay.Call(f2, [y, relay.var("w3")]) + z3 = relay.Call(f3, [y, relay.var("w2"), z2]) + z = z3 + return relay.Function(relay.ir_pass.free_vars(z), z) + + dshape = (1, 16, 64, 64) + z = before(dshape) + z = relay.ir_pass.infer_type(z) + zz = relay.ir_pass.fuse_ops(z, opt_level=2) + zz = relay.ir_pass.infer_type(zz) + after = relay.ir_pass.infer_type(expected(dshape)) + assert relay.ir_pass.alpha_equal(zz, after) + if __name__ == "__main__": test_fuse_simple() + test_conv2d_fuse()