From 110c9bec7491b49fd8e20fb7ce5b78416c86f4b3 Mon Sep 17 00:00:00 2001
From: Haichen Shen <shenhaichen@gmail.com>
Date: Wed, 18 Jan 2017 14:48:16 -0800
Subject: [PATCH] [PASS] Assign unique names to variables in ConvertSSA pass
 (#18)

* [PASS] Assign unique names to variables in ConvertSSA pass

* revert change to ConverSSA pass
---
 src/c_api/c_api_pass.cc         | 18 +++++++++---------
 tests/python/test_pass_basic.py |  8 ++++++--
 2 files changed, 15 insertions(+), 11 deletions(-)

diff --git a/src/c_api/c_api_pass.cc b/src/c_api/c_api_pass.cc
index 10ffe95f6..e45e25a26 100644
--- a/src/c_api/c_api_pass.cc
+++ b/src/c_api/c_api_pass.cc
@@ -15,22 +15,22 @@ using RetValue = APIVariantValue;
 
 TVM_REGISTER_API(_pass_Simplify)
 .set_body([](const ArgStack& args, RetValue *ret) {
-    CHECK(args.at(0).type_id == kNodeHandle);
-    if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) {
-      *ret = Simplify(args.at(0).operator Expr());
-    } else {
+    if (dynamic_cast<Stmt::ContainerType*>(args.at(0).sptr.get())) {
       *ret = Simplify(args.at(0).operator Stmt());
+    } else {
+      *ret = Simplify(args.at(0).operator Expr());
     }
   });
 
 TVM_REGISTER_API(_pass_Equal)
 .set_body([](const ArgStack& args, RetValue *ret) {
-    CHECK(args.at(0).type_id == kNodeHandle);
-    CHECK(args.at(1).type_id == kNodeHandle);
-    if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) {
-      *ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr());
-    } else {
+    if (dynamic_cast<Stmt::ContainerType*>(args.at(0).sptr.get())) {
+      CHECK(args.at(1).type_id == kNodeHandle);
       *ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt());
+    } else {
+      Expr a = args.at(0).operator Expr();
+      Expr b = args.at(1).operator Expr();
+      *ret = Equal(a, b);
     }
   });
 
diff --git a/tests/python/test_pass_basic.py b/tests/python/test_pass_basic.py
index ebffc5880..b9e8d501e 100644
--- a/tests/python/test_pass_basic.py
+++ b/tests/python/test_pass_basic.py
@@ -8,6 +8,9 @@ def test_simplify():
   assert(tvm.ir_pass.Equal(e2, x * 8))
   e3 = tvm.ir_pass.Simplify(x - x / 3 * 3)
   assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3)))
+  let = tvm.make.Let(x, 1, x + 3)
+  e4 = tvm.ir_pass.Simplify(let)
+  assert(tvm.ir_pass.Equal(e4, 4))
 
 
 def test_verify_ssa():
@@ -20,8 +23,9 @@ def test_verify_ssa():
 def test_convert_ssa():
     x = tvm.Var('x')
     y = tvm.Var()
-    let = tvm.make.Let(x, 1, x + 1)
-    z = tvm.make.Evaluate(let + let)
+    let1 = tvm.make.Let(x, 1, x + 1)
+    let2 = tvm.make.Let(x, 1, x + y)
+    z = tvm.make.Evaluate(let1 + let2)
     assert(not tvm.ir_pass.VerifySSA(z))
     z_ssa = tvm.ir_pass.ConvertSSA(z)
     assert(tvm.ir_pass.VerifySSA(z_ssa))
-- 
GitLab