From 515d4b6f0a2e8f364dd4720bf675686bb9bcf59d Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Wed, 11 Apr 2018 11:43:38 -0700
Subject: [PATCH] [PASS] More simplifier for mod and div (#1100)

* [PASS] More simplifier for mod and div

* fix testcase
---
 src/arithmetic/canonical.cc                  | 103 +++++++++++++++----
 src/pass/lower_warp_memory.cc                |   4 +-
 tests/python/unittest/test_arith_simplify.py |  20 ++++
 tests/python/unittest/test_pass_simplify.py  |   2 +-
 4 files changed, 108 insertions(+), 21 deletions(-)

diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc
index 736b8dad7..ed6239961 100644
--- a/src/arithmetic/canonical.cc
+++ b/src/arithmetic/canonical.cc
@@ -312,9 +312,23 @@ class Canonical::Internal : public IRMutator {
       return e;
     }
   }
-  // binary ops
+  // Div operator
   Expr Mutate_(const Div* op, const Expr& e) final {
-    return Binary(op, e);
+    if (!EnableOpt(op->type)) {
+      return Binary(op, e);
+    }
+    CacheEntry a = Produce(op->a);
+    CacheEntry b = Produce(op->b);
+    if (a.has_side_effect || b.has_side_effect) {
+      return Binary_(op, e, a.value, b.value);
+    }
+    if (is_const(a.value) && is_const(b.value)) {
+      return ComputeExpr<Div>(a.value, b.value);
+    } else if (is_const(b.value)) {
+      return SumDivConst(a.AsSum(), b.value);
+    } else {
+      return Binary(op, e);
+    }
   }
   // Mod operator
   Expr Mutate_(const Mod* op, const Expr& e) final {
@@ -445,29 +459,80 @@ class Canonical::Internal : public IRMutator {
     }
     return value;
   }
-  // subroutine to do produce a % v
-  Expr SumModConst(ComExpr a, Expr v) {
-    int64_t value = GetConstIntValue(v);
-    std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
-    int mod_level = 0;
-    n->base = a->base % value;
-    if (n->base != 0) mod_level = 1;
-    for (auto e : a->elem) {
-      if (e.scale % value == 0) continue;
-      e.scale = e.scale % value;
-      if (!EvalSet(v - e.value, var_range_).can_prove_positive()) {
-        mod_level = 2;
+  // Detect if a = x * coeff + y, where y \in [0, coeff), x >= 0
+  // return true if such detection is successful
+  // return false if it is not.
+  std::vector<ComExpr> TryLinearEquation(const ComExpr& a,
+                                         const Expr& coeff) {
+    Type type = coeff.type();
+    int64_t value = GetConstIntValue(coeff);
+    if (value < 0) return {};
+    std::shared_ptr<ComExprNode> xnode = std::make_shared<ComExprNode>();
+    std::shared_ptr<ComExprNode> ynode = std::make_shared<ComExprNode>();
+    if (a->base % value == 0) {
+      xnode->base = a->base;
+    } else {
+      ynode->base = a->base;
+    }
+    for (const auto& e : a->elem) {
+      if (e.scale % value == 0) {
+        xnode->elem.push_back(e);
       } else {
-        ++mod_level;
+        ynode->elem.push_back(e);
       }
-      n->elem.push_back(e);
     }
-    // cannot remove mode because there are more than two parts
-    if (mod_level >= 2) {
+    Expr yres = Sum2Expr(ComExpr(ynode), type);
+    IntSet yset = EvalSet(yres, var_range_);
+    // This relies on the integer division rounds down
+    // Most cases it is good for integer division.
+    if (yset.min().type() == type &&
+        can_prove(yset.min() >= make_zero(type)) &&
+        yset.max().type() == type &&
+        can_prove(yset.max() < coeff)) {
+      xnode->base /= value;
+      for (auto &e : xnode->elem) {
+        e.scale /= value;
+      }
+      return {ComExpr(xnode), ComExpr(ynode)};
+    } else {
+      return {};
+    }
+  }
+  // subroutine to do produce a % v
+  Expr SumModConst(ComExpr a, Expr v) {
+    std::vector<ComExpr> pair = TryLinearEquation(a, v);
+    if (pair.size() == 0) {
+      int64_t value = GetConstIntValue(v);
+      std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
+      n->base = a->base % value;
+      for (auto e : a->elem) {
+        if (e.scale % value == 0) continue;
+        e.scale = e.scale % value;
+        n->elem.push_back(e);
+      }
       Expr ret = Sum2Expr(ComExpr(n), v.type()) % v;
       return Binary(ret.as<Mod>(), ret);
     }
-    ret_entry_.sum = ComExpr(n);
+    ret_entry_.sum = pair[1];
+    ret_entry_.max_level = stack_.back().max_level;
+    ret_entry_.has_side_effect = stack_.back().has_side_effect;
+    auto it = cache_sum_.find(ret_entry_.sum);
+    if (it != cache_sum_.end()) {
+      ret_entry_ = it->second;
+    } else {
+      ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type());
+      cache_sum_[ret_entry_.sum] = ret_entry_;
+    }
+    return ret_entry_.value;
+  }
+  // subroutine to do produce a % v
+  Expr SumDivConst(ComExpr a, Expr v) {
+    std::vector<ComExpr> pair = TryLinearEquation(a, v);
+    if (pair.size() == 0) {
+      Expr ret = Sum2Expr(a, v.type()) / v;
+      return Binary(ret.as<Div>(), ret);
+    }
+    ret_entry_.sum = pair[0];
     ret_entry_.max_level = stack_.back().max_level;
     ret_entry_.has_side_effect = stack_.back().has_side_effect;
     auto it = cache_sum_.find(ret_entry_.sum);
diff --git a/src/pass/lower_warp_memory.cc b/src/pass/lower_warp_memory.cc
index d0412adc0..8f153fd61 100644
--- a/src/pass/lower_warp_memory.cc
+++ b/src/pass/lower_warp_memory.cc
@@ -279,7 +279,9 @@ class WarpMemoryRewriter : private IRMutator {
 
   Stmt Rewrite(Stmt stmt) {
     if (warp_size_ == 1) return stmt;
-    return this->Mutate(stmt);
+    stmt = this->Mutate(stmt);
+    stmt = CanonicalSimplify(stmt);
+    return stmt;
   }
 
  private:
diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py
index 8ce1773ee..e6689dddf 100644
--- a/tests/python/unittest/test_arith_simplify.py
+++ b/tests/python/unittest/test_arith_simplify.py
@@ -37,6 +37,26 @@ def test_simplify_mod():
     assert index == j
 
 
+def test_modular():
+    rx = tvm.var("rx")
+    ry = tvm.var("ry")
+    y = tvm.var("y")
+    x = tvm.var("x")
+    vmap = {rx: tvm.Range(tvm.const(0), tvm.const(3)),
+            ry: tvm.Range(tvm.const(0), tvm.const(3)),
+            y: tvm.Range(tvm.const(0), tvm.const(2)),
+            x: tvm.Range(tvm.const(0), tvm.const(14))}
+    idx = ry * 16 + rx + y * 16 + x
+    z1 = tvm.ir_pass.CanonicalSimplify(idx // 16, vmap)
+    z2 = tvm.ir_pass.CanonicalSimplify(idx % 16, vmap)
+    assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0
+    assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0
+
+
+
+
+
 if __name__ == "__main__":
     test_simplify_mod()
+    test_modular()
     test_simplify()
diff --git a/tests/python/unittest/test_pass_simplify.py b/tests/python/unittest/test_pass_simplify.py
index 29b5b3a84..c38083822 100644
--- a/tests/python/unittest/test_pass_simplify.py
+++ b/tests/python/unittest/test_pass_simplify.py
@@ -33,7 +33,6 @@ def test_bound():
     ret = tvm.ir_pass.Simplify(m % 10, vrange)
     assert ret == m
 
-
 def test_canonical():
     x = tvm.var("x")
     z = tvm.const(3)
@@ -54,6 +53,7 @@ def test_canonical():
     assert (tvm.ir_pass.Equal(ret1, ret2))
 
 if __name__ == "__main__":
+    test_modular()
     test_bound()
     test_basic()
     test_simplify()
-- 
GitLab