From 9c0da90fb14efd45686fa035ae7cef8d83b41913 Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Tue, 21 Nov 2017 11:11:41 -0800 Subject: [PATCH] [PASS/SETUP] Fix minior issues (#663) * [PASS/SETUP] Fix minior issues * fix lint --- include/tvm/ir_pass.h | 20 +++++++--- python/setup.py | 42 ++++++++++++++------- python/tvm/_ffi/libinfo.py | 3 +- src/api/api_pass.cc | 12 +++++- src/arithmetic/canonical.cc | 24 ++++++++++++ tests/python/unittest/test_pass_simplify.py | 8 ++++ 6 files changed, 86 insertions(+), 23 deletions(-) diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 897f96c76..b6b248228 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -22,13 +22,21 @@ namespace tvm { namespace ir { -inline Expr Simplify(Expr a) { - return Halide::Internal::simplify(a); -} +/*! + * \brief Simplify the expression. + * \param expr The expression to be simplifed. + * \param vrange The range information about the variable. + * \return Canonicalized statement. + */ +Expr Simplify(Expr expr, Map<Var, Range> vrange = Map<Var, Range>()); -inline Stmt Simplify(Stmt a) { - return Halide::Internal::simplify(a); -} +/*! + * \brief Simplify the statement. + * \param stmt The statement to be simplifed. + * \param vrange The range information about the variable. + * \return Canonicalized statement. + */ +Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>()); /*! * \brief Simplify by applying canonical form. diff --git a/python/setup.py b/python/setup.py index 168729391..5a87325e9 100644 --- a/python/setup.py +++ b/python/setup.py @@ -18,16 +18,25 @@ else: from setuptools import setup from setuptools.extension import Extension -# We can not import `libinfo.py` in setup.py directly since __init__.py -# Will be invoked which introduces dependences -CURRENT_DIR = os.path.dirname(__file__) -libinfo_py = os.path.join(CURRENT_DIR, './tvm/_ffi/libinfo.py') -libinfo = {'__file__': libinfo_py} -exec(compile(open(libinfo_py, "rb").read(), libinfo_py, 'exec'), libinfo, libinfo) +def get_lib_path(): + """Get library path, name and version""" + # We can not import `libinfo.py` in setup.py directly since __init__.py + # Will be invoked which introduces dependences + CURRENT_DIR = os.path.dirname(__file__) + libinfo_py = os.path.join(CURRENT_DIR, './tvm/_ffi/libinfo.py') + libinfo = {'__file__': libinfo_py} + exec(compile(open(libinfo_py, "rb").read(), libinfo_py, 'exec'), libinfo, libinfo) + lib_path = libinfo['find_lib_path']() + version = libinfo['__version__'] + libs = [lib_path[0]] + if libs[0].find("runtime") == -1: + for name in lib_path[1:]: + if name.find("runtime") != -1: + libs.append(name) + break + return libs, version -LIB_PATH = libinfo['find_lib_path']() -_, LIB_NAME = os.path.split(LIB_PATH[0]) -__version__ = libinfo['__version__'] +LIB_LIST, __version__ = get_lib_path() def config_cython(): """Try to configure cython and return cython configuration""" @@ -81,18 +90,21 @@ class BinaryDistribution(Distribution): # For bdist_wheel only if "bdist_wheel" in sys.argv: - shutil.copy(LIB_PATH[0], os.path.join(CURRENT_DIR, 'tvm')) with open("MANIFEST.in", "w") as fo: - fo.write("include tvm/%s\n" % LIB_NAME) + for path in LIB_LIST: + shutil.copy(path, os.path.join(CURRENT_DIR, 'tvm')) + _, libname = os.path.split(path) + fo.write("include tvm/%s\n" % libname) setup_kwargs = { "include_package_data": True } else: curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - rpath = os.path.relpath(LIB_PATH[0], curr_path) + for i, path in enumerate(LIB_LIST): + LIB_LIST[i] = os.path.relpath(path, curr_path) setup_kwargs = { "include_package_data": True, - "data_files": [('tvm', [rpath])] + "data_files": [('tvm', LIB_LIST)] } setup(name='tvm', @@ -112,4 +124,6 @@ setup(name='tvm', # Wheel cleanup if "bdist_wheel" in sys.argv: os.remove("MANIFEST.in") - os.remove("tvm/%s" % LIB_NAME) + for path in LIB_LIST: + _, libname = os.path.split(path) + os.remove("tvm/%s" % LIB_NAME) diff --git a/python/tvm/_ffi/libinfo.py b/python/tvm/_ffi/libinfo.py index 273e8f8fb..f3ed174c0 100644 --- a/python/tvm/_ffi/libinfo.py +++ b/python/tvm/_ffi/libinfo.py @@ -74,7 +74,8 @@ def find_lib_path(name=None, search_path=None): if not use_runtime: # try to find lib_dll_path lib_found = [p for p in lib_dll_path if os.path.exists(p) and os.path.isfile(p)] - if use_runtime or not lib_found: + lib_found += [p for p in runtime_dll_path if os.path.exists(p) and os.path.isfile(p)] + else: # try to find runtime_dll_path use_runtime = True lib_found = [p for p in runtime_dll_path if os.path.exists(p) and os.path.isfile(p)] diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index a3134f511..024af23a3 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -16,9 +16,17 @@ namespace ir { TVM_REGISTER_API("ir_pass.Simplify") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsNodeType<Stmt>()) { - *ret = Simplify(args[0].operator Stmt()); + if (args.size() > 1) { + *ret = Simplify(args[0].operator Stmt(), args[1]); + } else { + *ret = Simplify(args[0].operator Stmt()); + } } else { - *ret = Simplify(args[0].operator Expr()); + if (args.size() > 1) { + *ret = Simplify(args[0].operator Expr(), args[1]); + } else { + *ret = Simplify(args[0].operator Expr()); + } } }); diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc index 933a8f78e..808e070ef 100644 --- a/src/arithmetic/canonical.cc +++ b/src/arithmetic/canonical.cc @@ -7,6 +7,7 @@ #include <tvm/arithmetic.h> #include "./canonical.h" #include "./compute_expr.h" +#include "arithmetic/Simplify.h" namespace tvm { namespace arith { @@ -559,5 +560,28 @@ Stmt CanonicalSimplify(Stmt stmt) { Expr CanonicalSimplify(Expr expr) { return arith::Canonical().Simplify(expr); } + +template<typename T> +T Simplify_(T a, Map<Var, Range> vrange) { + using namespace Halide::Internal; + Scope<Interval> rscope; + for (auto kv : vrange) { + Range r = kv.second; + rscope.push( + kv.first.get(), + Interval(r->min, + simplify(r->min + r->extent - make_const(r->min.type(), 1)))); + } + return Halide::Internal::simplify(a, true, rscope); +} + + +Expr Simplify(Expr a, Map<Var, Range> vrange) { + return Simplify_(a, vrange); +} + +Stmt Simplify(Stmt a, Map<Var, Range> vrange) { + return Simplify_(a, vrange); +} } // namespace ir } // namespace tvm diff --git a/tests/python/unittest/test_pass_simplify.py b/tests/python/unittest/test_pass_simplify.py index 2cc8825e3..9105693b3 100644 --- a/tests/python/unittest/test_pass_simplify.py +++ b/tests/python/unittest/test_pass_simplify.py @@ -27,6 +27,13 @@ def test_basic(): assert str(ret.value) == "(m - 1)" +def test_bound(): + m = tvm.var('m') + vrange = tvm.convert({m: tvm.Range(tvm.const(0), tvm.const(10))}) + ret = tvm.ir_pass.Simplify(m % 10, vrange) + assert ret == m + + def test_canonical(): x = tvm.var("x") z = tvm.const(3) @@ -37,6 +44,7 @@ def test_canonical(): assert(tvm.ir_pass.Equal(ret, 0)) if __name__ == "__main__": + test_bound() test_basic() test_simplify() test_canonical() -- GitLab