diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index bad2fb3dd54e87225dded19bbf37679f5ebe423e..61731fb0c3e5180f919715118650442c099562bd 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -84,6 +84,13 @@ Stmt CanonicalSimplify(Stmt stmt); * \return The converted form. */ Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map); + +/*! + * \brief Substitute the var specified in key->var to be value. + * \param expr The source expression to be substituted + * \param value_map The map of new values. + * \return The converted expression. + */ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map); /*! diff --git a/python/setup.py b/python/setup.py index 23f6d3cc82622baf2a66217483901d88b1d690a6..1903cf7b368285078728643b34cabca59178f13e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -70,7 +70,7 @@ setuptools.setup( ], zip_safe=False, packages=[ - 'tvm', 'tvm.addon', + 'tvm', 'tvm.contrib', 'tvm._ffi', 'tvm._ffi._ctypes', 'tvm._ffi._cy2', 'tvm._ffi._cy3' ], diff --git a/python/tvm/api.py b/python/tvm/api.py index 2c1076474ebe9252b75068a286e96e2d7d8fd923..edac2efe57062c320ed2c7247480f6e5d0050ea2 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -23,10 +23,12 @@ handle = "handle" def min_value(dtype): + """minimum value of dtype""" return _api_internal._min_value(dtype) def max_value(dtype): + """maximum value of dtype""" return _api_internal._max_value(dtype) @@ -438,7 +440,7 @@ def comm_reducer(fcombine, fidentity, name="reduce"): ------- reducer : function A function which creates a reduce expression over axis. - There are two to use it: + There are two ways to use it: 1. accept (expr, axis, where) to produce an Reduce Expr on specified axis; diff --git a/tutorials/python/reduction.py b/tutorials/python/reduction.py index ab79b18b771d6875e25b4e5de0fd21d95492d677..21a65c380b1e74c0a43b124d13f8ae419318fbfa 100644 --- a/tutorials/python/reduction.py +++ b/tutorials/python/reduction.py @@ -124,6 +124,23 @@ fcuda(a, b) np.testing.assert_allclose( b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4) +###################################################################### +# Define General Commutative Reduction Operation +# ---------------------------------------------- +# Besides the built-in reduction operations like :any:`tvm.sum`, +# :any:`tvm.min` and :any:`tvm.max`, you can also define your +# commutative reduction operation by :any:`tvm.comm_reducer`. +# + +n = tvm.var('n') +m = tvm.var('m') +product = tvm.comm_reducer(lambda x, y: x*y, + lambda t: tvm.const(1, dtype=t), name="product") +A = tvm.placeholder((n, m), name='A') +k = tvm.reduce_axis((0, m), name='k') +B = tvm.compute((n,), lambda i: product(A[i, k], axis=k), name='B') + + ###################################################################### # Summary # ------- @@ -131,3 +148,4 @@ np.testing.assert_allclose( # # - Describe reduction with reduce_axis. # - Use rfactor to factor out axis if we need parallelism. +# - Define new reduction operation by :any:`tvm.comm_reducer`