From 5c07413cd46692b2b23188feea690fd6b631e713 Mon Sep 17 00:00:00 2001
From: Ziheng Jiang <jzhtomas@gmail.com>
Date: Sat, 11 Feb 2017 21:55:57 -0800
Subject: [PATCH] [PASS] Change IRVisitor interfaces to function override (#42)

* [PASS] Change IRVisitor interfaces to function override

* [PASS] Change IRMutator interfaces to overloadable function
---
 include/tvm/ir_mutator.h |  84 ++++++-
 include/tvm/ir_visitor.h |  35 ++-
 src/pass/ir_mutator.cc   | 488 ++++++++++++++++++++++++---------------
 src/pass/ir_visitor.cc   | 232 +++++++++++--------
 4 files changed, 549 insertions(+), 290 deletions(-)

diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h
index eea6a3343..c42823269 100644
--- a/include/tvm/ir_mutator.h
+++ b/include/tvm/ir_mutator.h
@@ -16,7 +16,8 @@ namespace ir {
 /*!
  * \brief a base class for mutator to iterative mutate the IR
  *
- *  This IRMutator is implemented via IRFunctor instead of Visitor Pattern.
+ *  This IRMutator is implemented via Visitor Pattern.
+ *  Also you can implement via IRFunctor.
  *  This enables easy extensions of possible new Node.
  *  It also makes changing return types easier.
  *
@@ -54,20 +55,91 @@ class IRMutator {
   static FMutateStmt& vtable_stmt();  // NOLINT(*)
   // Set of overloadable functions
   // The underscore allows Mutate not to be shadowed by inheritance
+  virtual Stmt Mutate_(const Variable* op, const Stmt& s);
   virtual Stmt Mutate_(const LetStmt* op, const Stmt& s);
   virtual Stmt Mutate_(const AttrStmt* op, const Stmt& s);
+  virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
   virtual Stmt Mutate_(const For* op, const Stmt& s);
-  virtual Stmt Mutate_(const Provide* op, const Stmt& s);
   virtual Stmt Mutate_(const Allocate* op, const Stmt& s);
-  virtual Stmt Mutate_(const Realize* op, const Stmt& s);
+  virtual Stmt Mutate_(const Load* op, const Stmt& s);
   virtual Stmt Mutate_(const Store* op, const Stmt& s);
+  virtual Stmt Mutate_(const Let* op, const Stmt& s);
   virtual Stmt Mutate_(const Free* op, const Stmt& s);
-  virtual Stmt Mutate_(const IfThenElse* op, const Stmt& s);
+  virtual Stmt Mutate_(const Call* op, const Stmt& s);
+  virtual Stmt Mutate_(const Add* op, const Stmt& e);
+  virtual Stmt Mutate_(const Sub* op, const Stmt& e);
+  virtual Stmt Mutate_(const Mul* op, const Stmt& e);
+  virtual Stmt Mutate_(const Div* op, const Stmt& e);
+  virtual Stmt Mutate_(const Mod* op, const Stmt& e);
+  virtual Stmt Mutate_(const Min* op, const Stmt& e);
+  virtual Stmt Mutate_(const Max* op, const Stmt& e);
+  virtual Stmt Mutate_(const EQ* op, const Stmt& e);
+  virtual Stmt Mutate_(const NE* op, const Stmt& e);
+  virtual Stmt Mutate_(const LT* op, const Stmt& e);
+  virtual Stmt Mutate_(const LE* op, const Stmt& e);
+  virtual Stmt Mutate_(const GT* op, const Stmt& e);
+  virtual Stmt Mutate_(const GE* op, const Stmt& e);
+  virtual Stmt Mutate_(const And* op, const Stmt& e);
+  virtual Stmt Mutate_(const Or* op, const Stmt& e);
+  virtual Stmt Mutate_(const Reduce* op, const Stmt& s);
+  virtual Stmt Mutate_(const Cast* op, const Stmt& s);
+  virtual Stmt Mutate_(const Not* op, const Stmt& s);
+  virtual Stmt Mutate_(const Select* op, const Stmt& s);
+  virtual Stmt Mutate_(const Ramp* op, const Stmt& s);
+  virtual Stmt Mutate_(const Broadcast* op, const Stmt& e);
+  virtual Stmt Mutate_(const AssertStmt* op, const Stmt& e);
+  virtual Stmt Mutate_(const ProducerConsumer* op, const Stmt& e);
+  virtual Stmt Mutate_(const Provide* op, const Stmt& e);
+  virtual Stmt Mutate_(const Realize* op, const Stmt& s);
   virtual Stmt Mutate_(const Block* op, const Stmt& s);
-  virtual Expr Mutate_(const Call* op, const Expr& e);
-  virtual Expr Mutate_(const Load* op, const Expr& s);
+  virtual Stmt Mutate_(const Evaluate* op, const Stmt& e);
+  virtual Stmt Mutate_(const IntImm* op, const Stmt& e);
+  virtual Stmt Mutate_(const UIntImm* op, const Stmt& e);
+  virtual Stmt Mutate_(const FloatImm* op, const Stmt& e);
+  virtual Stmt Mutate_(const StringImm* op, const Stmt& e);
+
   virtual Expr Mutate_(const Variable* op, const Expr& e);
+  virtual Expr Mutate_(const LetStmt* op, const Expr& e);
+  virtual Expr Mutate_(const AttrStmt* op, const Expr& e);
+  virtual Expr Mutate_(const IfThenElse* op, const Expr& e);
+  virtual Expr Mutate_(const For* op, const Expr& e);
+  virtual Expr Mutate_(const Allocate* op, const Expr& e);
+  virtual Expr Mutate_(const Load* op, const Expr& e);
+  virtual Expr Mutate_(const Store* op, const Expr& e);
   virtual Expr Mutate_(const Let* op, const Expr& e);
+  virtual Expr Mutate_(const Free* op, const Expr& e);
+  virtual Expr Mutate_(const Call* op, const Expr& e);
+  virtual Expr Mutate_(const Add* op, const Expr& e);
+  virtual Expr Mutate_(const Sub* op, const Expr& e);
+  virtual Expr Mutate_(const Mul* op, const Expr& e);
+  virtual Expr Mutate_(const Div* op, const Expr& e);
+  virtual Expr Mutate_(const Mod* op, const Expr& e);
+  virtual Expr Mutate_(const Min* op, const Expr& e);
+  virtual Expr Mutate_(const Max* op, const Expr& e);
+  virtual Expr Mutate_(const EQ* op, const Expr& e);
+  virtual Expr Mutate_(const NE* op, const Expr& e);
+  virtual Expr Mutate_(const LT* op, const Expr& e);
+  virtual Expr Mutate_(const LE* op, const Expr& e);
+  virtual Expr Mutate_(const GT* op, const Expr& e);
+  virtual Expr Mutate_(const GE* op, const Expr& e);
+  virtual Expr Mutate_(const And* op, const Expr& e);
+  virtual Expr Mutate_(const Or* op, const Expr& e);
+  virtual Expr Mutate_(const Reduce* op, const Expr& e);
+  virtual Expr Mutate_(const Cast* op, const Expr& e);
+  virtual Expr Mutate_(const Not* op, const Expr& e);
+  virtual Expr Mutate_(const Select* op, const Expr& e);
+  virtual Expr Mutate_(const Ramp* op, const Expr& e);
+  virtual Expr Mutate_(const Broadcast* op, const Expr& e);
+  virtual Expr Mutate_(const AssertStmt* op, const Expr& e);
+  virtual Expr Mutate_(const ProducerConsumer* op, const Expr& e);
+  virtual Expr Mutate_(const Provide* op, const Expr& e);
+  virtual Expr Mutate_(const Realize* op, const Expr& e);
+  virtual Expr Mutate_(const Block* op, const Expr& e);
+  virtual Expr Mutate_(const Evaluate* op, const Expr& e);
+  virtual Expr Mutate_(const IntImm* op, const Expr& e);
+  virtual Expr Mutate_(const UIntImm* op, const Expr& e);
+  virtual Expr Mutate_(const FloatImm* op, const Expr& e);
+  virtual Expr Mutate_(const StringImm* op, const Expr& e);
 };
 
 /*!
diff --git a/include/tvm/ir_visitor.h b/include/tvm/ir_visitor.h
index e5711f65f..6bfbce25a 100644
--- a/include/tvm/ir_visitor.h
+++ b/include/tvm/ir_visitor.h
@@ -36,16 +36,47 @@ class IRVisitor {
   static FVisit& vtable();
   // overloadable visit function.
   virtual void Visit_(const Variable* op);
-  virtual void Visit_(const AttrStmt* op);
   virtual void Visit_(const LetStmt* op);
+  virtual void Visit_(const AttrStmt* op);
+  virtual void Visit_(const IfThenElse* op);
   virtual void Visit_(const For* op);
   virtual void Visit_(const Allocate* op);
-  virtual void Visit_(const IfThenElse* op);
   virtual void Visit_(const Load* op);
   virtual void Visit_(const Store* op);
   virtual void Visit_(const Let* op);
   virtual void Visit_(const Free* op);
   virtual void Visit_(const Call* op);
+  virtual void Visit_(const Add* op);
+  virtual void Visit_(const Sub* op);
+  virtual void Visit_(const Mul* op);
+  virtual void Visit_(const Div* op);
+  virtual void Visit_(const Mod* op);
+  virtual void Visit_(const Min* op);
+  virtual void Visit_(const Max* op);
+  virtual void Visit_(const EQ* op);
+  virtual void Visit_(const NE* op);
+  virtual void Visit_(const LT* op);
+  virtual void Visit_(const LE* op);
+  virtual void Visit_(const GT* op);
+  virtual void Visit_(const GE* op);
+  virtual void Visit_(const And* op);
+  virtual void Visit_(const Or* op);
+  virtual void Visit_(const Reduce* op);
+  virtual void Visit_(const Cast* op);
+  virtual void Visit_(const Not* op);
+  virtual void Visit_(const Select* op);
+  virtual void Visit_(const Ramp* op);
+  virtual void Visit_(const Broadcast* op);
+  virtual void Visit_(const AssertStmt* op);
+  virtual void Visit_(const ProducerConsumer* op);
+  virtual void Visit_(const Provide* op);
+  virtual void Visit_(const Realize* op);
+  virtual void Visit_(const Block* op);
+  virtual void Visit_(const Evaluate* op);
+  virtual void Visit_(const IntImm* op);
+  virtual void Visit_(const UIntImm* op);
+  virtual void Visit_(const FloatImm* op);
+  virtual void Visit_(const StringImm* op);
 };
 
 /*!
diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc
index f10c9c089..07f2b6d21 100644
--- a/src/pass/ir_mutator.cc
+++ b/src/pass/ir_mutator.cc
@@ -16,11 +16,6 @@ IRMutator::FMutateStmt& IRMutator::vtable_stmt() {  // NOLINT(*)
   static FMutateStmt inst; return inst;
 }
 
-// const expr
-inline Expr ReturnSelfExpr(const NodeRef&, const Expr& e, IRMutator*) {
-  return e;
-}
-
 inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
   std::vector<Expr> new_arr(arr.size());
   bool changed = false;
@@ -58,47 +53,33 @@ inline Array<IterVar> MutateIterVarArr(Array<IterVar> rdom, IRMutator *m) {
   }
 }
 
+
+// Mutate Stmt
+
 #define DISPATCH_TO_MUTATE_STMT(OP)                                 \
   set_dispatch<OP>([](const OP* op, const Stmt& s, IRMutator* m) {  \
       return m->Mutate_(op, s);                                     \
     })
 
-#define DISPATCH_TO_MUTATE_EXPR(OP)                                 \
-  set_dispatch<OP>([](const OP* op, const Expr& e, IRMutator* m) {  \
-      return m->Mutate_(op, e);                                     \
-    })
-
-TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
-.DISPATCH_TO_MUTATE_STMT(LetStmt)
-.DISPATCH_TO_MUTATE_STMT(AttrStmt)
-.DISPATCH_TO_MUTATE_STMT(Provide)
-.DISPATCH_TO_MUTATE_STMT(Realize)
-.DISPATCH_TO_MUTATE_STMT(Store)
-.DISPATCH_TO_MUTATE_STMT(IfThenElse)
-.DISPATCH_TO_MUTATE_STMT(For)
-.DISPATCH_TO_MUTATE_STMT(Allocate)
-.DISPATCH_TO_MUTATE_STMT(Block)
-.DISPATCH_TO_MUTATE_STMT(Free);
-
-Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) {
+Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) {
   Expr value = this->Mutate(op->value);
   Stmt body = this->Mutate(op->body);
   if (value.same_as(op->value) &&
       body.same_as(op->body)) {
     return s;
   } else {
-    return LetStmt::make(op->var, value, body);
+    return AttrStmt::make(op->node, op->type_key, value, body);
   }
 }
 
-Stmt IRMutator::Mutate_(const AttrStmt* op, const Stmt& s) {
+Stmt IRMutator::Mutate_(const LetStmt *op, const Stmt& s) {
   Expr value = this->Mutate(op->value);
   Stmt body = this->Mutate(op->body);
   if (value.same_as(op->value) &&
       body.same_as(op->body)) {
     return s;
   } else {
-    return AttrStmt::make(op->node, op->type_key, value, body);
+    return LetStmt::make(op->var, value, body);
   }
 }
 
@@ -143,6 +124,36 @@ Stmt IRMutator::Mutate_(const Allocate* op, const Stmt& s) {
   }
 }
 
+Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
+  Expr condition = this->Mutate(op->condition);
+  Stmt then_case = this->Mutate(op->then_case);
+  Stmt else_case;
+  if (else_case.defined()) {
+    else_case = this->Mutate(op->else_case);
+  }
+  if (condition.same_as(op->condition) &&
+      then_case.same_as(op->then_case) &&
+      else_case.same_as(op->else_case)) {
+    return s;
+  } else {
+    return IfThenElse::make(condition, then_case, else_case);
+  }
+}
+
+Stmt IRMutator::Mutate_(const Load *op, const Stmt& s) {
+  return s;
+}
+
+Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) {
+  Expr value = this->Mutate(op->value);
+  Expr index = this->Mutate(op->index);
+  if (value.same_as(op->value) && index.same_as(op->index)) {
+    return s;
+  } else {
+    return Store::make(op->buffer_var, value, index);
+  }
+}
+
 Stmt IRMutator::Mutate_(const Provide* op, const Stmt& s) {
   auto new_args = MutateArray(op->args, this);
   auto new_value = this->Mutate(op->value);
@@ -183,63 +194,137 @@ Stmt IRMutator::Mutate_(const Realize* op, const Stmt& s) {
   }
 }
 
-Stmt IRMutator::Mutate_(const Store *op, const Stmt& s) {
-  Expr value = this->Mutate(op->value);
-  Expr index = this->Mutate(op->index);
-  if (value.same_as(op->value) && index.same_as(op->index)) {
+Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
+  Stmt first = this->Mutate(op->first);
+  Stmt rest = this->Mutate(op->rest);
+  if (first.same_as(op->first) &&
+      rest.same_as(op->rest)) {
     return s;
   } else {
-    return Store::make(op->buffer_var, value, index);
+    return Block::make(first, rest);
   }
 }
 
-Stmt IRMutator::Mutate_(const Free *op, const Stmt& s) {
-  return s;
-}
-
-Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
+Stmt IRMutator::Mutate_(const AssertStmt *op, const Stmt& s) {
   Expr condition = this->Mutate(op->condition);
-  Stmt then_case = this->Mutate(op->then_case);
-  Stmt else_case;
-  if (else_case.defined()) {
-    else_case = this->Mutate(op->else_case);
-  }
-  if (condition.same_as(op->condition) &&
-      then_case.same_as(op->then_case) &&
-      else_case.same_as(op->else_case)) {
+  Expr message = this->Mutate(op->message);
+
+  if (condition.same_as(op->condition) && message.same_as(op->message)) {
     return s;
   } else {
-    return IfThenElse::make(condition, then_case, else_case);
+    return AssertStmt::make(condition, message);
   }
 }
 
-Stmt IRMutator::Mutate_(const Block* op, const Stmt& s) {
-  Stmt first = this->Mutate(op->first);
-  Stmt rest = this->Mutate(op->rest);
-  if (first.same_as(op->first) &&
-      rest.same_as(op->rest)) {
+Stmt IRMutator::Mutate_(const ProducerConsumer *op, const Stmt& s) {
+  Stmt body = this->Mutate(op->body);
+  if (body.same_as(op->body)) {
     return s;
   } else {
-    return Block::make(first, rest);
+    return ProducerConsumer::make(op->func, op->is_producer, body);
   }
 }
 
-TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
-.DISPATCH_TO_MUTATE_EXPR(Call)
-.DISPATCH_TO_MUTATE_EXPR(Let)
-.DISPATCH_TO_MUTATE_EXPR(Load)
-.DISPATCH_TO_MUTATE_EXPR(Variable);
-
-Expr IRMutator::Mutate_(const Call* op, const Expr& e) {
-  auto new_args = MutateArray(op->args, this);
-  if (op->args.same_as(new_args)) {
-    return e;
+Stmt IRMutator::Mutate_(const Evaluate *op, const Stmt& s) {
+  Expr v = this->Mutate(op->value);
+  if (v.same_as(op->value)) {
+    return s;
   } else {
-    return Call::make(op->type, op->name, new_args, op->call_type,
-                      op->func, op->value_index);
+    return Evaluate::make(v);
   }
 }
 
+#define DEFINE_OP_RETURN_SELF_STMT_MUTATE_(OP)              \
+  Stmt IRMutator::Mutate_(const OP *op, const Stmt& s) {    \
+    return s;                                               \
+  }
+
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Variable)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Let)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Free)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Call)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Add)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Sub)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Mul)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Div)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Mod)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Min)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Max)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(EQ)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(NE)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(LT)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(LE)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(GT)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(GE)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(And)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Or)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Reduce)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Cast)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Not)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Select)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Ramp)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(Broadcast)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(IntImm)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(UIntImm)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(FloatImm)
+DEFINE_OP_RETURN_SELF_STMT_MUTATE_(StringImm)
+
+TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
+.DISPATCH_TO_MUTATE_STMT(Variable)
+.DISPATCH_TO_MUTATE_STMT(LetStmt)
+.DISPATCH_TO_MUTATE_STMT(AttrStmt)
+.DISPATCH_TO_MUTATE_STMT(IfThenElse)
+.DISPATCH_TO_MUTATE_STMT(For)
+.DISPATCH_TO_MUTATE_STMT(Allocate)
+.DISPATCH_TO_MUTATE_STMT(Load)
+.DISPATCH_TO_MUTATE_STMT(Store)
+.DISPATCH_TO_MUTATE_STMT(Let)
+.DISPATCH_TO_MUTATE_STMT(Free)
+.DISPATCH_TO_MUTATE_STMT(Call)
+.DISPATCH_TO_MUTATE_STMT(Add)
+.DISPATCH_TO_MUTATE_STMT(Sub)
+.DISPATCH_TO_MUTATE_STMT(Mul)
+.DISPATCH_TO_MUTATE_STMT(Div)
+.DISPATCH_TO_MUTATE_STMT(Mod)
+.DISPATCH_TO_MUTATE_STMT(Min)
+.DISPATCH_TO_MUTATE_STMT(Max)
+.DISPATCH_TO_MUTATE_STMT(EQ)
+.DISPATCH_TO_MUTATE_STMT(NE)
+.DISPATCH_TO_MUTATE_STMT(LT)
+.DISPATCH_TO_MUTATE_STMT(LE)
+.DISPATCH_TO_MUTATE_STMT(GT)
+.DISPATCH_TO_MUTATE_STMT(GE)
+.DISPATCH_TO_MUTATE_STMT(And)
+.DISPATCH_TO_MUTATE_STMT(Or)
+.DISPATCH_TO_MUTATE_STMT(Reduce)
+.DISPATCH_TO_MUTATE_STMT(Cast)
+.DISPATCH_TO_MUTATE_STMT(Not)
+.DISPATCH_TO_MUTATE_STMT(Select)
+.DISPATCH_TO_MUTATE_STMT(Ramp)
+.DISPATCH_TO_MUTATE_STMT(Broadcast)
+.DISPATCH_TO_MUTATE_STMT(AssertStmt)
+.DISPATCH_TO_MUTATE_STMT(ProducerConsumer)
+.DISPATCH_TO_MUTATE_STMT(Provide)
+.DISPATCH_TO_MUTATE_STMT(Realize)
+.DISPATCH_TO_MUTATE_STMT(Block)
+.DISPATCH_TO_MUTATE_STMT(Evaluate)
+.DISPATCH_TO_MUTATE_STMT(IntImm)
+.DISPATCH_TO_MUTATE_STMT(UIntImm)
+.DISPATCH_TO_MUTATE_STMT(FloatImm)
+.DISPATCH_TO_MUTATE_STMT(StringImm);
+
+
+// Mutate Expr
+
+#define DISPATCH_TO_MUTATE_EXPR(OP)                                 \
+  set_dispatch<OP>([](const OP* op, const Expr& e, IRMutator* m) {  \
+      return m->Mutate_(op, e);                                     \
+    })
+
+Expr IRMutator::Mutate_(const Variable *op, const Expr& e) {
+  return e;
+}
+
 Expr IRMutator::Mutate_(const Load *op, const Expr& e) {
   Expr index = this->Mutate(op->index);
   if (index.same_as(op->index)) {
@@ -249,11 +334,6 @@ Expr IRMutator::Mutate_(const Load *op, const Expr& e) {
   }
 }
 
-
-Expr IRMutator::Mutate_(const Variable *op, const Expr& e) {
-  return e;
-}
-
 Expr IRMutator::Mutate_(const Let *op, const Expr& e) {
   Expr value = this->Mutate(op->value);
   Expr body = this->Mutate(op->body);
@@ -265,130 +345,172 @@ Expr IRMutator::Mutate_(const Let *op, const Expr& e) {
   }
 }
 
-TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
-.set_dispatch<Reduce>([](const Reduce* op, const Expr& e, IRMutator* m) {
-    Array<IterVar> new_axis  = MutateIterVarArr(op->axis, m);
-    Expr new_source  = m->Mutate(op->source);
-    if (op->axis.same_as(new_axis) &&
-        op->source.same_as(new_source)) {
-      return e;
-    } else {
-      return Reduce::make(op->op, new_source, new_axis);
-    }
-  });
+Expr IRMutator::Mutate_(const Call* op, const Expr& e) {
+  auto new_args = MutateArray(op->args, this);
+  if (op->args.same_as(new_args)) {
+    return e;
+  } else {
+    return Call::make(op->type, op->name, new_args, op->call_type,
+                      op->func, op->value_index);
+  }
+}
 
-TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
-.set_dispatch<IntImm>(ReturnSelfExpr)
-.set_dispatch<UIntImm>(ReturnSelfExpr)
-.set_dispatch<FloatImm>(ReturnSelfExpr)
-.set_dispatch<StringImm>(ReturnSelfExpr);
+#define DEFINE_BIOP_EXPR_MUTATE_(OP)                        \
+  Expr IRMutator::Mutate_(const OP* op, const Expr& e) {    \
+    Expr a = this->Mutate(op->a);                           \
+    Expr b = this->Mutate(op->b);                           \
+    if (a.same_as(op->a) &&                                 \
+        b.same_as(op->b)) {                                 \
+      return e;                                             \
+    } else {                                                \
+      return OP::make(a, b);                                 \
+    }                                                       \
+  }
 
-TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
-.set_dispatch<Cast>([](const Cast* op, const Expr& e, IRMutator* m) {
-    Expr value = m->Mutate(op->value);
-    if (value.same_as(op->value)) {
-      return e;
-    } else {
-      return Cast::make(op->type, value);
-    }
-  });
-
-// binary operator
-template<typename T>
-inline Expr Binary(const T* op, const Expr& e, IRMutator* m) {
-  Expr a = m->Mutate(op->a);
-  Expr b = m->Mutate(op->b);
-  if (a.same_as(op->a) &&
-      b.same_as(op->b)) {
+DEFINE_BIOP_EXPR_MUTATE_(Add)
+DEFINE_BIOP_EXPR_MUTATE_(Sub)
+DEFINE_BIOP_EXPR_MUTATE_(Mul)
+DEFINE_BIOP_EXPR_MUTATE_(Div)
+DEFINE_BIOP_EXPR_MUTATE_(Mod)
+DEFINE_BIOP_EXPR_MUTATE_(Min)
+DEFINE_BIOP_EXPR_MUTATE_(Max)
+DEFINE_BIOP_EXPR_MUTATE_(EQ)
+DEFINE_BIOP_EXPR_MUTATE_(NE)
+DEFINE_BIOP_EXPR_MUTATE_(LT)
+DEFINE_BIOP_EXPR_MUTATE_(LE)
+DEFINE_BIOP_EXPR_MUTATE_(GT)
+DEFINE_BIOP_EXPR_MUTATE_(GE)
+DEFINE_BIOP_EXPR_MUTATE_(And)
+DEFINE_BIOP_EXPR_MUTATE_(Or)
+
+Expr IRMutator::Mutate_(const Reduce *op, const Expr& e) {
+  Array<IterVar> new_axis  = MutateIterVarArr(op->axis, this);
+  Expr new_source = this->Mutate(op->source);
+  if (op->axis.same_as(new_axis) &&
+      op->source.same_as(new_source)) {
     return e;
   } else {
-    return T::make(a, b);
+    return Reduce::make(op->op, new_source, new_axis);
   }
 }
 
-TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
-.set_dispatch<Add>(Binary<Add>)
-.set_dispatch<Sub>(Binary<Sub>)
-.set_dispatch<Mul>(Binary<Mul>)
-.set_dispatch<Div>(Binary<Div>)
-.set_dispatch<Mod>(Binary<Mod>)
-.set_dispatch<Min>(Binary<Min>)
-.set_dispatch<Max>(Binary<Max>)
-.set_dispatch<EQ>(Binary<EQ>)
-.set_dispatch<NE>(Binary<NE>)
-.set_dispatch<LT>(Binary<LT>)
-.set_dispatch<LE>(Binary<LE>)
-.set_dispatch<GT>(Binary<GT>)
-.set_dispatch<GE>(Binary<GE>)
-.set_dispatch<And>(Binary<And>)
-.set_dispatch<Or>(Binary<Or>);
+Expr IRMutator::Mutate_(const Cast *op, const Expr& e) {
+  Expr value = this->Mutate(op->value);
+  if (value.same_as(op->value)) {
+    return e;
+  } else {
+    return Cast::make(op->type, value);
+  }
+}
+
+Expr IRMutator::Mutate_(const Not *op, const Expr& e) {
+  Expr a = this->Mutate(op->a);
+  if (a.same_as(op->a)) {
+    return e;
+  } else {
+    return Not::make(a);
+  }
+}
+
+Expr IRMutator::Mutate_(const Select *op, const Expr& e) {
+  Expr cond = this->Mutate(op->condition);
+  Expr t = this->Mutate(op->true_value);
+  Expr f = this->Mutate(op->false_value);
+  if (cond.same_as(op->condition) &&
+      t.same_as(op->true_value) &&
+      f.same_as(op->false_value)) {
+    return e;
+  } else {
+    return Select::make(cond, t, f);
+  }
+}
+
+Expr IRMutator::Mutate_(const Ramp *op, const Expr& e) {
+  Expr base = this->Mutate(op->base);
+  Expr stride = this->Mutate(op->stride);
+  if (base.same_as(op->base) &&
+      stride.same_as(op->stride)) {
+    return e;
+  } else {
+    return Ramp::make(base, stride, op->lanes);
+  }
+}
+
+Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) {
+  Expr value = this->Mutate(op->value);
+  if (value.same_as(op->value)) {
+    return e;
+  } else {
+    return Broadcast::make(value, op->lanes);
+  }
+}
+
+#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP)              \
+  Expr IRMutator::Mutate_(const OP *op, const Expr& e) {    \
+    return e;                                               \
+  }
+
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(LetStmt)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(AttrStmt)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(For)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IfThenElse)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Allocate)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Store)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Free)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(AssertStmt)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(ProducerConsumer)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Provide)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Realize)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Block)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(Evaluate)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(IntImm)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(UIntImm)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(FloatImm)
+DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(StringImm)
 
 TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
-.set_dispatch<Not>([](const Not* op, const Expr& e, IRMutator* m) {
-    Expr a = m->Mutate(op->a);
-    if (a.same_as(op->a)) {
-      return e;
-    } else {
-      return Not::make(a);
-    }
-  })
-.set_dispatch<Select>([](const Select *op, const Expr& e, IRMutator* m) {
-    Expr cond = m->Mutate(op->condition);
-    Expr t = m->Mutate(op->true_value);
-    Expr f = m->Mutate(op->false_value);
-    if (cond.same_as(op->condition) &&
-        t.same_as(op->true_value) &&
-        f.same_as(op->false_value)) {
-      return e;
-    } else {
-      return Select::make(cond, t, f);
-    }
-  })
-.set_dispatch<Ramp>([](const Ramp *op, const Expr& e, IRMutator* m) {
-    Expr base = m->Mutate(op->base);
-    Expr stride = m->Mutate(op->stride);
-    if (base.same_as(op->base) &&
-        stride.same_as(op->stride)) {
-      return e;
-    } else {
-      return Ramp::make(base, stride, op->lanes);
-    }
-  })
-.set_dispatch<Broadcast>([](const Broadcast *op, const Expr& e, IRMutator* m) {
-    Expr value = m->Mutate(op->value);
-    if (value.same_as(op->value)) {
-      return e;
-    } else {
-      return Broadcast::make(value, op->lanes);
-    }
-  });
+.DISPATCH_TO_MUTATE_EXPR(Variable)
+.DISPATCH_TO_MUTATE_EXPR(LetStmt)
+.DISPATCH_TO_MUTATE_EXPR(AttrStmt)
+.DISPATCH_TO_MUTATE_EXPR(IfThenElse)
+.DISPATCH_TO_MUTATE_EXPR(For)
+.DISPATCH_TO_MUTATE_EXPR(Allocate)
+.DISPATCH_TO_MUTATE_EXPR(Load)
+.DISPATCH_TO_MUTATE_EXPR(Store)
+.DISPATCH_TO_MUTATE_EXPR(Let)
+.DISPATCH_TO_MUTATE_EXPR(Free)
+.DISPATCH_TO_MUTATE_EXPR(Call)
+.DISPATCH_TO_MUTATE_EXPR(Add)
+.DISPATCH_TO_MUTATE_EXPR(Sub)
+.DISPATCH_TO_MUTATE_EXPR(Mul)
+.DISPATCH_TO_MUTATE_EXPR(Div)
+.DISPATCH_TO_MUTATE_EXPR(Mod)
+.DISPATCH_TO_MUTATE_EXPR(Min)
+.DISPATCH_TO_MUTATE_EXPR(Max)
+.DISPATCH_TO_MUTATE_EXPR(EQ)
+.DISPATCH_TO_MUTATE_EXPR(NE)
+.DISPATCH_TO_MUTATE_EXPR(LT)
+.DISPATCH_TO_MUTATE_EXPR(LE)
+.DISPATCH_TO_MUTATE_EXPR(GT)
+.DISPATCH_TO_MUTATE_EXPR(GE)
+.DISPATCH_TO_MUTATE_EXPR(And)
+.DISPATCH_TO_MUTATE_EXPR(Or)
+.DISPATCH_TO_MUTATE_EXPR(Reduce)
+.DISPATCH_TO_MUTATE_EXPR(Cast)
+.DISPATCH_TO_MUTATE_EXPR(Not)
+.DISPATCH_TO_MUTATE_EXPR(Select)
+.DISPATCH_TO_MUTATE_EXPR(Ramp)
+.DISPATCH_TO_MUTATE_EXPR(Broadcast)
+.DISPATCH_TO_MUTATE_EXPR(AssertStmt)
+.DISPATCH_TO_MUTATE_EXPR(ProducerConsumer)
+.DISPATCH_TO_MUTATE_EXPR(Provide)
+.DISPATCH_TO_MUTATE_EXPR(Realize)
+.DISPATCH_TO_MUTATE_EXPR(Block)
+.DISPATCH_TO_MUTATE_EXPR(Evaluate)
+.DISPATCH_TO_MUTATE_EXPR(IntImm)
+.DISPATCH_TO_MUTATE_EXPR(UIntImm)
+.DISPATCH_TO_MUTATE_EXPR(FloatImm)
+.DISPATCH_TO_MUTATE_EXPR(StringImm);
 
-TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_stmt)
-.set_dispatch<AssertStmt>([](const AssertStmt *op, const Stmt& s, IRMutator* m) {
-    Expr condition = m->Mutate(op->condition);
-    Expr message = m->Mutate(op->message);
-
-    if (condition.same_as(op->condition) && message.same_as(op->message)) {
-      return s;
-    } else {
-      return AssertStmt::make(condition, message);
-    }
-  })
-.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, const Stmt& s, IRMutator* m) {
-    Stmt body = m->Mutate(op->body);
-    if (body.same_as(op->body)) {
-      return s;
-    } else {
-      return ProducerConsumer::make(op->func, op->is_producer, body);
-    }
-  })
-.set_dispatch<Evaluate>([](const Evaluate *op, const Stmt& s, IRMutator* m) {
-    Expr v = m->Mutate(op->value);
-    if (v.same_as(op->value)) {
-      return s;
-    } else {
-      return Evaluate::make(v);
-    }
-  });
 }  // namespace ir
 }  // namespace tvm
diff --git a/src/pass/ir_visitor.cc b/src/pass/ir_visitor.cc
index 5baaa8519..f82f9130f 100644
--- a/src/pass/ir_visitor.cc
+++ b/src/pass/ir_visitor.cc
@@ -34,9 +34,6 @@ IRVisitor::FVisit& IRVisitor::vtable() {  // NOLINT(*)
   static FVisit inst; return inst;
 }
 
-void NoOp(const NodeRef& n, IRVisitor* v) {
-}
-
 inline void VisitArray(const Array<Expr>& arr, IRVisitor* v) {
   for (size_t i = 0; i < arr.size(); i++) {
     v->Visit(arr[i]);
@@ -51,24 +48,6 @@ inline void VisitRDom(const Array<IterVar>& rdom, IRVisitor* v) {
   }
 }
 
-#define DISPATCH_TO_VISIT(OP)                       \
-  set_dispatch<OP>([](const OP* op, IRVisitor* v) { \
-      v->Visit_(op);                                \
-    })
-
-TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
-.DISPATCH_TO_VISIT(Variable)
-.DISPATCH_TO_VISIT(LetStmt)
-.DISPATCH_TO_VISIT(AttrStmt)
-.DISPATCH_TO_VISIT(IfThenElse)
-.DISPATCH_TO_VISIT(For)
-.DISPATCH_TO_VISIT(Allocate)
-.DISPATCH_TO_VISIT(Load)
-.DISPATCH_TO_VISIT(Store)
-.DISPATCH_TO_VISIT(Let)
-.DISPATCH_TO_VISIT(Call)
-.DISPATCH_TO_VISIT(Free);
-
 void IRVisitor::Visit_(const Variable* op) {}
 
 void IRVisitor::Visit_(const LetStmt *op) {
@@ -128,91 +107,146 @@ void IRVisitor::Visit_(const Call *op) {
   VisitArray(op->args, this);
 }
 
-TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
-.set_dispatch<Reduce>([](const Reduce* op, IRVisitor* v) {
-    VisitRDom(op->axis, v);
-    v->Visit(op->source);
-  })
-.set_dispatch<IntImm>(NoOp)
-.set_dispatch<UIntImm>(NoOp)
-.set_dispatch<FloatImm>(NoOp)
-.set_dispatch<StringImm>(NoOp);
+#define DEFINE_BINOP_VISIT_(OP)                     \
+  void IRVisitor::Visit_(const OP* op) {            \
+    this->Visit(op->a);                             \
+    this->Visit(op->b);                             \
+  }
 
-TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
-.set_dispatch<Cast>([](const Cast* op, IRVisitor* v) {
-    v->Visit(op->value);
-  });
+DEFINE_BINOP_VISIT_(Add)
+DEFINE_BINOP_VISIT_(Sub)
+DEFINE_BINOP_VISIT_(Mul)
+DEFINE_BINOP_VISIT_(Div)
+DEFINE_BINOP_VISIT_(Mod)
+DEFINE_BINOP_VISIT_(Min)
+DEFINE_BINOP_VISIT_(Max)
+DEFINE_BINOP_VISIT_(EQ)
+DEFINE_BINOP_VISIT_(NE)
+DEFINE_BINOP_VISIT_(LT)
+DEFINE_BINOP_VISIT_(LE)
+DEFINE_BINOP_VISIT_(GT)
+DEFINE_BINOP_VISIT_(GE)
+DEFINE_BINOP_VISIT_(And)
+DEFINE_BINOP_VISIT_(Or)
+
+void IRVisitor::Visit_(const Reduce* op) {
+  VisitRDom(op->axis, this);
+  this->Visit(op->source);
+}
 
-// binary operator
-template<typename T>
-inline void Binary(const T* op, IRVisitor* v) {
-  v->Visit(op->a);
-  v->Visit(op->b);
+void IRVisitor::Visit_(const Cast* op) {
+  this->Visit(op->value);
 }
 
-TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
-.set_dispatch<Add>(Binary<Add>)
-.set_dispatch<Sub>(Binary<Sub>)
-.set_dispatch<Mul>(Binary<Mul>)
-.set_dispatch<Div>(Binary<Div>)
-.set_dispatch<Mod>(Binary<Mod>)
-.set_dispatch<Min>(Binary<Min>)
-.set_dispatch<Max>(Binary<Max>)
-.set_dispatch<EQ>(Binary<EQ>)
-.set_dispatch<NE>(Binary<NE>)
-.set_dispatch<LT>(Binary<LT>)
-.set_dispatch<LE>(Binary<LE>)
-.set_dispatch<GT>(Binary<GT>)
-.set_dispatch<GE>(Binary<GE>)
-.set_dispatch<And>(Binary<And>)
-.set_dispatch<Or>(Binary<Or>);
+void IRVisitor::Visit_(const Not* op) {
+  this->Visit(op->a);
+}
 
-TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
-.set_dispatch<Not>([](const Not* op, IRVisitor* v) {
-    v->Visit(op->a);
-  })
-.set_dispatch<Select>([](const Select *op, IRVisitor* v) {
-    v->Visit(op->condition);
-    v->Visit(op->true_value);
-    v->Visit(op->false_value);
-  })
-.set_dispatch<Ramp>([](const Ramp *op, IRVisitor* v) {
-    v->Visit(op->base);
-    v->Visit(op->stride);
-  })
-.set_dispatch<Broadcast>([](const Broadcast *op, IRVisitor* v) {
-    v->Visit(op->value);
-  });
+void IRVisitor::Visit_(const Select* op) {
+  this->Visit(op->condition);
+  this->Visit(op->true_value);
+  this->Visit(op->false_value);
+}
 
-TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
-.set_dispatch<AssertStmt>([](const AssertStmt *op, IRVisitor* v) {
-    v->Visit(op->condition);
-    v->Visit(op->message);
-  })
-.set_dispatch<ProducerConsumer>([](const ProducerConsumer *op, IRVisitor* v) {
-    v->Visit(op->body);
-  })
-.set_dispatch<Provide>([](const Provide *op, IRVisitor* v) {
-    VisitArray(op->args, v);
-    v->Visit(op->value);
-  })
-.set_dispatch<Realize>([](const Realize *op, IRVisitor* v) {
+void IRVisitor::Visit_(const Ramp *op) {
+  this->Visit(op->base);
+  this->Visit(op->stride);
+}
+
+void IRVisitor::Visit_(const Broadcast *op) {
+  this->Visit(op->value);
+}
+
+void IRVisitor::Visit_(const AssertStmt *op) {
+  this->Visit(op->condition);
+  this->Visit(op->message);
+}
+
+void IRVisitor::Visit_(const ProducerConsumer *op) {
+  this->Visit(op->body);
+}
+
+void IRVisitor::Visit_(const Provide *op) {
+  VisitArray(op->args, this);
+  this->Visit(op->value);
+}
+
+void IRVisitor::Visit_(const Realize *op) {
     // Mutate the bounds
-    for (size_t i = 0; i < op->bounds.size(); i++) {
-      v->Visit(op->bounds[i]->min);
-      v->Visit(op->bounds[i]->extent);
-    }
-
-    v->Visit(op->body);
-    v->Visit(op->condition);
-  })
-.set_dispatch<Block>([](const Block *op, IRVisitor* v) {
-    v->Visit(op->first);
-    v->Visit(op->rest);
-  })
-.set_dispatch<Evaluate>([](const Evaluate *op, IRVisitor* v) {
-    v->Visit(op->value);
-  });
+  for (size_t i = 0; i < op->bounds.size(); i++) {
+    this->Visit(op->bounds[i]->min);
+    this->Visit(op->bounds[i]->extent);
+  }
+
+  this->Visit(op->body);
+  this->Visit(op->condition);
+}
+
+void IRVisitor::Visit_(const Block *op) {
+  this->Visit(op->first);
+  this->Visit(op->rest);
+}
+
+void IRVisitor::Visit_(const Evaluate *op) {
+  this->Visit(op->value);
+}
+
+#define DEFINE_OP_NO_VISIT_(OP)                     \
+  void IRVisitor::Visit_(const OP* op) {}
+
+DEFINE_OP_NO_VISIT_(IntImm)
+DEFINE_OP_NO_VISIT_(UIntImm)
+DEFINE_OP_NO_VISIT_(FloatImm)
+DEFINE_OP_NO_VISIT_(StringImm)
+
+#define DISPATCH_TO_VISIT(OP)                       \
+  set_dispatch<OP>([](const OP* op, IRVisitor* v) { \
+      v->Visit_(op);                                \
+    })
+
+TVM_STATIC_IR_FUNCTOR(IRVisitor, vtable)
+.DISPATCH_TO_VISIT(Variable)
+.DISPATCH_TO_VISIT(LetStmt)
+.DISPATCH_TO_VISIT(AttrStmt)
+.DISPATCH_TO_VISIT(IfThenElse)
+.DISPATCH_TO_VISIT(For)
+.DISPATCH_TO_VISIT(Allocate)
+.DISPATCH_TO_VISIT(Load)
+.DISPATCH_TO_VISIT(Store)
+.DISPATCH_TO_VISIT(Let)
+.DISPATCH_TO_VISIT(Free)
+.DISPATCH_TO_VISIT(Call)
+.DISPATCH_TO_VISIT(Add)
+.DISPATCH_TO_VISIT(Sub)
+.DISPATCH_TO_VISIT(Mul)
+.DISPATCH_TO_VISIT(Div)
+.DISPATCH_TO_VISIT(Mod)
+.DISPATCH_TO_VISIT(Min)
+.DISPATCH_TO_VISIT(Max)
+.DISPATCH_TO_VISIT(EQ)
+.DISPATCH_TO_VISIT(NE)
+.DISPATCH_TO_VISIT(LT)
+.DISPATCH_TO_VISIT(LE)
+.DISPATCH_TO_VISIT(GT)
+.DISPATCH_TO_VISIT(GE)
+.DISPATCH_TO_VISIT(And)
+.DISPATCH_TO_VISIT(Or)
+.DISPATCH_TO_VISIT(Reduce)
+.DISPATCH_TO_VISIT(Cast)
+.DISPATCH_TO_VISIT(Not)
+.DISPATCH_TO_VISIT(Select)
+.DISPATCH_TO_VISIT(Ramp)
+.DISPATCH_TO_VISIT(Broadcast)
+.DISPATCH_TO_VISIT(AssertStmt)
+.DISPATCH_TO_VISIT(ProducerConsumer)
+.DISPATCH_TO_VISIT(Provide)
+.DISPATCH_TO_VISIT(Realize)
+.DISPATCH_TO_VISIT(Block)
+.DISPATCH_TO_VISIT(Evaluate)
+.DISPATCH_TO_VISIT(IntImm)
+.DISPATCH_TO_VISIT(UIntImm)
+.DISPATCH_TO_VISIT(FloatImm)
+.DISPATCH_TO_VISIT(StringImm);
 
 }  // namespace ir
 }  // namespace tvm
-- 
GitLab