From f2f1526daa654743b165732b522799d60ac1c979 Mon Sep 17 00:00:00 2001 From: Haichen Shen <shenhaichen@gmail.com> Date: Mon, 16 Jan 2017 14:53:47 -0800 Subject: [PATCH] [PASS] Export simplify and equal to python (#14) * [PASS] Export simplify and equal to python * fix naming convention --- include/tvm/ir_pass.h | 17 +++++++++++++++++ src/c_api/c_api_pass.cc | 21 +++++++++++++++++++++ tests/python/test_pass_basic.py | 10 ++++++++++ 3 files changed, 48 insertions(+) diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index d4456ed74..a45bbbb91 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -9,6 +9,8 @@ #ifndef TVM_IR_PASS_H_ #define TVM_IR_PASS_H_ +#include <ir/IREquality.h> +#include <pass/Simplify.h> #include <tvm/ir_functor.h> #include <unordered_map> #include <vector> @@ -19,6 +21,21 @@ namespace tvm { namespace ir { +inline bool Equal(Expr a, Expr b) { + return Halide::Internal::equal(a, b); +} + +inline bool Equal(Stmt a, Stmt b) { + return Halide::Internal::equal(a, b); +} + +inline Expr Simplify(Expr a) { + return Halide::Internal::simplify(a); +} + +inline Stmt Simplify(Stmt a) { + return Halide::Internal::simplify(a); +} /*! * \brief Schedule s' dependent operations. diff --git a/src/c_api/c_api_pass.cc b/src/c_api/c_api_pass.cc index e05f696bd..10ffe95f6 100644 --- a/src/c_api/c_api_pass.cc +++ b/src/c_api/c_api_pass.cc @@ -13,6 +13,27 @@ namespace ir { using ArgStack = const std::vector<APIVariantValue>; using RetValue = APIVariantValue; +TVM_REGISTER_API(_pass_Simplify) +.set_body([](const ArgStack& args, RetValue *ret) { + CHECK(args.at(0).type_id == kNodeHandle); + if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) { + *ret = Simplify(args.at(0).operator Expr()); + } else { + *ret = Simplify(args.at(0).operator Stmt()); + } + }); + +TVM_REGISTER_API(_pass_Equal) +.set_body([](const ArgStack& args, RetValue *ret) { + CHECK(args.at(0).type_id == kNodeHandle); + CHECK(args.at(1).type_id == kNodeHandle); + if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) { + *ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr()); + } else { + *ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt()); + } + }); + // make from two arguments #define REGISTER_PASS1(PassName) \ TVM_REGISTER_API(_pass_## PassName) \ diff --git a/tests/python/test_pass_basic.py b/tests/python/test_pass_basic.py index 23262f1cc..ebffc5880 100644 --- a/tests/python/test_pass_basic.py +++ b/tests/python/test_pass_basic.py @@ -1,5 +1,15 @@ import tvm +def test_simplify(): + x = tvm.Var('x') + e1 = tvm.ir_pass.Simplify(x + 2 + 1) + assert(tvm.ir_pass.Equal(e1, x + 3)) + e2 = tvm.ir_pass.Simplify(x * 3 + 5 * x) + assert(tvm.ir_pass.Equal(e2, x * 8)) + e3 = tvm.ir_pass.Simplify(x - x / 3 * 3) + assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3))) + + def test_verify_ssa(): x = tvm.Var('x') y = tvm.Var() -- GitLab