From d45b6d4b8425915727ddfe2cffc42bbae17c5664 Mon Sep 17 00:00:00 2001 From: ziheng <ziheng@apache.org> Date: Sat, 29 Apr 2017 14:39:24 -0700 Subject: [PATCH] [DOC] Add intro to 'comm_reducer' in tutorial; fix doc (#108) * [DOC] Add intro to 'comm_reducer' in tutorial; fix doc * Fix * Fix --- include/tvm/ir_pass.h | 7 +++++++ python/setup.py | 2 +- python/tvm/api.py | 4 +++- tutorials/python/reduction.py | 18 ++++++++++++++++++ 4 files changed, 29 insertions(+), 2 deletions(-) diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index bad2fb3dd..61731fb0c 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -84,6 +84,13 @@ Stmt CanonicalSimplify(Stmt stmt); * \return The converted form. */ Stmt Substitute(Stmt stmt, const Map<Var, Expr>& value_map); + +/*! + * \brief Substitute the var specified in key->var to be value. + * \param expr The source expression to be substituted + * \param value_map The map of new values. + * \return The converted expression. + */ Expr Substitute(Expr expr, const Map<Var, Expr>& value_map); /*! diff --git a/python/setup.py b/python/setup.py index 23f6d3cc8..1903cf7b3 100644 --- a/python/setup.py +++ b/python/setup.py @@ -70,7 +70,7 @@ setuptools.setup( ], zip_safe=False, packages=[ - 'tvm', 'tvm.addon', + 'tvm', 'tvm.contrib', 'tvm._ffi', 'tvm._ffi._ctypes', 'tvm._ffi._cy2', 'tvm._ffi._cy3' ], diff --git a/python/tvm/api.py b/python/tvm/api.py index 2c1076474..edac2efe5 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -23,10 +23,12 @@ handle = "handle" def min_value(dtype): + """minimum value of dtype""" return _api_internal._min_value(dtype) def max_value(dtype): + """maximum value of dtype""" return _api_internal._max_value(dtype) @@ -438,7 +440,7 @@ def comm_reducer(fcombine, fidentity, name="reduce"): ------- reducer : function A function which creates a reduce expression over axis. - There are two to use it: + There are two ways to use it: 1. accept (expr, axis, where) to produce an Reduce Expr on specified axis; diff --git a/tutorials/python/reduction.py b/tutorials/python/reduction.py index ab79b18b7..21a65c380 100644 --- a/tutorials/python/reduction.py +++ b/tutorials/python/reduction.py @@ -124,6 +124,23 @@ fcuda(a, b) np.testing.assert_allclose( b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4) +###################################################################### +# Define General Commutative Reduction Operation +# ---------------------------------------------- +# Besides the built-in reduction operations like :any:`tvm.sum`, +# :any:`tvm.min` and :any:`tvm.max`, you can also define your +# commutative reduction operation by :any:`tvm.comm_reducer`. +# + +n = tvm.var('n') +m = tvm.var('m') +product = tvm.comm_reducer(lambda x, y: x*y, + lambda t: tvm.const(1, dtype=t), name="product") +A = tvm.placeholder((n, m), name='A') +k = tvm.reduce_axis((0, m), name='k') +B = tvm.compute((n,), lambda i: product(A[i, k], axis=k), name='B') + + ###################################################################### # Summary # ------- @@ -131,3 +148,4 @@ np.testing.assert_allclose( # # - Describe reduction with reduce_axis. # - Use rfactor to factor out axis if we need parallelism. +# - Define new reduction operation by :any:`tvm.comm_reducer` -- GitLab