From f2f1526daa654743b165732b522799d60ac1c979 Mon Sep 17 00:00:00 2001
From: Haichen Shen <shenhaichen@gmail.com>
Date: Mon, 16 Jan 2017 14:53:47 -0800
Subject: [PATCH] [PASS] Export simplify and equal to python (#14)

* [PASS] Export simplify and equal to python

* fix naming convention
---
 include/tvm/ir_pass.h           | 17 +++++++++++++++++
 src/c_api/c_api_pass.cc         | 21 +++++++++++++++++++++
 tests/python/test_pass_basic.py | 10 ++++++++++
 3 files changed, 48 insertions(+)

diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index d4456ed74..a45bbbb91 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -9,6 +9,8 @@
 #ifndef TVM_IR_PASS_H_
 #define TVM_IR_PASS_H_
 
+#include <ir/IREquality.h>
+#include <pass/Simplify.h>
 #include <tvm/ir_functor.h>
 #include <unordered_map>
 #include <vector>
@@ -19,6 +21,21 @@
 namespace tvm {
 namespace ir {
 
+inline bool Equal(Expr a, Expr b) {
+  return Halide::Internal::equal(a, b);
+}
+
+inline bool Equal(Stmt a, Stmt b) {
+  return Halide::Internal::equal(a, b);
+}
+
+inline Expr Simplify(Expr a) {
+  return Halide::Internal::simplify(a);
+}
+
+inline Stmt Simplify(Stmt a) {
+  return Halide::Internal::simplify(a);
+}
 
 /*!
  * \brief Schedule s' dependent operations.
diff --git a/src/c_api/c_api_pass.cc b/src/c_api/c_api_pass.cc
index e05f696bd..10ffe95f6 100644
--- a/src/c_api/c_api_pass.cc
+++ b/src/c_api/c_api_pass.cc
@@ -13,6 +13,27 @@ namespace ir {
 using ArgStack = const std::vector<APIVariantValue>;
 using RetValue = APIVariantValue;
 
+TVM_REGISTER_API(_pass_Simplify)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    CHECK(args.at(0).type_id == kNodeHandle);
+    if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) {
+      *ret = Simplify(args.at(0).operator Expr());
+    } else {
+      *ret = Simplify(args.at(0).operator Stmt());
+    }
+  });
+
+TVM_REGISTER_API(_pass_Equal)
+.set_body([](const ArgStack& args, RetValue *ret) {
+    CHECK(args.at(0).type_id == kNodeHandle);
+    CHECK(args.at(1).type_id == kNodeHandle);
+    if (dynamic_cast<Expr::ContainerType*>(args.at(0).sptr.get())) {
+      *ret = Equal(args.at(0).operator Expr(), args.at(1).operator Expr());
+    } else {
+      *ret = Equal(args.at(0).operator Stmt(), args.at(1).operator Stmt());
+    }
+  });
+
 // make from two arguments
 #define REGISTER_PASS1(PassName)                                  \
   TVM_REGISTER_API(_pass_## PassName)                             \
diff --git a/tests/python/test_pass_basic.py b/tests/python/test_pass_basic.py
index 23262f1cc..ebffc5880 100644
--- a/tests/python/test_pass_basic.py
+++ b/tests/python/test_pass_basic.py
@@ -1,5 +1,15 @@
 import tvm
 
+def test_simplify():
+  x = tvm.Var('x')
+  e1 = tvm.ir_pass.Simplify(x + 2 + 1)
+  assert(tvm.ir_pass.Equal(e1, x + 3))
+  e2 = tvm.ir_pass.Simplify(x * 3 + 5 * x)
+  assert(tvm.ir_pass.Equal(e2, x * 8))
+  e3 = tvm.ir_pass.Simplify(x - x / 3 * 3)
+  assert(tvm.ir_pass.Equal(e3, tvm.make.Mod(x, 3)))
+
+
 def test_verify_ssa():
     x = tvm.Var('x')
     y = tvm.Var()
-- 
GitLab