From 110c9bec7491b49fd8e20fb7ce5b78416c86f4b3 Mon Sep 17 00:00:00 2001 From: Haichen Shen <shenhaichen@gmail.com> Date: Wed, 18 Jan 2017 14:48:16 -0800 Subject: [PATCH] [PASS] Assign unique names to variables in ConvertSSA pass (#18) * [PASS] Assign unique names to variables in ConvertSSA pass * revert change to ConverSSA pass --- src/c_api/c_api_pass.cc | 18 +++++++++--------- tests/python/test_pass_basic.py | 8 ++++++-- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/src/c_api/c_api_pass.cc b/src/c_api/c_api_pass.cc index 10ffe95f6..e45e25a26 100644 --- a/src/c_api/c_api_pass.cc +++ b/src/c_api/c_api_pass.cc @@ -15,22 +15,22 @@ using RetValue = APIVariantValue; TVM_REGISTER_API(_pass_Simplify) .set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_id == kNodeHandle); - if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) { - *ret = Simplify(args.at(0).operator Expr()); - } else { + if (dynamic_cast<Stmt::ContainerType*>(args.at(0).sptr.get())) { *ret = Simplify(args.at(0).operator Stmt()); + } else { + *ret = Simplify(args.at(0).operator Expr()); } }); TVM_REGISTER_API(_pass_Equal) .set_body([](const ArgStack& args, RetValue *ret) { - CHECK(args.at(0).type_id == kNodeHandle); - CHECK(args.at(1).type_id == kNodeHandle); - if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) { - *ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr()); - } else { + if (dynamic_cast<Stmt::ContainerType*>(args.at(0).sptr.get())) { + CHECK(args.at(1).type_id == kNodeHandle); *ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt()); + } else { + Expr a = args.at(0).operator Expr(); + Expr b = args.at(1).operator Expr(); + *ret = Equal(a, b); } }); diff --git a/tests/python/test_pass_basic.py b/tests/python/test_pass_basic.py index ebffc5880..b9e8d501e 100644 --- a/tests/python/test_pass_basic.py +++ b/tests/python/test_pass_basic.py @@ -8,6 +8,9 @@ def test_simplify(): assert(tvm.ir_pass.Equal(e2, x * 8)) e3 = tvm.ir_pass.Simplify(x - x / 3 * 3) assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3))) + let = tvm.make.Let(x, 1, x + 3) + e4 = tvm.ir_pass.Simplify(let) + assert(tvm.ir_pass.Equal(e4, 4)) def test_verify_ssa(): @@ -20,8 +23,9 @@ def test_verify_ssa(): def test_convert_ssa(): x = tvm.Var('x') y = tvm.Var() - let = tvm.make.Let(x, 1, x + 1) - z = tvm.make.Evaluate(let + let) + let1 = tvm.make.Let(x, 1, x + 1) + let2 = tvm.make.Let(x, 1, x + y) + z = tvm.make.Evaluate(let1 + let2) assert(not tvm.ir_pass.VerifySSA(z)) z_ssa = tvm.ir_pass.ConvertSSA(z) assert(tvm.ir_pass.VerifySSA(z_ssa)) -- GitLab