diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc index 87300a8fa767807cb51cacfc0e7e0c0592dec702..ad18fecd1bb0275d5919fb709861748fe5e93dbf 100644 --- a/src/arithmetic/canonical.cc +++ b/src/arithmetic/canonical.cc @@ -93,6 +93,26 @@ struct ComExpr { std::shared_ptr<ComExprNode> ptr_; }; +// binary comparison op. +struct BinaryExpr { + int kind; + Expr lhs, rhs; + // comparator + bool operator<(const BinaryExpr& b) const { + if (kind < b.kind) return true; + if (kind > b.kind) return false; + if (lhs.get() < b.lhs.get()) return true; + if (lhs.get() > b.lhs.get()) return false; + return rhs.get() < b.rhs.get(); + } + // equality + bool operator==(const BinaryExpr& b) const { + return kind == b.kind && + lhs.same_as(b.lhs) && + rhs.same_as(b.rhs); + } +}; + template<typename T> inline Expr Binary_(const T* op, const Expr& e, @@ -104,12 +124,6 @@ inline Expr Binary_(const T* op, } } -template<typename T> -inline Expr Binary( - const T* op, const Expr& e, IRMutator* m) { - return Binary_(op, e, m->Mutate(op->a), m->Mutate(op->b)); -} - // internal of canonical engine. class Canonical::Internal : public IRMutator { public: @@ -200,7 +214,7 @@ class Canonical::Internal : public IRMutator { // Add Expr Mutate_(const Add* op, const Expr& e) final { if (!EnableOpt(op->type)) { - return Binary(op, e, this); + return Binary(op, e); } CacheEntry a = Produce(op->a); CacheEntry b = Produce(op->b); @@ -212,7 +226,7 @@ class Canonical::Internal : public IRMutator { // Sub Expr Mutate_(const Sub* op, const Expr& e) final { if (!EnableOpt(op->type)) { - return Binary(op, e, this); + return Binary(op, e); } CacheEntry a = Produce(op->a); CacheEntry b = Produce(op->b); @@ -224,7 +238,7 @@ class Canonical::Internal : public IRMutator { // Mul Expr Mutate_(const Mul* op, const Expr& e) final { if (!EnableOpt(op->type)) { - return Binary(op, e, this); + return Binary(op, e); } CacheEntry a = Produce(op->a); CacheEntry b = Produce(op->b); @@ -252,7 +266,7 @@ class Canonical::Internal : public IRMutator { // comparison Expr Mutate_(const LT* op, const Expr& e) { if (!EnableOpt(op->a.type())) { - return Binary(op, e, this); + return Binary(op, e); } CacheEntry a = Produce(op->a); CacheEntry b = Produce(op->b); @@ -266,6 +280,23 @@ class Canonical::Internal : public IRMutator { return Binary_(op, e, a.value, b.value); } } + // IntImm + Expr Mutate_(const IntImm* op, const Expr& e) final { + auto it = cache_intimm_.find(op->value); + if (it != cache_intimm_.end()) { + return it->second; + } else { + cache_intimm_[op->value] = e; + return e; + } + } + // binary ops + Expr Mutate_(const Div* op, const Expr& e) final { + return Binary(op, e); + } + Expr Mutate_(const Mod* op, const Expr& e) final { + return Binary(op, e); + } // Call Expr Mutate_(const Call* op, const Expr& e) final { if (!op->is_pure()) { @@ -309,12 +340,30 @@ class Canonical::Internal : public IRMutator { } private: + template<typename T> + Expr Binary(const T* op, const Expr& e) { + Expr a = this->Mutate(op->a); + Expr b = this->Mutate(op->b); + BinaryExpr key{static_cast<int>(T::_type_info), a, b}; + auto it = cache_binary_.find(key); + if (it != cache_binary_.end()) { + return it->second; + } else { + Expr ret = Binary_(op, e, a, b); + cache_binary_[key] = ret; + return ret; + } + } // return entry CacheEntry ret_entry_; // internal information stack std::vector<StackEntry> stack_; // cache sum std::map<ComExpr, CacheEntry> cache_sum_; + // cache of normal binary op + std::map<BinaryExpr, Expr> cache_binary_; + // cache of int constant + std::unordered_map<int64_t, Expr> cache_intimm_; // range of each var std::unordered_map<const Variable*, IntSet> var_range_; // level of each var diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py new file mode 100644 index 0000000000000000000000000000000000000000..9ff8571eac42a3ed7b670cc1a9c8bb29ff134538 --- /dev/null +++ b/tests/python/unittest/test_arith_simplify.py @@ -0,0 +1,24 @@ +import tvm + +def csimplify(z): + return tvm.ir_pass.CanonicalSimplify( + tvm.make.Evaluate(z)).value + +def test_simplify(): + x = tvm.var('n') + z = x * 4 - x * 2 + zz = csimplify(z) + assert zz.b.value == 2 + + z = (x / 4) * 2 - (x / 4) + zz = csimplify(z) + assert zz.a == x and zz.b.value == 4 + + z = (x % 4) * 3 + (x % 4) + zz = csimplify(z) + assert zz.b.value == 4 + zz = zz.a + assert zz.a == x and zz.b.value == 4 + +if __name__ == "__main__": + test_simplify()