From 2bb1d8e458463a8528d79303d502f249af3fcc30 Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Tue, 28 Nov 2017 13:26:17 -0800 Subject: [PATCH] [ARITH] Upgrade CanonicalSimplify to Simplify Mod (#676) --- include/tvm/ir_pass.h | 8 +- src/api/api_pass.cc | 12 ++- src/arithmetic/canonical.cc | 83 +++++++++++++++++--- src/arithmetic/canonical.h | 2 +- tests/python/unittest/test_arith_simplify.py | 18 +++++ 5 files changed, 106 insertions(+), 17 deletions(-) diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index b6b248228..d0f32478e 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -41,16 +41,20 @@ Stmt Simplify(Stmt stmt, Map<Var, Range> vrange = Map<Var, Range>()); /*! * \brief Simplify by applying canonical form. * \param stmt The statement to be canonically simplifed. + * \param vrange Optional map from variables to their value ranges (defaults to empty). * \return Canonicalized statement. */ -Stmt CanonicalSimplify(Stmt stmt); +Stmt CanonicalSimplify(Stmt stmt, + Map<Var, Range> vrange = Map<Var, Range>()); /*! * \brief Simplify by applying canonical form. * \param expr The statement to be canonically simplifed. + * \param vrange Optional map from variables to their value ranges (defaults to empty). * \return Canonicalized expression. */ -Expr CanonicalSimplify(Expr expr); +Expr CanonicalSimplify(Expr expr, + Map<Var, Range> vrange = Map<Var, Range>()); /*! 
* \brief Deep compare lhs and rhs diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 024af23a3..23deb03af 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -33,9 +33,17 @@ TVM_REGISTER_API("ir_pass.Simplify") TVM_REGISTER_API("ir_pass.CanonicalSimplify") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsNodeType<Stmt>()) { - *ret = CanonicalSimplify(args[0].operator Stmt()); + if (args.size() > 1) { + *ret = CanonicalSimplify(args[0].operator Stmt(), args[1]); + } else { + *ret = CanonicalSimplify(args[0].operator Stmt()); + } } else { - *ret = CanonicalSimplify(args[0].operator Expr()); + if (args.size() > 1) { + *ret = CanonicalSimplify(args[0].operator Expr(), args[1]); + } else { + *ret = CanonicalSimplify(args[0].operator Expr()); + } } }); diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc index 808e070ef..e7f9da1b4 100644 --- a/src/arithmetic/canonical.cc +++ b/src/arithmetic/canonical.cc @@ -129,6 +129,11 @@ inline Expr Binary_(const T* op, // internal of canonical engine. class Canonical::Internal : public IRMutator { public: + explicit Internal(Map<Var, Range> vrange) { + for (auto kv : vrange) { + SetRange(kv.first, kv.second, 0); + } + } // stack entry. 
struct StackEntry { int max_level{0}; @@ -300,9 +305,25 @@ class Canonical::Internal : public IRMutator { Expr Mutate_(const Div* op, const Expr& e) final { return Binary(op, e); } + // Mod operator: fold constant operands, or reduce a canonical sum modulo a constant Expr Mutate_(const Mod* op, const Expr& e) final { - return Binary(op, e); + if (!EnableOpt(op->type)) { + return Binary(op, e); + } + CacheEntry a = Produce(op->a); + CacheEntry b = Produce(op->b); + if (a.has_side_effect || b.has_side_effect) { + return Binary_(op, e, a.value, b.value); + } + if (is_const(a.value) && is_const(b.value)) { + return ComputeExpr<Mod>(a.value, b.value); + } else if (is_const(b.value)) { + return SumModConst(a.AsSum(), b.value); + } else { + return Binary(op, e); + } } + Expr Mutate_(const And* op, const Expr& e) final { Expr expr = IRMutator::Mutate_(op, e); op = expr.as<And>(); @@ -367,7 +388,7 @@ class Canonical::Internal : public IRMutator { private: template<typename T> - Expr Binary(const T* op, const Expr& e) { + Expr Binary(const T* op, Expr e) { Expr a = this->Mutate(op->a); Expr b = this->Mutate(op->b); BinaryExpr key{static_cast<int>(T::_type_info), a, b}; @@ -398,8 +419,8 @@ class Canonical::Internal : public IRMutator { std::vector<Var> var_rec_; // level counter int level_counter_{0}; - // subroutine to do produce - Expr SumMulConst(ComExpr a, Expr v) { + // get constant int value + int64_t GetConstIntValue(const Expr& v) { int64_t value = 0; const int64_t *v1 = as_const_int(v); const uint64_t *v2 = as_const_uint(v); @@ -411,7 +432,45 @@ class Canonical::Internal : public IRMutator { static_cast<uint64_t>(std::numeric_limits<int64_t>::max())); value = static_cast<int64_t>(*v2); } - + return value; + } + // subroutine to produce a % v + Expr SumModConst(ComExpr a, Expr v) { + int64_t value = GetConstIntValue(v); + std::shared_ptr<ComExprNode> n = std::make_shared<ComExprNode>(); + int mod_level = 0; + n->base = a->base % value; + if (n->base != 0) mod_level = 1; + for (auto e : a->elem) { + if (e.scale % value == 0) 
continue; + e.scale = e.scale % value; + if (!EvalSet(v - e.value, var_range_).can_prove_positive()) { + mod_level = 2; + } else { + ++mod_level; + } + n->elem.push_back(e); + } + // cannot remove the mod: two or more terms remain, or a term is not provably smaller than v + if (mod_level >= 2) { + Expr ret = Sum2Expr(ComExpr(n), v.type()) % v; + return Binary(ret.as<Mod>(), ret); + } + ret_entry_.sum = ComExpr(n); + ret_entry_.max_level = stack_.back().max_level; + ret_entry_.has_side_effect = stack_.back().has_side_effect; + auto it = cache_sum_.find(ret_entry_.sum); + if (it != cache_sum_.end()) { + ret_entry_ = it->second; + } else { + ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type()); + cache_sum_[ret_entry_.sum] = ret_entry_; + } + return ret_entry_.value; + } + // subroutine to produce a * v by scaling every term of the sum + Expr SumMulConst(ComExpr a, Expr v) { + int64_t value = GetConstIntValue(v); if (value == 0) { return make_zero(v.type()); } @@ -421,9 +480,9 @@ class Canonical::Internal : public IRMutator { for (auto& e : vsum->elem) { e.scale *= value; } + ret_entry_.sum = ComExpr(vsum); ret_entry_.max_level = stack_.back().max_level; ret_entry_.has_side_effect = stack_.back().has_side_effect; - ret_entry_.sum = ComExpr(vsum); auto it = cache_sum_.find(ret_entry_.sum); if (it != cache_sum_.end()) { ret_entry_ = it->second; @@ -536,8 +595,8 @@ class Canonical::Internal : public IRMutator { using CInternal = Canonical::Internal; -Canonical::Canonical() - : ptr_(std::make_shared<Internal>()) {} +Canonical::Canonical(Map<Var, Range> vrange) + : ptr_(std::make_shared<Internal>(vrange)) {} Expr Canonical::Simplify(Expr expr) { return ptr_->Mutate(expr); @@ -553,12 +612,12 @@ void Canonical::SetRange(Var v, Range r, int level) { } // namespace arith namespace ir { -Stmt CanonicalSimplify(Stmt stmt) { - return arith::Canonical().Simplify(stmt); +Stmt CanonicalSimplify(Stmt stmt, Map<Var, Range> vrange) { + return arith::Canonical(vrange).Simplify(stmt); } -Expr CanonicalSimplify(Expr expr) { - return 
arith::Canonical().Simplify(expr); +Expr CanonicalSimplify(Expr expr, Map<Var, Range> vrange) { + return arith::Canonical(vrange).Simplify(expr); } template<typename T> diff --git a/src/arithmetic/canonical.h b/src/arithmetic/canonical.h index 174acc20a..37f9a178f 100644 --- a/src/arithmetic/canonical.h +++ b/src/arithmetic/canonical.h @@ -22,7 +22,7 @@ namespace arith { class Canonical { public: /*! \brief constructor */ - Canonical(); + explicit Canonical(Map<Var, Range> var_range); /*! * \brief simplify expression e. * \param expr The expression to be simplified. diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py index 9ff8571ea..8ce1773ee 100644 --- a/tests/python/unittest/test_arith_simplify.py +++ b/tests/python/unittest/test_arith_simplify.py @@ -20,5 +20,23 @@ def test_simplify(): zz = zz.a assert zz.a == x and zz.b.value == 4 +def test_simplify_mod(): + """Not yet working, mock design""" + ib = tvm.ir_builder.create() + n = tvm.var('n') + j = tvm.var('j') + A = ib.pointer("float32", name="A") + with ib.for_range(0, 16, name="i") as i: + A[i] = A[((n * 4 + j * 2) * 8 + i+1) % 16] + body = ib.get() + stmt = tvm.ir_pass.CanonicalSimplify(body) + diff = tvm.ir_pass.CanonicalSimplify(stmt.body.value.index - (1 + i) % 16) + assert diff.value == 0 + index = tvm.ir_pass.CanonicalSimplify( + (j + n * 32) % 16, {j: tvm.Range(0, 6)}) + assert index == j + + if __name__ == "__main__": + test_simplify_mod() test_simplify() -- GitLab