From 9c0da90fb14efd45686fa035ae7cef8d83b41913 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Tue, 21 Nov 2017 11:11:41 -0800
Subject: [PATCH] [PASS/SETUP] Fix minior issues (#663)

* [PASS/SETUP] Fix minior issues

* fix lint
---
 include/tvm/ir_pass.h                       | 20 +++++++---
 python/setup.py                             | 42 ++++++++++++++-------
 python/tvm/_ffi/libinfo.py                  |  3 +-
 src/api/api_pass.cc                         | 12 +++++-
 src/arithmetic/canonical.cc                 | 24 ++++++++++++
 tests/python/unittest/test_pass_simplify.py |  8 ++++
 6 files changed, 86 insertions(+), 23 deletions(-)

diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index 897f96c76..b6b248228 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -22,13 +22,21 @@
 namespace tvm {
 namespace ir {
 
-inline Expr Simplify(Expr a) {
-  return Halide::Internal::simplify(a);
-}
+/*!
+ * \brief Simplify the expression.
+ * \param expr The expression to be simplifed.
+ * \param vrange The range information about the variable.
+ * \return Canonicalized statement.
+ */
+Expr Simplify(Expr expr, Map<Var, Range> vrange = Map<Var, Range>());
 
-inline Stmt Simplify(Stmt a) {
-  return Halide::Internal::simplify(a);
-}
+/*!
+ * \brief Simplify the statement.
+ * \param stmt The statement to be simplifed.
+ * \param vrange The range information about the variable.
+ * \return Canonicalized statement.
+ */
+Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>());
 
 /*!
  * \brief Simplify by applying canonical form.
diff --git a/python/setup.py b/python/setup.py
index 168729391..5a87325e9 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -18,16 +18,25 @@ else:
     from setuptools import setup
     from setuptools.extension import Extension
 
-# We can not import `libinfo.py` in setup.py directly since __init__.py
-# Will be invoked which introduces dependences
-CURRENT_DIR = os.path.dirname(__file__)
-libinfo_py = os.path.join(CURRENT_DIR, './tvm/_ffi/libinfo.py')
-libinfo = {'__file__': libinfo_py}
-exec(compile(open(libinfo_py, "rb").read(), libinfo_py, 'exec'), libinfo, libinfo)
+def get_lib_path():
+    """Get library path, name and version"""
+    # We can not import `libinfo.py` in setup.py directly since __init__.py
+    # Will be invoked which introduces dependences
+    CURRENT_DIR = os.path.dirname(__file__)
+    libinfo_py = os.path.join(CURRENT_DIR, './tvm/_ffi/libinfo.py')
+    libinfo = {'__file__': libinfo_py}
+    exec(compile(open(libinfo_py, "rb").read(), libinfo_py, 'exec'), libinfo, libinfo)
+    lib_path = libinfo['find_lib_path']()
+    version = libinfo['__version__']
+    libs = [lib_path[0]]
+    if libs[0].find("runtime") == -1:
+        for name in lib_path[1:]:
+            if name.find("runtime") != -1:
+                libs.append(name)
+                break
+    return libs, version
 
-LIB_PATH = libinfo['find_lib_path']()
-_, LIB_NAME = os.path.split(LIB_PATH[0])
-__version__ = libinfo['__version__']
+LIB_LIST, __version__ = get_lib_path()
 
 def config_cython():
     """Try to configure cython and return cython configuration"""
@@ -81,18 +90,21 @@ class BinaryDistribution(Distribution):
 
 # For bdist_wheel only
 if "bdist_wheel" in sys.argv:
-    shutil.copy(LIB_PATH[0], os.path.join(CURRENT_DIR, 'tvm'))
     with open("MANIFEST.in", "w") as fo:
-        fo.write("include tvm/%s\n" % LIB_NAME)
+        for path in LIB_LIST:
+            shutil.copy(path, os.path.join(CURRENT_DIR, 'tvm'))
+            _, libname = os.path.split(path)
+            fo.write("include tvm/%s\n" % libname)
     setup_kwargs = {
         "include_package_data": True
     }
 else:
     curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
-    rpath = os.path.relpath(LIB_PATH[0], curr_path)
+    for i, path in enumerate(LIB_LIST):
+        LIB_LIST[i] = os.path.relpath(path, curr_path)
     setup_kwargs = {
         "include_package_data": True,
-        "data_files": [('tvm', [rpath])]
+        "data_files": [('tvm', LIB_LIST)]
     }
 
 setup(name='tvm',
@@ -112,4 +124,6 @@ setup(name='tvm',
 # Wheel cleanup
 if "bdist_wheel" in sys.argv:
     os.remove("MANIFEST.in")
-    os.remove("tvm/%s" % LIB_NAME)
+    for path in LIB_LIST:
+        _, libname = os.path.split(path)
+        os.remove("tvm/%s" % LIB_NAME)
diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py
index 273e8f8fb..f3ed174c0 100644
--- a/python/tvm/_ffi/libinfo.py
+++ b/python/tvm/_ffi/libinfo.py
@@ -74,7 +74,8 @@ def find_lib_path(name=None, search_path=None):
     if not use_runtime:
         # try to find lib_dll_path
         lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)]
-    if use_runtime or not lib_found:
+        lib_found += [p for p in runtime_dll_path if os.path.exists(p) and os.path.isfile(p)]
+    else:
         # try to find runtime_dll_path
         use_runtime = True
         lib_found = [p for p in runtime_dll_path if os.path.exists(p) and os.path.isfile(p)]
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index a3134f511..024af23a3 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -16,9 +16,17 @@ namespace ir {
 TVM_REGISTER_API("ir_pass.Simplify")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     if (args[0].IsNodeType<Stmt>()) {
-      *ret = Simplify(args[0].operator Stmt());
+      if (args.size() > 1) {
+        *ret = Simplify(args[0].operator Stmt(), args[1]);
+      } else {
+        *ret = Simplify(args[0].operator Stmt());
+      }
     } else {
-      *ret = Simplify(args[0].operator Expr());
+      if (args.size() > 1) {
+        *ret = Simplify(args[0].operator Expr(), args[1]);
+      } else {
+        *ret = Simplify(args[0].operator Expr());
+      }
     }
   });
 
diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc
index 933a8f78e..808e070ef 100644
--- a/src/arithmetic/canonical.cc
+++ b/src/arithmetic/canonical.cc
@@ -7,6 +7,7 @@
 #include <tvm/arithmetic.h>
 #include "./canonical.h"
 #include "./compute_expr.h"
+#include "arithmetic/Simplify.h"
 
 namespace tvm {
 namespace arith {
@@ -559,5 +560,28 @@ Stmt CanonicalSimplify(Stmt stmt) {
 Expr CanonicalSimplify(Expr expr) {
   return arith::Canonical().Simplify(expr);
 }
+
+template<typename T>
+T Simplify_(T a, Map<Var, Range> vrange) {
+  using namespace Halide::Internal;
+  Scope<Interval> rscope;
+  for (auto kv : vrange) {
+    Range r = kv.second;
+    rscope.push(
+        kv.first.get(),
+        Interval(r->min,
+                 simplify(r->min + r->extent - make_const(r->min.type(), 1))));
+  }
+  return Halide::Internal::simplify(a, true, rscope);
+}
+
+
+Expr Simplify(Expr a, Map<Var, Range> vrange) {
+  return Simplify_(a, vrange);
+}
+
+Stmt Simplify(Stmt a, Map<Var, Range> vrange) {
+  return Simplify_(a, vrange);
+}
 }  // namespace ir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_pass_simplify.py b/tests/python/unittest/test_pass_simplify.py
index 2cc8825e3..9105693b3 100644
--- a/tests/python/unittest/test_pass_simplify.py
+++ b/tests/python/unittest/test_pass_simplify.py
@@ -27,6 +27,13 @@ def test_basic():
     assert str(ret.value) == "(m - 1)"
 
 
+def test_bound():
+    m = tvm.var('m')
+    vrange = tvm.convert({m: tvm.Range(tvm.const(0), tvm.const(10))})
+    ret = tvm.ir_pass.Simplify(m % 10, vrange)
+    assert ret == m
+
+
 def test_canonical():
     x = tvm.var("x")
     z = tvm.const(3)
@@ -37,6 +44,7 @@ def test_canonical():
     assert(tvm.ir_pass.Equal(ret, 0))
 
 if __name__ == "__main__":
+    test_bound()
     test_basic()
     test_simplify()
     test_canonical()
-- 
GitLab