From 515d4b6f0a2e8f364dd4720bf675686bb9bcf59d Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Wed, 11 Apr 2018 11:43:38 -0700 Subject: [PATCH] [PASS] More simplifier for mod and div (#1100) * [PASS] More simplifier for mod and div * fix testcase --- src/arithmetic/canonical.cc | 103 +++++++++++++++---- src/pass/lower_warp_memory.cc | 4 +- tests/python/unittest/test_arith_simplify.py | 20 ++++ tests/python/unittest/test_pass_simplify.py | 2 +- 4 files changed, 108 insertions(+), 21 deletions(-) diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc index 736b8dad7..ed6239961 100644 --- a/src/arithmetic/canonical.cc +++ b/src/arithmetic/canonical.cc @@ -312,9 +312,23 @@ class Canonical::Internal : public IRMutator { return e; } } - // binary ops + // Div operator Expr Mutate_(const Div* op, const Expr& e) final { - return Binary(op, e); + if (!EnableOpt(op->type)) { + return Binary(op, e); + } + CacheEntry a = Produce(op->a); + CacheEntry b = Produce(op->b); + if (a.has_side_effect || b.has_side_effect) { + return Binary_(op, e, a.value, b.value); + } + if (is_const(a.value) && is_const(b.value)) { + return ComputeExpr<Div>(a.value, b.value); + } else if (is_const(b.value)) { + return SumDivConst(a.AsSum(), b.value); + } else { + return Binary(op, e); + } } // Mod operator Expr Mutate_(const Mod* op, const Expr& e) final { @@ -445,29 +459,80 @@ class Canonical::Internal : public IRMutator { } return value; } - // subroutine to do produce a % v - Expr SumModConst(ComExpr a, Expr v) { - int64_t value = GetConstIntValue(v); - std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>(); - int mod_level = 0; - n->base = a->base % value; - if (n->base != 0) mod_level = 1; - for (auto e : a->elem) { - if (e.scale % value == 0) continue; - e.scale = e.scale % value; - if (!EvalSet(v - e.value, var_range_).can_prove_positive()) { - mod_level = 2; + // Detect if a = x * coeff + y, where y \in [0, coeff), x >= 0 + // return true if such detection is successful + // return false if it is not. + std::vector<ComExpr> TryLinearEquation(const ComExpr& a, + const Expr& coeff) { + Type type = coeff.type(); + int64_t value = GetConstIntValue(coeff); + if (value < 0) return {}; + std::shared_ptr<ComExprNode> xnode = std::make_shared<ComExprNode>(); + std::shared_ptr<ComExprNode> ynode = std::make_shared<ComExprNode>(); + if (a->base % value == 0) { + xnode->base = a->base; + } else { + ynode->base = a->base; + } + for (const auto& e : a->elem) { + if (e.scale % value == 0) { + xnode->elem.push_back(e); } else { - ++mod_level; + ynode->elem.push_back(e); } - n->elem.push_back(e); } - // cannot remove mode because there are more than two parts - if (mod_level >= 2) { + Expr yres = Sum2Expr(ComExpr(ynode), type); + IntSet yset = EvalSet(yres, var_range_); + // This relies on the integer division rounds down + // Most cases it is good for integer division. + if (yset.min().type() == type && + can_prove(yset.min() >= make_zero(type)) && + yset.max().type() == type && + can_prove(yset.max() < coeff)) { + xnode->base /= value; + for (auto &e : xnode->elem) { + e.scale /= value; + } + return {ComExpr(xnode), ComExpr(ynode)}; + } else { + return {}; + } + } + // subroutine to do produce a % v + Expr SumModConst(ComExpr a, Expr v) { + std::vector<ComExpr> pair = TryLinearEquation(a, v); + if (pair.size() == 0) { + int64_t value = GetConstIntValue(v); + std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>(); + n->base = a->base % value; + for (auto e : a->elem) { + if (e.scale % value == 0) continue; + e.scale = e.scale % value; + n->elem.push_back(e); + } Expr ret = Sum2Expr(ComExpr(n), v.type()) % v; return Binary(ret.as<Mod>(), ret); } - ret_entry_.sum = ComExpr(n); + ret_entry_.sum = pair[1]; + ret_entry_.max_level = stack_.back().max_level; + ret_entry_.has_side_effect = stack_.back().has_side_effect; + auto it = cache_sum_.find(ret_entry_.sum); + if (it != cache_sum_.end()) { + ret_entry_ = it->second; + } else { + ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type()); + cache_sum_[ret_entry_.sum] = ret_entry_; + } + return ret_entry_.value; + } + // subroutine to do produce a % v + Expr SumDivConst(ComExpr a, Expr v) { + std::vector<ComExpr> pair = TryLinearEquation(a, v); + if (pair.size() == 0) { + Expr ret = Sum2Expr(a, v.type()) / v; + return Binary(ret.as<Div>(), ret); + } + ret_entry_.sum = pair[0]; ret_entry_.max_level = stack_.back().max_level; ret_entry_.has_side_effect = stack_.back().has_side_effect; auto it = cache_sum_.find(ret_entry_.sum); diff --git a/src/pass/lower_warp_memory.cc b/src/pass/lower_warp_memory.cc index d0412adc0..8f153fd61 100644 --- a/src/pass/lower_warp_memory.cc +++ b/src/pass/lower_warp_memory.cc @@ -279,7 +279,9 @@ class WarpMemoryRewriter : private IRMutator { Stmt Rewrite(Stmt stmt) { if (warp_size_ == 1) return stmt; - return this->Mutate(stmt); + stmt = this->Mutate(stmt); + stmt = CanonicalSimplify(stmt); + return stmt; } private: diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py index 8ce1773ee..e6689dddf 100644 --- a/tests/python/unittest/test_arith_simplify.py +++ b/tests/python/unittest/test_arith_simplify.py @@ -37,6 +37,26 @@ def test_simplify_mod(): assert index == j +def test_modular(): + rx = tvm.var("rx") + ry = tvm.var("ry") + y = tvm.var("y") + x = tvm.var("x") + vmap = {rx: tvm.Range(tvm.const(0), tvm.const(3)), + ry: tvm.Range(tvm.const(0), tvm.const(3)), + y: tvm.Range(tvm.const(0), tvm.const(2)), + x: tvm.Range(tvm.const(0), tvm.const(14))} + idx = ry * 16 + rx + y * 16 + x + z1 = tvm.ir_pass.CanonicalSimplify(idx // 16, vmap) + z2 = tvm.ir_pass.CanonicalSimplify(idx % 16, vmap) + assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0 + assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0 + + + + + if __name__ == "__main__": test_simplify_mod() + test_modular() test_simplify() diff --git a/tests/python/unittest/test_pass_simplify.py b/tests/python/unittest/test_pass_simplify.py index 29b5b3a84..c38083822 100644 --- a/tests/python/unittest/test_pass_simplify.py +++ b/tests/python/unittest/test_pass_simplify.py @@ -33,7 +33,6 @@ def test_bound(): ret = tvm.ir_pass.Simplify(m % 10, vrange) assert ret == m - def test_canonical(): x = tvm.var("x") z = tvm.const(3) @@ -54,6 +53,7 @@ def test_canonical(): assert (tvm.ir_pass.Equal(ret1, ret2)) if __name__ == "__main__": + test_modular() test_bound() test_basic() test_simplify() -- GitLab