From f01cc0e6e54ca87e1fab8c56a8e443ecce85f2f8 Mon Sep 17 00:00:00 2001
From: Sergei Grechanik <grechanik.sergey@huawei.com>
Date: Tue, 16 Oct 2018 19:48:19 +0300
Subject: [PATCH] [TVM] Eagerer const folding for logic ops (#1907)

---
 src/lang/ir_operator.cc                     | 22 +++++++-----
 tests/python/unittest/test_lang_operator.py | 38 +++++++++++++++++++++
 2 files changed, 52 insertions(+), 8 deletions(-)

diff --git a/src/lang/ir_operator.cc b/src/lang/ir_operator.cc
index 307427643..275752644 100644
--- a/src/lang/ir_operator.cc
+++ b/src/lang/ir_operator.cc
@@ -310,20 +310,26 @@ Expr operator!=(Expr a, Expr b) {
 
 Expr operator&&(Expr a, Expr b) {
   using ir::UIntImm;
-  const UIntImm* pa = a.as<UIntImm>();
-  const UIntImm* pb = b.as<UIntImm>();
-  if (pa && pb) {
-    return UIntImm::make(UInt(1), pa->value && pb->value);
+  if (a.type().is_bool() && b.type().is_bool()) {
+    const UIntImm* pa = a.as<UIntImm>();
+    const UIntImm* pb = b.as<UIntImm>();
+    if (pa && pa->value) return b;
+    if (pa && !pa->value) return a;
+    if (pb && pb->value) return a;
+    if (pb && !pb->value) return b;
   }
   return ir::And::make(a, b);
 }
 
 Expr operator||(Expr a, Expr b) {
   using ir::UIntImm;
-  const UIntImm* pa = a.as<UIntImm>();
-  const UIntImm* pb = b.as<UIntImm>();
-  if (pa && pb) {
-    return UIntImm::make(UInt(1), pa->value || pb->value);
+  if (a.type().is_bool() && b.type().is_bool()) {
+    const UIntImm* pa = a.as<UIntImm>();
+    const UIntImm* pb = b.as<UIntImm>();
+    if (pa && pa->value) return a;
+    if (pa && !pa->value) return b;
+    if (pb && pb->value) return b;
+    if (pb && !pb->value) return a;
   }
   return ir::Or::make(a, b);
 }
diff --git a/tests/python/unittest/test_lang_operator.py b/tests/python/unittest/test_lang_operator.py
index 9c701ed2a..af7d9fd55 100644
--- a/tests/python/unittest/test_lang_operator.py
+++ b/tests/python/unittest/test_lang_operator.py
@@ -30,6 +30,44 @@ def test_const_fold2():
     assert (1 * x).same_as(x)
     assert isinstance((1 / x), tvm.expr.Div)
 
+def test_const_fold3():
+    def check_throws(f):
+        try:
+            f()
+        except tvm.TVMError:
+            pass
+        else:
+            raise AssertionError("Should have raised an exception but didn't.")
+
+    # Test that using ints with logic operations is forbidden
+    x = tvm.var("x")
+    for val in [0, 1]:
+        for func in [tvm.all, tvm.any]:
+            check_throws(lambda: func(tvm.const(val, 'uint1'), x))
+            check_throws(lambda: func(x, tvm.const(val, 'uint1')))
+
+    # Test const folding when both arguments are const
+    for tvm_func, py_func in [(tvm.all, lambda a, b: a and b), (tvm.any, lambda a, b: a or b)]:
+        for v1 in [0, 1]:
+            for v2 in [0, 1]:
+                assert tvm.ir_pass.Equal(tvm_func(tvm.const(v1, 'uint1'), tvm.const(v2, 'uint1')),
+                                         tvm.const(py_func(v1, v2), 'uint1'))
+
+    x = tvm.var("x", 'uint1')
+    true = tvm.const(1, 'uint1')
+    false = tvm.const(0, 'uint1')
+
+    assert tvm.all(x, true).same_as(x)
+    assert tvm.all(true, x).same_as(x)
+    assert tvm.any(x, false).same_as(x)
+    assert tvm.any(false, x).same_as(x)
+
+    assert tvm.all(x, false).same_as(false)
+    assert tvm.all(false, x).same_as(false)
+    assert tvm.any(x, true).same_as(true)
+    assert tvm.any(true, x).same_as(true)
+
 if __name__ == "__main__":
     test_const_fold()
     test_const_fold2()
+    test_const_fold3()
-- 
GitLab