From 547a0913f69c29deb89b38098880d76fbc563f49 Mon Sep 17 00:00:00 2001 From: Sergei Grechanik <grechanik.sergey@huawei.com> Date: Fri, 11 Jan 2019 21:21:24 +0300 Subject: [PATCH] [TVM] Reduction simplification improvements (#2284) --- src/arithmetic/canonical.cc | 126 ++++++++++++++++++++ tests/python/unittest/test_pass_simplify.py | 84 +++++++++++++ 2 files changed, 210 insertions(+) diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc index 0ec306213..5ba602bc3 100644 --- a/src/arithmetic/canonical.cc +++ b/src/arithmetic/canonical.cc @@ -781,12 +781,138 @@ T Simplify_(T a, Map<Var, Range> vrange) { } +/*! + * \brief Simplify just the combiner of the given reduce node. + * + * This function applies Simplify to the components of the top reduction's + * combiner, but not to the source or condition of the reduction. + * It also removes all components which are not used to + * compute the resulting value (the value_index-th value). + * + * If \p expr is not a reduction node, it is left unchanged. + * + * \param expr The expression to be simplifed. + * \return Simplified expression. + */ +Expr SimplifyCombiner(const Expr& expr, const Map<Var, Range>& vrange = Map<Var, Range>()) { + const Reduce* op = expr.as<Reduce>(); + if (!op) { + return expr; + } + + // First simplify the results + Array<Expr> simplified_result; + for (const auto& res : op->combiner->result) { + simplified_result.push_back(Simplify(res, vrange)); + } + + // Which components to keep + std::vector<int> used(op->combiner->result.size(), false); + + // This function recursively marks the used components starting from + // the index idx + std::function<void(int)> mark_used; + mark_used = [&used, &simplified_result, op, &mark_used](size_t idx) { + // if the idx-th component was marked as used before, do nothing + if (used[idx]) return; + used[idx] = true; + + // check if the idx-th result expr uses some lhs or rhs variables + // and recursively mark the corresponding components + for (size_t i = 0; i < simplified_result.size(); ++i) + if (!used[i]) { + if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) || + ExprUseVar(simplified_result[idx], op->combiner->rhs[i])) + mark_used(i); + } + }; + + // mark all used components starting from the value_index + mark_used(op->value_index); + + // components which have side effects should also be preserved + for (size_t i = 0; i < used.size(); ++i) { + if (HasSideEffect(op->source[i]) || HasSideEffect(op->combiner->identity_element[i]) || + HasSideEffect(op->combiner->result[i])) { + mark_used(i); + } + } + + int new_value_index = op->value_index; + Array<Expr> new_result; + Array<Expr> new_identity; + Array<Var> new_lhs; + Array<Var> new_rhs; + Array<Expr> new_source; + + // new stuff is old stuff which is used + for (size_t i = 0; i < used.size(); ++i) { + if (used[i]) { + // We simplify the result and identity, but not the source + new_result.push_back(simplified_result[i]); + new_identity.push_back(Simplify(op->combiner->identity_element[i], vrange)); + new_lhs.push_back(op->combiner->lhs[i]); + new_rhs.push_back(op->combiner->rhs[i]); + new_source.push_back(op->source[i]); + } else if (static_cast<int>(i) < op->value_index) { + // value_index should also be adjusted + new_value_index--; + } + } + + CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); + return Reduce::make(new_combiner, new_source, op->axis, op->condition, new_value_index); +} + +/*! + * \brief Remove a single reduction over empty axis. + * + * If \p e is a reduction node and its axis is empty, replace it with its source, + * otherwise return \p e unchanged. + * + * \param e The expression to be transformed. + * \return The transformed expression. + */ +Expr RemoveEmptyReduction(const Expr& e) { + const Reduce* r = e.as<Reduce>(); + if (r && r->axis.empty()) { + // Note that here we assume that the identity element is indeed identity. Without this + // assumption we would have to perform a single iteration of the loop, i.e. use + // `(*r->combiner.get())(r->combiner->identity_element, r->source)[r->value_index]` + // instead of `r->source[r->value_index]`. The former may be more difficult to simplify. + return Select::make(r->condition, + r->source[r->value_index], + r->combiner->identity_element[r->value_index]); + } + return e; +} + Expr Simplify(Expr a, Map<Var, Range> vrange) { // We should not pass an expression having a non-HalideIR op to // Halide::Internal::simplify. Reduce op is the only such op at this time // and it only appears as the top op in an expression. So we strip it // first and send the sub-expressions to the simplifier. if (const Reduce* r = a.as<Reduce>()) { + // If axis is empty, we can remove the reduce op completely. + if (r->axis.empty()) + return Simplify_(RemoveEmptyReduction(a), vrange); + + // Simplify the combiner of the reduction + a = SimplifyCombiner(a, vrange); + r = a.as<Reduce>(); + + // If axis is not empty then we add the information about ranges to vrange + for (const IterVar& iv : r->axis) { + if (vrange.count(iv->var)) { + Range existing_range = vrange[iv->var]; + CHECK(Equal(existing_range->min, iv->dom->min) && + Equal(existing_range->extent, iv->dom->extent)) + << "Simplify was given vrange stating that the range of the reduction var " + << iv << " is " << existing_range << ". This is probably a mistake."; + } + vrange.Set(iv->var, iv->dom); + } + Array<Expr> new_source; for (auto& e : r->source) { new_source.push_back(Simplify_(e, vrange)); diff --git a/tests/python/unittest/test_pass_simplify.py b/tests/python/unittest/test_pass_simplify.py index a42230df8..939a08f5b 100644 --- a/tests/python/unittest/test_pass_simplify.py +++ b/tests/python/unittest/test_pass_simplify.py @@ -1,5 +1,7 @@ import tvm import numpy +from tvm import comm_reducer +from tvm.ir_pass import Simplify, CanonicalSimplify, Equal def test_simplify(): """Not yet working, mock design""" @@ -52,8 +54,90 @@ def test_canonical(): ret2 = tvm.ir_pass.CanonicalSimplify(x % 3 + x % 4) assert (tvm.ir_pass.Equal(ret1, ret2)) + +def test_simplify_combiner(): + dummy = tvm.var('dummy') + + prod = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0)) + + sum_or_prod = comm_reducer(lambda x, y: tvm.expr.Select(dummy < 0, + x + y, x*y), + lambda t0: tvm.expr.Select(dummy < 0, + tvm.const(0, t0), tvm.const(1, t0))) + + sum_and_prod = comm_reducer(lambda x, y: (x[0] + y[0], + x[1]*y[1]), + lambda t0, t1: (tvm.const(0, t0), + tvm.const(5, t0) - tvm.const(4, t0))) + + sum_and_prod2 = comm_reducer(lambda x, y: (x[0] + y[0], + x[1]*y[1] + 0*x[0] + y[0] - y[0]), + lambda t0, t1: (tvm.const(5, t0) - tvm.const(5, t0), + tvm.const(1, t1))) + + some_reducer1 = comm_reducer(lambda x, y: (x[0] + y[0], + x[0] + y[0] + x[1] + y[1], + x[0]*y[2] + y[0]*x[2], + x[1] + y[2], + 4.0), + lambda t0, t1, t2, t3, t4: (tvm.const(0, t0), + tvm.const(1, t1), + tvm.const(2, t2), + tvm.const(3, t3), + tvm.const(4, t4))) + + k = tvm.reduce_axis((0, 10), name="k") + A = tvm.placeholder((10,), name='A') + + # Test that SimplifyCombiner makes use of vranges + vrange = {dummy: tvm.Range(-10, -5)} + assert Equal(Simplify(sum_or_prod(A[k], k), vrange), tvm.sum(A[k], k)) + vrange = {dummy: tvm.Range(5, 10)} + assert Equal(Simplify(sum_or_prod(A[k], k), vrange), prod(A[k], k)) + + assert Equal(Simplify(sum_and_prod((A[k], A[10-k]), k)[0]), tvm.sum(A[k], k)) + assert Equal(Simplify(sum_and_prod((A[k], A[10-k]), k)[1]), prod(A[10-k], k)) + + assert Equal(Simplify(sum_and_prod2((A[k], A[10-k]), k)[0]), tvm.sum(A[k], k)) + assert Equal(Simplify(sum_and_prod2((A[k], A[10-k]), k)[1]), prod(A[10-k], k)) + + reference_simplified_sources = [[A[0]], + [A[0], A[1]], + [A[0], A[2]], + [A[0], A[1], A[2], A[3]], + [A[4]]] + for j in range(5): + # Here we use the j-th component of the result, so only it and the components it + # depends on are left. + simplified = Simplify(some_reducer1((A[0], A[1], A[2], A[3], A[4]), k)[j]) + + # Check that the remaining components are the expected ones. + for lhs, rhs in zip(simplified.source, reference_simplified_sources[j]): + assert Equal(lhs, rhs) + + # Test that components with side effects are not removed + side_effect = lambda *xs: tvm.make.Call("int32", "dummy", xs, tvm.expr.Call.Intrinsic, None, 0) + assert Equal(Simplify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0]), + sum_and_prod((A[k], side_effect(A[10-k])), k)[0]) + assert Equal(Simplify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0]), + tvm.sum(side_effect(A[k]), k)) + + +def test_simplify_reduce(): + k = tvm.reduce_axis((0, 10), name="k") + j = tvm.reduce_axis((-5, 3), name="j") + A = tvm.placeholder((10,), name='A') + + assert Equal(Simplify(tvm.sum(k/10, k)), tvm.sum(tvm.const(0, "int32"), k)) + assert Equal(Simplify(tvm.sum(A[3], [])), A[3]) + assert Equal(Simplify(tvm.sum(tvm.expr.Select(k + j < 12, k + j, 0), [k, j])), + tvm.sum(k + j, [k, j])) + + if __name__ == "__main__": test_bound() test_basic() test_simplify() test_canonical() + test_simplify_combiner() + test_simplify_reduce() -- GitLab