diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc index 8f913ccd43504f1178be57c61fb472bee06f98f7..2151ebf2adba21f0fa4ecace5c15cbfe9e8fa9d9 100644 --- a/src/arithmetic/canonical.cc +++ b/src/arithmetic/canonical.cc @@ -236,6 +236,24 @@ class Canonical::Internal : public IRMutator { bool EnableOpt(Type t) const { return (t.lanes() == 1 && (t.is_int() || t.is_uint())); } + // Max + Expr Mutate_(const Max* op, const Expr& e) final { + CacheEntry a = Produce(op->a); + CacheEntry b = Produce(op->b); + if (a.has_side_effect || b.has_side_effect) { + return Binary_(op, e, a.value, b.value); + } + return Binary(op, e); + } + // Min + Expr Mutate_(const Min* op, const Expr& e) final { + CacheEntry a = Produce(op->a); + CacheEntry b = Produce(op->b); + if (a.has_side_effect || b.has_side_effect) { + return Binary_(op, e, a.value, b.value); + } + return Binary(op, e); + } // Add Expr Mutate_(const Add* op, const Expr& e) final { if (!EnableOpt(op->type)) { @@ -277,7 +295,7 @@ class Canonical::Internal : public IRMutator { } else if (is_const(b.value)) { return SumMulConst(a.AsSum(), b.value); } else { - return Binary_(op, e, a.value, b.value); + return Binary(op, e); } } // Variable diff --git a/tests/cpp/ir_simplify_test.cc b/tests/cpp/ir_simplify_test.cc index 0667dc27367c07b5233a05fe7714178e8e4b6e34..8114bb51b7710c6c89bcde56d71970ca15e93413 100644 --- a/tests/cpp/ir_simplify_test.cc +++ b/tests/cpp/ir_simplify_test.cc @@ -1,5 +1,6 @@ #include <dmlc/logging.h> #include <gtest/gtest.h> +#include <tvm/ir_pass.h> #include <tvm/tvm.h> #include <arithmetic/Simplify.h> @@ -8,6 +9,24 @@ TEST(IRSIMPLIFY, Basic) { simplify_test(); } +TEST(IRSIMPLIFY, MinMax) { + auto x = tvm::var("x"); + auto e1 = (tvm::max(x, 1) - tvm::max(x, 1)) ; + auto e1s = tvm::ir::CanonicalSimplify(e1); + CHECK(is_zero(e1s)); + + auto e2 = (x * tvm::min(x, 1)) - (x * tvm::min(x, 1)); + auto e2s = tvm::ir::CanonicalSimplify(e2); + CHECK(is_zero(e2s)); +} + +TEST(IRSIMPLIFY, Mul) { + auto x = tvm::var("x"); + auto e = (x * x) - (x * x) ; + auto es = tvm::ir::CanonicalSimplify(e); + CHECK(is_zero(es)); +} + int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py index e9315eda3257909fe3319e6ab32b7c6df24ce5b3..f6a78b6e3770863929546e464db268a1e7e03acd 100644 --- a/tests/python/unittest/test_arith_simplify.py +++ b/tests/python/unittest/test_arith_simplify.py @@ -46,6 +46,21 @@ def test_simplify_mod(): (j + n * 32) % 16, {j: tvm.Range(0, 6)}) assert index == j +def test_simplify_minmax(): + x = tvm.var('x') + e1 = tvm.max(x, 1) - tvm.max(x, 1) + e1s = tvm.ir_pass.CanonicalSimplify(e1) + assert e1s.value == 0 + + e2 = tvm.min(x, 1) - tvm.min(x, 1) + e2s = tvm.ir_pass.CanonicalSimplify(e2) + assert e2s.value == 0 + +def test_mul(): + x = tvm.var('x') + e = x * x - x * x + es = tvm.ir_pass.CanonicalSimplify(e) + assert es.value == 0 def test_modular(): rx = tvm.var("rx") @@ -62,11 +77,9 @@ def test_modular(): assert tvm.ir_pass.CanonicalSimplify(z1 - (ry + y)).value == 0 assert tvm.ir_pass.CanonicalSimplify(z2 - (rx + x)).value == 0 - - - - if __name__ == "__main__": test_simplify_mod() test_modular() test_simplify() + test_mul() + test_simplify_minmax() \ No newline at end of file