diff --git a/src/lang/ir_operator.cc b/src/lang/ir_operator.cc
index 30742764351d6d28e5e27967da702010d6bae3bf..275752644be9a743141c35fbfff1d251cfc32fd8 100644
--- a/src/lang/ir_operator.cc
+++ b/src/lang/ir_operator.cc
@@ -310,20 +310,26 @@ Expr operator!=(Expr a, Expr b) {
 
 Expr operator&&(Expr a, Expr b) {
   using ir::UIntImm;
-  const UIntImm* pa = a.as<UIntImm>();
-  const UIntImm* pb = b.as<UIntImm>();
-  if (pa && pb) {
-    return UIntImm::make(UInt(1), pa->value && pb->value);
+  if (a.type().is_bool() && b.type().is_bool()) {
+    const UIntImm* pa = a.as<UIntImm>();
+    const UIntImm* pb = b.as<UIntImm>();
+    if (pa && pa->value) return b;
+    if (pa && !pa->value) return a;
+    if (pb && pb->value) return a;
+    if (pb && !pb->value) return b;
   }
   return ir::And::make(a, b);
 }
 
 Expr operator||(Expr a, Expr b) {
   using ir::UIntImm;
-  const UIntImm* pa = a.as<UIntImm>();
-  const UIntImm* pb = b.as<UIntImm>();
-  if (pa && pb) {
-    return UIntImm::make(UInt(1), pa->value || pb->value);
+  if (a.type().is_bool() && b.type().is_bool()) {
+    const UIntImm* pa = a.as<UIntImm>();
+    const UIntImm* pb = b.as<UIntImm>();
+    if (pa && pa->value) return a;
+    if (pa && !pa->value) return b;
+    if (pb && pb->value) return b;
+    if (pb && !pb->value) return a;
   }
   return ir::Or::make(a, b);
 }
diff --git a/tests/python/unittest/test_lang_operator.py b/tests/python/unittest/test_lang_operator.py
index 9c701ed2abe355d4f97c1d5ec704dc82c791a9c5..af7d9fd5544afe6a68efa9b05fc5cdff1b5eccd0 100644
--- a/tests/python/unittest/test_lang_operator.py
+++ b/tests/python/unittest/test_lang_operator.py
@@ -30,6 +30,44 @@ def test_const_fold2():
     assert (1 * x).same_as(x)
     assert isinstance((1 / x), tvm.expr.Div)
 
+def test_const_fold3():
+    def check_throws(f):
+        try:
+            f()
+        except tvm.TVMError:
+            pass
+        else:
+            raise AssertionError("Should have raised an exception but didn't.")
+
+    # Test that using ints with logic operations is forbidden
+    x = tvm.var("x")
+    for val in [0, 1]:
+        for func in [tvm.all, tvm.any]:
+            check_throws(lambda: func(tvm.const(val, 'uint1'), x))
+            check_throws(lambda: func(x, tvm.const(val, 'uint1')))
+
+    # Test const folding when both arguments are const
+    for tvm_func, py_func in [(tvm.all, lambda a, b: a and b), (tvm.any, lambda a, b: a or b)]:
+        for v1 in [0, 1]:
+            for v2 in [0, 1]:
+                assert tvm.ir_pass.Equal(tvm_func(tvm.const(v1, 'uint1'), tvm.const(v2, 'uint1')),
+                                         tvm.const(py_func(v1, v2), 'uint1'))
+
+    x = tvm.var("x", 'uint1')
+    true = tvm.const(1, 'uint1')
+    false = tvm.const(0, 'uint1')
+
+    assert tvm.all(x, true).same_as(x)
+    assert tvm.all(true, x).same_as(x)
+    assert tvm.any(x, false).same_as(x)
+    assert tvm.any(false, x).same_as(x)
+
+    assert tvm.all(x, false).same_as(false)
+    assert tvm.all(false, x).same_as(false)
+    assert tvm.any(x, true).same_as(true)
+    assert tvm.any(true, x).same_as(true)
+
 if __name__ == "__main__":
     test_const_fold()
     test_const_fold2()
+    test_const_fold3()