From ff6b8d82116bb26d4e8592a4d236b8020571e17a Mon Sep 17 00:00:00 2001 From: tqchen <tianqi.tchen@gmail.com> Date: Thu, 17 Nov 2016 20:55:47 -0800 Subject: [PATCH] check substitute --- include/tvm/ir_mutator.h | 8 ++++++++ src/pass/ir_mutator.cc | 26 ++++++++++++++++++++++++++ tests/cpp/ir_mutator_test.cc | 19 +++++++++++++++++++ 3 files changed, 53 insertions(+) diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index 3056fc503..f0c2dcb4b 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -7,6 +7,7 @@ #define TVM_IR_MUTATOR_H_ #include <tvm/ir_node.h> +#include <unordered_map> #include "./expr.h" namespace tvm { @@ -72,6 +73,13 @@ class IRMutatorExample : public IRMutator { static FMutateStmt& vtable_stmt(); // NOLINT(*) }; +/*! + * \brief Substitute occurance of IRNode to be expr + * \param replacements The replacement rule of substitution + * \param expr The expression to be substituted. + */ +Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr); + } // namespace ir } // namespace tvm #endif // TVM_IR_MUTATOR_H_ diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index 264f1c669..0b19c348d 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -8,6 +8,32 @@ namespace tvm { namespace ir { +namespace { +// visitor to implement apply +class IRSubstitute : public IRMutator { + public: + Expr mutate(Expr expr) final { + const IRNode* v = expr.get(); + if (v != nullptr) { + auto it = replacements_.find(v); + if (it != replacements_.end()) { + return it->second; + } + } + return IRMutator::mutate(expr); + } + explicit IRSubstitute(const std::unordered_map<const IRNode*, Expr>& replacements) + : replacements_(replacements) {} + + private: + const std::unordered_map<const IRNode*, Expr>& replacements_; +}; +} // namespace + +Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr) { + return IRSubstitute(replacements).mutate(expr); +} + IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*) static FMutateExpr inst; return inst; } diff --git a/tests/cpp/ir_mutator_test.cc b/tests/cpp/ir_mutator_test.cc index 93843d7ff..8a56e5097 100644 --- a/tests/cpp/ir_mutator_test.cc +++ b/tests/cpp/ir_mutator_test.cc @@ -52,6 +52,25 @@ TEST(IRMutator, Basic) { CHECK(os.str() == "(x + 10)"); } +TEST(IRMutator, Substitute) { + using namespace Halide::Internal; + using namespace tvm; + Var x("x"), y; + auto z = x + y; + { + auto zz = Substitute({{y.get(), 11}}, z); + std::ostringstream os; + os << zz; + CHECK(os.str() == "(x + 11)"); + } + { + auto zz = Substitute({{z.get(), 11}}, z); + std::ostringstream os; + os << zz; + CHECK(os.str() == "11"); + } +} + int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; -- GitLab