From f01cc0e6e54ca87e1fab8c56a8e443ecce85f2f8 Mon Sep 17 00:00:00 2001 From: Sergei Grechanik <grechanik.sergey@huawei.com> Date: Tue, 16 Oct 2018 19:48:19 +0300 Subject: [PATCH] [TVM] Eagerer const folding for logic ops (#1907) --- src/lang/ir_operator.cc | 22 +++++++----- tests/python/unittest/test_lang_operator.py | 38 +++++++++++++++++++++ 2 files changed, 52 insertions(+), 8 deletions(-) diff --git a/src/lang/ir_operator.cc b/src/lang/ir_operator.cc index 307427643..275752644 100644 --- a/src/lang/ir_operator.cc +++ b/src/lang/ir_operator.cc @@ -310,20 +310,26 @@ Expr operator!=(Expr a, Expr b) { Expr operator&&(Expr a, Expr b) { using ir::UIntImm; - const UIntImm* pa = a.as<UIntImm>(); - const UIntImm* pb = b.as<UIntImm>(); - if (pa && pb) { - return UIntImm::make(UInt(1), pa->value && pb->value); + if (a.type().is_bool() && b.type().is_bool()) { + const UIntImm* pa = a.as<UIntImm>(); + const UIntImm* pb = b.as<UIntImm>(); + if (pa && pa->value) return b; + if (pa && !pa->value) return a; + if (pb && pb->value) return a; + if (pb && !pb->value) return b; } return ir::And::make(a, b); } Expr operator||(Expr a, Expr b) { using ir::UIntImm; - const UIntImm* pa = a.as<UIntImm>(); - const UIntImm* pb = b.as<UIntImm>(); - if (pa && pb) { - return UIntImm::make(UInt(1), pa->value || pb->value); + if (a.type().is_bool() && b.type().is_bool()) { + const UIntImm* pa = a.as<UIntImm>(); + const UIntImm* pb = b.as<UIntImm>(); + if (pa && pa->value) return a; + if (pa && !pa->value) return b; + if (pb && pb->value) return b; + if (pb && !pb->value) return a; } return ir::Or::make(a, b); } diff --git a/tests/python/unittest/test_lang_operator.py b/tests/python/unittest/test_lang_operator.py index 9c701ed2a..af7d9fd55 100644 --- a/tests/python/unittest/test_lang_operator.py +++ b/tests/python/unittest/test_lang_operator.py @@ -30,6 +30,44 @@ def test_const_fold2(): assert (1 * x).same_as(x) assert isinstance((1 / x), tvm.expr.Div) +def test_const_fold3(): + def check_throws(f): + try: + f() + except tvm.TVMError: + pass + else: + raise AssertionError("Should have raised an exception but didn't.") + + # Test that using ints with logic operations is forbidden + x = tvm.var("x") + for val in [0, 1]: + for func in [tvm.all, tvm.any]: + check_throws(lambda: func(tvm.const(val, 'uint1'), x)) + check_throws(lambda: func(x, tvm.const(val, 'uint1'))) + + # Test const folding when both arguments are const + for tvm_func, py_func in [(tvm.all, lambda a, b: a and b), (tvm.any, lambda a, b: a or b)]: + for v1 in [0, 1]: + for v2 in [0, 1]: + assert tvm.ir_pass.Equal(tvm_func(tvm.const(v1, 'uint1'), tvm.const(v2, 'uint1')), + tvm.const(py_func(v1, v2), 'uint1')) + + x = tvm.var("x", 'uint1') + true = tvm.const(1, 'uint1') + false = tvm.const(0, 'uint1') + + assert tvm.all(x, true).same_as(x) + assert tvm.all(true, x).same_as(x) + assert tvm.any(x, false).same_as(x) + assert tvm.any(false, x).same_as(x) + + assert tvm.all(x, false).same_as(false) + assert tvm.all(false, x).same_as(false) + assert tvm.any(x, true).same_as(true) + assert tvm.any(true, x).same_as(true) + if __name__ == "__main__": test_const_fold() test_const_fold2() + test_const_fold3() -- GitLab