Skip to content
Snippets Groups Projects
Commit ff6b8d82 authored by tqchen's avatar tqchen
Browse files

check substitute

parent 363cc280
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,7 @@
#define TVM_IR_MUTATOR_H_
#include <tvm/ir_node.h>
#include <unordered_map>
#include "./expr.h"
namespace tvm {
......@@ -72,6 +73,13 @@ class IRMutatorExample : public IRMutator {
static FMutateStmt& vtable_stmt(); // NOLINT(*)
};
/*!
* \brief Substitute occurance of IRNode to be expr
* \param replacements The replacement rule of substitution
* \param expr The expression to be substituted.
*/
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr);
} // namespace ir
} // namespace tvm
#endif // TVM_IR_MUTATOR_H_
......@@ -8,6 +8,32 @@
namespace tvm {
namespace ir {
namespace {
// visitor to implement apply
class IRSubstitute : public IRMutator {
public:
Expr mutate(Expr expr) final {
const IRNode* v = expr.get();
if (v != nullptr) {
auto it = replacements_.find(v);
if (it != replacements_.end()) {
return it->second;
}
}
return IRMutator::mutate(expr);
}
explicit IRSubstitute(const std::unordered_map<const IRNode*, Expr>& replacements)
: replacements_(replacements) {}
private:
const std::unordered_map<const IRNode*, Expr>& replacements_;
};
} // namespace
Expr Substitute(const std::unordered_map<const IRNode*, Expr>& replacements, Expr expr) {
return IRSubstitute(replacements).mutate(expr);
}
IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*)
static FMutateExpr inst; return inst;
}
......
......@@ -52,6 +52,25 @@ TEST(IRMutator, Basic) {
CHECK(os.str() == "(x + 10)");
}
TEST(IRMutator, Substitute) {
using namespace Halide::Internal;
using namespace tvm;
Var x("x"), y;
auto z = x + y;
{
auto zz = Substitute({{y.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "(x + 11)");
}
{
auto zz = Substitute({{z.get(), 11}}, z);
std::ostringstream os;
os << zz;
CHECK(os.str() == "11");
}
}
int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment