diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc
index 27ccfb09cdeb0e8eceeb7e51784b93b0bed948e1..e219b5541bdc390cc42143e69aa384455c080f3d 100644
--- a/src/arithmetic/canonical.cc
+++ b/src/arithmetic/canonical.cc
@@ -29,9 +29,17 @@ struct ComExprEntry {
   inline bool operator<(const ComExprEntry& other) const {
     if (level < other.level) return true;
     if (level > other.level) return false;
+    // compare top operator of entries and sort on that if possible (fast check)
     if (value.type_index() < other.value.type_index()) return true;
     if (value.type_index() > other.value.type_index()) return false;
-    return value.get() < other.value.get();
+    // if none of the above distinguishes the terms, compare the expression tree of the entries.
+    // This is a slower check.
+    int compare_result = Compare(value, other.value);
+    if (compare_result < 0) return true;
+    if (compare_result > 0) return false;
+    // it's a problem if we see identical entries at this point. They should've been merged earlier.
+    LOG(WARNING) << "we should not have identical entries at this point";
+    return false;
   }
 };
 
diff --git a/tests/python/unittest/test_pass_simplify.py b/tests/python/unittest/test_pass_simplify.py
index 9105693b3835691df81269db84e598ece0e7e503..29b5b3a8450dc7919835ce4c4ce273abb13dd4a5 100644
--- a/tests/python/unittest/test_pass_simplify.py
+++ b/tests/python/unittest/test_pass_simplify.py
@@ -43,6 +43,16 @@ def test_canonical():
     ret = tvm.ir_pass.CanonicalSimplify(x / (z+z) - x / (z+z))
     assert(tvm.ir_pass.Equal(ret, 0))
 
+    #make sure terms are ordered based on their top operators (e.g., / always precedes %)
+    ret1 = tvm.ir_pass.CanonicalSimplify(x % 3 + x / 3)
+    ret2 = tvm.ir_pass.CanonicalSimplify(x / 3 + x % 3)
+    assert(tvm.ir_pass.Equal(ret1, ret2))
+
+    #when top operators match, compare string representation of terms
+    ret1 = tvm.ir_pass.CanonicalSimplify(x % 4 + x % 3)
+    ret2 = tvm.ir_pass.CanonicalSimplify(x % 3 + x % 4)
+    assert (tvm.ir_pass.Equal(ret1, ret2))
+
 if __name__ == "__main__":
     test_bound()
     test_basic()