From 32af4d2899c6dd1a3a5a24f27957a3383a022fde Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Mon, 1 Oct 2018 13:45:53 -0700 Subject: [PATCH] [IR] eager constant folding in operator overloading (#1789) --- include/tvm/buffer.h | 1 + include/tvm/expr.h | 10 - include/tvm/ir.h | 2 - include/tvm/ir_operator.h | 589 ++++++++++++++++-- include/tvm/tensor.h | 1 + nnvm/src/top/tensor/reduce.cc | 2 +- python/tvm/api.py | 12 +- python/tvm/expr.py | 16 +- python/tvm/generic.py | 8 +- src/api/api_ir.cc | 99 +-- src/arithmetic/compute_expr.h | 61 +- src/arithmetic/detect_linear_equation.cc | 4 +- src/codegen/codegen_cuda.cc | 2 +- src/codegen/verilog/verilog_ir.cc | 2 +- src/lang/expr.cc | 1 + src/lang/ir_operator.cc | 402 +++++++++++- src/pass/ir_util.h | 5 +- src/pass/split_pipeline.cc | 3 +- src/pass/storage_rewrite.cc | 2 +- src/pass/vectorize_loop.cc | 1 - tests/cpp/ir_mutator_test.cc | 1 + tests/python/unittest/test_arith_intset.py | 9 +- tests/python/unittest/test_lang_basic.py | 2 +- tests/python/unittest/test_lang_operator.py | 35 ++ tests/python/unittest/test_lang_reflection.py | 2 +- tests/python/unittest/test_pass_simplify.py | 1 - topi/include/topi/elemwise.h | 9 +- topi/include/topi/nn/pooling.h | 10 +- topi/python/topi/vision/ssd/multibox.py | 10 +- 29 files changed, 1106 insertions(+), 196 deletions(-) create mode 100644 tests/python/unittest/test_lang_operator.py diff --git a/include/tvm/buffer.h b/include/tvm/buffer.h index 5901a27fe..cda76cd14 100644 --- a/include/tvm/buffer.h +++ b/include/tvm/buffer.h @@ -10,6 +10,7 @@ #include "base.h" #include "expr.h" +#include "ir_operator.h" #include "node/container.h" namespace tvm { diff --git a/include/tvm/expr.h b/include/tvm/expr.h index a199d656c..050ab4c33 100644 --- a/include/tvm/expr.h +++ b/include/tvm/expr.h @@ -7,7 +7,6 @@ #define TVM_EXPR_H_ #include <ir/Expr.h> -#include <ir/IROperator.h> #include <ir/IRPrinter.h> #include <string> #include <algorithm> @@ -34,15 +33,6 @@ using HalideIR::Internal::Stmt; using HalideIR::Internal::IRPrinter; using HalideIR::Internal::Variable; -using HalideIR::Internal::make_const; -using HalideIR::Internal::make_zero; -using HalideIR::Internal::make_one; -using HalideIR::Internal::as_const_int; -using HalideIR::Internal::as_const_uint; -using HalideIR::Internal::const_true; -using HalideIR::Internal::const_false; -using HalideIR::Internal::is_no_op; - inline Type TVMShapeIndexType() { if (std::is_signed<tvm_index_t>::value) { return Int(sizeof(tvm_index_t) * 8); diff --git a/include/tvm/ir.h b/include/tvm/ir.h index b75d75c18..14e601465 100644 --- a/include/tvm/ir.h +++ b/include/tvm/ir.h @@ -495,8 +495,6 @@ using HalideIR::Internal::Block; using HalideIR::Internal::IfThenElse; using HalideIR::Internal::Evaluate; using HalideIR::Internal::Shuffle; -// ir functions -using HalideIR::Internal::is_const_power_of_two_integer; /*! * \brief Create a type annotation expression diff --git a/include/tvm/ir_operator.h b/include/tvm/ir_operator.h index 39588a222..5abd95b8c 100644 --- a/include/tvm/ir_operator.h +++ b/include/tvm/ir_operator.h @@ -1,24 +1,426 @@ /*! - * Copyright (c) 2017 by Contributors + * Copyright (c) 2018 by Contributors * \file tvm/ir_operator.h - * \brief Common operators of Expr + * \brief Common operators defined for Expr. + * + * \note Most of the operator defined here perform simple constant folding + * when the type is int32 or int64 for simplifying the index expressions. */ #ifndef TVM_IR_OPERATOR_H_ #define TVM_IR_OPERATOR_H_ #include <algorithm> +#include <type_traits> #include "expr.h" #include "ir.h" namespace tvm { +/*! + * \brief Make a const value with certain data type. + * \param t The target type. + * \param value The input value + * \return the result expression. + * \tparam ValueType The constant value type + */ +template<typename ValueType, + typename = typename std::enable_if<std::is_pod<ValueType>::value>::type> +inline Expr make_const(Type t, ValueType value); +/*! + * \brief Make a const zero expr. + * \param t The target type. + * \return the result expression. + */ +inline Expr make_zero(Type t); +/*! + * \brief Make a constant true expression. + * \param lanes The number of lanes in the bool + * \return The result expression. + */ +inline Expr const_true(int lanes = 1) { + return make_const(UInt(1, lanes), 1); +} +/*! + * \brief Make a constant false expression. + * \param lanes The number of lanes in the bool + * \return The result expression. + */ +inline Expr const_false(int lanes = 1) { + return make_const(UInt(1, lanes), 0); +} +/*! + * \brief Get x as constant int expression. + * \param x The expression + * \return the address to the int expression, + * return nullptr, if x is not IntImm. + */ +inline const int64_t* as_const_int(const Expr& x) { + if (!x.defined()) return nullptr; + if (const ir::IntImm* op = x.as<ir::IntImm>()) { + return &(op->value); + } else { + return nullptr; + } +} + +/*! + * \brief Get x as constant uint expression. + * \param x The expression + * \return the address to the int expression, + * return nullptr, if x is not UIntImm. + */ +inline const uint64_t* as_const_uint(const Expr& x) { + if (!x.defined()) return nullptr; + if (const ir::UIntImm* op = x.as<ir::UIntImm>()) { + return &(op->value); + } else { + return nullptr; + } +} + +/*! + * \brief Check whether x is a constant integer expression. + * \param x The input argument + * \param value the value to be compared against. + * \return whether x is constant expression. + */ +inline bool is_const_int(const Expr& x, int64_t value); + +/*! + * \brief Check whether stmt is nop. + * \param stmt The input statement + * \return whether stmt is nop + */ +inline bool is_no_op(const Stmt& stmt); + +/*! + * \brief Check whether x is a constant integer 1 + * \param x The input argument. + * \note This only return true for integer types. + * \return whether x is constant 1 + */ +inline bool is_one(const Expr& x) { + return is_const_int(x, 1); +} -using HalideIR::likely; -using HalideIR::likely_if_innermost; -// functions -using HalideIR::cast; -using HalideIR::min; -using HalideIR::max; -using HalideIR::select; +/*! + * \brief Check whether x is a constant integer 0 + * \param x The input argument + * \return whether x is constant 0 + * \note This only return true for integer types. + */ +inline bool is_zero(const Expr& x) { + return is_const_int(x, 0); +} + +/*! + * \brief Check whether x is a constant. + * \note This only return true for integer types. + * \return whether x is constant + */ +inline bool is_const(const Expr& x); + +/*! + * \brief Check whether x is a constant power of two + * If x is power of two, write the power to the shift. + * + * \param x The input expression. + * \param shift The output shift if x is power of two. + * \return whether x is constant power of two + */ +TVM_DLL bool is_const_power_of_two_integer(const Expr& x, int* shift); + +/*! + * \brief cast value to type. + * + * \param t the target type. + * \param value The value + * \return The result expression. + * \note This function may return value if the type is the same. + */ +TVM_DLL Expr cast(const Type& t, Expr value); +/*! + * \brief perform reinterpret cast value to type. + * + * \param t the target type. + * \param value The value + * \return The result expression. + * \note This function may return value if the type is the same. + */ +TVM_DLL Expr reinterpret(const Type& t, Expr value); +/*! + * \brief add operator + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator+(Expr a, Expr b); +/*! + * \brief subtraction operator + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator-(Expr a, Expr b); +/*! + * \brief negation. + * + * \param a input. + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator-(Expr a); +/*! + * \brief multiplication operator + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator*(Expr a, Expr b); +/*! + * \brief division operator + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator/(Expr a, Expr b); +/*! + * \brief mod operator + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator%(Expr a, Expr b); +/*! + * \brief left shift operator + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator<<(Expr a, Expr b); +/*! + * \brief right shift operator + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator>>(Expr a, Expr b); +/*! + * \brief greater + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator>(Expr a, Expr b); +/*! + * \brief greater_equal + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator>=(Expr a, Expr b); +/*! + * \brief less + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator<(Expr a, Expr b); +/*! + * \brief less_equal + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator<=(Expr a, Expr b); +/*! + * \brief equal + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator==(Expr a, Expr b); +/*! + * \brief not_equal + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator!=(Expr a, Expr b); +/*! + * \brief and + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note This operator does eager constant folding. + */ +TVM_DLL Expr operator&&(Expr a, Expr b); +/*! + * \brief or + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note This operator does eager constant folding. + */ +TVM_DLL Expr operator||(Expr a, Expr b); +/*! + * \brief not + * + * \param a left operand + * \return The result expression. + * \note This operator does eager constant folding. + */ +TVM_DLL Expr operator!(Expr a); +/*! + * \brief take maximum of two values + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr max(Expr a, Expr b); +/*! + * \brief take minimum of two values + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr min(Expr a, Expr b); +/*! + * \brief right shift + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator>>(Expr a, Expr b); +/*! + * \brief left shift + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator<<(Expr a, Expr b); +/*! + * \brief take bitwise and of two values + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator&(Expr a, Expr b); +/*! + * \brief take bitwise or of two values + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator|(Expr a, Expr b); +/*! + * \brief take bitwise xor of two values + * + * \param a left operand + * \param b right operand + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator^(Expr a, Expr b); +/*! + * \brief take bitwise negation of two values + * + * \param a the input expression. + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr operator~(Expr a); +/*! + * \brief select result by condition + * + * \param cond The condition + * \param true_value The value when results are true. + * \param false_value The value when results are false. + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL Expr select(Expr cond, Expr true_value, Expr false_value); +/*! + * \brief Mark condition as likely. + * \param cond The condition + * \return The marked expression. + */ +TVM_DLL Expr likely(Expr cond); +/*! + * \brief Calculate power(x, y) + * \param x The left operand. + * \param y The right operand. + */ +TVM_DLL Expr pow(Expr x, Expr y); +/*! + * \brief Calculate absolute value of x. + * \param x The input data + * + * \return The aboslute value of input data x + */ +TVM_DLL Expr abs(Expr x); /*! * \brief sum of of source expression over axis @@ -48,13 +450,12 @@ TVM_DLL Expr min(Expr source, Array<IterVar> axis); */ TVM_DLL Expr prod(Expr source, Array<IterVar> axis); -// Unary intrinsic operators +// Intrinsic operators #define TVM_DECLARE_INTRIN_UNARY(OpName) \ inline Expr OpName(Expr x) { \ return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \ } \ - TVM_DECLARE_INTRIN_UNARY(exp); TVM_DECLARE_INTRIN_UNARY(tanh); TVM_DECLARE_INTRIN_UNARY(sigmoid); @@ -64,38 +465,152 @@ TVM_DECLARE_INTRIN_UNARY(floor); TVM_DECLARE_INTRIN_UNARY(ceil); TVM_DECLARE_INTRIN_UNARY(round); TVM_DECLARE_INTRIN_UNARY(trunc); +TVM_DECLARE_INTRIN_UNARY(popcount); -/*! - * \brief Calculate power(x, y) - * \param x The left operand. - * \param y The right operand. - */ -inline Expr pow(Expr x, Expr y) { - match_types(x, y); - CHECK(x.type().is_float()) << "power only applies to float"; - return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic); + +// Implementation details after this +inline bool is_const(const Expr& x) { + if (x.as<ir::IntImm>() || x.as<ir::UIntImm>()) { + return true; + } else if (const auto* op = x.as<ir::Broadcast>()) { + const Expr& val = op->value; + if (val.as<ir::IntImm>() || val.as<ir::UIntImm>()) { + return true; + } + } + return false; } -/*! - * \brief Calculate absolute value of x, elementwise - * \param x The input data - * - * \return The aboslute value of input data x - */ -inline Expr abs(Expr x) { - if (x.type().is_int()) { - return select(x >= make_zero(x.type()), x, -x); - } else if (x.type().is_float()) { - return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic); - } else if (x.type().is_uint()) { - return x; +inline bool is_positive_const(const Expr& a) { + if (const ir::IntImm* op = a.as<ir::IntImm>()) { + return op->value > 0; + } else if (const ir::UIntImm* op = a.as<ir::UIntImm>()) { + return op->value > 0; } else { - LOG(WARNING) << "Warning: Data type " << x.type() - <<" not supported for absolute op. Skipping absolute op..."; - return x; + return false; } } -} // namespace tvm +inline bool is_negative_const(const Expr& a) { + if (const ir::IntImm* op = a.as<ir::IntImm>()) { + return op->value < 0; + } else { + return false; + } +} + +inline bool is_const_int(const Expr& x, int64_t value) { + if (const auto* op = x.as<ir::IntImm>()) { + return op->value == value; + } else if (const auto* op = x.as<ir::UIntImm>()) { + return op->value == static_cast<uint64_t>(value); + } else if (const auto* op = x.as<ir::Broadcast>()) { + const Expr& val = op->value; + if (const auto* opv = val.as<ir::IntImm>()) { + return opv->value == value; + } else if (const auto* opv = val.as<ir::UIntImm>()) { + return opv->value == static_cast<uint64_t>(value); + } + } + return false; +} + +inline bool is_no_op(const Stmt& stmt) { + if (!stmt.defined()) return true; + if (const auto* op = stmt.as<ir::Evaluate>()) { + return is_const(op->value); + } + return false; +} + +template<typename ValueType> +inline Expr MakeConstScalar(Type t, ValueType value) { + if (t.is_int()) return ir::IntImm::make(t, static_cast<int64_t>(value)); + if (t.is_uint()) return ir::UIntImm::make(t, static_cast<uint64_t>(value)); + if (t.is_float()) return ir::FloatImm::make(t, static_cast<double>(value)); + LOG(FATAL) << "cannot make const for type " << t; + return Expr(); +} + +template<typename ValueType, typename> +inline Expr make_const(Type t, ValueType value) { + if (t.lanes() == 1) { + return MakeConstScalar(t, value); + } else { + return ir::Broadcast::make( + MakeConstScalar(t.element_of(), value), t.lanes()); + } +} + +inline Expr make_zero(Type t) { + if (t.is_handle()) { + return reinterpret(t, make_const(UInt(64), 0)); + } + return make_const(t, 0); +} + +// additional const expression overloading +#define TVM_DEFINE_ASSIGN_OP_OVERLOAD(Name, OpFunc) \ + inline Expr Name(Expr& a, Expr b) { \ + a = OpFunc(a, b); \ + return a; \ + } +#define TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(Name) \ + inline Expr Name(const Expr& a, float b) { \ + return Name(a, Expr(b)); \ + } \ + inline Expr Name(float a, const Expr& b) { \ + return Name(Expr(a), b); \ + } \ + inline Expr Name(int a, const Expr& b) { \ + return Name(make_const(b.type(), a), b); \ + } \ + inline Expr Name(const Expr& a, int b) { \ + return Name(a, make_const(a.type(), b)); \ + } + +#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \ + inline Expr Name(const Expr& a, bool b) { \ + return Name(a, Expr(b)); \ + } \ + inline Expr Name(bool a, const Expr& b) { \ + return Name(Expr(a), b); \ + } + +#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \ + inline Expr Name(const Expr& a, int b) { \ + return Name(a, make_const(a.type(), b)); \ + } \ + inline Expr Name(int a, const Expr& b) { \ + return Name(make_const(b.type(), a), b); \ + } + + +TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator+=, operator+); +TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator-=, operator-); +TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator*=, operator*); +TVM_DEFINE_ASSIGN_OP_OVERLOAD(operator/=, operator/); +TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator+); +TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator-); +TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator*); +TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator/); +TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(max); +TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(min); +TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>); // NOLINT(*) +TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator>=); +TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<); // NOLINT(*) +TVM_DEFINE_BINOP_CONST_VAL_OVERLOAD(operator<=); +// integer related ops +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator%); +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator>>); // NOLINT(*) +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator<<); // NOLINT(*) +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator&); +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator|); +TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(operator^); +// logical ops +TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&); +TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator||); + +} // namespace tvm #endif // TVM_IR_OPERATOR_H_ diff --git a/include/tvm/tensor.h b/include/tvm/tensor.h index e205f6b9f..7665e724b 100644 --- a/include/tvm/tensor.h +++ b/include/tvm/tensor.h @@ -13,6 +13,7 @@ #include "base.h" #include "expr.h" +#include "ir_operator.h" #include "arithmetic.h" #include "node/container.h" diff --git a/nnvm/src/top/tensor/reduce.cc b/nnvm/src/top/tensor/reduce.cc index 91d2ea720..7241c4b4b 100644 --- a/nnvm/src/top/tensor/reduce.cc +++ b/nnvm/src/top/tensor/reduce.cc @@ -354,7 +354,7 @@ Example:: if (!r_axes.ndim()) return Array<Tensor> { topi::identity(inputs[0]) }; auto axis = ShapeToArray(r_axes); - Expr count = make_one(inputs[0]->dtype); + Expr count = make_const(inputs[0]->dtype, 1); for (auto& i : r_axes) { count *= inputs[0]->shape[i]; } diff --git a/python/tvm/api.py b/python/tvm/api.py index 34fe2ba49..8cf507de6 100644 --- a/python/tvm/api.py +++ b/python/tvm/api.py @@ -156,9 +156,9 @@ def any(*args): raise ValueError("Any must take at least 1 argument") if len(args) == 1: return args[0] - ret = _expr.Or(args[0], args[1]) + ret = _make._OpOr(args[0], args[1]) for i in range(2, len(args)): - ret = _expr.Or(ret, args[i]) + ret = _make._OpOr(ret, args[i]) return ret @@ -180,9 +180,9 @@ def all(*args): raise ValueError("Any must take at least 1 argument") if len(args) == 1: return args[0] - ret = _expr.And(args[0], args[1]) + ret = _make._OpAnd(args[0], args[1]) for i in range(2, len(args)): - ret = _expr.And(ret, args[i]) + ret = _make._OpAnd(ret, args[i]) return ret @@ -773,5 +773,5 @@ def comm_reducer(fcombine, fidentity, name="reduce"): _init_api("tvm.api") #pylint: disable=unnecessary-lambda sum = comm_reducer(lambda x, y: x+y, lambda t: const(0, dtype=t), name="sum") -min = comm_reducer(lambda x, y: _expr.Min(x, y), max_value, name='min') -max = comm_reducer(lambda x, y: _expr.Max(x, y), min_value, name='max') +min = comm_reducer(lambda x, y: _make._OpMin(x, y), max_value, name='min') +max = comm_reducer(lambda x, y: _make._OpMax(x, y), min_value, name='max') diff --git a/python/tvm/expr.py b/python/tvm/expr.py index 1c1c9f82c..00a523416 100644 --- a/python/tvm/expr.py +++ b/python/tvm/expr.py @@ -60,7 +60,7 @@ class ExprOp(object): return self.__rdiv__(other) def __mod__(self, other): - return _make.Mod(self, other) + return _make._OpMod(self, other) def __neg__(self): neg_one = _api_internal._const(-1, self.dtype) @@ -85,10 +85,10 @@ class ExprOp(object): return _make.Call(self.dtype, "bitwise_not", [self], Call.PureIntrinsic, None, 0) def __lt__(self, other): - return _make.LT(self, other) + return _make._OpLT(self, other) def __le__(self, other): - return _make.LE(self, other) + return _make._OpLE(self, other) def __eq__(self, other): return EqualOp(self, other) @@ -97,10 +97,10 @@ class ExprOp(object): return NotEqualOp(self, other) def __gt__(self, other): - return _make.GT(self, other) + return _make._OpGT(self, other) def __ge__(self, other): - return _make.GE(self, other) + return _make._OpGE(self, other) def __nonzero__(self): raise ValueError("Cannot use and / or / not operator to Expr, hint: " + @@ -122,7 +122,7 @@ class ExprOp(object): ret : Expr The equality expression. """ - return _make.EQ(self, other) + return _make._OpEQ(self, other) def astype(self, dtype): """Cast the expression to other type. @@ -169,7 +169,7 @@ class EqualOp(NodeGeneric, ExprOp): def asnode(self): """Convert node.""" - return _make.EQ(self.a, self.b) + return _make._OpEQ(self.a, self.b) class NotEqualOp(NodeGeneric, ExprOp): @@ -201,7 +201,7 @@ class NotEqualOp(NodeGeneric, ExprOp): def asnode(self): """Convert node.""" - return _make.NE(self.a, self.b) + return _make._OpNE(self.a, self.b) class Expr(ExprOp, NodeBase): diff --git a/python/tvm/generic.py b/python/tvm/generic.py index 2926f73d5..ab1a80d3f 100644 --- a/python/tvm/generic.py +++ b/python/tvm/generic.py @@ -24,7 +24,7 @@ def add(lhs, rhs): op : tvm.Expr The result Expr of add operaton. """ - return _make.Add(lhs, rhs) + return _make._OpAdd(lhs, rhs) def subtract(lhs, rhs): @@ -42,7 +42,7 @@ def subtract(lhs, rhs): op : tvm.Expr The result Expr of subtract operaton. """ - return _make.Sub(lhs, rhs) + return _make._OpSub(lhs, rhs) def multiply(lhs, rhs): @@ -60,7 +60,7 @@ def multiply(lhs, rhs): op : tvm.Expr The result Expr of multiply operaton. """ - return _make.Mul(lhs, rhs) + return _make._OpMul(lhs, rhs) def divide(lhs, rhs): @@ -78,7 +78,7 @@ def divide(lhs, rhs): op : tvm.Expr The result Expr of divide operaton. """ - return _make.Div(lhs, rhs) + return _make._OpDiv(lhs, rhs) def cast(src, dtype): diff --git a/src/api/api_ir.cc b/src/api/api_ir.cc index 8a65260a0..1040f6ce6 100644 --- a/src/api/api_ir.cc +++ b/src/api/api_ir.cc @@ -5,7 +5,7 @@ */ #include <tvm/expr.h> #include <tvm/ir.h> -#include <ir/IROperator.h> +#include <tvm/ir_operator.h> #include <tvm/api_registry.h> #include <tvm/ir_operator.h> @@ -117,6 +117,50 @@ TVM_REGISTER_API("make.CommReducer") *ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \ }) \ + +REGISTER_MAKE5(Reduce); +REGISTER_MAKE4(AttrStmt); + +REGISTER_MAKE2(IntImm); +REGISTER_MAKE2(UIntImm); +REGISTER_MAKE2(FloatImm); +REGISTER_MAKE1(StringImm); + +REGISTER_MAKE2(Add); +REGISTER_MAKE2(Sub); +REGISTER_MAKE2(Mul); +REGISTER_MAKE2(Div); +REGISTER_MAKE2(Mod); +REGISTER_MAKE2(Min); +REGISTER_MAKE2(Max); +REGISTER_MAKE2(EQ); +REGISTER_MAKE2(NE); +REGISTER_MAKE2(LT); +REGISTER_MAKE2(LE); +REGISTER_MAKE2(GT); +REGISTER_MAKE2(GE); +REGISTER_MAKE2(And); +REGISTER_MAKE2(Or); + +REGISTER_MAKE1(Not); +REGISTER_MAKE3(Select); +REGISTER_MAKE3(Ramp); +REGISTER_MAKE2(Cast); +REGISTER_MAKE2(Broadcast); +REGISTER_MAKE2(Shuffle); +REGISTER_MAKE3(Let); +REGISTER_MAKE3(LetStmt); +REGISTER_MAKE3(AssertStmt); +REGISTER_MAKE3(ProducerConsumer); +REGISTER_MAKE5(Allocate); +REGISTER_MAKE4(Provide); +REGISTER_MAKE4(Prefetch); +REGISTER_MAKE1(Free); +REGISTER_MAKE2(Block); +REGISTER_MAKE3(IfThenElse); +REGISTER_MAKE1(Evaluate); + +// operator overloading, smarter than make #define REGISTER_MAKE_BINARY_OP(Node, Func) \ TVM_REGISTER_API("make."#Node) \ .set_body([](TVMArgs args, TVMRetValue *ret) { \ @@ -138,50 +182,27 @@ TVM_REGISTER_API("make.CommReducer") } \ }) -REGISTER_MAKE5(Reduce); -REGISTER_MAKE4(AttrStmt); -REGISTER_MAKE2(IntImm); -REGISTER_MAKE2(UIntImm); -REGISTER_MAKE2(FloatImm); -REGISTER_MAKE1(StringImm); -REGISTER_MAKE_BINARY_OP(Add, operator+); -REGISTER_MAKE_BINARY_OP(Sub, operator-); -REGISTER_MAKE_BINARY_OP(Mul, operator*); -REGISTER_MAKE_BINARY_OP(Div, operator/); -REGISTER_MAKE_BINARY_OP(Mod, operator%); -REGISTER_MAKE_BINARY_OP(Min, min); -REGISTER_MAKE_BINARY_OP(Max, max); -REGISTER_MAKE_BINARY_OP(EQ, operator==); -REGISTER_MAKE_BINARY_OP(NE, operator!=); -REGISTER_MAKE_BINARY_OP(LT, operator<); // NOLINT(*) -REGISTER_MAKE_BINARY_OP(LE, operator<=); // NOLINT(*) -REGISTER_MAKE_BINARY_OP(GT, operator>); // NOLINT(*) -REGISTER_MAKE_BINARY_OP(GE, operator>=); -REGISTER_MAKE_BINARY_OP(And, operator&&); -REGISTER_MAKE_BINARY_OP(Or, operator||); +REGISTER_MAKE_BINARY_OP(_OpAdd, operator+); +REGISTER_MAKE_BINARY_OP(_OpSub, operator-); +REGISTER_MAKE_BINARY_OP(_OpMul, operator*); +REGISTER_MAKE_BINARY_OP(_OpDiv, operator/); +REGISTER_MAKE_BINARY_OP(_OpMod, operator%); +REGISTER_MAKE_BINARY_OP(_OpMin, min); +REGISTER_MAKE_BINARY_OP(_OpMax, max); +REGISTER_MAKE_BINARY_OP(_OpEQ, operator==); +REGISTER_MAKE_BINARY_OP(_OpNE, operator!=); +REGISTER_MAKE_BINARY_OP(_OpLT, operator<); // NOLINT(*) +REGISTER_MAKE_BINARY_OP(_OpLE, operator<=); // NOLINT(*) +REGISTER_MAKE_BINARY_OP(_OpGT, operator>); // NOLINT(*) +REGISTER_MAKE_BINARY_OP(_OpGE, operator>=); +REGISTER_MAKE_BINARY_OP(_OpAnd, operator&&); +REGISTER_MAKE_BINARY_OP(_OpOr, operator||); REGISTER_MAKE_BIT_OP(bitwise_and, operator&); REGISTER_MAKE_BIT_OP(bitwise_or, operator|); REGISTER_MAKE_BIT_OP(bitwise_xor, operator^); REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*) REGISTER_MAKE_BIT_OP(right_shift, operator>>); -REGISTER_MAKE1(Not); -REGISTER_MAKE3(Select); -REGISTER_MAKE3(Ramp); -REGISTER_MAKE2(Cast); -REGISTER_MAKE2(Broadcast); -REGISTER_MAKE2(Shuffle); -REGISTER_MAKE3(Let); -REGISTER_MAKE3(LetStmt); -REGISTER_MAKE3(AssertStmt); -REGISTER_MAKE3(ProducerConsumer); -REGISTER_MAKE5(Allocate); -REGISTER_MAKE4(Provide); -REGISTER_MAKE4(Prefetch); -REGISTER_MAKE1(Free); -REGISTER_MAKE2(Block); -REGISTER_MAKE3(IfThenElse); -REGISTER_MAKE1(Evaluate); } // namespace ir } // namespace tvm diff --git a/src/arithmetic/compute_expr.h b/src/arithmetic/compute_expr.h index 5f44347f3..218e9d218 100644 --- a/src/arithmetic/compute_expr.h +++ b/src/arithmetic/compute_expr.h @@ -14,10 +14,6 @@ namespace tvm { namespace arith { -using HalideIR::Internal::add_would_overflow; -using HalideIR::Internal::sub_would_overflow; -using HalideIR::Internal::mul_would_overflow; - /*! * \brief Compute the expression with the given binary op. * \param lhs The left operand @@ -42,23 +38,9 @@ template<typename Op> inline Expr ComputeReduce( const Array<Expr>& values, Expr empty_value); -template<typename T> -inline bool GetConst(Expr e, T* out); - -template<> -inline bool GetConst<int64_t>(Expr e, int64_t *out) { - if (e.type().is_vector()) return false; - const int64_t *v = as_const_int(e); - if (v) { - *out = *v; return true; - } else { - return false; - } -} -template<> -inline bool GetConst<uint64_t>(Expr e, uint64_t *out) { +inline bool GetConst(Expr e, int64_t* out) { if (e.type().is_vector()) return false; - const uint64_t *v = as_const_uint(e); + const int64_t* v = as_const_int(e); if (v) { *out = *v; return true; } else { @@ -69,66 +51,37 @@ inline bool GetConst<uint64_t>(Expr e, uint64_t *out) { // get a small constant int inline bool GetConstInt(Expr e, int* out) { int64_t v1 = 0; - uint64_t v2 = 0; if (GetConst(e, &v1)) { if (v1 > static_cast<int64_t>( std::numeric_limits<int>::max())) return false; *out = static_cast<int>(v1); return true; } - if (GetConst(e, &v2)) { - if (v2 > static_cast<uint64_t>( - std::numeric_limits<int>::max())) return false; - *out = static_cast<int>(v2); return true; - } return false; } -#define TVM_CONST_PROPAGATION(OP_NAME, OP) \ - int64_t ia = 0, ib = 0; \ - if (GetConst(a, &ia) && GetConst(b, &ib)) { \ - if (OP_NAME ## _would_overflow(a.type().bits(), ia, ib)) { \ - LOG(FATAL) << "signed int overflow"; \ - } \ - return ir::IntImm::make(a.type(), ia OP ib); \ - } \ - uint64_t ua = 0, ub = 0; \ - if (GetConst(a, &ua) && GetConst(b, &ub)) { \ - return ir::UIntImm::make(a.type(), ua OP ub); \ - } \ - template<> inline Expr ComputeExpr<ir::Add>(Expr a, Expr b) { - if (is_zero(a)) return b; - if (is_zero(b)) return a; - TVM_CONST_PROPAGATION(add, +); - return ir::Add::make(a, b); + return a + b; } template<> inline Expr ComputeExpr<ir::Sub>(Expr a, Expr b) { - if (is_zero(b)) return a; - TVM_CONST_PROPAGATION(sub, -); - return ir::Sub::make(a, b); + return a - b; } template<> inline Expr ComputeExpr<ir::Mul>(Expr a, Expr b) { - if (is_one(a)) return b; - if (is_one(b)) return a; - TVM_CONST_PROPAGATION(mul, *); - return ir::Mul::make(a, b); + return a * b; } template<> inline Expr ComputeExpr<ir::Div>(Expr a, Expr b) { - if (is_one(b)) return a; - return ir::Div::make(a, b); + return a / b; } template<> inline Expr ComputeExpr<ir::Mod>(Expr a, Expr b) { - if (is_zero(a)) return make_zero(a.type()); - return ir::Mod::make(a, b); + return a % b; } template<> diff --git a/src/arithmetic/detect_linear_equation.cc b/src/arithmetic/detect_linear_equation.cc index 109cdc6d9..4e6d8caf3 100644 --- a/src/arithmetic/detect_linear_equation.cc +++ b/src/arithmetic/detect_linear_equation.cc @@ -194,7 +194,7 @@ bool DetectClipBound( if (!LinearEqDetector(var).Detect(canonical, &ret)) return false; ret.coeff = Simplify(ret.coeff); IntervalEntry& p = (*bmap)[var.get()]; - if (is_one(ret.coeff)) { + if (is_const_int(ret.coeff, 1)) { // var + shift >=0 -> var >= -shift if (p.min_value.defined()) { p.min_value = ir::Max::make(p.min_value, -ret.base); @@ -203,7 +203,7 @@ bool DetectClipBound( } return true; } - if (is_const(ret.coeff, -1)) { + if (is_const_int(ret.coeff, -1)) { // -var + shift >=0 -> var <= shift if (p.max_value.defined()) { p.max_value = ir::Min::make(p.max_value, ret.base); diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc index 7c8399cfc..0960106ae 100644 --- a/src/codegen/codegen_cuda.cc +++ b/src/codegen/codegen_cuda.cc @@ -42,7 +42,7 @@ std::string CodeGenCUDA::Finish() { } void CodeGenCUDA::VisitStmt_(const ir::For* op) { - CHECK(is_zero(op->min)); + CHECK(is_const_int(op->min, 0)); if (op->for_type == ir::ForType::Unrolled) { PrintIndent(); stream << "#pragma unroll\n"; diff --git a/src/codegen/verilog/verilog_ir.cc b/src/codegen/verilog/verilog_ir.cc index dea8ebaeb..0cc4b9cf3 100644 --- a/src/codegen/verilog/verilog_ir.cc +++ b/src/codegen/verilog/verilog_ir.cc @@ -195,7 +195,7 @@ class PipelineExtractor: public IRVisitor { ChannelEntry& cb = cmap_.at(ch->handle_var.get()); trigger->signal_index = static_cast<int>(cb.node->ctrl_signals.size()); // Grab the advance constant size. - int trigger_size; + int trigger_size = 0; if (attr->attr_key == attr::pipeline_stage_scope) { cb.node->ctrl_signals.push_back( ControlSignalNode::make(kComputeFinish, 0)); diff --git a/src/lang/expr.cc b/src/lang/expr.cc index 062ea9217..7ac0e3723 100644 --- a/src/lang/expr.cc +++ b/src/lang/expr.cc @@ -5,6 +5,7 @@ #include <tvm/base.h> #include <tvm/expr.h> #include <tvm/ir.h> +#include <tvm/ir_operator.h> #include <ir/IRPrinter.h> #include <memory> diff --git a/src/lang/ir_operator.cc b/src/lang/ir_operator.cc index 5cad23e8c..307427643 100644 --- a/src/lang/ir_operator.cc +++ b/src/lang/ir_operator.cc @@ -8,6 +8,406 @@ namespace tvm { +/*! + * \brief Check whether type is used to represent index. + * + * Index types are frequently used in shape computation + * and need to be aggressively constant-folded. + * + * \param type The type to represent index. + * \return the checked result. + */ +inline bool IsIndexType(const Type& type) { + return type.is_int() && type.lanes() == 1 && + (type.bits() == 32 || type.bits() == 64); +} + +// simple cast that only checks if type matches and cast +inline Expr SimpleCast(const Type& t, Expr value) { + if (value.type() == t) return value; + return ir::Cast::make(t, value); +} + +// The public function with a quick checking path. +void BinaryOpMatchTypes(Expr& lhs, Expr& rhs) { // NOLINT(*) + if (lhs.type() == rhs.type()) return; + Type ltype = lhs.type(); + Type rtype = rhs.type(); + if (ltype.lanes() == 1 && rtype.lanes() != 1) { + lhs = ir::Broadcast::make(lhs, rtype.lanes()); + } else if (rtype.lanes() == 1 && ltype.lanes() != 1) { + rhs = ir::Broadcast::make(rhs, ltype.lanes()); + } else { + CHECK(ltype.lanes() == rtype.lanes()) + << "Cannot match type " << ltype << " vs " << rtype; + } + if (lhs.type() == rhs.type()) return; + // Only do very simple type coversion + // int->float, int(32)->int(64) + // require the types to be relatively consistent + // This will the reduce amount code generated by operators + // and also help user to find potential type conversion problems. + if (!lhs.type().is_float() && rhs.type().is_float()) { + // int->float + lhs = ir::Cast::make(rhs.type(), lhs); + } else if (lhs.type().is_float() && !rhs.type().is_float()) { + // int->float + rhs = ir::Cast::make(lhs.type(), rhs); + } else if ((lhs.type().is_int() && rhs.type().is_int()) || + (lhs.type().is_uint() && rhs.type().is_uint())) { + // promote int to higher bits + if (lhs.type().bits() < rhs.type().bits()) { + lhs = ir::Cast::make(rhs.type(), lhs); + } else { + rhs = ir::Cast::make(lhs.type(), rhs); + } + } else if ((lhs.type().is_int() && rhs.type().is_uint()) || + (lhs.type().is_uint() && rhs.type().is_int())) { + int bits = std::max(lhs.type().bits(), rhs.type().bits()); + lhs = SimpleCast(Int(bits, lhs.type().lanes()), lhs); + rhs = SimpleCast(Int(bits, rhs.type().lanes()), rhs); + } else { + LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype; + } +} + + +template<typename ValueType> +inline bool ConstPowerHelper(ValueType val, int *shift) { + if (val <= 0) return false; + shift[0] = 0; + while (val != 0) { + if (val & 1) { + return (val == 1); + } + ++shift[0]; + val = val >> 1; + } + return true; +} + +bool is_const_power_of_two_integer(const Expr& x, int* shift) { + if (const auto* op = x.as<ir::IntImm>()) { + return ConstPowerHelper(op->value, shift); + } else if (const auto* op = x.as<ir::UIntImm>()) { + return ConstPowerHelper(op->value, shift); + } else { + return false; + } +} + +Expr cast(const Type& t, Expr value) { + using ir::IntImm; + if (value.type() == t) return value; + // const fold IntImm as they are used in index computations + if (t.lanes() == 1) { + if (const IntImm* op = value.as<IntImm>()) { + return make_const(t, op->value); + } + return ir::Cast::make(t, value); + } else { + if (value.type().lanes() == 1) { + // manually unroll cast + Type vtype = t.element_of(); + if (value.type() != vtype) { + if (const IntImm* op = value.as<IntImm>()) { + value = make_const(vtype, op->value); + } else { + value = ir::Cast::make(vtype, value); + } + } + return ir::Broadcast::make(value, t.lanes()); + } else { + CHECK(value.type().lanes() == t.lanes()); + return ir::Cast::make(t, value); + } + } +} + +Expr reinterpret(const Type& t, Expr value) { + if (value.type() == t) return value; + return ir::Call::make(t, ir::Call::reinterpret, { value }, ir::Call::PureIntrinsic); +} + +#define TVM_CONST_PROPAGATION(BODY) \ + using ir::IntImm; \ + using ir::UIntImm; \ + const IntImm* pa = a.as<IntImm>(); \ + const IntImm* pb = b.as<IntImm>(); \ + const Type& ta = a.type(); \ + const Type& tb = b.type(); \ + if (IsIndexType(ta) && IsIndexType(tb)) { \ + BODY; \ + } \ + BinaryOpMatchTypes(a, b); + + +Expr operator+(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + Type rtype = ta.bits() >= tb.bits() ? ta : tb; + if (pa && pb) return IntImm::make(rtype, pa->value + pb->value); + if (pa && pa->value == 0) return SimpleCast(rtype, b); + if (pb && pb->value == 0) return SimpleCast(rtype, a); + }); + return ir::Add::make(a, b); +} + +Expr operator-(Expr a) { + using ir::IntImm; + const IntImm* pa = a.as<IntImm>(); + if (pa) { + return ir::IntImm::make(a.type(), -pa->value); + } + return make_zero(a.type()) - a; +} + +Expr operator-(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + Type rtype = ta.bits() >= tb.bits() ? ta : tb; + if (pa && pb) return IntImm::make(rtype, pa->value - pb->value); + if (pb && pb->value == 0) return SimpleCast(rtype, a); + }); + return ir::Sub::make(a, b); +} + +Expr operator*(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + Type rtype = ta.bits() >= tb.bits() ? ta : tb; + if (pa && pb) return IntImm::make(rtype, pa->value * pb->value); + if (pa) { + if (pa->value == 1) return SimpleCast(rtype, b); + if (pa->value == 0) return SimpleCast(rtype, a); + } + if (pb) { + if (pb->value == 1) return SimpleCast(rtype, a); + if (pb->value == 0) return SimpleCast(rtype, b); + } + }); + return ir::Mul::make(a, b); +} + +Expr operator/(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + Type rtype = ta.bits() >= tb.bits() ? ta : tb; + // due to division and mod can have different modes + // only constant fold positive number where rule is fixed. + if (pa && pb && pa->value >= 0 && pb->value > 0) { + return IntImm::make(rtype, pa->value / pb->value); + } + if (pa) { + if (pa->value == 0) return SimpleCast(rtype, a); + } + if (pb) { + if (pb->value == 1) return SimpleCast(rtype, a); + CHECK_NE(pb->value, 0) << "Divide by zero"; + } + }); + return ir::Div::make(a, b); +} + +Expr operator%(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + Type rtype = ta.bits() >= tb.bits() ? ta : tb; + // due to division and mod can have different modes + // only constant fold positive number where rule is fixed. + if (pa && pb && pa->value >= 0 && pb->value > 0) { + return IntImm::make(rtype, pa->value % pb->value); + } + if (pa) { + if (pa->value == 0) return SimpleCast(rtype, a); + } + if (pb) { + if (pb->value == 1) return make_zero(rtype); + CHECK_NE(pb->value, 0) << "Divide by zero"; + } + }); + return ir::Mod::make(a, b); +} + +Expr min(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + Type rtype = ta.bits() >= tb.bits() ? ta : tb; + if (pa && pb) return IntImm::make(rtype, std::min(pa->value, pb->value)); + }); + return ir::Min::make(a, b); +} + +Expr max(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + Type rtype = ta.bits() >= tb.bits() ? ta : tb; + if (pa && pb) return IntImm::make(rtype, std::max(pa->value, pb->value)); + }); + return ir::Max::make(a, b); +} + +Expr select(Expr cond, Expr true_value, Expr false_value) { + using ir::IntImm; + using ir::UIntImm; + CHECK(cond.type().is_bool()); + BinaryOpMatchTypes(true_value, false_value); + if (const UIntImm* op = cond.as<UIntImm>()) { + if (op->value != 0) { + return true_value; + } else { + return false_value; + } + } else if (const IntImm* op = cond.as<IntImm>()) { + if (op->value != 0) { + return true_value; + } else { + return false_value; + } + } + return ir::Select::make(cond, true_value, false_value); +} + +Expr likely(Expr cond) { + if (is_const(cond)) return cond; + return ir::Call::make(cond.type(), ir::Call::likely, { cond }, ir::Call::PureIntrinsic); +} + +Expr operator>(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + if (pa && pb) return UIntImm::make(UInt(1), pa->value > pb->value); + }); + return ir::GT::make(a, b); +} + +Expr operator>=(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + if (pa && pb) return UIntImm::make(UInt(1), pa->value >= pb->value); + }); + return ir::GE::make(a, b); +} + +Expr operator<(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + if (pa && pb) return UIntImm::make(UInt(1), pa->value < pb->value); + }); + return ir::LT::make(a, b); +} + +Expr operator<=(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + if (pa && pb) return UIntImm::make(UInt(1), pa->value <= pb->value); + }); + return ir::LE::make(a, b); +} + +Expr operator==(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + if (pa && pb) return UIntImm::make(UInt(1), pa->value == pb->value); + }); + return ir::EQ::make(a, b); +} + +Expr operator!=(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + if (pa && pb) return UIntImm::make(UInt(1), pa->value != pb->value); + }); + return ir::NE::make(a, b); +} + +Expr operator&&(Expr a, Expr b) { + using ir::UIntImm; + const UIntImm* pa = a.as<UIntImm>(); + const UIntImm* pb = b.as<UIntImm>(); + if (pa && pb) { + return UIntImm::make(UInt(1), pa->value && pb->value); + } + return ir::And::make(a, b); +} + +Expr operator||(Expr a, Expr b) { + using ir::UIntImm; + const UIntImm* pa = a.as<UIntImm>(); + const UIntImm* pb = b.as<UIntImm>(); + if (pa && pb) { + return UIntImm::make(UInt(1), pa->value || pb->value); + } + return ir::Or::make(a, b); +} + +Expr operator!(Expr a) { + using ir::UIntImm; + const UIntImm* pa = a.as<UIntImm>(); + if (pa) { + return UIntImm::make(UInt(1), !(pa->value)); + } + return ir::Not::make(a); +} + +Expr operator>>(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + Type rtype = ta.bits() >= tb.bits() ? ta : tb; + if (pa && pb) return IntImm::make(rtype, (pa->value >> pb->value)); + if (pb) { + if (pb->value == 0) return SimpleCast(rtype, a); + } + }); + return ir::Call::make(a.type(), ir::Call::shift_right, { a, b }, ir::Call::PureIntrinsic); +} + +Expr operator<<(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + Type rtype = ta.bits() >= tb.bits() ? ta : tb; + if (pa && pb) return IntImm::make(rtype, (pa->value << pb->value)); + if (pb) { + if (pb->value == 0) return SimpleCast(rtype, a); + } + }); + return ir::Call::make(a.type(), ir::Call::shift_left, { a, b }, ir::Call::PureIntrinsic); +} + +Expr operator&(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + Type rtype = ta.bits() >= tb.bits() ? ta : tb; + if (pa && pb) return IntImm::make(rtype, (pa->value & pb->value)); + }); + return ir::Call::make(a.type(), ir::Call::bitwise_and, { a, b }, ir::Call::PureIntrinsic); +} + +Expr operator|(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + Type rtype = ta.bits() >= tb.bits() ? ta : tb; + if (pa && pb) return IntImm::make(rtype, (pa->value | pb->value)); + }); + return ir::Call::make(a.type(), ir::Call::bitwise_or, { a, b }, ir::Call::PureIntrinsic); +} + +Expr operator^(Expr a, Expr b) { + TVM_CONST_PROPAGATION({ + Type rtype = ta.bits() >= tb.bits() ? ta : tb; + if (pa && pb) return IntImm::make(rtype, (pa->value ^ pb->value)); + }); + return ir::Call::make(a.type(), ir::Call::bitwise_xor, { a, b }, ir::Call::PureIntrinsic); +} + +Expr operator~(Expr a) { + CHECK(a.type().is_int() || a.type().is_uint()); + return ir::Call::make(a.type(), ir::Call::bitwise_not, { a }, ir::Call::PureIntrinsic); +} + +Expr pow(Expr x, Expr y) { + BinaryOpMatchTypes(x, y); + CHECK(x.type().is_float()) << "power only applies to float"; + return ir::Call::make(x.type(), "pow", { x, y }, ir::Call::PureIntrinsic); +} + +Expr abs(Expr x) { + if (x.type().is_int()) { + return select(x >= make_zero(x.type()), x, -x); + } else if (x.type().is_float()) { + return ir::Call::make(x.type(), "fabs", {x}, ir::Call::PureIntrinsic); + } else if (x.type().is_uint()) { + return x; + } else { + LOG(FATAL) << "Data type " << x.type() + <<" not supported for absolute op. Skipping absolute op..."; + return x; + } +} + Expr sum(Expr source, Array<IterVar> rdom) { Var x("x", source.type()), y("y", source.type()); Expr result = ir::Add::make(x, y); @@ -38,7 +438,7 @@ Expr min(Expr source, Array<IterVar> rdom) { Expr prod(Expr source, Array<IterVar> rdom) { Var x("x", source.type()), y("y", source.type()); Expr result = ir::Mul::make(x, y); - Expr identity_element = make_one(source.type()); + Expr identity_element = make_const(source.type(), 1); ir::CommReducer combiner = ir::CommReducerNode::make({x}, {y}, {result}, {identity_element}); return ir::Reduce::make(combiner, {source}, rdom, make_const(Bool(1), true), 0); diff --git a/src/pass/ir_util.h b/src/pass/ir_util.h index f871133fb..3cef4486e 100644 --- a/src/pass/ir_util.h +++ b/src/pass/ir_util.h @@ -7,6 +7,7 @@ #define TVM_PASS_IR_UTIL_H_ #include <tvm/ir.h> +#include <tvm/ir_operator.h> #include <tvm/runtime/device_api.h> #include <vector> @@ -75,7 +76,7 @@ inline Expr TVMStructGet( Array<Expr> args ={ handle, make_const(Int(32), index), - make_const(Int(32), kind)}; + make_const(Int(32), static_cast<int>(kind))}; return Call::make(dtype, intrinsic::tvm_struct_get, args, Call::PureIntrinsic); } @@ -125,7 +126,7 @@ inline Stmt TVMStructSet( Array<Expr> args ={ handle, make_const(Int(32), index), - make_const(Int(32), kind), + make_const(Int(32), static_cast<int>(kind)), value}; return Evaluate::make( Call::make(Int(32), intrinsic::tvm_struct_set, args, Call::Intrinsic)); diff --git a/src/pass/split_pipeline.cc b/src/pass/split_pipeline.cc index 0dd5bd651..c143a0d19 100644 --- a/src/pass/split_pipeline.cc +++ b/src/pass/split_pipeline.cc @@ -102,9 +102,8 @@ class MarkChannelAccess : public IRMutator { } else { alloc_size = op->extents[0]; for (size_t i = 1; i < op->extents.size(); ++i) { - alloc_size *= op->extents[i]; + alloc_size = alloc_size * op->extents[i]; } - alloc_size = ir::Simplify(alloc_size); } if (rw.write_count) { diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc index 2bab21d85..54f5010f1 100644 --- a/src/pass/storage_rewrite.cc +++ b/src/pass/storage_rewrite.cc @@ -578,7 +578,7 @@ class StoragePlanRewriter : public IRMutator { combo_size = combo_size / type_bits; // round up for can not divided if (!divided) { - combo_size += make_const(Int(32), 1); + combo_size = combo_size + make_const(Int(32), 1); } combo_size = ir::Simplify(combo_size); e->new_alloc = Allocate::make( diff --git a/src/pass/vectorize_loop.cc b/src/pass/vectorize_loop.cc index 206b75ed0..fe2f81980 100644 --- a/src/pass/vectorize_loop.cc +++ b/src/pass/vectorize_loop.cc @@ -437,7 +437,6 @@ class LoopVectorizer : public IRMutator { Stmt Mutate_(const For* op, const Stmt& s) final { if (op->for_type == ForType::Vectorized) { CHECK(is_zero(op->min)); - CHECK(is_positive_const(op->extent)); int lanes = 0; bool succ = arith::GetConstInt(op->extent, &lanes); if (!succ || lanes < 1) { diff --git a/tests/cpp/ir_mutator_test.cc b/tests/cpp/ir_mutator_test.cc index fd5a60756..0802d405b 100644 --- a/tests/cpp/ir_mutator_test.cc +++ b/tests/cpp/ir_mutator_test.cc @@ -1,6 +1,7 @@ #include <dmlc/logging.h> #include <gtest/gtest.h> #include <tvm/ir_mutator.h> +#include <tvm/ir_operator.h> namespace { using namespace tvm::ir; diff --git a/tests/python/unittest/test_arith_intset.py b/tests/python/unittest/test_arith_intset.py index 78589cf3a..9b869fedd 100644 --- a/tests/python/unittest/test_arith_intset.py +++ b/tests/python/unittest/test_arith_intset.py @@ -35,7 +35,7 @@ def test_deduce(): e1 = (a*4+b < c) res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) - ans1 = (((c - b) + -1)/4) + ans1 = (((c - b) + -1)/4) assert str(tvm.ir_pass.Simplify(res1.max())) == str(ans1) e2 = (tvm.max(5, a * 4) < 0) @@ -63,7 +63,7 @@ def test_check(): assert res1.is_nothing() # multiple compare operators - res2 = tvm.arith.DeduceBound(a, (a+b>3)>c , {b: b_s, c: c_s}, {}) + res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {}) assert res2.is_nothing() # multiple target variable @@ -88,11 +88,11 @@ def test_deduce_basic(): res1 = tvm.arith.DeduceBound(a, e0<=17, {b: b_s}, {b: b_s}) [x, y] = [res1.max(), b_s.max()] if coff > 0 else [res1.min(), b_s.min()] assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1 - + res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s}) [x, y] = [res1.max(), b_s.max()] if coff < 0 else [res1.min(), b_s.min()] assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1 - + test_basic(0, 4, 4) test_basic(1, 5, 4) test_basic(2, 6, 4) @@ -137,4 +137,3 @@ if __name__ == "__main__": test_check() test_deduce_basic() test_deduce_complex() - diff --git a/tests/python/unittest/test_lang_basic.py b/tests/python/unittest/test_lang_basic.py index c9a04747b..bf25ca3df 100644 --- a/tests/python/unittest/test_lang_basic.py +++ b/tests/python/unittest/test_lang_basic.py @@ -8,7 +8,7 @@ def test_const(): def test_make(): x = tvm.const(1) - y = tvm.make.IntImm('int32', 1) + y = tvm.var("x") z = x + y assert isinstance(tvm.max(x, y), tvm.expr.Max) assert isinstance(tvm.min(x, y), tvm.expr.Min) diff --git a/tests/python/unittest/test_lang_operator.py b/tests/python/unittest/test_lang_operator.py new file mode 100644 index 000000000..9c701ed2a --- /dev/null +++ b/tests/python/unittest/test_lang_operator.py @@ -0,0 +1,35 @@ +import tvm + +def test_const_fold(): + def check(f, *args): + x = f(*[tvm.const(x) for x in args]) + y = f(*args) + if not isinstance(x, (tvm.expr.IntImm, tvm.expr.UIntImm)) or x.value != int(y): + raise ValueError("check error: %s vs %s " % (x, y)) + + check(lambda x, y: x + y, 3, 4) + check(lambda x, y: x * y, 3, 12) + check(lambda x, y: x * y - 10, 3, 12) + check(lambda x, y: x - y % 10, 3, 12) + check(lambda x, y: x // y + 10, 100, 12) + check(lambda x, y: x & y + 10, 112, 128) + check(lambda x, y: x > y, 112, 128) + check(lambda x, y: x < y, 112, 128) + check(lambda x, y: x <= y, 112, 128) + check(lambda x, y: x >= y, 112, 128) + check(lambda x, y: (x | y) ^ 10, 112, 128) + + +def test_const_fold2(): + x = tvm.var("x") + assert (x + 0).same_as(x) + assert (0 + x).same_as(x) + assert (x - 0).same_as(x) + assert (x % 1).value == 0 + assert (x * 1).same_as(x) + assert (1 * x).same_as(x) + assert isinstance((1 / x), tvm.expr.Div) + +if __name__ == "__main__": + test_const_fold() + test_const_fold2() diff --git a/tests/python/unittest/test_lang_reflection.py b/tests/python/unittest/test_lang_reflection.py index 83b440a2c..3ec760f20 100644 --- a/tests/python/unittest/test_lang_reflection.py +++ b/tests/python/unittest/test_lang_reflection.py @@ -15,7 +15,7 @@ def test_make_smap(): # save load json x = tvm.const(1) y = tvm.const(10) - z = x + y + z = tvm.expr.Add(x, y) smap = tvm.convert({"z": z, "x": x}) json_str = tvm.save_json(tvm.convert([smap])) arr = tvm.load_json(json_str) diff --git a/tests/python/unittest/test_pass_simplify.py b/tests/python/unittest/test_pass_simplify.py index c38083822..fce6eaed5 100644 --- a/tests/python/unittest/test_pass_simplify.py +++ b/tests/python/unittest/test_pass_simplify.py @@ -53,7 +53,6 @@ def test_canonical(): assert (tvm.ir_pass.Equal(ret1, ret2)) if __name__ == "__main__": - test_modular() test_bound() test_basic() test_simplify() diff --git a/topi/include/topi/elemwise.h b/topi/include/topi/elemwise.h index 88c77f0af..02bc51515 100644 --- a/topi/include/topi/elemwise.h +++ b/topi/include/topi/elemwise.h @@ -163,7 +163,7 @@ inline Tensor full(const Array<Expr>& shape, const Expr fill_value, std::string name = "tensor", std::string tag = kElementWise) { - Expr ev = lossless_cast(dtype, fill_value); + Expr ev = cast(dtype, fill_value); if (!ev.defined()) { LOG(ERROR) << "Can't cast fill_value to " << dtype; } @@ -173,7 +173,7 @@ inline Tensor full(const Array<Expr>& shape, } /*! -* \brief Creates an operation that construct a tensor with same shape as input tensor, +* \brief Creates an operation that construct a tensor with same shape as input tensor, * then fill a tensor with fill_value * * \param x The input tensor @@ -187,10 +187,7 @@ inline Tensor full_like(const Tensor& x, const Expr fill_value, std::string name = "tensor", std::string tag = kElementWise) { - Expr ev = lossless_cast(x->dtype, fill_value); - if (!ev.defined()) { - LOG(ERROR) << "Can't cast fill_value to " << x->dtype; - } + Expr ev = cast(x->dtype, fill_value); return compute(x->shape, [&](const Array<Var>& i) { return ev; }, name, tag); diff --git a/topi/include/topi/nn/pooling.h b/topi/include/topi/nn/pooling.h index ca318adfe..795d04a31 100644 --- a/topi/include/topi/nn/pooling.h +++ b/topi/include/topi/nn/pooling.h @@ -94,10 +94,10 @@ inline Tensor pool_impl(const Tensor& x, out_shape.Set(height_axis, out_height); out_shape.Set(width_axis, out_width); - const int64_t *padding_h0 = HalideIR::Internal::as_const_int(pad_top); - const int64_t *padding_w0 = HalideIR::Internal::as_const_int(pad_left); - const int64_t *padding_h1 = HalideIR::Internal::as_const_int(pad_bottom); - const int64_t *padding_w1 = HalideIR::Internal::as_const_int(pad_right); + const int64_t *padding_h0 = as_const_int(pad_top); + const int64_t *padding_w0 = as_const_int(pad_left); + const int64_t *padding_h1 = as_const_int(pad_bottom); + const int64_t *padding_w1 = as_const_int(pad_right); const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) || ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1)); @@ -192,7 +192,7 @@ inline bool find_height_width(const std::string& layout, * Since pooling does not care about the factor size of dimensions * other than `H` and `W`, one can pass `NCHWc` as well. * \param count_include_pad Whether include padding in the calculation when pool_type is 'avg' -* +* * * \return The output tensor in the same layout */ diff --git a/topi/python/topi/vision/ssd/multibox.py b/topi/python/topi/vision/ssd/multibox.py index a8f971465..4e6e6ab27 100644 --- a/topi/python/topi/vision/ssd/multibox.py +++ b/topi/python/topi/vision/ssd/multibox.py @@ -164,10 +164,10 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho oy = py * vy * ah + ay ow = tvm.exp(pw * vw) * aw / 2.0 oh = tvm.exp(ph * vh) * ah / 2.0 - return tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox - ow)), ox - ow), \ - tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy - oh)), oy - oh), \ - tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, ox + ow)), ox + ow), \ - tvm.select(clip, tvm.make.Max(0, tvm.make.Min(1, oy + oh)), oy + oh) + return tvm.select(clip, tvm.max(0, tvm.min(1, ox - ow)), ox - ow), \ + tvm.select(clip, tvm.max(0, tvm.min(1, oy - oh)), oy - oh), \ + tvm.select(clip, tvm.max(0, tvm.min(1, ox + ow)), ox + ow), \ + tvm.select(clip, tvm.max(0, tvm.min(1, oy + oh)), oy + oh) batch_size = cls_prob.shape[0] num_classes = cls_prob.shape[1] @@ -191,7 +191,7 @@ def transform_loc_ir(cls_prob, loc_pred, anchor, valid_count, out, clip, thresho with ib.if_scope(j > 0): temp = p_cls_prob[n * num_anchors * num_classes + j * num_anchors + i] cls_id[0] = tvm.select(temp > score[0], j, cls_id[0]) - score[0] = tvm.make.Max(temp, score[0]) + score[0] = tvm.max(temp, score[0]) with ib.if_scope(tvm.all(cls_id[0] > 0, score[0] < threshold)): cls_id[0] = 0 # [id, prob, xmin, ymin, xmax, ymax] -- GitLab