diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc
index 87300a8fa767807cb51cacfc0e7e0c0592dec702..ad18fecd1bb0275d5919fb709861748fe5e93dbf 100644
--- a/src/arithmetic/canonical.cc
+++ b/src/arithmetic/canonical.cc
@@ -93,6 +93,26 @@ struct ComExpr {
   std::shared_ptr<ComExprNode> ptr_;
 };
 
+// binary comparison op.
+struct BinaryExpr {
+  int kind;
+  Expr lhs, rhs;
+  // comparator
+  bool operator<(const BinaryExpr& b) const {
+    if (kind < b.kind) return true;
+    if (kind > b.kind) return false;
+    if (lhs.get() < b.lhs.get()) return true;
+    if (lhs.get() > b.lhs.get()) return false;
+    return rhs.get() < b.rhs.get();
+  }
+  // equality
+  bool operator==(const BinaryExpr& b) const {
+    return kind == b.kind &&
+        lhs.same_as(b.lhs) &&
+        rhs.same_as(b.rhs);
+  }
+};
+
 template<typename T>
 inline Expr Binary_(const T* op,
                     const Expr& e,
@@ -104,12 +124,6 @@ inline Expr Binary_(const T* op,
   }
 }
 
-template<typename T>
-inline Expr Binary(
-    const T* op, const Expr& e, IRMutator* m) {
-  return Binary_(op, e, m->Mutate(op->a), m->Mutate(op->b));
-}
-
 // internal of canonical engine.
 class Canonical::Internal : public IRMutator {
  public:
@@ -200,7 +214,7 @@ class Canonical::Internal : public IRMutator {
   // Add
   Expr Mutate_(const Add* op, const Expr& e) final {
     if (!EnableOpt(op->type)) {
-      return Binary(op, e, this);
+      return Binary(op, e);
     }
     CacheEntry a = Produce(op->a);
     CacheEntry b = Produce(op->b);
@@ -212,7 +226,7 @@ class Canonical::Internal : public IRMutator {
   // Sub
   Expr Mutate_(const Sub* op, const Expr& e) final {
     if (!EnableOpt(op->type)) {
-      return Binary(op, e, this);
+      return Binary(op, e);
     }
     CacheEntry a = Produce(op->a);
     CacheEntry b = Produce(op->b);
@@ -224,7 +238,7 @@ class Canonical::Internal : public IRMutator {
   // Mul
   Expr Mutate_(const Mul* op, const Expr& e) final {
     if (!EnableOpt(op->type)) {
-      return Binary(op, e, this);
+      return Binary(op, e);
     }
     CacheEntry a = Produce(op->a);
     CacheEntry b = Produce(op->b);
@@ -252,7 +266,7 @@ class Canonical::Internal : public IRMutator {
   // comparison
   Expr Mutate_(const LT* op, const Expr& e) {
     if (!EnableOpt(op->a.type())) {
-      return Binary(op, e, this);
+      return Binary(op, e);
     }
     CacheEntry a = Produce(op->a);
     CacheEntry b = Produce(op->b);
@@ -266,6 +280,23 @@ class Canonical::Internal : public IRMutator {
       return Binary_(op, e, a.value, b.value);
     }
   }
+  // IntImm
+  Expr Mutate_(const IntImm* op, const Expr& e) final {
+    auto it = cache_intimm_.find(op->value);
+    if (it != cache_intimm_.end()) {
+      return it->second;
+    } else {
+      cache_intimm_[op->value] = e;
+      return e;
+    }
+  }
+  // binary ops
+  Expr Mutate_(const Div* op, const Expr& e) final {
+    return Binary(op, e);
+  }
+  Expr Mutate_(const Mod* op, const Expr& e) final {
+    return Binary(op, e);
+  }
   // Call
   Expr Mutate_(const Call* op, const Expr& e) final {
     if (!op->is_pure()) {
@@ -309,12 +340,30 @@ class Canonical::Internal : public IRMutator {
   }
 
  private:
+  template<typename T>
+  Expr Binary(const T* op, const Expr& e) {
+    Expr a = this->Mutate(op->a);
+    Expr b = this->Mutate(op->b);
+    BinaryExpr key{static_cast<int>(T::_type_info), a, b};
+    auto it = cache_binary_.find(key);
+    if (it != cache_binary_.end()) {
+      return it->second;
+    } else {
+      Expr ret = Binary_(op, e, a, b);
+      cache_binary_[key] = ret;
+      return ret;
+    }
+  }
   // return entry
   CacheEntry ret_entry_;
   // internal information stack
   std::vector<StackEntry> stack_;
   // cache sum
   std::map<ComExpr, CacheEntry> cache_sum_;
+  // cache of normal binary op
+  std::map<BinaryExpr, Expr> cache_binary_;
+  // cache of int constant
+  std::unordered_map<int64_t, Expr> cache_intimm_;
   // range of each var
   std::unordered_map<const Variable*, IntSet> var_range_;
   // level of each var
diff --git a/tests/python/unittest/test_arith_simplify.py b/tests/python/unittest/test_arith_simplify.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ff8571eac42a3ed7b670cc1a9c8bb29ff134538
--- /dev/null
+++ b/tests/python/unittest/test_arith_simplify.py
@@ -0,0 +1,24 @@
+import tvm
+
+def csimplify(z):
+    return tvm.ir_pass.CanonicalSimplify(
+        tvm.make.Evaluate(z)).value
+
+def test_simplify():
+    x = tvm.var('n')
+    z = x * 4  - x * 2
+    zz = csimplify(z)
+    assert zz.b.value == 2
+
+    z = (x / 4) * 2  - (x / 4)
+    zz = csimplify(z)
+    assert zz.a == x and zz.b.value == 4
+
+    z = (x % 4) * 3  + (x % 4)
+    zz = csimplify(z)
+    assert zz.b.value == 4
+    zz = zz.a
+    assert zz.a == x and zz.b.value == 4
+
+if __name__ == "__main__":
+    test_simplify()