diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index 340a7110ae4898a947b1bd31db7b62138c9b6421..904a16530fc76124b66801e8dd15703ff304bf67 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -102,6 +102,25 @@ class IRMutator { virtual Expr Mutate_(const Shuffle* op, const Expr& e); }; +/*! + * \brief recursively visit the ir in post DFS order node, and transform it + * + * \param node The ir to be transformed. + * \param preorder The function called in before recursive mutation + * If preorder returns None, then the transform will proceed to recursive call. + * If preorder returns a not None Stmt/Expr, the transformer will simply return it and + * won't do further recursion. + * \param postorder The function called after recursive mutation. + * The recursive mutation result is passed to postorder for further mutation. + * \param only_enable List of StringImm. + * If it is empty, all IRNode will call preorder/postorder + * If it is not empty, preorder/postorder will only be called + * when the IRNode's type key is in the list. + */ +Stmt IRTransform(const Stmt& node, + const runtime::PackedFunc& preorder, + const runtime::PackedFunc& postorder, + const Array<Expr>& only_enable = {}); } // namespace ir } // namespace tvm #endif // TVM_IR_MUTATOR_H_ diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 4868db657d112d56cb7aa3ae9191961b68e415c9..e5505b68df64dfd5ffbc686058d555274c340969 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -7,6 +7,7 @@ #include <tvm/ir.h> #include <tvm/ir_pass.h> #include <tvm/ir_visitor.h> +#include <tvm/ir_mutator.h> #include <tvm/api_registry.h> namespace tvm { @@ -88,6 +89,7 @@ REGISTER_PASS1(VerifySSA); REGISTER_PASS1(RewriteUnsafeSelect); REGISTER_PASS4(Inline); REGISTER_PASS3(StorageFlatten); +REGISTER_PASS4(IRTransform); REGISTER_PASS1(VectorizeLoop); REGISTER_PASS4(UnrollLoop); REGISTER_PASS2(ThreadSync); diff --git a/src/api/dsl_api.cc b/src/api/dsl_api.cc index 923f004b530ea9ace04670693f300315f1fc70d2..4e247ed2bf4c2afdd1100ca563a30b8fdaacd2b3 100644 --- a/src/api/dsl_api.cc +++ b/src/api/dsl_api.cc @@ -110,7 +110,7 @@ class DSLAPIImpl : public DSLAPI { *out_index = static_cast<int>(Node::TypeKey2Index(type_key)); } void NodeGetTypeIndex(NodeHandle handle, - int* out_index) const final { + int* out_index) const final { *out_index = static_cast<int>( (*static_cast<TVMAPINode*>(handle))->type_index()); } diff --git a/src/pass/coproc_sync.cc b/src/pass/coproc_sync.cc index ef8e1f8771d93c90d6a330c472caa9f0ba1bd4c7..fa77942b6058ca03c6c96969f451dd55779f38f3 100644 --- a/src/pass/coproc_sync.cc +++ b/src/pass/coproc_sync.cc @@ -11,7 +11,6 @@ #include "./ir_util.h" #include "./storage_access.h" - namespace tvm { namespace ir { diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index efe7e5ee13d81d83beae64a283aa9ea96127c92b..993b68f835d743f408775340f5925135a39b1c46 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -4,11 +4,65 @@ */ #include <tvm/ir.h> #include <tvm/ir_mutator.h> +#include <tvm/packed_func_ext.h> #include "./ir_util.h" namespace tvm { namespace ir { +class IRTransformer final : public IRMutator { + public: + IRTransformer(const runtime::PackedFunc& f_preorder, + const runtime::PackedFunc& f_postorder, + const std::unordered_set<uint32_t>& only_enable) + : f_preorder_(f_preorder), + f_postorder_(f_postorder), + only_enable_(only_enable) { + } + Stmt Mutate(Stmt stmt) final { + return MutateInternal<Stmt>(stmt); + } + Expr Mutate(Expr expr) final { + return MutateInternal<Expr>(expr); + } + + private: + template<typename T> + T MutateInternal(T node) { + if (only_enable_.size() && + !only_enable_.count(node->type_index())) { + return IRMutator::Mutate(node); + } + if (f_preorder_ != nullptr) { + T pre = f_preorder_(node); + if (pre.defined()) return pre; + } + node = IRMutator::Mutate(node); + if (f_postorder_ != nullptr) { + T post = f_postorder_(node); + if (post.defined()) return post; + } + return node; + } + // The functions + const runtime::PackedFunc& f_preorder_; + const runtime::PackedFunc& f_postorder_; + // type indices enabled. + const std::unordered_set<uint32_t>& only_enable_; +}; + +Stmt IRTransform(const Stmt& ir_node, + const runtime::PackedFunc& f_preorder, + const runtime::PackedFunc& f_postorder, + const Array<Expr>& only_enable) { + std::unordered_set<uint32_t> only_type_index; + for (Expr s : only_enable) { + only_type_index.insert(Node::TypeKey2Index(s.as<StringImm>()->value.c_str())); + } + return IRTransformer(f_preorder, f_postorder, only_type_index) + .Mutate(ir_node); +} + IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*) static FMutateExpr inst; return inst; } diff --git a/tests/python/unittest/test_arith_detect_linear_equation.py b/tests/python/unittest/test_arith_detect_linear_equation.py index 32f341bf25250f9b99747387183690a913b5d1f6..412effa2f0a27ab46511e513b27abb8aedc65144 100644 --- a/tests/python/unittest/test_arith_detect_linear_equation.py +++ b/tests/python/unittest/test_arith_detect_linear_equation.py @@ -14,5 +14,11 @@ def test_basic(): assert m[1].value == 5 assert tvm.ir_pass.Simplify(m[0] - (b * 6 + 7 + 1)).value == 0 + m = tvm.arith.DetectLinearEquation(a * b + 7, a) + assert m[1] == b + + m = tvm.arith.DetectLinearEquation(b * 7, a) + assert m[1].value == 0 + if __name__ == "__main__": test_basic() diff --git a/tests/python/unittest/test_pass_ir_transform.py b/tests/python/unittest/test_pass_ir_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..87c82c5e5f5ff427b8069cbd12863722926a40f6 --- /dev/null +++ b/tests/python/unittest/test_pass_ir_transform.py @@ -0,0 +1,29 @@ +import tvm + +def test_ir_transform(): + ib = tvm.ir_builder.create() + n = tvm.var("n") + with ib.for_range(0, n, name="i") as i: + with ib.for_range(0, 10, name="j") as j: + x = tvm.call_extern("int32", "TestA", i * 3 + j * 1) + ib.emit(tvm.call_extern("int32", "TestB", x)) + ib.emit(tvm.call_extern("int32", "TestC", x)) + body = ib.get() + + def preorder(op): + if op.name == "TestC": + return tvm.const(0, "int32") + return None + + def postorder(op): + assert isinstance(op, tvm.expr.Call) + if op.name == "TestA": + return tvm.call_extern("int32", "TestB", op.args[0] + 1) + return op + body = tvm.ir_pass.IRTransform(body, preorder, postorder, ["Call"]) + stmt_list = tvm.make.stmt_list(body.body.body) + assert stmt_list[0].value.args[0].name == "TestB" + assert stmt_list[1].value.value == 0 + +if __name__ == "__main__": + test_ir_transform()