diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 1a5470489ce2179a3841b0fb723bba4ad4461d05..2319f8baec00a8244fa0e5172a6499a6258a022b 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -429,6 +429,16 @@ inline const TTypeNode* ExprNode::type_as() const {
   return node;
 }
 
+/*!
+ * \brief Print node as text format.
+ * \param node The node to be printed.
+ * \param annotate An optional callback function for attaching
+ *        additional comment block to an expr.
+ * \return The text representation.
+ */
+std::string RelayPrint(
+    const NodeRef& node,
+    runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_EXPR_H_
diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h
index 59ad52ccf3fdbd1eb165a240536a0d794e92def3..f25785d39eeb19bdb22f3cd5f29840cee5c66aec 100644
--- a/include/tvm/runtime/packed_func.h
+++ b/include/tvm/runtime/packed_func.h
@@ -161,6 +161,8 @@ class TypedPackedFunc<R(Args...)> {
   using TSelf = TypedPackedFunc<R(Args...)>;
   /*! \brief default constructor */
   TypedPackedFunc() {}
+  /*! \brief constructor from null */
+  TypedPackedFunc(std::nullptr_t null) {}  // NOLINT(*)
   /*!
    * \brief construct by wrap a PackedFunc
    *
diff --git a/python/tvm/relay/base.py b/python/tvm/relay/base.py
index 5a92eb57d209115c1850c4a9983f6be8840d053a..012315b40f5104b3809e079ba4db42698384c878 100644
--- a/python/tvm/relay/base.py
+++ b/python/tvm/relay/base.py
@@ -22,15 +22,22 @@ def register_relay_node(type_key=None):
 
 
 class RelayNode(NodeBase):
-    def astext(self):
+    """Base class of all relay nodes."""
+    def astext(self, annotate=None):
         """Get the text format of the expression.
 
+        Parameters
+        ----------
+        annotate : Optional[relay.Expr->str]
+            Optional annotate function to provide additional
+            information in the comment block.
+
         Returns
         -------
         text : str
             The text format of the expression.
         """
-        return _expr._text_print(self)
+        return _expr.RelayPrint(self, annotate)
 
 
 @register_relay_node
diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py
index c48ec90e9e12536330d2a676b431eb7b5399abb0..0f33e86ab5cdcfbe3f8a8c2f3d8c81cabe77e470 100644
--- a/python/tvm/relay/build_module.py
+++ b/python/tvm/relay/build_module.py
@@ -173,11 +173,13 @@ def build(func,
     else:
         tophub_context = autotvm.util.EmptyContext()
 
+    cfg = BuildConfig.current
+
     with tophub_context:
         func = optimize(func)
         # Fuse ops before running code gen
         func = ir_pass.infer_type(func)
-        func = ir_pass.fuse_ops(func)
+        func = ir_pass.fuse_ops(func, cfg.opt_level)
         # Graph code generation
         func = ir_pass.infer_type(func)
         graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index 43cff0bac57a06e00aa1031e826ae1ddcf57e627..f82ea09a102ae75ec51124b5d891bc3a34cfcdba 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -6,7 +6,6 @@ from numbers import Number as _Number
 import numpy as _np
 from .base import RelayNode, register_relay_node
 from . import _make
-from . import _expr
 from . import ty as _ty
 from .._ffi import base as _base
 from .. import nd as _nd
@@ -477,7 +476,7 @@ class TupleWrapper(object):
         text : str
             The text format of the tuple expression.
         """
-        return _expr._text_print(self.tuple_value)
+        return self.tuple_value.astext()
 
     def __getitem__(self, index):
         if index >= len(self):
diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py
index 274761f0a27bf5f26839c120be5233b62303b7c0..b1a76d6fae6fb0ad4a126407d7f49299431b30f8 100644
--- a/python/tvm/relay/ir_pass.py
+++ b/python/tvm/relay/ir_pass.py
@@ -259,7 +259,7 @@ def structural_hash(value):
         raise TypeError(msg)
 
 
-def fuse_ops(expr):
+def fuse_ops(expr, opt_level=1):
     """Fuse operators in expr together.
 
     Parameters
@@ -267,9 +267,12 @@ def fuse_ops(expr):
     expr : tvm.relay.Expr
         The input expression.
 
+    opt_level : int
+        The level of fuse optimization.
+
     Returns
     -------
     transformed_expr : tvm.relay.Expr
         Transformed expression, containing fused result.
     """
-    return _ir_pass.FuseOps(expr)
+    return _ir_pass.FuseOps(expr, opt_level)
diff --git a/src/common/arena.h b/src/common/arena.h
index e8d4b2e23e37588f7836b7d19687565499735b32..c5da093a70b881d45ff07ea33d6f7f37033ac572 100644
--- a/src/common/arena.h
+++ b/src/common/arena.h
@@ -38,11 +38,29 @@ class Arena {
   /*!
    * \brief Allocate a space from Arena for type T
    * \param T the data type to be allocated
+   * \note The space of T is not initialized.
    */
   template<typename T>
-  T* Alloc() {
+  T* allocate_() {
     return static_cast<T*>(Alloc(sizeof(T), alignof(T)));
   }
+  /*!
+   * \brief Create a new instance of type T.
+   * \param args The constructor argument.
+   * \tparam T the type to be created.
+   * \tparam Args Arguments to the constructor.
+   *
+   * \return The allocated object.
+   * \note The type T must be simple type, or only contain
+   *  memory allocated from the same arena.
+   *  Otherwise the destructor needs to be called explicitly.
+   */
+  template<typename T, typename... Args>
+  T* make(Args&&... args) {
+    T* ptr = allocate_<T>();
+    new (ptr) T(std::forward<Args>(args)...);
+    return ptr;
+  }
 
  private:
   // page size 16 KB
@@ -87,6 +105,44 @@ class Arena {
   }
 };
 
+/*!
+ * \brief Singly linked list node.
+ * \tparam T the content data type
+ */
+template<typename T>
+struct LinkNode {
+  /*! \brief The content value */
+  T value;
+  /*! \brief pointer to the next node, nullptr by default */
+  LinkNode<T>* next{nullptr};
+};
+/*!
+ * \brief LinkedList structure
+ * \tparam T the content data type
+ * \note This is a simple data structure that can be used together with the arena.
+ * \sa LinkNode
+ */
+template<typename T>
+struct LinkedList {
+  /*! \brief Head pointer, nullptr while the list is empty */
+  LinkNode<T>* head{nullptr};
+  /*! \brief Tail pointer, nullptr while the list is empty */
+  LinkNode<T>* tail{nullptr};
+  /*!
+   * \brief Push a new node to the end of the linked list.
+   * \param node The node to be pushed, its next pointer is reset to nullptr.
+   */
+  void Push(LinkNode<T>* node) {
+    node->next = nullptr;
+    if (this->tail != nullptr) {
+      this->tail->next = node;
+      this->tail = node;
+    } else {
+      head = tail = node;  // first element: it is both head and tail
+    }
+  }
+};
+
 }  // namespace common
 }  // namespace tvm
 #endif  // TVM_COMMON_ARENA_H_
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 38e3f6c2a7b81f0d9c11ee018e7ff35238b68d5c..dc094e00e05b38eea6e313128fab4614614f8638 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -109,6 +109,29 @@ class ScheduleGetter :
     return {};
   }
 
+  Array<Tensor> VisitExpr_(const ConstantNode* op) final {
+    CHECK(op->is_scalar());
+    void* data = op->data->data;
+    DataType dtype = TVMType2Type(op->data->dtype);
+    Tensor value = tvm::compute({}, [&](const Array<tvm::Var>&) {
+        if (dtype == Int(32)) {
+          return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+        } else if (dtype == Int(64)) {
+          return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+        } else if (dtype == Float(32)) {
+          return make_const(dtype, static_cast<const float*>(data)[0]);
+        } else if (dtype == Float(64)) {
+          return make_const(dtype, static_cast<const double*>(data)[0]);
+        } else if (dtype == Bool()) {
+          return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+        } else {
+          LOG(FATAL) << "not handled";
+          return tvm::Expr();
+        }
+      });
+    return {value};
+  }
+
   Array<Tensor> VisitExpr_(const CallNode* call_node) final {
     static auto fcompute =
         Op::GetAttr<FTVMCompute>("FTVMCompute");
diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc
index f28db371706e26677db3105d1f1529dfbf31e652..93ed76bed3c28e0cdfaf43fd64561f4160cab60c 100644
--- a/src/relay/ir/text_printer.cc
+++ b/src/relay/ir/text_printer.cc
@@ -125,6 +125,8 @@ class TextPrinter :
     public TypeFunctor<void (const Type&, std::ostream& os)>,  // NOLINT(*)
     public AttrFunctor<void (const NodeRef&, std::ostream& os)> { // NOLINT(*)
  public:
+  explicit TextPrinter(runtime::TypedPackedFunc<std::string(Expr)> annotate)
+      : annotate_(annotate) {}
   /*!
    * \brief Print a node to string.
    * \param node.
@@ -279,11 +281,11 @@ class TextPrinter :
 
   TextValue VisitExpr_(const CallNode* op) final {
     // possibly through meta-data
-    TextValue call_op = GetValue(op->op);
     std::vector<TextValue> args;
     for (Expr arg : op->args) {
       args.emplace_back(GetValue(arg));
     }
+    TextValue call_op = GetValue(op->op);
     TextValue id = this->AllocTempVar();
     this->PrintIndent();
 
@@ -532,7 +534,9 @@ class TextPrinter :
    */
   void PrintOptionalInfo(const Expr& expr) {
     // additional information in comment.
-    if (expr->checked_type_.defined()) {
+    if (annotate_ != nullptr) {
+      stream_ << " # " << annotate_(expr);
+    } else if (expr->checked_type_.defined()) {
       stream_ << " # ty=";
       this->PrintType(expr->checked_type(), stream_);
     }
@@ -678,7 +682,10 @@ class TextPrinter :
       name = "%" + name;
     }
     TextValue val(GetUniqueName(name));
-    CHECK(!memo_.count(var)) << "Duplicated variable " << var;
+    // still print if ir is malformed, but tag the name so the error shows up in the text.
+    if (memo_.count(var)) {
+      val = TextValue(val.name + "-malformed-ir");
+    }
     memo_[var] = val;
     return val;
   }
@@ -686,6 +693,8 @@ class TextPrinter :
  private:
   class AttrPrinter;
   friend class AttrPrinter;
+  /*! \brief additional comment function */
+  runtime::TypedPackedFunc<std::string(Expr)> annotate_;
   /*! \brief meta data context */
   TextMetaDataContext meta_;
   /*! \brief Check whether scope is still valid */
@@ -776,12 +785,15 @@ void TextPrinter::PrintCallAttrs(const Expr& op,
   os << ", " << meta_.GetMetaNode(attrs);
 }
 
-std::string RelayPrint(const NodeRef& node) {
-  return TextPrinter().Print(node);
+std::string RelayPrint(const NodeRef& node,
+                       runtime::TypedPackedFunc<std::string(Expr)> annotate) {
+  return TextPrinter(annotate).Print(node);
 }
 
-TVM_REGISTER_API("relay._expr._text_print")
-.set_body_typed<std::string(const NodeRef&)>(RelayPrint);
+TVM_REGISTER_API("relay._expr.RelayPrint")
+.set_body_typed<std::string(
+    const NodeRef&,
+    runtime::TypedPackedFunc<std::string(Expr)>)>(RelayPrint);
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/pass/fold_scale_axis.cc b/src/relay/pass/fold_scale_axis.cc
index e757118f33f27cc8903bb145e3591294cf300bbf..038f34df576087a88661e02090ead848b7b25109 100644
--- a/src/relay/pass/fold_scale_axis.cc
+++ b/src/relay/pass/fold_scale_axis.cc
@@ -10,6 +10,7 @@
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/expr_functor.h>
 #include "pattern_util.h"
+#include "pass_util.h"
 #include "../op/nn/layout.h"
 
 namespace tvm {
@@ -580,23 +581,6 @@ using FBackwardTransform = TypedPackedFunc<
 //----------------------------------------------
 // Generic Visitors for FScaleAxisBackward
 //----------------------------------------------
-/*!
- * \brief Get reference counter of each internal ExprNode in body.
- * \param body The body expression.
- * \return The reference count mapping.
- */
-std::unordered_map<const Node*, size_t>
-GetExprRefCount(const Expr& body) {
-  class ExprRefCounter : private ExprVisitor {
-   public:
-    std::unordered_map<const Node*, size_t>
-    Get(const Expr& body) {
-      this->VisitExpr(body);
-      return std::move(this->visit_counter_);
-    }
-  };
-  return ExprRefCounter().Get(body);
-}
 
 class BackwardPrep : private ExprVisitor {
  public:
diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc
index 2bd16a4f840fb8962b9814c9971ded65260e87f9..2503bd5f53faee2ecd2325d5b6043368aa1efa52 100644
--- a/src/relay/pass/fuse_ops.cc
+++ b/src/relay/pass/fuse_ops.cc
@@ -9,13 +9,686 @@
 #include <tvm/ir_operator.h>
 #include <tvm/relay/pass.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op_attr_types.h>
+#include "../../common/arena.h"
+
 
 namespace tvm {
 namespace relay {
 
-// Simple fuser that only makes each operator function as primitive.
-class SimpleFuser : public ExprMutator {
+/*
+  Note on Fusing algorithm:
+
+  The main challenge of a general fuser is to handle possible diamond shape branches,
+  in the following graph, conv2d can be fused to elemwise add.
+
+            conv2d
+            /  |  \
+           /   |   \
+         op    op   op
+          \    |    /
+           \   |   /
+          elemwise add
+               |
+
+  However, at the point of conv2d we do not necessarily know that all its future paths
+  will merge at the elemwise add. The new fuser algorithm applies post-dominator analysis.
+  The immediate post-dominator of a node is defined as the closest node that all its future paths go into.
+  In the above case, the elemwise add is the post-dominator of conv2d. The general algorithm is as follows:
+
+  - Construct a DAG of dataflow graph for dominator analysis
+  - Construct a post-dominator tree which gives immediate post dominator of each node.
+  - Run fusion algorithm with the given post-dominator information.
+
+  Note that, because we run analysis on a DAG, we use a single pass post-dominator
+  tree construction algorithm via LCA, which is simpler than the full version that handles cycles.
+
+  The fusion algorithm traverses from each node and checks if it can be fused to its
+  immediate post dominator. It has to check the following things:
+
+  - CheckPath: check that all the paths between a node and its immediate post-dominator
+               satisfy the fuse condition.
+  - Note that these intermediate nodes can already be fused with other nodes, the algorithm
+      will still run correctly.
+  - CommitFuse: mark all the nodes between source and post-dominator as the same group.
+  - We use a Union-Find data structure to manage the groups.
+*/
+using common::LinkNode;
+using common::LinkedList;
+
+/*!
+ * \brief Indexed data flow graph in forward direction.
+ *  This is a temporary data structure used for operator fusion analysis.
+ *
+ *  This data structure only captures the dataflow fragment and
+ *  could ignore blocks like let by simply ordering each dataflow block
+ *  and mark the output node as extern_ref;
+ */
+class IndexedForwardGraph {
+ public:
+  struct Node;
+  /*!
+   * The forward edge in the dataflow graph.
+   */
+  struct Edge {
+    /*! \brief The corresponding node */
+    Node* node{nullptr};
+    /*! \brief The respective pattern of this op */
+    OpPatternKind pattern{kOpaque};
+  };
+  /*! \brief A node in the graph. */
+  struct Node {
+    /*! \brief weak reference to the corresponding IR node. */
+    const tvm::Node* ref{nullptr};
+    /*! \brief The index of the node in post_dfs_order. */
+    size_t index{0};
+    /*! \brief Whether this node is referenced by external source */
+    bool extern_ref{false};
+    /*! \brief The general pattern in the node */
+    OpPatternKind pattern{kOpaque};
+    /*! \brief The outputs of the node. */
+    LinkedList<Edge> outputs;
+  };
+  /*! \brief The node map that maps an IR node to its graph node */
+  std::unordered_map<const tvm::Node*, Node*> node_map;
+  /*! \brief All the nodes in post DFS order */
+  std::vector<Node*> post_dfs_order;
+
+  /*! \brief Format the graph as text and write it to LOG(INFO). */
+  void DebugDump() {
+    std::ostringstream os;
+    for (size_t i = 0; i < post_dfs_order.size(); ++i) {
+      Node* node = post_dfs_order[i];
+      os << "node[" << i << "], "
+         << GetRef<NodeRef>(node->ref)
+         << " outputs=[";
+      for (auto* link = node->outputs.head; link != nullptr; link = link->next) {
+        os << link->value.node->index << ", ";  // index of each output node
+      }
+      os << "]\n";
+    }
+    LOG(INFO) << os.str();
+  }
+  /*!
+   * \brief Create an indexed forward graph.
+   * \param arena The arena used for data allocation.
+   * \param body The body of the expression to create a graph.
+   */
+  static IndexedForwardGraph Create(common::Arena* arena, const Expr& body);
+
+ private:
+  class Creator;
+};
+
+// Creator of the indexed forward graph of the dataflow
+class IndexedForwardGraph::Creator : private ExprVisitor {
+ public:
+  explicit Creator(common::Arena* arena)
+      : arena_(arena) {}
+
+  IndexedForwardGraph Prepare(const Expr& body) {
+    this->Update(body, nullptr, kOpaque);  // the root has no parent: marked extern_ref
+    this->VisitExpr(body);
+    return std::move(graph_);
+  }
+
+ private:
+  /*! \brief allocator of all the internal node object */
+  common::Arena* arena_;
+  // The output.
+  IndexedForwardGraph graph_;
+  // attribute equal comparator
+  AttrsEqual attr_equal_;
+  // Find or create the graph node of `node`; link it to parent, or mark extern_ref if null.
+  void Update(const Expr& node,
+              IndexedForwardGraph::Node* parent,
+              OpPatternKind pattern) {
+    const tvm::Node* key = node.get();
+    IndexedForwardGraph::Node* current;
+    auto it = graph_.node_map.find(key);
+    if (it != graph_.node_map.end()) {
+      current = it->second;
+    } else {
+      current = arena_->make<IndexedForwardGraph::Node>();
+      graph_.node_map[key] = current;
+    }
+    if (parent != nullptr) {
+      auto* link = arena_->make<LinkNode<IndexedForwardGraph::Edge> >();
+      link->value.node = parent;
+      link->value.pattern = pattern;
+      current->outputs.Push(link);
+    } else {
+      current->extern_ref = true;
+    }
+  }
+  void AddNode(const tvm::Node* key) {
+    auto it = graph_.node_map.find(key);
+    CHECK(it != graph_.node_map.end())
+        << "Cannot find node " << GetRef<NodeRef>(key);
+    IndexedForwardGraph::Node* node = it->second;
+    CHECK(node->ref == nullptr);  // a node must be appended to the order only once
+    node->ref = key;
+    node->index = graph_.post_dfs_order.size();
+    graph_.post_dfs_order.push_back(node);
+  }
+
+  // Post order tree
+  void VisitExpr_(const FunctionNode* op) {
+    for (auto param : op->params) {
+      this->Update(param, nullptr, kOpaque);
+    }
+    this->Update(op->body, nullptr, kOpaque);
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitExpr_(const ConstantNode* op) {
+    this->AddNode(op);
+    Node* node = graph_.node_map.at(op);
+    DataType dtype = TVMType2Type(op->data->dtype);
+    // This rule must be consistent with code generator.
+    bool is_simple_const = (
+        dtype == Int(32) ||
+        dtype == Int(64) ||
+        dtype == Float(32) ||
+        dtype == Float(64) ||
+        dtype == Bool());
+    if (op->is_scalar() && is_simple_const) {
+      node->pattern = kElemWise;
+    } else {
+      // for now, mark non-scalar constant
+      // as opaque, we will not choose to fuse it.
+      node->pattern = kOpaque;
+    }
+  }
+
+  void VisitExpr_(const CallNode* call) {
+    CHECK(graph_.node_map.count(call));
+    Node* node = graph_.node_map.at(call);
+    static auto fpattern =
+        Op::GetAttr<TOpPattern>("TOpPattern");
+    // setup pattern.
+    OpPatternKind op_pattern = kOpaque;
+    if (const OpNode* opnode = call->op.as<OpNode>()) {
+      op_pattern = static_cast<OpPatternKind>(fpattern[GetRef<Op>(opnode)]);
+    }
+    node->pattern = op_pattern;
+    const auto* rtype = call->checked_type().as<TensorTypeNode>();
+    // pass the message back to all the children it references.
+    for (size_t i = 0; i < call->args.size(); ++i) {
+      const auto* arg_type =
+          call->args[i]->checked_type().as<TensorTypeNode>();
+      // a broadcast edge whose argument shape equals the result shape is treated as elemwise
+      OpPatternKind edge_pattern = op_pattern;
+      if (edge_pattern == kBroadcast &&
+          arg_type != nullptr &&
+          rtype != nullptr &&
+          attr_equal_(rtype->shape, arg_type->shape)) {
+        edge_pattern = kElemWise;
+      }
+      this->Update(call->args[i], node, edge_pattern);
+    }
+    ExprVisitor::VisitExpr_(call);
+    this->AddNode(call);
+  }
+
+  void VisitExpr_(const TupleNode* op) {
+    for (const Expr& field : op->fields) {
+      this->Update(field, nullptr, kOpaque);
+    }
+    ExprVisitor::VisitExpr_(op);
+    this->AddNode(op);
+  }
+
+  void VisitExpr_(const TupleGetItemNode* op) {
+    CHECK(graph_.node_map.count(op));
+    Node* node = graph_.node_map.at(op);
+    this->Update(op->tuple, node, kOpaque);
+    ExprVisitor::VisitExpr_(op);
+    this->AddNode(op);
+  }
+
+  void VisitExpr_(const VarNode* op) {
+    this->AddNode(op);
+  }
+
+  void VisitExpr_(const LetNode* op) {
+    // do not fuse through let.
+    this->Update(op->var, nullptr, kOpaque);
+    this->Update(op->value, nullptr, kOpaque);
+    this->Update(op->body, nullptr, kOpaque);
+    ExprVisitor::VisitExpr_(op);
+    this->AddNode(op);
+  }
+
+  void VisitExpr_(const IfNode* op) {
+    // do not fuse through if.
+    this->Update(op->cond, nullptr, kOpaque);
+    this->Update(op->true_branch, nullptr, kOpaque);
+    this->Update(op->false_branch, nullptr, kOpaque);
+    ExprVisitor::VisitExpr_(op);
+    this->AddNode(op);
+  }
+};
+
+IndexedForwardGraph IndexedForwardGraph::Create(
+    common::Arena* arena, const Expr& body) {
+  return Creator(arena).Prepare(body);
+}
+
+/*!
+ * \brief Dominator tree that represents the domination or
+ *  post domination relation of the nodes.
+ */
+class DominatorTree {
  public:
+  /*!
+   * \brief A node in the dominator tree.
+   */
+  struct Node {
+    /*! \brief The node in the tree */
+    IndexedForwardGraph::Node* gnode{nullptr};
+    /*! \brief parent of the tree */
+    Node* parent{nullptr};
+    /*! \brief current depth*/
+    int depth{0};
+    /*! \brief aggregated pattern to parent */
+    OpPatternKind pattern{kOpaque};
+  };
+  // index -> node.
+  std::vector<Node*> nodes;
+  /*!
+   * \brief compute a post dominator relation for a given dataflow graph.
+   * \param arena The arena used for node allocation.
+   * \param graph The graph to be analyzed.
+   * \return The dominator tree of the graph.
+   * \note This algorithm makes use of the fact that graph is DAG,
+   *       and runs a single pass algorithm via LCA.
+   */
+  static DominatorTree PostDom(common::Arena* arena,
+                               const IndexedForwardGraph& graph);
+
+ private:
+  // Combine two patterns: the result is the larger enum value of the two.
+  static OpPatternKind CombinePattern(
+      OpPatternKind lhs, OpPatternKind rhs) {
+    if (lhs > rhs) return lhs;
+    return rhs;
+  }
+  /*!
+   * \brief Find the least common ancestor of the two nodes.
+   * \param lhs The left node.
+   * \param rhs The right node.
+   * \param edge_pattern
+   *        The combined edge pattern across all the parents.
+   * \return The least common ancestor of the two, or nullptr if there is none.
+   */
+  static Node* LeastCommonAcenstor(
+      Node* lhs,
+      Node* rhs,
+      OpPatternKind* edge_pattern) {
+    while (lhs != rhs) {
+      if (lhs == nullptr) return nullptr;
+      if (rhs == nullptr) return nullptr;
+      if (lhs->depth < rhs->depth) {
+        edge_pattern[0] = CombinePattern(
+            edge_pattern[0], rhs->pattern);
+        rhs = rhs->parent;
+      } else if (rhs->depth < lhs->depth) {
+        edge_pattern[0] = CombinePattern(
+            edge_pattern[0], lhs->pattern);
+        lhs = lhs->parent;
+      } else {
+        // fold in both edges before stepping up: the parents may be nullptr (roots).
+        edge_pattern[0] = CombinePattern(
+            edge_pattern[0], lhs->pattern);
+        edge_pattern[0] = CombinePattern(
+            edge_pattern[0], rhs->pattern);
+        lhs = lhs->parent;
+        rhs = rhs->parent;
+      }
+    }
+    return lhs;
+  }
+};
+
+DominatorTree DominatorTree::PostDom(common::Arena* arena,
+                                     const IndexedForwardGraph& graph) {
+  DominatorTree tree;
+  tree.nodes.resize(graph.post_dfs_order.size(), nullptr);
+  // walk post_dfs_order backwards: the outputs of a node must already have tree nodes.
+  for (size_t i = graph.post_dfs_order.size(); i != 0; --i) {
+    size_t index = i - 1;
+    Node* tnode = arena->make<Node>();
+    auto* gnode = graph.post_dfs_order[index];
+    tnode->gnode = gnode;
+    if (gnode->extern_ref) {
+      // externally referenced nodes become roots of the tree.
+      tnode->depth = 1;
+      tnode->parent = nullptr;
+      tnode->pattern = kOpaque;
+    } else {
+      // the parent is the LCA of all outputs in the tree built so far.
+      OpPatternKind pattern = kElemWise;
+      Node* parent = nullptr;
+      for (auto link = gnode->outputs.head; link != nullptr; link= link->next) {
+        size_t oindex = link->value.node->index;
+        CHECK_LT(oindex, tree.nodes.size());
+        Node* onode = tree.nodes[oindex];
+        CHECK(onode != nullptr);
+        if (parent != nullptr) {
+          parent = LeastCommonAcenstor(parent, onode, &pattern);
+        } else {
+          parent = onode;
+        }
+        pattern = CombinePattern(pattern, link->value.pattern);
+      }
+      CHECK(parent != nullptr);
+      tnode->depth = parent->depth + 1;
+      tnode->parent = parent;
+      tnode->pattern = pattern;
+    }
+    tree.nodes[index] = tnode;
+  }
+  return tree;
+}
+
+/*!
+ * \brief A partition of the graph marked by union find data structure.
+ */
+class GraphPartitioner {
+ public:
+  explicit GraphPartitioner(common::Arena* arena, int opt_level)
+      : arena_(arena), opt_level_(opt_level) {}
+  /*!
+   * \brief Group as a union find data structure.
+   */
+  struct Group {
+    /*! \brief The parent in the union find data structure. */
+    Group* parent{nullptr};
+    /*! \brief The pattern of the group */
+    OpPatternKind pattern;
+    /*! \brief reference to the root node. */
+    const tvm::Node* root_ref{nullptr};
+    /*!
+     * \brief Reference to the master node,
+     * this field is not nullptr only if pattern is kOutEWiseFusable.
+     */
+    const tvm::Node* master_ref{nullptr};
+    /*!
+     * \brief Find the group root, perform path compression
+     * \return The root group.
+     */
+    Group* FindRoot() {
+      // fast path
+      if (this->parent == nullptr) return this;
+      // slow path with path compression.
+      Group* root = this;
+      while (root->parent != nullptr) {
+        root = root->parent;
+      }
+      for (Group* p = this; p != root;) {
+        Group* parent = p->parent;
+        p->parent = root;  // re-point every group on the path directly at root
+        p = parent;
+      }
+      return root;
+    }
+  };
+  /*!
+   * \brief Partition a graph.
+   * \return group assignments of each node.
+   */
+  std::vector<Group*> Partition(const IndexedForwardGraph& graph);
+
+ private:
+  /*! \brief The internal arena for temporary space. */
+  common::Arena* arena_;
+  /*! \brief optimization level for fuse operation. */
+  int opt_level_;
+  /*! \brief The internal groups. */
+  std::vector<Group*> groups_;
+  /*! \brief internal field used for deduplication */
+  std::unordered_set<IndexedForwardGraph::Node*> visited_;
+  // Internal implementation of CheckPath
+  template<typename F>
+  bool CheckPath_(IndexedForwardGraph::Node* src,
+                  IndexedForwardGraph::Node* sink,
+                  F fcond) {
+    if (visited_.count(src)) return true;
+    visited_.insert(src);
+    Group* gnode =  groups_[src->index];
+    CHECK(gnode != nullptr);
+    gnode = gnode->FindRoot();
+    if (!fcond(gnode->pattern, src == sink)) return false;
+    if (src == sink) return true;
+    for (auto link = src->outputs.head; link != nullptr; link = link->next) {
+      if (!CheckPath_(link->value.node, sink, fcond)) return false;
+    }
+    return true;
+  }
+  /*!
+   * \brief Check that all the nodes between src and sink satisfy fcond.
+   *
+   * src is not checked; sink is checked with is_sink = true.
+   *
+   * \param src The source node.
+   * \param sink The termination node.
+   * \param fcond The condition to be checked.
+   * \tparam F the condition function.
+   * \note sink must be a post-dominator of src.
+   */
+  template<typename F>
+  bool CheckPath(IndexedForwardGraph::Node* src,
+                 IndexedForwardGraph::Node* sink,
+                 F fcond) {
+    CHECK(!src->extern_ref);
+    visited_.clear();
+    CHECK(src != sink);
+    for (auto link = src->outputs.head; link != nullptr; link = link->next) {
+      if (!CheckPath_(link->value.node, sink, fcond)) return false;
+    }
+    return true;
+  }
+  // Combine two patterns together.
+  static OpPatternKind CombinePattern(
+      OpPatternKind lhs, OpPatternKind rhs) {
+    if (lhs > kBroadcast && rhs > kBroadcast) {
+      LOG(FATAL) << "Cannot merge two complex group together";
+    }
+    if (lhs > rhs) return lhs;
+    return rhs;
+  }
+  /*!
+   * \brief Merge the child group to the parent.
+   * \param child The child group.
+   * \param parent The parent group.
+   */
+  void MergeFromTo(Group* child, Group* parent) {
+    child = child->FindRoot();
+    parent = parent->FindRoot();
+    if (child == parent) return;
+    child->parent = parent;
+    // update master ref and pattern
+    if (child->master_ref != nullptr) {
+      CHECK(parent->master_ref == nullptr);
+      parent->master_ref = child->master_ref;
+      parent->pattern = CombinePattern(
+          child->pattern, parent->pattern);
+    }
+  }
+  // Internal implementation of CommitFuse
+  void CommitFuse_(IndexedForwardGraph::Node* src,
+                   IndexedForwardGraph::Node* sink,
+                   Group* target) {
+    if (src == sink) return;
+    if (visited_.count(src)) return;
+    visited_.insert(src);
+    Group* gnode = groups_[src->index];
+    CHECK(gnode != nullptr);
+    // merge the current group to the parent if possible.
+    MergeFromTo(gnode, target);
+    for (auto link = src->outputs.head; link != nullptr; link = link->next) {
+      CommitFuse_(link->value.node, sink, target);;
+    }
+  }
+  /*!
+   * \brief Commit fusion operation.
+   * \param src The source node.
+   * \param sink The termination node.
+   * \note Every node visited from src is merged into the group of sink.
+   * \note sink must be a post-dominator of src.
+   */
+  void CommitFuse(IndexedForwardGraph::Node* src,
+                  IndexedForwardGraph::Node* sink) {
+    Group* target = groups_[sink->index];
+    visited_.clear();
+    CHECK(src != sink);
+    CommitFuse_(src, sink, target);
+  }
+
+  // Initialize the groups.
+  void InitGroups(const IndexedForwardGraph& graph) {
+    groups_.resize(graph.post_dfs_order.size());
+    for (size_t nid = 0; nid < groups_.size(); ++nid) {
+      const auto* graph_node = graph.post_dfs_order[nid];
+      auto* group_node = arena_->make<Group>();
+      group_node->pattern = graph_node->pattern;
+      group_node->root_ref = graph_node->ref;
+      // set master ref if necessary.
+      if (group_node->pattern == kOutEWiseFusable) {
+        group_node->master_ref = graph_node->ref;
+      }
+      groups_[nid] = group_node;
+    }
+  }
+
+  // execute the fusion algorithm.
+  void RunFuse(const IndexedForwardGraph& graph,
+               const DominatorTree& post_dom_tree,
+               int phase) {
+    for (size_t nid = 0; nid < groups_.size(); ++nid) {
+      // look up the graph node, dominator tree node and group of the current index.
+      auto* graph_node = graph.post_dfs_order[nid];
+      auto* dom_node = post_dom_tree.nodes[nid];
+      Group* group_node = groups_[nid];
+      CHECK(group_node != nullptr);
+      // no actions for opaque nodes
+      if (group_node->pattern == kOpaque) continue;
+      // no actions needed if the current node have no dominator
+      if (dom_node->parent == nullptr) continue;
+      CHECK(!graph_node->extern_ref);
+      // Skip if current node is already fused to the parent.
+      size_t dom_parent_gindex = dom_node->parent->gnode->index;
+      if (groups_[dom_parent_gindex] != nullptr &&
+          group_node->FindRoot() == groups_[dom_parent_gindex]->FindRoot()) {
+        continue;
+      }
+      // Try to fuse current node to its post-dominator.
+      if (group_node->pattern == kOutEWiseFusable) {
+        if (phase != 0) continue;
+        // Path for OutEWiseFusable: conv2d
+        // Check if the dominator relation is elemwise.
+        if (dom_node->parent != nullptr && dom_node->pattern == kElemWise) {
+          CHECK(dom_node->parent->gnode != nullptr);
+          // The fuse can be executed if all the intermediate ops are still broadcast.
+          auto fcond = [](OpPatternKind kind, bool is_sink) {
+            return kind <= kBroadcast;
+          };
+          if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
+            CommitFuse(graph_node, dom_node->parent->gnode);
+          }
+        }
+      } else if (group_node->pattern <= kBroadcast) {
+        // Intermediate ops must be broadcast; the sink may also be reduce or OutEWiseFusable.
+        auto fcond = [](OpPatternKind kind, bool is_sink) {
+          if (!is_sink) {
+            return kind <= kBroadcast;
+          } else {
+            return (kind <= kBroadcast ||
+                    kind == kCommReduce ||
+                    kind == kOutEWiseFusable);
+          }
+        };
+        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
+          CommitFuse(graph_node, dom_node->parent->gnode);
+        }
+      } else if (group_node->pattern == kInjective) {
+        // defer injective fusion to second phase.
+        // so conv2d always finishes fusing.
+        if (phase != 1) continue;
+        // Check if all path are injective.
+        auto fcond = [](OpPatternKind kind, bool is_sink) {
+          return kind <= kInjective;
+        };
+        if (CheckPath(graph_node, dom_node->parent->gnode, fcond)) {
+          CommitFuse(graph_node, dom_node->parent->gnode);
+        }
+      } else {
+        // do nothing.
+        CHECK(group_node->pattern == kCommReduce);
+      }
+    }
+  }
+};
+
+std::vector<GraphPartitioner::Group*>
+GraphPartitioner::Partition(const IndexedForwardGraph& graph) {
+  this->InitGroups(graph);  // one group per graph node
+  if (opt_level_ == 0) return std::move(groups_);  // level 0: no fusion is run
+  // build the post-dominator tree of the graph
+  auto post_dom_tree = DominatorTree::PostDom(arena_, graph);
+  // run the two fusion phases (0, then 1) in order.
+  for (int phase = 0; phase < 2; ++phase) {
+    this->RunFuse(graph, post_dom_tree, phase);
+  }
+  return std::move(groups_);  // groups_ is moved out of the partitioner
+}
+
+class FuseMutator : private ExprMutator {
+ public:
+  // Partition body into groups, record them in gmap_, then rewrite body.
+  Expr Transform(const Expr& body, int fuse_opt_level) {
+    // build the graph and partition it; groups[nid] pairs with post_dfs_order[nid].
+    auto graph = IndexedForwardGraph::Create(&arena_, body);
+    auto groups = GraphPartitioner(&arena_, fuse_opt_level).Partition(
+        graph);
+    for (size_t nid = 0; nid < graph.post_dfs_order.size(); ++nid) {
+      CHECK(graph.post_dfs_order[nid]->ref != nullptr);
+      gmap_[graph.post_dfs_order[nid]->ref] = groups[nid];
+    }
+    // Uncomment the following line to dump the group assignment.
+    // this->DebugDumpGroup(body);
+    return this->Mutate(body);
+  }
+
+
+ private:
+  /*! \brief Temporary information from each group. */
+  struct GroupInfo {
+   public:
+    // Parameters of the fused function; params[i] stands for arguments[i].
+    Array<Var> params;
+    // Expressions passed when calling the function, parallel to params.
+    Array<Expr> arguments;
+    // Return the parameter already bound to expr, or allocate a new one.
+    Var GetOrAllocParam(const Expr& expr, const Type& type) {
+      // run linear scan as most fused groups contain only a few inputs.
+      for (size_t i = 0; i < arguments.size(); ++i) {
+        if (expr.same_as(arguments[i])) return params[i];
+      }
+      // not found: create a parameter named p<index> and record the pair.
+      std::ostringstream os;
+      os << "p" << params.size();
+      auto var = VarNode::make(os.str(), type);
+      params.push_back(var);
+      arguments.push_back(expr);
+      return var;
+    }
+  };
+  /*! \brief Internal arena. */
+  common::Arena arena_;
+  /*! \brief The group assignment map. */
+  std::unordered_map<const Node*, GraphPartitioner::Group*> gmap_;
+  /*! \brief Internal group information map. */
+  std::unordered_map<GraphPartitioner::Group*, GroupInfo> ginfo_;
   // Skip primitive function.
   Expr VisitExpr_(const FunctionNode* fn_node) {
     NodeRef res = FunctionGetAttr(GetRef<Function>(fn_node), "Primitive");
@@ -26,48 +699,74 @@ class SimpleFuser : public ExprMutator {
       return ExprMutator::VisitExpr_(fn_node);
     }
   }
-
+  // Rewrite primitive op calls according to their group assignment.
   Expr VisitExpr_(const CallNode* call) {
     if (call->op.as<OpNode>()) {
-      // Placeholder fusion algorithm which abstracts
-      // single definitions into functions only.
-      Array<Var> params;
-      Array<Expr> inner_args;
-      Array<Expr> args;
-
-      int param_number = 0;
+      // If it is a primitive op call
+      // then we must have a group assignment for it already.
+      CHECK(gmap_.count(call));
+      auto* ret_group = gmap_.at(call)->FindRoot();
+      Array<Expr> new_args;
       for (auto arg : call->args) {
-        std::ostringstream os;
-        os << "p" << param_number++;
         auto type = arg->checked_type();
-        auto var = VarNode::make(os.str(), type);
-        params.push_back(var);
-        inner_args.push_back(var);
-        args.push_back(this->Mutate(arg));
+        CHECK(gmap_.count(arg.get()))
+            << "cannot find group of " << arg;
+        auto* arg_group = gmap_.at(arg.get())->FindRoot();
+        Expr new_arg = this->Mutate(arg);
+
+        if (ret_group != arg_group) {  // argument belongs to a different root group
+          Var param = ginfo_[ret_group].GetOrAllocParam(new_arg, type);
+          new_args.push_back(param);
+        } else {  // same root group: the argument is used directly
+          new_args.push_back(new_arg);
+        }
+      }
+      auto new_call = CallNode::make(
+          call->op, new_args, call->attrs, call->type_args);
+
+      if (ret_group->root_ref == call) {
+        // This call is the root of its group:
+        // wrap the body in a function marked Primitive and call it.
+        const GroupInfo& ginfo = ginfo_[ret_group];
+        auto func = FunctionNode::make(
+            ginfo.params, new_call, call->checked_type(), {});
+        func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
+        return CallNode::make(func, ginfo.arguments, Attrs());
+      } else {
+        // This is an intermediate node of a fused function
+        // simply return the new call.
+        return new_call;
       }
-      auto body = CallNode::make(call->op, inner_args, call->attrs);
-      auto func = FunctionNode::make(
-          params, body, call->checked_type(), {});
-      func = FunctionSetAttr(func, "Primitive", tvm::Integer(1));
-      return CallNode::make(func, args, Attrs());
     } else {
       return ExprMutator::VisitExpr_(call);
     }
   }
+  // Debug helper: log body as text, annotating each expr with its root group pointer.
+  void DebugDumpGroup(const Expr& body) {
+    std::string text = RelayPrint(body, [this](const Expr& expr) -> std::string {
+        auto it = gmap_.find(expr.get());
+        if (it == gmap_.end()) return "";  // expr has no group: empty annotation
+        std::ostringstream os;
+        auto *group = it->second->FindRoot();
+        os << "group=" << group;  // streams the root group pointer itself
+        return os.str();
+      });
+    LOG(INFO) << "Dump of group info:\n" << text;
+  }
 };
 
 
-Expr FuseOps(const Expr& expr) {
+Expr FuseOps(const Expr& expr, int fuse_opt_level) {
   // First we convert all chains of fusable ops into
   // abstracted functions which we mark as primtive
   // then we convert these primtive functions into
   // new operators.
-  return SimpleFuser().Mutate(expr);
+  return FuseMutator().Transform(expr, fuse_opt_level);  // level goes to the partitioner
 }
 
 TVM_REGISTER_API("relay._ir_pass.FuseOps")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
-    *ret = FuseOps(args[0]);
+    *ret = FuseOps(args[0], args[1]);  // args[0]: expr, args[1]: fuse_opt_level
 });
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/pass/pass_util.h b/src/relay/pass/pass_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..bf52297e8930f46feb0df111e196f3d247dbf0ae
--- /dev/null
+++ b/src/relay/pass/pass_util.h
@@ -0,0 +1,27 @@
+/*!
+ *  Copyright (c) 2018 by Contributors.
+ *
+ * \file tvm/relay/pass/pass_util.h
+ * \brief Utilities for writing passes.
+ */
+#ifndef TVM_RELAY_PASS_PASS_UTIL_H_
+#define TVM_RELAY_PASS_PASS_UTIL_H_
+
+#include <tvm/relay/op.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/attrs/transform.h>
+
+namespace tvm {
+namespace relay {
+
+/*!
+ * \brief Get the reference count of each internal ExprNode in body.
+ * \param body The body expression to be traversed.
+ * \return Map from node pointer to the count recorded for that node.
+ */
+std::unordered_map<const Node*, size_t>
+GetExprRefCount(const Expr& body);
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_PASS_PASS_UTIL_H_
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index b224a099aee1ee2d2c7ceb525b6dc897d0bb2f27..5cabfbdabc49e6c1b884d76fe675654ac14e7550 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -442,6 +442,9 @@ class TypeInferencer::Resolver : public ExprMutator {
     VarNode* new_var =(
         std::is_base_of<VarNode, T>::value ?
         static_cast<VarNode*>(new_e.node_.get()) : nullptr);
+    FunctionNode* new_fn =(
+        std::is_base_of<FunctionNode, T>::value ?
+        static_cast<FunctionNode*>(new_e.node_.get()) : nullptr);
 
     // check if we need update the new_e
     bool need_update_type = !checked_type.same_as(new_e->checked_type_);
@@ -454,7 +457,17 @@ class TypeInferencer::Resolver : public ExprMutator {
         update_missing_type_annotation_ &&
         !new_var->type_annotation.defined());
 
-    if (!need_update_type && !need_update_var && !need_update_call) return new_e;
+    bool need_update_fn = (
+        std::is_base_of<FunctionNode, T>::value &&
+        update_missing_type_annotation_ &&
+        !new_fn->ret_type.defined());
+
+    if (!need_update_type &&
+        !need_update_var &&
+        !need_update_call &&
+        !need_update_fn) {
+      return new_e;
+    }
 
     if (!new_e.node_.unique()) {
       // Copy on write optimization
@@ -467,6 +480,9 @@ class TypeInferencer::Resolver : public ExprMutator {
       new_var = (
           std::is_base_of<VarNode, T>::value ?
           static_cast<VarNode*>(new_e.node_.get()) : nullptr);
+      new_fn = (
+          std::is_base_of<FunctionNode, T>::value ?
+          static_cast<FunctionNode*>(new_e.node_.get()) : nullptr);
     }
 
     // attach the information.
@@ -483,6 +499,11 @@ class TypeInferencer::Resolver : public ExprMutator {
     if (need_update_var) {
       new_var->type_annotation = checked_type;
     }
+    if (need_update_fn) {
+      auto* fn_type = checked_type.as<FuncTypeNode>();
+      CHECK(fn_type != nullptr);
+      new_fn->ret_type = fn_type->ret_type;
+    }
     return new_e;
   }
 
diff --git a/src/relay/pass/type_solver.cc b/src/relay/pass/type_solver.cc
index 3ca161d23f728e593b01a64e7703f3f6ae743c65..e1efcbbdd0b91d2ce09ff096dcaa4428b6fc58cf 100644
--- a/src/relay/pass/type_solver.cc
+++ b/src/relay/pass/type_solver.cc
@@ -85,18 +85,18 @@ Type TypeSolver::Unify(const Type& dst, const Type& src) {
 void TypeSolver::AddConstraint(const TypeConstraint& constraint) {
   if (auto *op = constraint.as<TypeRelationNode>()) {
     // create a new relation node.
-    RelationNode* rnode = make<RelationNode>();
+    RelationNode* rnode = arena_.make<RelationNode>();
     rnode->rel = GetRef<TypeRelation>(op);
     rel_nodes_.push_back(rnode);
     // populate the type information.
     for (size_t i = 0; i < op->args.size(); ++i) {
       // insert link to the type list
-      LinkNode<TypeNode*>* tlink = make<LinkNode<TypeNode*> >();
+      LinkNode<TypeNode*>* tlink = arena_.make<LinkNode<TypeNode*> >();
       TypeNode* tnode = GetTypeNode(op->args[i]);
       tlink->value = tnode;
       rnode->type_list.Push(tlink);
       // insert type->relation node
-      LinkNode<RelationNode*>* rlink = make<LinkNode<RelationNode*> >();
+      LinkNode<RelationNode*>* rlink = arena_.make<LinkNode<RelationNode*> >();
       rlink->value = rnode;
       tnode->rel_list.Push(rlink);
     }
diff --git a/src/relay/pass/type_solver.h b/src/relay/pass/type_solver.h
index 30f82f980a751244a0eadd3dda411680491636a5..2f311c9b9810229b54cf178197abf608619269b8 100644
--- a/src/relay/pass/type_solver.h
+++ b/src/relay/pass/type_solver.h
@@ -16,6 +16,8 @@
 namespace tvm {
 namespace relay {
 
+using common::LinkNode;
+using common::LinkedList;
 /*!
  * \brief Interface of type solver used in type inference.
  *
@@ -69,41 +71,6 @@ class TypeSolver {
   // Internally the solver maintains a bipartite graph of Relation and Types.
   // All the object in the structure is managed by a arena allocator
   // which releases the memory upon distruction of the type solver.
-  /*!
-   * \brief Link list node
-   * \tparam T the content data type
-   */
-  template<typename T>
-  struct LinkNode {
-    /*! \brief The content value */
-    T value;
-    /*! \brief pointer to the next location */
-    LinkNode<T>* next{nullptr};
-  };
-  /*!
-   * \brief LinkedList structure
-   * \tparam T the content data type
-   */
-  template<typename T>
-  struct LinkedList {
-    /*! \brief Head pointer */
-    LinkNode<T>* head{nullptr};
-    /*! \brief Tail pointer */
-    LinkNode<T>* tail{nullptr};
-    /*!
-     * \brief Push a new node to the end of the linked list.
-     * \param node The node to be pushed.
-     */
-    void Push(LinkNode<T>* node) {
-      node->next = nullptr;
-      if (this->tail != nullptr) {
-        this->tail->next = node;
-        this->tail = node;
-      } else {
-        head = tail = node;
-      }
-    }
-  };
   /*!
    * \brief type node struct
    *  TypeNode implements a union-find data structure(via parent)
@@ -164,18 +131,6 @@ class TypeSolver {
   common::Arena arena_;
   /*! \brief Reporter that reports back to self */
   TypeReporter reporter_;
-  /*!
-   * \brief Create function to create a new node ptr via arena
-   * \tparam The type parameter
-   * \return The node pointer.
-   */
-  template<typename T>
-  T* make() {
-    T* ptr = arena_.Alloc<T>();
-    // call constructor
-    new (ptr) T();
-    return ptr;
-  }
   /*!
    * \brief GetTypeNode that is corresponds to t.
    * if it do not exist, create a new one.
@@ -186,7 +141,7 @@ class TypeSolver {
     if (it != tmap_.end()) {
       return it->second->FindRoot();
     } else {
-      TypeNode* n = make<TypeNode>();
+      TypeNode* n = arena_.make<TypeNode>();
       type_nodes_.push_back(n);
       n->resolved_type = t;
       tmap_[t] = n;
diff --git a/src/relay/pass/util.cc b/src/relay/pass/util.cc
index 51ef0377868f18f74a9f4573cec66ba997fad914..ebc4e6fc16e614b82a866fc5a3fe43b4a29a8c1c 100644
--- a/src/relay/pass/util.cc
+++ b/src/relay/pass/util.cc
@@ -129,5 +129,23 @@ TVM_REGISTER_API("relay._ir_pass.free_type_vars")
     }
   });
 
+/*!
+ * \brief Get the reference count of each internal ExprNode in body.
+ * \param body The body expression to be traversed.
+ * \return Map from node pointer to the count recorded for that node.
+ */
+std::unordered_map<const Node*, size_t>
+GetExprRefCount(const Expr& body) {
+  class ExprRefCounter : private ExprVisitor {  // local helper, used only below
+   public:
+    std::unordered_map<const Node*, size_t>
+    Get(const Expr& body) {
+      this->VisitExpr(body);  // traverse body once
+      return std::move(this->visit_counter_);  // presumably filled by ExprVisitor
+    }
+  };
+  return ExprRefCounter().Get(body);
+}
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py
index dd790a6d7d873e5c9fd82f12ef40430d0a50d23a..d12804d512f09321e2d733917f2fd1666ce6e9eb 100644
--- a/tests/python/relay/test_ir_text_printer.py
+++ b/tests/python/relay/test_ir_text_printer.py
@@ -33,6 +33,7 @@ def test_env():
     text = env.astext()
     assert "def @myf" in text
     assert "%1 = add(%0, %0) # ty=float32" in text
+    show(env.astext(annotate=lambda x: str(x.checked_type.dtype)))
     show(text)
 
 
diff --git a/tests/python/relay/test_pass_fold_scale_axis.py b/tests/python/relay/test_pass_fold_scale_axis.py
index 1b57bdce0e0c9b839a5e6591406a456c17caa99e..a5a7a05a974cd4ee518ed2ca96408cd974c4f3f3 100644
--- a/tests/python/relay/test_pass_fold_scale_axis.py
+++ b/tests/python/relay/test_pass_fold_scale_axis.py
@@ -46,6 +46,8 @@ def test_fold_fwd_simple():
         weight = relay.var("weight", type_dict["weight"])
         y1_folded = relay.ir_pass.forward_fold_scale_axis(y1)
         y1_expected = expected(x, weight, in_bias, in_scale, channels)
+        y1_folded = relay.ir_pass.infer_type(y1_folded)
+        y1_expected = relay.ir_pass.infer_type(y1_expected)
         assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
 
     check((2, 4, 10, 10), 2)
@@ -113,6 +115,8 @@ def test_fold_fwd_dual_path():
         type_dict = {x.name_hint:x.checked_type for x in y1.params}
         weight = relay.var("weight", type_dict["weight"])
         y1_expected = expected(x, weight, in_bias, in_scale, channels)
+        y1_folded = relay.ir_pass.infer_type(y1_folded)
+        y1_expected = relay.ir_pass.infer_type(y1_expected)
         assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
 
     check((2, 4, 10, 3), 3)
@@ -194,6 +198,8 @@ def test_fold_bwd_simple():
         weight = relay.var("weight", type_dict["weight"])
         y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
         y1_expected = expected(x, weight, out_bias, out_scale, channels)
+        y1_folded = relay.ir_pass.infer_type(y1_folded)
+        y1_expected = relay.ir_pass.infer_type(y1_expected)
         assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
 
     check((2, 4, 10, 10), 8)
@@ -255,6 +261,8 @@ def test_fold_bwd_dual_path():
         weight = relay.var("weight", type_dict["weight"])
         y1_folded = relay.ir_pass.backward_fold_scale_axis(y1)
         y1_expected = expected(x, weight, out_bias, out_scale, channels)
+        y1_folded = relay.ir_pass.infer_type(y1_folded)
+        y1_expected = relay.ir_pass.infer_type(y1_expected)
         assert relay.ir_pass.alpha_equal(y1_folded, y1_expected)
 
     check((2, 4, 10, 10), 8)
diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py
index 2bbc1dce9693db4aa04d773764fe89a204c8e25a..19bec20ac4afd90ee85556b3abcd8b215c9c90af 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -3,15 +3,103 @@ from tvm import relay
 
 def test_fuse_simple():
     """Simple testcase."""
-    x = relay.var("x", shape=(10, 20))
-    y = relay.add(x, x)
-    z = relay.exp(y)
+    def before():
+        x = relay.var("x", shape=(10, 20))
+        y = relay.add(x, relay.const(1, "float32"))
+        z = relay.exp(y)
+        return relay.Function([x], z)
+
+    def expected():
+        x = relay.var("p", shape=(10, 20))
+        y = relay.add(x, relay.const(1, "float32"))
+        z = relay.exp(y)
+        f1 = relay.Function([x], z)
+        x = relay.var("x", shape=(10, 20))
+        y = relay.Call(f1, [x])
+        return relay.Function([x], y)
+
+    z = before()
     z = relay.ir_pass.infer_type(z)
-    zz = relay.ir_pass.fuse_ops(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    zz = relay.ir_pass.infer_type(zz)
     zz = relay.ir_pass.fuse_ops(zz)
     zz = relay.ir_pass.infer_type(zz)
-    zz.astext()
+    after = relay.ir_pass.infer_type(expected())
+    assert relay.ir_pass.alpha_equal(zz, after)
+
+
+
+def test_conv2d_fuse():
+    """Test fusion case of conv2d"""
+    def before(dshape):
+        x = relay.var("x", shape=dshape)
+        y = relay.nn.conv2d(x, relay.var("w1"),
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            channels=16)
+        # this is the next dominator.
+        y1 = relay.add(relay.const(1, "float32"), y)
+        y = relay.add(y, y1)
+        # second path
+        z2 = relay.nn.conv2d(y, relay.var("w2"),
+                             kernel_size=(1, 1),
+                             padding=(0,0),
+                             channels=16)
+        z3 = relay.nn.conv2d(y, relay.var("w3"),
+                             kernel_size=(3, 3),
+                             padding=(1,1),
+                             channels=16)
+        # add can only be fused to z2 (the 1x1 conv), as expected() shows
+        z = relay.add(z2, z3)
+        return relay.Function(relay.ir_pass.free_vars(z), z)
+
+    def expected(dshape):
+        # segment 1
+        x = relay.var("p0", shape=dshape)
+        w = relay.var("p1")
+        y = relay.nn.conv2d(x, w,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            channels=16)
+        y1 = relay.add(relay.const(1, "float32"), y)
+        y = relay.add(y, y1)
+        f1 = relay.Function([x, w], y)
+        # segment 2
+        x = relay.var("p0", shape=dshape)
+        w = relay.var("p1")
+        z2 = relay.nn.conv2d(x, w,
+                             kernel_size=(3, 3),
+                             padding=(1,1),
+                             channels=16)
+        f2 = relay.Function([x, w], z2)
+        # segment 3
+        x = relay.var("p0", shape=dshape)
+        w = relay.var("p1")
+        offset = relay.var("p2", shape=dshape)
+        z3 = relay.nn.conv2d(x, w,
+                             kernel_size=(1, 1),
+                             padding=(0, 0),
+                             channels=16)
+        z3 = relay.add(z3, offset)
+        f3 = relay.Function([x, w, offset], z3)
+        # compose
+        x = relay.var("x", shape=dshape)
+        y = relay.Call(f1, [x, relay.var("w1")])
+        z2 = relay.Call(f2, [y, relay.var("w3")])
+        z3 = relay.Call(f3, [y, relay.var("w2"), z2])
+        z = z3
+        return relay.Function(relay.ir_pass.free_vars(z), z)
+
+    dshape = (1, 16, 64, 64)
+    z = before(dshape)
+    z = relay.ir_pass.infer_type(z)
+    zz = relay.ir_pass.fuse_ops(z, opt_level=2)
+    zz = relay.ir_pass.infer_type(zz)
+    after = relay.ir_pass.infer_type(expected(dshape))
+    assert relay.ir_pass.alpha_equal(zz, after)
+
 
 
 if __name__ == "__main__":
     test_fuse_simple()
+    test_conv2d_fuse()