diff --git a/src/lang/ir_operator.cc b/src/lang/ir_operator.cc index 30742764351d6d28e5e27967da702010d6bae3bf..275752644be9a743141c35fbfff1d251cfc32fd8 100644 --- a/src/lang/ir_operator.cc +++ b/src/lang/ir_operator.cc @@ -310,20 +310,26 @@ Expr operator!=(Expr a, Expr b) { Expr operator&&(Expr a, Expr b) { using ir::UIntImm; - const UIntImm* pa = a.as<UIntImm>(); - const UIntImm* pb = b.as<UIntImm>(); - if (pa && pb) { - return UIntImm::make(UInt(1), pa->value && pb->value); + if (a.type().is_bool() && b.type().is_bool()) { + const UIntImm* pa = a.as<UIntImm>(); + const UIntImm* pb = b.as<UIntImm>(); + if (pa && pa->value) return b; + if (pa && !pa->value) return a; + if (pb && pb->value) return a; + if (pb && !pb->value) return b; } return ir::And::make(a, b); } Expr operator||(Expr a, Expr b) { using ir::UIntImm; - const UIntImm* pa = a.as<UIntImm>(); - const UIntImm* pb = b.as<UIntImm>(); - if (pa && pb) { - return UIntImm::make(UInt(1), pa->value || pb->value); + if (a.type().is_bool() && b.type().is_bool()) { + const UIntImm* pa = a.as<UIntImm>(); + const UIntImm* pb = b.as<UIntImm>(); + if (pa && pa->value) return a; + if (pa && !pa->value) return b; + if (pb && pb->value) return b; + if (pb && !pb->value) return a; } return ir::Or::make(a, b); } diff --git a/tests/python/unittest/test_lang_operator.py b/tests/python/unittest/test_lang_operator.py index 9c701ed2abe355d4f97c1d5ec704dc82c791a9c5..af7d9fd5544afe6a68efa9b05fc5cdff1b5eccd0 100644 --- a/tests/python/unittest/test_lang_operator.py +++ b/tests/python/unittest/test_lang_operator.py @@ -30,6 +30,44 @@ def test_const_fold2(): assert (1 * x).same_as(x) assert isinstance((1 / x), tvm.expr.Div) +def test_const_fold3(): + def check_throws(f): + try: + f() + except tvm.TVMError: + pass + else: + raise AssertionError("Should have raised an exception but didn't.") + + # Test that using ints with logic operations is forbidden + x = tvm.var("x") + for val in [0, 1]: + for func in [tvm.all, tvm.any]: + check_throws(lambda: func(tvm.const(val, 'uint1'), x)) + check_throws(lambda: func(x, tvm.const(val, 'uint1'))) + + # Test const folding when both arguments are const + for tvm_func, py_func in [(tvm.all, lambda a, b: a and b), (tvm.any, lambda a, b: a or b)]: + for v1 in [0, 1]: + for v2 in [0, 1]: + assert tvm.ir_pass.Equal(tvm_func(tvm.const(v1, 'uint1'), tvm.const(v2, 'uint1')), + tvm.const(py_func(v1, v2), 'uint1')) + + x = tvm.var("x", 'uint1') + true = tvm.const(1, 'uint1') + false = tvm.const(0, 'uint1') + + assert tvm.all(x, true).same_as(x) + assert tvm.all(true, x).same_as(x) + assert tvm.any(x, false).same_as(x) + assert tvm.any(false, x).same_as(x) + + assert tvm.all(x, false).same_as(false) + assert tvm.all(false, x).same_as(false) + assert tvm.any(x, true).same_as(true) + assert tvm.any(true, x).same_as(true) + if __name__ == "__main__": test_const_fold() test_const_fold2() + test_const_fold3()