diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc index 27ccfb09cdeb0e8eceeb7e51784b93b0bed948e1..e219b5541bdc390cc42143e69aa384455c080f3d 100644 --- a/src/arithmetic/canonical.cc +++ b/src/arithmetic/canonical.cc @@ -29,9 +29,17 @@ struct ComExprEntry { inline bool operator<(const ComExprEntry& other) const { if (level < other.level) return true; if (level > other.level) return false; + // compare top operator of entries and sort on that if possible (fast check) if (value.type_index() < other.value.type_index()) return true; if (value.type_index() > other.value.type_index()) return false; - return value.get() < other.value.get(); + // if none of the above distinguishes the terms, compare the expression tree of the entries. + // This is a slower check. + int compare_result = Compare(value, other.value); + if (compare_result < 0) return true; + if (compare_result > 0) return false; + // it's a problem if we see identical entries at this point. They should've been merged earlier. + LOG(WARNING) << "we should not have identical entries at this point"; + return false; } }; diff --git a/tests/python/unittest/test_pass_simplify.py b/tests/python/unittest/test_pass_simplify.py index 9105693b3835691df81269db84e598ece0e7e503..29b5b3a8450dc7919835ce4c4ce273abb13dd4a5 100644 --- a/tests/python/unittest/test_pass_simplify.py +++ b/tests/python/unittest/test_pass_simplify.py @@ -43,6 +43,16 @@ def test_canonical(): ret = tvm.ir_pass.CanonicalSimplify(x / (z+z) - x / (z+z)) assert(tvm.ir_pass.Equal(ret, 0)) + #make sure terms are ordered based on their top operators (e.g., / always precedes %) + ret1 = tvm.ir_pass.CanonicalSimplify(x % 3 + x / 3) + ret2 = tvm.ir_pass.CanonicalSimplify(x / 3 + x % 3) + assert(tvm.ir_pass.Equal(ret1, ret2)) + + #when top operators match, compare string representation of terms + ret1 = tvm.ir_pass.CanonicalSimplify(x % 4 + x % 3) + ret2 = tvm.ir_pass.CanonicalSimplify(x % 3 + x % 4) + assert (tvm.ir_pass.Equal(ret1, ret2)) + if __name__ == "__main__": test_bound() test_basic()