diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h
index 340a7110ae4898a947b1bd31db7b62138c9b6421..904a16530fc76124b66801e8dd15703ff304bf67 100644
--- a/include/tvm/ir_mutator.h
+++ b/include/tvm/ir_mutator.h
@@ -102,6 +102,25 @@ class IRMutator {
   virtual Expr Mutate_(const Shuffle* op, const Expr& e);
 };
 
+/*!
+ * \brief recursively visit the ir in post DFS order node, and transform it
+ *
+ * \param node The ir to be transformed.
+ * \param preorder The function called in before recursive mutation
+ *          If preorder returns None, then the transform will proceed to recursive call.
+ *          If preorder returns a not None Stmt/Expr, the transformer will simply return it and
+ *          won't do further recursion.
+ * \param postorder The function called after recursive mutation.
+ *          The recursive mutation result is passed to postorder for further mutation.
+ * \param only_enable List of StringImm.
+ *          If it is empty, all IRNode will call preorder/postorder
+ *          If it is not empty, preorder/postorder will only be called
+ *          when the IRNode's type key is in the list.
+ */
+Stmt IRTransform(const Stmt& node,
+                 const runtime::PackedFunc& preorder,
+                 const runtime::PackedFunc& postorder,
+                 const Array<Expr>& only_enable = {});
 }  // namespace ir
 }  // namespace tvm
 #endif  // TVM_IR_MUTATOR_H_
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index 4868db657d112d56cb7aa3ae9191961b68e415c9..e5505b68df64dfd5ffbc686058d555274c340969 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -7,6 +7,7 @@
 #include <tvm/ir.h>
 #include <tvm/ir_pass.h>
 #include <tvm/ir_visitor.h>
+#include <tvm/ir_mutator.h>
 #include <tvm/api_registry.h>
 
 namespace tvm {
@@ -88,6 +89,7 @@ REGISTER_PASS1(VerifySSA);
 REGISTER_PASS1(RewriteUnsafeSelect);
 REGISTER_PASS4(Inline);
 REGISTER_PASS3(StorageFlatten);
+REGISTER_PASS4(IRTransform);
 REGISTER_PASS1(VectorizeLoop);
 REGISTER_PASS4(UnrollLoop);
 REGISTER_PASS2(ThreadSync);
diff --git a/src/api/dsl_api.cc b/src/api/dsl_api.cc
index 923f004b530ea9ace04670693f300315f1fc70d2..4e247ed2bf4c2afdd1100ca563a30b8fdaacd2b3 100644
--- a/src/api/dsl_api.cc
+++ b/src/api/dsl_api.cc
@@ -110,7 +110,7 @@ class DSLAPIImpl : public DSLAPI {
     *out_index = static_cast<int>(Node::TypeKey2Index(type_key));
   }
   void NodeGetTypeIndex(NodeHandle handle,
-                          int* out_index) const final {
+                        int* out_index) const final {
     *out_index = static_cast<int>(
         (*static_cast<TVMAPINode*>(handle))->type_index());
   }
diff --git a/src/pass/coproc_sync.cc b/src/pass/coproc_sync.cc
index ef8e1f8771d93c90d6a330c472caa9f0ba1bd4c7..fa77942b6058ca03c6c96969f451dd55779f38f3 100644
--- a/src/pass/coproc_sync.cc
+++ b/src/pass/coproc_sync.cc
@@ -11,7 +11,6 @@
 #include "./ir_util.h"
 #include "./storage_access.h"
 
-
 namespace tvm {
 namespace ir {
 
diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc
index efe7e5ee13d81d83beae64a283aa9ea96127c92b..993b68f835d743f408775340f5925135a39b1c46 100644
--- a/src/pass/ir_mutator.cc
+++ b/src/pass/ir_mutator.cc
@@ -4,11 +4,65 @@
  */
 #include <tvm/ir.h>
 #include <tvm/ir_mutator.h>
+#include <tvm/packed_func_ext.h>
 #include "./ir_util.h"
 
 namespace tvm {
 namespace ir {
 
+class IRTransformer final : public IRMutator {
+ public:
+  IRTransformer(const runtime::PackedFunc& f_preorder,
+                const runtime::PackedFunc& f_postorder,
+                const std::unordered_set<uint32_t>& only_enable)
+      : f_preorder_(f_preorder),
+        f_postorder_(f_postorder),
+        only_enable_(only_enable) {
+  }
+  Stmt Mutate(Stmt stmt) final {
+    return MutateInternal<Stmt>(stmt);
+  }
+  Expr Mutate(Expr expr) final {
+    return MutateInternal<Expr>(expr);
+  }
+
+ private:
+  template<typename T>
+  T MutateInternal(T node) {
+    if (only_enable_.size() &&
+        !only_enable_.count(node->type_index())) {
+      return IRMutator::Mutate(node);
+    }
+    if (f_preorder_ != nullptr) {
+      T pre = f_preorder_(node);
+      if (pre.defined()) return pre;
+    }
+    node = IRMutator::Mutate(node);
+    if (f_postorder_ != nullptr) {
+      T post = f_postorder_(node);
+      if (post.defined()) return post;
+    }
+    return node;
+  }
+  // The functions
+  const runtime::PackedFunc& f_preorder_;
+  const runtime::PackedFunc& f_postorder_;
+  // type indices enabled.
+  const std::unordered_set<uint32_t>& only_enable_;
+};
+
+Stmt IRTransform(const Stmt& ir_node,
+                 const runtime::PackedFunc& f_preorder,
+                 const runtime::PackedFunc& f_postorder,
+                 const Array<Expr>& only_enable) {
+  std::unordered_set<uint32_t> only_type_index;
+  for (Expr s : only_enable) {
+    only_type_index.insert(Node::TypeKey2Index(s.as<StringImm>()->value.c_str()));
+  }
+  return IRTransformer(f_preorder, f_postorder, only_type_index)
+      .Mutate(ir_node);
+}
+
 IRMutator::FMutateExpr& IRMutator::vtable_expr() {  // NOLINT(*)
   static FMutateExpr inst; return inst;
 }
diff --git a/tests/python/unittest/test_arith_detect_linear_equation.py b/tests/python/unittest/test_arith_detect_linear_equation.py
index 32f341bf25250f9b99747387183690a913b5d1f6..412effa2f0a27ab46511e513b27abb8aedc65144 100644
--- a/tests/python/unittest/test_arith_detect_linear_equation.py
+++ b/tests/python/unittest/test_arith_detect_linear_equation.py
@@ -14,5 +14,11 @@ def test_basic():
     assert m[1].value == 5
     assert tvm.ir_pass.Simplify(m[0] - (b * 6 + 7 + 1)).value == 0
 
+    m = tvm.arith.DetectLinearEquation(a * b + 7, a)
+    assert m[1] == b
+
+    m = tvm.arith.DetectLinearEquation(b * 7, a)
+    assert m[1].value == 0
+
 if __name__ == "__main__":
     test_basic()
diff --git a/tests/python/unittest/test_pass_ir_transform.py b/tests/python/unittest/test_pass_ir_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..87c82c5e5f5ff427b8069cbd12863722926a40f6
--- /dev/null
+++ b/tests/python/unittest/test_pass_ir_transform.py
@@ -0,0 +1,29 @@
+import tvm
+
+def test_ir_transform():
+    ib = tvm.ir_builder.create()
+    n = tvm.var("n")
+    with ib.for_range(0, n, name="i") as i:
+        with ib.for_range(0, 10, name="j") as j:
+            x = tvm.call_extern("int32", "TestA", i * 3 + j * 1)
+            ib.emit(tvm.call_extern("int32", "TestB", x))
+            ib.emit(tvm.call_extern("int32", "TestC", x))
+    body = ib.get()
+
+    def preorder(op):
+        if op.name == "TestC":
+            return tvm.const(0, "int32")
+        return None
+
+    def postorder(op):
+        assert isinstance(op, tvm.expr.Call)
+        if op.name == "TestA":
+            return tvm.call_extern("int32", "TestB", op.args[0] + 1)
+        return op
+    body = tvm.ir_pass.IRTransform(body, preorder, postorder, ["Call"])
+    stmt_list = tvm.make.stmt_list(body.body.body)
+    assert stmt_list[0].value.args[0].name == "TestB"
+    assert stmt_list[1].value.value == 0
+
+if __name__ == "__main__":
+    test_ir_transform()