From 4578048ce14d9b914a5bf4772cbcb8cfc01d52c3 Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Wed, 30 Aug 2017 16:01:35 -0700 Subject: [PATCH] [PASS] IRTransform to enable IR pass proptype in python (#401) --- include/tvm/ir_mutator.h | 19 +++++++ src/api/api_pass.cc | 2 + src/api/dsl_api.cc | 2 +- src/pass/coproc_sync.cc | 1 - src/pass/ir_mutator.cc | 54 +++++++++++++++++++ .../test_arith_detect_linear_equation.py | 6 +++ .../python/unittest/test_pass_ir_transform.py | 29 ++++++++++ 7 files changed, 111 insertions(+), 2 deletions(-) create mode 100644 tests/python/unittest/test_pass_ir_transform.py diff --git a/include/tvm/ir_mutator.h b/include/tvm/ir_mutator.h index 340a7110a..904a16530 100644 --- a/include/tvm/ir_mutator.h +++ b/include/tvm/ir_mutator.h @@ -102,6 +102,25 @@ class IRMutator { virtual Expr Mutate_(const Shuffle* op, const Expr& e); }; +/*! + * \brief recursively visit the ir in post DFS order node, and transform it + * + * \param node The ir to be transformed. + * \param preorder The function called in before recursive mutation + * If preorder returns None, then the transform will proceed to recursive call. + * If preorder returns a not None Stmt/Expr, the transformer will simply return it and + * won't do further recursion. + * \param postorder The function called after recursive mutation. + * The recursive mutation result is passed to postorder for further mutation. + * \param only_enable List of StringImm. + * If it is empty, all IRNode will call preorder/postorder + * If it is not empty, preorder/postorder will only be called + * when the IRNode's type key is in the list. + */ +Stmt IRTransform(const Stmt& node, + const runtime::PackedFunc& preorder, + const runtime::PackedFunc& postorder, + const Array<Expr>& only_enable = {}); } // namespace ir } // namespace tvm #endif // TVM_IR_MUTATOR_H_ diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 4868db657..e5505b68d 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -7,6 +7,7 @@ #include <tvm/ir.h> #include <tvm/ir_pass.h> #include <tvm/ir_visitor.h> +#include <tvm/ir_mutator.h> #include <tvm/api_registry.h> namespace tvm { @@ -88,6 +89,7 @@ REGISTER_PASS1(VerifySSA); REGISTER_PASS1(RewriteUnsafeSelect); REGISTER_PASS4(Inline); REGISTER_PASS3(StorageFlatten); +REGISTER_PASS4(IRTransform); REGISTER_PASS1(VectorizeLoop); REGISTER_PASS4(UnrollLoop); REGISTER_PASS2(ThreadSync); diff --git a/src/api/dsl_api.cc b/src/api/dsl_api.cc index 923f004b5..4e247ed2b 100644 --- a/src/api/dsl_api.cc +++ b/src/api/dsl_api.cc @@ -110,7 +110,7 @@ class DSLAPIImpl : public DSLAPI { *out_index = static_cast<int>(Node::TypeKey2Index(type_key)); } void NodeGetTypeIndex(NodeHandle handle, - int* out_index) const final { + int* out_index) const final { *out_index = static_cast<int>( (*static_cast<TVMAPINode*>(handle))->type_index()); } diff --git a/src/pass/coproc_sync.cc b/src/pass/coproc_sync.cc index ef8e1f877..fa77942b6 100644 --- a/src/pass/coproc_sync.cc +++ b/src/pass/coproc_sync.cc @@ -11,7 +11,6 @@ #include "./ir_util.h" #include "./storage_access.h" - namespace tvm { namespace ir { diff --git a/src/pass/ir_mutator.cc b/src/pass/ir_mutator.cc index efe7e5ee1..993b68f83 100644 --- a/src/pass/ir_mutator.cc +++ b/src/pass/ir_mutator.cc @@ -4,11 +4,65 @@ */ #include <tvm/ir.h> #include <tvm/ir_mutator.h> +#include <tvm/packed_func_ext.h> #include "./ir_util.h" namespace tvm { namespace ir { +class IRTransformer final : public IRMutator { + public: + IRTransformer(const runtime::PackedFunc& f_preorder, + const runtime::PackedFunc& f_postorder, + const std::unordered_set<uint32_t>& only_enable) + : f_preorder_(f_preorder), + f_postorder_(f_postorder), + only_enable_(only_enable) { + } + Stmt Mutate(Stmt stmt) final { + return MutateInternal<Stmt>(stmt); + } + Expr Mutate(Expr expr) final { + return MutateInternal<Expr>(expr); + } + + private: + template<typename T> + T MutateInternal(T node) { + if (only_enable_.size() && + !only_enable_.count(node->type_index())) { + return IRMutator::Mutate(node); + } + if (f_preorder_ != nullptr) { + T pre = f_preorder_(node); + if (pre.defined()) return pre; + } + node = IRMutator::Mutate(node); + if (f_postorder_ != nullptr) { + T post = f_postorder_(node); + if (post.defined()) return post; + } + return node; + } + // The functions + const runtime::PackedFunc& f_preorder_; + const runtime::PackedFunc& f_postorder_; + // type indices enabled. + const std::unordered_set<uint32_t>& only_enable_; +}; + +Stmt IRTransform(const Stmt& ir_node, + const runtime::PackedFunc& f_preorder, + const runtime::PackedFunc& f_postorder, + const Array<Expr>& only_enable) { + std::unordered_set<uint32_t> only_type_index; + for (Expr s : only_enable) { + only_type_index.insert(Node::TypeKey2Index(s.as<StringImm>()->value.c_str())); + } + return IRTransformer(f_preorder, f_postorder, only_type_index) + .Mutate(ir_node); +} + IRMutator::FMutateExpr& IRMutator::vtable_expr() { // NOLINT(*) static FMutateExpr inst; return inst; } diff --git a/tests/python/unittest/test_arith_detect_linear_equation.py b/tests/python/unittest/test_arith_detect_linear_equation.py index 32f341bf2..412effa2f 100644 --- a/tests/python/unittest/test_arith_detect_linear_equation.py +++ b/tests/python/unittest/test_arith_detect_linear_equation.py @@ -14,5 +14,11 @@ def test_basic(): assert m[1].value == 5 assert tvm.ir_pass.Simplify(m[0] - (b * 6 + 7 + 1)).value == 0 + m = tvm.arith.DetectLinearEquation(a * b + 7, a) + assert m[1] == b + + m = tvm.arith.DetectLinearEquation(b * 7, a) + assert m[1].value == 0 + if __name__ == "__main__": test_basic() diff --git a/tests/python/unittest/test_pass_ir_transform.py b/tests/python/unittest/test_pass_ir_transform.py new file mode 100644 index 000000000..87c82c5e5 --- /dev/null +++ b/tests/python/unittest/test_pass_ir_transform.py @@ -0,0 +1,29 @@ +import tvm + +def test_ir_transform(): + ib = tvm.ir_builder.create() + n = tvm.var("n") + with ib.for_range(0, n, name="i") as i: + with ib.for_range(0, 10, name="j") as j: + x = tvm.call_extern("int32", "TestA", i * 3 + j * 1) + ib.emit(tvm.call_extern("int32", "TestB", x)) + ib.emit(tvm.call_extern("int32", "TestC", x)) + body = ib.get() + + def preorder(op): + if op.name == "TestC": + return tvm.const(0, "int32") + return None + + def postorder(op): + assert isinstance(op, tvm.expr.Call) + if op.name == "TestA": + return tvm.call_extern("int32", "TestB", op.args[0] + 1) + return op + body = tvm.ir_pass.IRTransform(body, preorder, postorder, ["Call"]) + stmt_list = tvm.make.stmt_list(body.body.body) + assert stmt_list[0].value.args[0].name == "TestB" + assert stmt_list[1].value.value == 0 + +if __name__ == "__main__": + test_ir_transform() -- GitLab