From 547a0913f69c29deb89b38098880d76fbc563f49 Mon Sep 17 00:00:00 2001
From: Sergei Grechanik <grechanik.sergey@huawei.com>
Date: Fri, 11 Jan 2019 21:21:24 +0300
Subject: [PATCH] [TVM] Reduction simplification improvements (#2284)

---
 src/arithmetic/canonical.cc                 | 126 ++++++++++++++++++++
 tests/python/unittest/test_pass_simplify.py |  84 +++++++++++++
 2 files changed, 210 insertions(+)

diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc
index 0ec306213..5ba602bc3 100644
--- a/src/arithmetic/canonical.cc
+++ b/src/arithmetic/canonical.cc
@@ -781,12 +781,138 @@ T Simplify_(T a, Map<Var, Range> vrange) {
 }
 
 
+/*!
+ * \brief Simplify just the combiner of the given reduce node.
+ *
+ *  Applies Simplify to the result and identity expressions of the top
+ *  reduction's combiner, but not to the source or condition of the reduction.
+ *  Components needed by neither the value_index-th value nor a component with
+ *  side effects are dropped. If \p expr is not a reduction, it is left unchanged.
+ *
+ * \param expr The expression to be simplified.
+ * \param vrange Variable ranges passed on to Simplify for the combiner's
+ *        result and identity expressions.
+ * \return The reduction with the simplified and pruned combiner, or \p expr itself.
+ */
+Expr SimplifyCombiner(const Expr& expr, const Map<Var, Range>& vrange = Map<Var, Range>()) {
+  const Reduce* op = expr.as<Reduce>();
+  if (!op) {
+    return expr;
+  }
+
+  // Simplify all results first; dependencies are then checked on the simplified forms
+  Array<Expr> simplified_result;
+  for (const auto& res : op->combiner->result) {
+    simplified_result.push_back(Simplify(res, vrange));
+  }
+
+  // used[i] is nonzero iff the i-th component has to be kept
+  std::vector<int> used(op->combiner->result.size(), false);
+
+  // Marks component idx as used, then recursively every component whose lhs or rhs
+  // variable appears in the idx-th simplified result
+  std::function<void(int)> mark_used;
+  mark_used = [&used, &simplified_result, op, &mark_used](size_t idx) {
+    // already marked: nothing to do (this also stops the recursion on cyclic dependencies)
+    if (used[idx]) return;
+    used[idx] = true;
+
+    // scan the components not marked yet: those whose lhs or rhs variable occurs in
+    // the idx-th simplified result are dependencies and get marked as well
+    for (size_t i = 0; i < simplified_result.size(); ++i)
+      if (!used[i]) {
+        if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) ||
+            ExprUseVar(simplified_result[idx], op->combiner->rhs[i]))
+          mark_used(i);
+      }
+  };
+
+  // the requested output and everything it depends on must be kept
+  mark_used(op->value_index);
+
+  // components with side effects in their source, identity or result are kept as well
+  for (size_t i = 0; i < used.size(); ++i) {
+    if (HasSideEffect(op->source[i]) || HasSideEffect(op->combiner->identity_element[i]) ||
+        HasSideEffect(op->combiner->result[i])) {
+      mark_used(i);
+    }
+  }
+
+  int new_value_index = op->value_index;
+  Array<Expr> new_result;
+  Array<Expr> new_identity;
+  Array<Var> new_lhs;
+  Array<Var> new_rhs;
+  Array<Expr> new_source;
+
+  // collect the used components in their original order
+  for (size_t i = 0; i < used.size(); ++i) {
+    if (used[i]) {
+      // The result and the identity are simplified, the source is passed through as is
+      new_result.push_back(simplified_result[i]);
+      new_identity.push_back(Simplify(op->combiner->identity_element[i], vrange));
+      new_lhs.push_back(op->combiner->lhs[i]);
+      new_rhs.push_back(op->combiner->rhs[i]);
+      new_source.push_back(op->source[i]);
+    } else if (static_cast<int>(i) < op->value_index) {
+      // a dropped component located before value_index shifts the index down by one
+      new_value_index--;
+    }
+  }
+
+  CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity);
+  return Reduce::make(new_combiner, new_source, op->axis, op->condition, new_value_index);
+}
+
+/*!
+ * \brief Remove a single reduction over an empty axis.
+ *
+ *  If \p e is a reduction node with an empty axis, replace it with a Select yielding the
+ *  value_index-th source if the condition holds and the identity element otherwise.
+ *
+ * \param e The expression to be transformed.
+ * \return The Select described above, or \p e unchanged if it is not such a reduction.
+ */
+Expr RemoveEmptyReduction(const Expr& e) {
+  const Reduce* r = e.as<Reduce>();
+  if (r && r->axis.empty()) {
+    // Note that here we assume that the identity element is indeed identity. Without this
+    // assumption we would have to perform a single iteration of the loop, i.e. use
+    // `(*r->combiner.get())(r->combiner->identity_element, r->source)[r->value_index]`
+    // instead of `r->source[r->value_index]`. The former may be more difficult to simplify.
+    return Select::make(r->condition,
+                        r->source[r->value_index],
+                        r->combiner->identity_element[r->value_index]);
+  }
+  return e;
+}
+
 Expr Simplify(Expr a, Map<Var, Range> vrange) {
   // We should not pass an expression having a non-HalideIR op to
   // Halide::Internal::simplify. Reduce op is the only such op at this time
   // and it only appears as the top op in an expression. So we strip it
   // first and send the sub-expressions to the simplifier.
   if (const Reduce* r = a.as<Reduce>()) {
+    // If axis is empty, we can remove the reduce op completely.
+    if (r->axis.empty())
+      return Simplify_(RemoveEmptyReduction(a), vrange);
+
+    // Simplify the combiner of the reduction
+    a = SimplifyCombiner(a, vrange);
+    r = a.as<Reduce>();
+
+    // If axis is not empty then we add the information about ranges to vrange
+    for (const IterVar& iv : r->axis) {
+      if (vrange.count(iv->var)) {
+        Range existing_range = vrange[iv->var];
+        CHECK(Equal(existing_range->min, iv->dom->min) &&
+              Equal(existing_range->extent, iv->dom->extent))
+          << "Simplify was given vrange stating that the range of the reduction var "
+          << iv << " is " << existing_range << ". This is probably a mistake.";
+      }
+      vrange.Set(iv->var, iv->dom);
+    }
+
     Array<Expr> new_source;
     for (auto& e : r->source) {
       new_source.push_back(Simplify_(e, vrange));
diff --git a/tests/python/unittest/test_pass_simplify.py b/tests/python/unittest/test_pass_simplify.py
index a42230df8..939a08f5b 100644
--- a/tests/python/unittest/test_pass_simplify.py
+++ b/tests/python/unittest/test_pass_simplify.py
@@ -1,5 +1,7 @@
 import tvm
 import numpy
+from tvm import comm_reducer
+from tvm.ir_pass import Simplify, CanonicalSimplify, Equal
 
 def test_simplify():
     """Not yet working, mock design"""
@@ -52,8 +54,90 @@ def test_canonical():
     ret2 = tvm.ir_pass.CanonicalSimplify(x % 3 + x % 4)
     assert (tvm.ir_pass.Equal(ret1, ret2))
 
+
+def test_simplify_combiner():
+    dummy = tvm.var('dummy')
+
+    prod = comm_reducer(lambda x, y: x*y, lambda t0: tvm.const(1, t0))
+    # Behaves like a sum when dummy < 0 and like a product otherwise
+    sum_or_prod = comm_reducer(lambda x, y: tvm.expr.Select(dummy < 0,
+                                                            x + y, x*y),
+                               lambda t0: tvm.expr.Select(dummy < 0,
+                                                          tvm.const(0, t0), tvm.const(1, t0)))
+
+    sum_and_prod = comm_reducer(lambda x, y: (x[0] + y[0],
+                                              x[1]*y[1]),
+                                lambda t0, t1: (tvm.const(0, t0),
+                                                tvm.const(5, t0) - tvm.const(4, t0)))
+    # Same as sum_and_prod once simplified: the extra terms in the second component cancel
+    sum_and_prod2 = comm_reducer(lambda x, y: (x[0] + y[0],
+                                               x[1]*y[1] + 0*x[0] + y[0] - y[0]),
+                                 lambda t0, t1: (tvm.const(5, t0) - tvm.const(5, t0),
+                                                 tvm.const(1, t1)))
+    # Five components with dependencies between them, for testing removal of unused ones
+    some_reducer1 = comm_reducer(lambda x, y: (x[0] + y[0],
+                                               x[0] + y[0] + x[1] + y[1],
+                                               x[0]*y[2] + y[0]*x[2],
+                                               x[1] + y[2],
+                                               4.0),
+                                 lambda t0, t1, t2, t3, t4: (tvm.const(0, t0),
+                                                             tvm.const(1, t1),
+                                                             tvm.const(2, t2),
+                                                             tvm.const(3, t3),
+                                                             tvm.const(4, t4)))
+
+    k = tvm.reduce_axis((0, 10), name="k")
+    A = tvm.placeholder((10,), name='A')
+
+    # Test that SimplifyCombiner makes use of vranges
+    vrange = {dummy: tvm.Range(-10, -5)}
+    assert Equal(Simplify(sum_or_prod(A[k], k), vrange), tvm.sum(A[k], k))
+    vrange = {dummy: tvm.Range(5, 10)}
+    assert Equal(Simplify(sum_or_prod(A[k], k), vrange), prod(A[k], k))
+
+    assert Equal(Simplify(sum_and_prod((A[k], A[10-k]), k)[0]), tvm.sum(A[k], k))
+    assert Equal(Simplify(sum_and_prod((A[k], A[10-k]), k)[1]), prod(A[10-k], k))
+
+    assert Equal(Simplify(sum_and_prod2((A[k], A[10-k]), k)[0]), tvm.sum(A[k], k))
+    assert Equal(Simplify(sum_and_prod2((A[k], A[10-k]), k)[1]), prod(A[10-k], k))
+
+    reference_simplified_sources = [[A[0]],
+                                    [A[0], A[1]],
+                                    [A[0], A[2]],
+                                    [A[0], A[1], A[2], A[3]],
+                                    [A[4]]]
+    for j in range(5):
+        # Here we use the j-th component of the result, so only it and the components it
+        # depends on are left.
+        simplified = Simplify(some_reducer1((A[0], A[1], A[2], A[3], A[4]), k)[j])
+
+        # Check the remaining sources pairwise (note that zip does not compare the lengths).
+        for lhs, rhs in zip(simplified.source, reference_simplified_sources[j]):
+            assert Equal(lhs, rhs)
+
+    # Test that components with side effects are not removed
+    side_effect = lambda *xs: tvm.make.Call("int32", "dummy", xs, tvm.expr.Call.Intrinsic, None, 0)
+    assert Equal(Simplify(sum_and_prod((A[k], side_effect(A[10-k])), k)[0]),
+                 sum_and_prod((A[k], side_effect(A[10-k])), k)[0])
+    assert Equal(Simplify(sum_and_prod((side_effect(A[k]), A[10-k]), k)[0]),
+                 tvm.sum(side_effect(A[k]), k))
+
+
+def test_simplify_reduce():
+    k = tvm.reduce_axis((0, 10), name="k")
+    j = tvm.reduce_axis((-5, 3), name="j")
+    A = tvm.placeholder((10,), name='A')
+    # Expected: sources simplified using the axis ranges; a reduce over no axes is removed
+    assert Equal(Simplify(tvm.sum(k/10, k)), tvm.sum(tvm.const(0, "int32"), k))
+    assert Equal(Simplify(tvm.sum(A[3], [])), A[3])
+    assert Equal(Simplify(tvm.sum(tvm.expr.Select(k + j < 12, k + j, 0), [k, j])),
+                 tvm.sum(k + j, [k, j]))
+
+
 if __name__ == "__main__":
     test_bound()
     test_basic()
     test_simplify()
     test_canonical()
+    test_simplify_combiner()
+    test_simplify_reduce()
-- 
GitLab