diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h
index d726b1dab66061c1156418ca4c42089fc6002609..efa930568c48234be6af88de8eb7436885f67c31 100644
--- a/include/tvm/node/node.h
+++ b/include/tvm/node/node.h
@@ -102,10 +102,10 @@ class TVM_DLL Node : public NodeBase {
   template<typename T>
   inline bool is_type() const;
   /*!
-   * \brief Get a NodeRef that holds reference to this Node.
-   * \return the NodeRef
+   * \brief Get a NodePtr that holds reference to this Node.
+   * \return the NodePtr
    */
-  inline NodeRef GetNodeRef() const;
+  inline NodePtr<Node> GetNodePtr() const;
   // node ref can see this
   friend class NodeRef;
   static constexpr const char* _type_key = "Node";
@@ -176,6 +176,32 @@ class NodeRef {
   NodePtr<Node> node_;
 };
 
+/*!
+ * \brief Get a reference type from a Node ptr type
+ *
+ *  It is always important to get a reference type
+ *  if we want to return a value as reference or keep
+ *  the node alive beyond the scope of the function.
+ *
+ * \param ptr The node pointer
+ * \tparam RefType The reference type
+ * \tparam NodeType The node type
+ * \return The corresponding RefType
+ */
+template <typename RefType, typename NodeType>
+inline RefType GetRef(const NodeType* ptr);
+
+/*!
+ * \brief Downcast a base reference type to a more specific type.
+ *
+ * \param ref The inptut reference
+ * \return The corresponding SubRef.
+ * \tparam SubRef The target specific reference type.
+ * \tparam BaseRef the current reference type.
+ */
+template <typename SubRef, typename BaseRef>
+inline SubRef Downcast(BaseRef ref);
+
 /*!
  * \brief helper macro to declare type information in a base node.
  */
@@ -218,8 +244,24 @@ inline bool Node::derived_from() const {
   return this->_DerivedFrom(type_id);
 }
 
-inline NodeRef Node::GetNodeRef() const {
-  return NodeRef(NodePtr<Node>(const_cast<Node*>(this)));
+inline NodePtr<Node> Node::GetNodePtr() const {
+  return NodePtr<Node>(const_cast<Node*>(this));
+}
+
+template <typename RefType, typename NodeType>
+inline RefType GetRef(const NodeType* ptr) {
+  static_assert(std::is_base_of<typename RefType::ContainerType, NodeType>::value,
+                "Can only cast to the ref of same container type");
+  return RefType(ptr->GetNodePtr());
+}
+
+template <typename SubRef, typename BaseRef>
+inline SubRef Downcast(BaseRef ref) {
+  CHECK(ref->template is_type<typename SubRef::ContainerType>() ||
+        ref->template derived_from<typename SubRef::ContainerType>())
+      << "Downcast from " << ref->type_key() << " to "
+      << SubRef::ContainerType::_type_key << " failed.";
+  return SubRef(std::move(ref.node_));
 }
 
 inline const Node* NodeRef::get() const {
diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h
index ecf45353af67bc4d07bd031b49555370c4828719..ab55f6f3965f70312bef3a06cca19b9a6c79ca8e 100644
--- a/include/tvm/relay/base.h
+++ b/include/tvm/relay/base.h
@@ -158,43 +158,6 @@ class RelayNode : public Node {
   TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node);
 };
 
-/*!
- * \brief Get a reference type from a Node ptr type
- *
- *  It is always important to get a reference type
- *  if we want to return a value as reference or keep
- *  the node alive beyond the scope of the function.
- *
- * \param ptr The node pointer
- * \tparam RefType The reference type
- * \tparam NodeType The node type
- * \return The corresponding RefType
- */
-template <typename RefType, typename NodeType>
-RefType GetRef(const NodeType* ptr) {
-  static_assert(std::is_same<typename RefType::ContainerType, NodeType>::value,
-                "Can only cast to the ref of same container type");
-  return RefType(std::move(ptr->GetNodeRef().node_));
-}
-
-// TODO(@tqchen, @jroesch): can we move these semantics to HalideIR
-template <typename T>
-inline const T* As(const NodeRef& node) {
-  const Node* ptr = static_cast<const Node*>(node.get());
-  if (ptr && (ptr->is_type<T>() || ptr->derived_from<T>())) {
-    return static_cast<const T*>(ptr);
-  }
-  return nullptr;
-}
-
-template <typename SubRef, typename BaseRef>
-SubRef Downcast(BaseRef ref) {
-  CHECK(ref->template is_type<typename SubRef::ContainerType>())
-      << "Downcast from " << ref->type_key() << " to "
-      << SubRef::ContainerType::_type_key << " failed.";
-  return SubRef(ref.node_);
-}
-
 }  // namespace relay
 }  // namespace tvm
 
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 6388e8367bf68525158bc84c97999cbaa843d32e..0dc2ff6fce2da644731d867f7f35cec0cbf2ce63 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -65,7 +65,9 @@ class ConstantNode : public ExprNode {
   TensorType tensor_type() const;
 
   /*! \return Whether it is scalar(rank-0 tensor) */
-  bool is_scalar() const { return data->ndim == 0; }
+  bool is_scalar() const {
+    return data->ndim == 0;
+  }
 
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("data", &data);
@@ -341,7 +343,7 @@ RELAY_DEFINE_NODE_REF(Let, LetNode, Expr);
  *
  * let x = if (true) { 1 } else { 0 }; // x is 1
  * let y = if (false) { 1 } else { 0 }; // y is 0
- * 
+ *
  * \note This is similar to C's ternary operator.
  */
 class If;
diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
index 27bb464b98a3d2524d9fa6b8953429da313a8177..e79535a5034bb422555d47ce7de5279bc275233a 100644
--- a/include/tvm/relay/expr_functor.h
+++ b/include/tvm/relay/expr_functor.h
@@ -139,19 +139,19 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
 * the cost of using functional updates.
 */
 class ExprMutator
-    : public ::tvm::relay::ExprFunctor<Expr(const Expr&, const Expr&)> {
+    : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
  public:
   Expr Mutate(const Expr& expr);
-  Expr VisitExpr_(const VarNode* op, const Expr& e) override;
-  Expr VisitExpr_(const ConstantNode* op, const Expr& e) override;
-  Expr VisitExpr_(const GlobalVarNode* op, const Expr& e) override;
-  Expr VisitExpr_(const OpNode* op, const Expr& expr) override;
-  Expr VisitExpr_(const TupleNode* op, const Expr& e) override;
-  Expr VisitExpr_(const ParamNode* op, const Expr& e) override;
-  Expr VisitExpr_(const FunctionNode* op, const Expr& e) override;
-  Expr VisitExpr_(const CallNode* call_node, const Expr& e) override;
-  Expr VisitExpr_(const LetNode* op, const Expr& e) override;
-  Expr VisitExpr_(const IfNode* op, const Expr& e) override;
+  Expr VisitExpr_(const VarNode* op) override;
+  Expr VisitExpr_(const ConstantNode* op) override;
+  Expr VisitExpr_(const GlobalVarNode* op) override;
+  Expr VisitExpr_(const OpNode* op) override;
+  Expr VisitExpr_(const TupleNode* op) override;
+  Expr VisitExpr_(const ParamNode* op) override;
+  Expr VisitExpr_(const FunctionNode* op) override;
+  Expr VisitExpr_(const CallNode* call_node) override;
+  Expr VisitExpr_(const LetNode* op) override;
+  Expr VisitExpr_(const IfNode* op) override;
   /*! \brief Used to visit the types inside of expressions.
    *
    * Can be overloaded to transform the types in arbitrary
@@ -162,7 +162,7 @@ class ExprMutator
 
  private:
   /*! \brief Internal map used for memoization. */
-  tvm::Map<Expr, Expr> memo_;
+  std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
 };
 
 }  // namespace relay
diff --git a/src/relay/ir/environment.cc b/src/relay/ir/environment.cc
index 16b0314507cfda2436931a999cb1c70ea321b56c..d7a28231ceacd4bb99d32a711cda4bd337880bea 100644
--- a/src/relay/ir/environment.cc
+++ b/src/relay/ir/environment.cc
@@ -41,12 +41,12 @@ void EnvironmentNode::Add(const GlobalVar &var,
                           const Function &func,
                           bool update) {
   // Type check the item before we add it to the environment.
-  auto env = relay::GetRef<Environment>(this);
+  auto env = GetRef<Environment>(this);
 
   Expr checked_expr = InferType(env, var, func);
 
   if (const FunctionNode *func_node = checked_expr.as<FunctionNode>()) {
-    auto checked_func = relay::GetRef<Function>(func_node);
+    auto checked_func = GetRef<Function>(func_node);
     auto type = checked_func->checked_type();
 
     CHECK(IsFullyResolved(type));
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index 85ae5ffa694e439076b41efec3470ffff540c555..e3393bdb039bb95b60e87b27f4cdf059f67c0dc3 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -13,33 +13,33 @@ namespace tvm {
 namespace relay {
 
 Expr ExprMutator::Mutate(const Expr& expr) {
-  auto cached_expr = this->memo_.find(expr);
-  if (cached_expr != this->memo_.end()) {
-    return (*cached_expr).second;
+  auto it = this->memo_.find(expr);
+  if (it != this->memo_.end()) {
+    return it->second;
   } else {
-    auto new_expr = this->ExprMutator::VisitExpr(expr, expr);
-    this->memo_.Set(expr, new_expr);
+    Expr new_expr = ExprMutator::VisitExpr(expr);
+    memo_[expr] = new_expr;
     return new_expr;
   }
 }
 
-Expr ExprMutator::VisitExpr_(const VarNode* op, const Expr& expr) {
-  return expr;
+Expr ExprMutator::VisitExpr_(const VarNode* op) {
+  return GetRef<Expr>(op);
 }
 
-Expr ExprMutator::VisitExpr_(const ConstantNode* op, const Expr& expr) {
-  return expr;
+Expr ExprMutator::VisitExpr_(const ConstantNode* op) {
+  return GetRef<Expr>(op);
 }
 
-Expr ExprMutator::VisitExpr_(const GlobalVarNode* op, const Expr& expr) {
-  return expr;
+Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) {
+  return GetRef<Expr>(op);
 }
 
-Expr ExprMutator::VisitExpr_(const OpNode* op, const Expr& expr) {
-  return expr;
+Expr ExprMutator::VisitExpr_(const OpNode* op) {
+  return GetRef<Expr>(op);
 }
 
-Expr ExprMutator::VisitExpr_(const TupleNode* op, const Expr& e) {
+Expr ExprMutator::VisitExpr_(const TupleNode* op) {
   tvm::Array<Expr> fields;
   bool all_fields_unchanged = true;
   for (auto field : op->fields) {
@@ -49,23 +49,23 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op, const Expr& e) {
   }
 
   if (all_fields_unchanged) {
-    return e;
+    return GetRef<Expr>(op);
   } else {
     return TupleNode::make(fields);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const ParamNode* op, const Expr& e) {
+Expr ExprMutator::VisitExpr_(const ParamNode* op) {
   Var var = Downcast<Var>(this->Mutate(op->var));
   auto type = this->VisitType(op->type);
-  if (var == op->var && type == op->type) {
-    return e;
+  if (op->var.same_as(var) && op->type.same_as(type)) {
+    return GetRef<Expr>(op);
   } else {
     return ParamNode::make(var, type);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const FunctionNode* op, const Expr& e) {
+Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
   tvm::Array<TypeParam> ty_params;
   bool all_ty_params_changed = true;
 
@@ -86,74 +86,82 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op, const Expr& e) {
   auto ret_type = this->VisitType(op->ret_type);
   auto body = this->Mutate(op->body);
 
-  if (ty_params.same_as(op->type_params) && params.same_as(op->params) &&
-      ret_type.same_as(op->ret_type) && body.same_as(op->body)) {
-    return e;
+  if (ty_params.same_as(op->type_params) &&
+      params.same_as(op->params) &&
+      ret_type.same_as(op->ret_type) &&
+      body.same_as(op->body)) {
+    return GetRef<Expr>(op);
   } else {
     return FunctionNode::make(params, ret_type, body, ty_params);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const CallNode* call_node, const Expr& e) {
-  auto op = this->Mutate(call_node->op);
+Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
+  auto new_op = this->Mutate(call_node->op);
+  bool unchanged = call_node->op.same_as(new_op);
 
   tvm::Array<Type> ty_args;
-  bool all_ty_args_unchanged = true;
   for (auto ty_arg : call_node->type_args) {
     auto new_ty_arg = this->VisitType(ty_arg);
     ty_args.push_back(new_ty_arg);
-    all_ty_args_unchanged &= new_ty_arg.same_as(ty_arg);
+    unchanged &= new_ty_arg.same_as(ty_arg);
   }
 
   tvm::Array<Expr> call_args;
-  bool all_args_unchanged = true;
   for (auto arg : call_node->args) {
     auto new_arg = this->Mutate(arg);
     call_args.push_back(new_arg);
-    all_args_unchanged &= new_arg.same_as(arg);
+    unchanged &= new_arg.same_as(arg);
   }
 
-  if (all_ty_args_unchanged && all_args_unchanged &&
-      call_node->op.same_as(op)) {
-    return e;
+  if (unchanged) {
+    return GetRef<Expr>(call_node);
   } else {
-    return CallNode::make(op, call_args, call_node->attrs, ty_args);
+    return CallNode::make(new_op, call_args, call_node->attrs, ty_args);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const LetNode* op, const Expr& e) {
+Expr ExprMutator::VisitExpr_(const LetNode* op) {
   Var var = Downcast<Var>(this->Mutate(op->var));
   auto type = this->VisitType(op->value_type);
   auto value = this->Mutate(op->value);
   auto body = this->Mutate(op->body);
 
-  if (var.same_as(op->var) && type.same_as(op->value_type) &&
-      value.same_as(op->value) && body.same_as(op->body)) {
-    return e;
+  if (var.same_as(op->var) &&
+      type.same_as(op->value_type) &&
+      value.same_as(op->value) &&
+      body.same_as(op->body)) {
+    return GetRef<Expr>(op);
   } else {
     return LetNode::make(var, value, body, type);
   }
 }
 
-Expr ExprMutator::VisitExpr_(const IfNode* op, const Expr& e) {
+Expr ExprMutator::VisitExpr_(const IfNode* op) {
   auto guard = this->Mutate(op->cond);
   auto true_b = this->Mutate(op->true_branch);
   auto false_b = this->Mutate(op->false_branch);
-  if (op->cond == guard && true_b == op->true_branch &&
-      false_b == op->false_branch) {
-    return e;
+  if (op->cond.same_as(guard) &&
+      op->true_branch.same_as(true_b) &&
+      op->false_branch.same_as(false_b)) {
+    return GetRef<Expr>(op);;
   } else {
     return IfNode::make(guard, true_b, false_b);
   }
 }
 
-Type ExprMutator::VisitType(const Type& t) { return t; }
+Type ExprMutator::VisitType(const Type& t) {
+  return t;
+}
 
-void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) { return; }
+void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
+}
 
-void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) { return; }
+void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) {
+}
 
-void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) { return; }
+void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) {
+}
 
 void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) {
   for (auto field : op->fields) {
@@ -202,4 +210,3 @@ void ExprVisitor::VisitType(const Type& t) { return; }
 
 }  // namespace relay
 }  // namespace tvm
-
diff --git a/src/relay/pass/type_visitor.h b/src/relay/pass/type_visitor.h
index 725e3d9b3846564f4be89445d055d542272d991f..c37b536ce0d01d171326252ebb454e34feef3641 100644
--- a/src/relay/pass/type_visitor.h
+++ b/src/relay/pass/type_visitor.h
@@ -78,7 +78,8 @@ struct TypeMutator : TypeFunctor<Type(const Type& n)> {
     Array<TypeConstraint> type_constraints;
     for (auto type_cs : op->type_constraints) {
       auto new_type_cs = VisitType(type_cs);
-      if (const TypeConstraintNode* tin = As<TypeConstraintNode>(new_type_cs)) {
+      if (const TypeConstraintNode* tin =
+          new_type_cs.as_derived<TypeConstraintNode>()) {
         type_constraints.push_back(GetRef<TypeConstraint>(tin));
       } else {
         CHECK(false) << new_type_cs << std::endl;
diff --git a/tests/cpp/expr_test.cc b/tests/cpp/expr_test.cc
index 9cdfef7f6a01afeb5c224f96bed84a4fb79fa543..dca76205d79f1b0382ec03bc35203aad39d1394e 100644
--- a/tests/cpp/expr_test.cc
+++ b/tests/cpp/expr_test.cc
@@ -20,7 +20,7 @@ TEST(ExprNodeRef, Basic) {
   Var x("x");
   Expr z = max(x + 1 + 2, 100);
   const ir::Max* op = z.as<ir::Max>();
-  CHECK(op->GetNodeRef().same_as(z));
+  CHECK(NodeRef(op->GetNodePtr()).same_as(z));
 }