From 2bb1d8e458463a8528d79303d502f249af3fcc30 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Tue, 28 Nov 2017 13:26:17 -0800
Subject: [PATCH] [ARITH] Upgrade CanonicalSimplify to Simplify Mod (#676)

---
 include/tvm/ir_pass.h                        |  8 +-
 src/api/api_pass.cc                          | 12 ++-
 src/arithmetic/canonical.cc                  | 83 +++++++++++++++++---
 src/arithmetic/canonical.h                   |  2 +-
 tests/python/unittest/test_arith_simplify.py | 18 +++++
 5 files changed, 106 insertions(+), 17 deletions(-)

diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index b6b248228..d0f32478e 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -41,16 +41,20 @@ Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>());
 /*!
  * \brief Simplify by applying canonical form.
  * \param stmt The statement to be canonically simplifed.
+ * \param vrange Optional map from variables to their ranges; defaults to an empty map.
  * \return Canonicalized statement.
  */
-Stmt CanonicalSimplify(Stmt stmt);
+Stmt CanonicalSimplify(Stmt stmt,
+                       Map<Var, Range> vrange = Map<Var, Range>());
 
 /*!
  * \brief Simplify by applying canonical form.
  * \param expr The statement to be canonically simplifed.
+ * \param vrange Optional map from variables to their ranges; defaults to an empty map.
  * \return Canonicalized expression.
  */
-Expr CanonicalSimplify(Expr expr);
+Expr CanonicalSimplify(Expr expr,
+                       Map<Var, Range> vrange = Map<Var, Range>());
 
 /*!
  * \brief Deep compare lhs and rhs
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index 024af23a3..23deb03af 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -33,9 +33,17 @@ TVM_REGISTER_API("ir_pass.Simplify")
 TVM_REGISTER_API("ir_pass.CanonicalSimplify")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
     if (args[0].IsNodeType<Stmt>()) {
-      *ret = CanonicalSimplify(args[0].operator Stmt());
+      if (args.size() > 1) {  // optional second argument is forwarded as the range map
+        *ret = CanonicalSimplify(args[0].operator Stmt(), args[1]);
+      } else {
+        *ret = CanonicalSimplify(args[0].operator Stmt());
+      }
     } else {
-      *ret = CanonicalSimplify(args[0].operator Expr());
+      if (args.size() > 1) {  // same optional range map for the Expr overload
+        *ret = CanonicalSimplify(args[0].operator Expr(), args[1]);
+      } else {
+        *ret = CanonicalSimplify(args[0].operator Expr());
+      }
     }
   });
 
diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc
index 808e070ef..e7f9da1b4 100644
--- a/src/arithmetic/canonical.cc
+++ b/src/arithmetic/canonical.cc
@@ -129,6 +129,11 @@ inline Expr Binary_(const T* op,
 // internal of canonical engine.
 class Canonical::Internal : public IRMutator {
  public:
+  explicit Internal(Map<Var, Range> vrange) {  // seed the engine with caller-supplied ranges
+    for (auto kv : vrange) {
+      SetRange(kv.first, kv.second, 0);  // every initial range is registered with level 0
+    }
+  }
   // stack entry.
   struct StackEntry {
     int max_level{0};
@@ -300,9 +305,25 @@ class Canonical::Internal : public IRMutator {
   Expr Mutate_(const Div* op, const Expr& e) final {
     return Binary(op, e);
   }
+  // Mod operator: fold constant operands, or reduce (sum % const) through SumModConst.
   Expr Mutate_(const Mod* op, const Expr& e) final {
-    return Binary(op, e);
+    if (!EnableOpt(op->type)) {
+      return Binary(op, e);
+    }
+    CacheEntry a = Produce(op->a);
+    CacheEntry b = Produce(op->b);
+    if (a.has_side_effect || b.has_side_effect) {
+      return Binary_(op, e, a.value, b.value);
+    }
+    if (is_const(a.value) && is_const(b.value)) {
+      return ComputeExpr<Mod>(a.value, b.value);  // both operands constant: fold a % b directly
+    } else if (is_const(b.value)) {
+      return SumModConst(a.AsSum(), b.value);  // constant divisor: simplify term by term
+    } else {
+      return Binary(op, e);
+    }
   }
+
   Expr Mutate_(const And* op, const Expr& e) final {
     Expr expr = IRMutator::Mutate_(op, e);
     op = expr.as<And>();
@@ -367,7 +388,7 @@ class Canonical::Internal : public IRMutator {
 
  private:
   template<typename T>
-  Expr Binary(const T* op, const Expr& e) {
+  Expr Binary(const T* op, Expr e) {
     Expr a = this->Mutate(op->a);
     Expr b = this->Mutate(op->b);
     BinaryExpr key{static_cast<int>(T::_type_info), a, b};
@@ -398,8 +419,8 @@ class Canonical::Internal : public IRMutator {
   std::vector<Var> var_rec_;
   // level counter
   int level_counter_{0};
-  // subroutine to do produce
-  Expr SumMulConst(ComExpr a, Expr v) {
+  // get constant int value
+  int64_t GetConstIntValue(const Expr& v) {
     int64_t value = 0;
     const int64_t *v1 = as_const_int(v);
     const uint64_t *v2 = as_const_uint(v);
@@ -411,7 +432,45 @@ class Canonical::Internal : public IRMutator {
                static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
       value = static_cast<int64_t>(*v2);
     }
-
+    return value;
+  }
+  // subroutine to produce a % v; v is expected to be a constant (read via GetConstIntValue)
+  Expr SumModConst(ComExpr a, Expr v) {
+    int64_t value = GetConstIntValue(v);  // NOTE(review): confirm callers never pass a zero divisor
+    std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>();
+    int mod_level = 0;  // count of surviving parts; forced to 2 when a part is not provably < v
+    n->base = a->base % value;
+    if (n->base != 0) mod_level = 1;
+    for (auto e : a->elem) {
+      if (e.scale % value == 0) continue;  // multiples of v contribute nothing to the remainder
+      e.scale = e.scale % value;
+      if (e.scale != 1 || !EvalSet(v - e.value, var_range_).can_prove_positive()) {
+        mod_level = 2;  // term may reach v (scaled, or value not provably < v): keep the mod
+      } else {
+        ++mod_level;  // NOTE(review): only value < v is proven here; confirm negative values
+      }
+      n->elem.push_back(e);
+    }
+    // cannot remove mod because there are two or more parts, or a part with an unproven bound
+    if (mod_level >= 2) {
+      Expr ret = Sum2Expr(ComExpr(n), v.type()) % v;
+      return Binary(ret.as<Mod>(), ret);
+    }
+    ret_entry_.sum = ComExpr(n);  // at most one part remains, so the result is the sum itself
+    ret_entry_.max_level = stack_.back().max_level;
+    ret_entry_.has_side_effect = stack_.back().has_side_effect;
+    auto it = cache_sum_.find(ret_entry_.sum);
+    if (it != cache_sum_.end()) {
+      ret_entry_ = it->second;  // reuse the cached entry for an identical sum
+    } else {
+      ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type());
+      cache_sum_[ret_entry_.sum] = ret_entry_;
+    }
+    return ret_entry_.value;
+  }
+  // subroutine to do produce
+  Expr SumMulConst(ComExpr a, Expr v) {
+    int64_t value = GetConstIntValue(v);
     if (value == 0) {
       return make_zero(v.type());
     }
@@ -421,9 +480,9 @@ class Canonical::Internal : public IRMutator {
     for (auto& e : vsum->elem) {
       e.scale *= value;
     }
+    ret_entry_.sum = ComExpr(vsum);
     ret_entry_.max_level = stack_.back().max_level;
     ret_entry_.has_side_effect = stack_.back().has_side_effect;
-    ret_entry_.sum = ComExpr(vsum);
     auto it = cache_sum_.find(ret_entry_.sum);
     if (it != cache_sum_.end()) {
       ret_entry_ = it->second;
@@ -536,8 +595,8 @@ class Canonical::Internal : public IRMutator {
 
 using CInternal = Canonical::Internal;
 
-Canonical::Canonical()
-    : ptr_(std::make_shared<Internal>()) {}
+Canonical::Canonical(Map<Var, Range> vrange)
+    : ptr_(std::make_shared<Internal>(vrange)) {}  // forward the initial ranges to Internal
 
 Expr Canonical::Simplify(Expr expr) {
   return ptr_->Mutate(expr);
@@ -553,12 +612,12 @@ void Canonical::SetRange(Var v, Range r, int level) {
 }  // namespace arith
 
 namespace ir {
-Stmt CanonicalSimplify(Stmt stmt) {
-  return arith::Canonical().Simplify(stmt);
+Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) {
+  return arith::Canonical(vrange).Simplify(stmt);  // fresh simplifier constructed from vrange
 }
 
-Expr CanonicalSimplify(Expr expr) {
-  return arith::Canonical().Simplify(expr);
+Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) {
+  return arith::Canonical(vrange).Simplify(expr);  // fresh simplifier constructed from vrange
 }
 
 template<typename T>
diff --git a/src/arithmetic/canonical.h b/src/arithmetic/canonical.h
index 174acc20a..37f9a178f 100644
--- a/src/arithmetic/canonical.h
+++ b/src/arithmetic/canonical.h
@@ -22,7 +22,7 @@ namespace arith {
 class Canonical {
  public:
   /*! \brief constructor */
-  Canonical();
+  explicit Canonical(Map<Var, Range> var_range);
   /*!
    * \brief simplify expression e.
    * \param expr The expression to be simplified.
diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py
index 9ff8571ea..8ce1773ee 100644
--- a/tests/python/unittest/test_arith_simplify.py
+++ b/tests/python/unittest/test_arith_simplify.py
@@ -20,5 +20,23 @@ def test_simplify():
     zz = zz.a
     assert zz.a == x and zz.b.value == 4
 
+def test_simplify_mod():
+    """Check CanonicalSimplify on modulo index expressions, with and without a range map."""
+    ib = tvm.ir_builder.create()
+    n = tvm.var('n')
+    j = tvm.var('j')
+    A = ib.pointer("float32", name="A")
+    with ib.for_range(0, 16, name="i") as i:
+        A[i] = A[((n * 4 + j * 2) * 8 + i+1) % 16]  # expands to n*32 + j*16 + i + 1
+    body = ib.get()
+    stmt = tvm.ir_pass.CanonicalSimplify(body)
+    diff = tvm.ir_pass.CanonicalSimplify(stmt.body.value.index - (1 + i) % 16)
+    assert diff.value == 0  # the n*32 and j*16 terms must vanish under % 16
+    index = tvm.ir_pass.CanonicalSimplify(
+        (j + n * 32) % 16, {j: tvm.Range(0, 6)})
+    assert index == j  # with the supplied range for j the mod is expected to drop entirely
+
+
 if __name__ == "__main__":
+    test_simplify_mod()
     test_simplify()
-- 
GitLab