From ff6b8d82116bb26d4e8592a4d236b8020571e17a Mon Sep 17 00:00:00 2001
From: tqchen <tianqi.tchen@gmail.com>
Date: Thu, 17 Nov 2016 20:55:47 -0800
Subject: [PATCH] check substitute

---
 include/tvm/ir_mutator.h     |  8 ++++++++
 src/pass/ir_mutator.cc       | 26 ++++++++++++++++++++++++++
 tests/cpp/ir_mutator_test.cc | 19 +++++++++++++++++++
 3 files changed, 53 insertions(+)

diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h
index 3056fc503..f0c2dcb4b 100644
--- a/include/tvm/ir_mutator.h
+++ b/include/tvm/ir_mutator.h
@@ -7,6 +7,7 @@
 #define TVM_IR_MUTATOR_H_
 
 #include <tvm/ir_node.h>
+#include <unordered_map>
 #include "./expr.h"
 
 namespace tvm {
@@ -72,6 +73,13 @@ class IRMutatorExample : public IRMutator {
   static FMutateStmt& vtable_stmt();  // NOLINT(*)
 };
 
+/*!
+ * \brief Substitute occurance of IRNode to be expr
+ * \param replacements The replacement rule of substitution
+ * \param expr The expression to be substituted.
+ */
+Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr);
+
 }  // namespace ir
 }  // namespace tvm
 #endif  // TVM_IR_MUTATOR_H_
diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc
index 264f1c669..0b19c348d 100644
--- a/src/pass/ir_mutator.cc
+++ b/src/pass/ir_mutator.cc
@@ -8,6 +8,32 @@
 namespace tvm {
 namespace ir {
 
+namespace {
+// visitor to implement apply
+class IRSubstitute : public IRMutator {
+ public:
+  Expr mutate(Expr expr) final {
+    const IRNode* v = expr.get();
+    if (v != nullptr) {
+      auto it = replacements_.find(v);
+      if (it != replacements_.end()) {
+        return it->second;
+      }
+    }
+    return IRMutator::mutate(expr);
+  }
+  explicit IRSubstitute(const std::unordered_map<const IRNode*, Expr>& replacements)
+      : replacements_(replacements) {}
+
+ private:
+  const std::unordered_map<const IRNode*, Expr>& replacements_;
+};
+}  // namespace
+
+Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr) {
+  return IRSubstitute(replacements).mutate(expr);
+}
+
 IRMutator::FMutateExpr& IRMutator::vtable_expr() {  // NOLINT(*)
   static FMutateExpr inst; return inst;
 }
diff --git a/tests/cpp/ir_mutator_test.cc b/tests/cpp/ir_mutator_test.cc
index 93843d7ff..8a56e5097 100644
--- a/tests/cpp/ir_mutator_test.cc
+++ b/tests/cpp/ir_mutator_test.cc
@@ -52,6 +52,25 @@ TEST(IRMutator, Basic) {
   CHECK(os.str() == "(x + 10)");
 }
 
+TEST(IRMutator, Substitute) {
+  using namespace Halide::Internal;
+  using namespace tvm;
+  Var x("x"), y;
+  auto z = x + y;
+  {
+    auto zz = Substitute({{y.get(), 11}}, z);
+    std::ostringstream os;
+    os << zz;
+    CHECK(os.str() == "(x + 11)");
+  }
+  {
+    auto zz = Substitute({{z.get(), 11}}, z);
+    std::ostringstream os;
+    os << zz;
+    CHECK(os.str() == "11");
+  }
+}
+
 int main(int argc, char ** argv) {
   testing::InitGoogleTest(&argc, argv);
   testing::FLAGS_gtest_death_test_style = "threadsafe";
-- 
GitLab