diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index a2221847dae474065ee47eaa4e5b96f3390aabd4..a14355d0f796f63bbb4fd198e051356dbf3d718f 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -46,6 +46,8 @@ List of operators topi.max topi.sum topi.min + topi.argmax + topi.argmin topi.broadcast_to topi.add topi.subtract @@ -57,6 +59,10 @@ List of operators topi.power topi.greater topi.less + topi.equal + topi.not_equal + topi.greater_equal + topi.less_equal topi.image.resize diff --git a/topi/include/topi/broadcast.h b/topi/include/topi/broadcast.h index 1a93d26f8c55709afaa2469aeca8f740de2f788c..ad1c04ae132776b1554abaab64ced5d684e45dbe 100644 --- a/topi/include/topi/broadcast.h +++ b/topi/include/topi/broadcast.h @@ -257,6 +257,58 @@ TOPI_DEFINE_BCAST_OP(greater, { return (a > b); }); */ TOPI_DEFINE_BCAST_OP(less, { return (a < b); }); +/*! + * \fn equal + * \brief Compute (A == B) with auto-broadcasting. + * + * \param A The first tensor, or Expr + * \param B The second tensor, or Expr + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return The result. + */ +TOPI_DEFINE_BCAST_OP(equal, { return (a == b); }); + +/*! + * \fn not_equal + * \brief Compute (A != B) with auto-broadcasting. + * + * \param A The first tensor, or Expr + * \param B The second tensor, or Expr + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return The result. + */ +TOPI_DEFINE_BCAST_OP(not_equal, { return (a != b); }); + +/*! + * \fn greater_equal + * \brief Compute (A >= B) with auto-broadcasting. + * + * \param A The first tensor, or Expr + * \param B The second tensor, or Expr + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return The result. + */ +TOPI_DEFINE_BCAST_OP(greater_equal, { return (a >= b); }); + +/*! + * \fn less_equal + * \brief Compute (A <= B) with auto-broadcasting. + * + * \param A The first tensor, or Expr + * \param B The second tensor, or Expr + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return The result. + */ +TOPI_DEFINE_BCAST_OP(less_equal, { return (a <= b); }); + } // namespace topi #endif // TOPI_BROADCAST_H_ diff --git a/topi/python/topi/broadcast.py b/topi/python/topi/broadcast.py index f088e48b0f14d70cf4fc7197643433c7fb62ef5f..9cfd99fbeec40780a549f635b48decaf0ceb00da 100644 --- a/topi/python/topi/broadcast.py +++ b/topi/python/topi/broadcast.py @@ -249,3 +249,79 @@ def less(lhs, rhs): Otherwise returns Tensor. """ return _cpp.less(lhs, rhs) + + +def equal(lhs, rhs): + """Compute (lhs==rhs) with auto-broadcasting + + Parameters + ---------- + lhs : tvm.Tensor or Expr + The left operand + rhs : tvm.Tensor or Expr + The right operand + + Returns + ------- + ret : tvm.Tensor or Expr + Returns Expr if both operands are Expr. + Otherwise returns Tensor. + """ + return _cpp.equal(lhs, rhs) + + +def not_equal(lhs, rhs): + """Compute (lhs!=rhs) with auto-broadcasting + + Parameters + ---------- + lhs : tvm.Tensor or Expr + The left operand + rhs : tvm.Tensor or Expr + The right operand + + Returns + ------- + ret : tvm.Tensor or Expr + Returns Expr if both operands are Expr. + Otherwise returns Tensor. + """ + return _cpp.not_equal(lhs, rhs) + + +def greater_equal(lhs, rhs): + """Compute (lhs>=rhs) with auto-broadcasting + + Parameters + ---------- + lhs : tvm.Tensor or Expr + The left operand + rhs : tvm.Tensor or Expr + The right operand + + Returns + ------- + ret : tvm.Tensor or Expr + Returns Expr if both operands are Expr. + Otherwise returns Tensor. + """ + return _cpp.greater_equal(lhs, rhs) + + +def less_equal(lhs, rhs): + """Compute (lhs<=rhs) with auto-broadcasting + + Parameters + ---------- + lhs : tvm.Tensor or Expr + The left operand + rhs : tvm.Tensor or Expr + The right operand + + Returns + ------- + ret : tvm.Tensor or Expr + Returns Expr if both operands are Expr. + Otherwise returns Tensor. + """ + return _cpp.less_equal(lhs, rhs) diff --git a/topi/src/topi.cc b/topi/src/topi.cc index fe1f4098570d1f9675ddf06705f07550ef19c868..a4a87b3096a34f806a1e395556993515d2f3f6a6 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -116,6 +116,10 @@ TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift); TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift); TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater); TOPI_REGISTER_BCAST_OP("topi.less", topi::less); +TOPI_REGISTER_BCAST_OP("topi.equal", topi::equal); +TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal); +TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal); +TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal); /* Ops from elemwise.h */ TVM_REGISTER_GLOBAL("topi.exp") diff --git a/topi/tests/python/test_topi_broadcast.py b/topi/tests/python/test_topi_broadcast.py index ecd238e78761b4640b9307d16d1cdffbd6809da9..c5720050e538e2d72974e1a591e2e90e9a2ae9ea 100644 --- a/topi/tests/python/test_topi_broadcast.py +++ b/topi/tests/python/test_topi_broadcast.py @@ -142,10 +142,30 @@ def test_cmp(): return topi.greater(x, y).astype("int8") def less(x, y): return topi.less(x, y).astype("int8") + def equal(x, y): + return topi.equal(x, y).astype("int8") + def not_equal(x, y): + return topi.not_equal(x, y).astype("int8") + def greater_equal(x, y): + return topi.greater_equal(x, y).astype("int8") + def less_equal(x, y): + return topi.less_equal(x, y).astype("int8") verify_broadcast_binary_ele( (1, 2, 2), (2,), greater, np.greater) verify_broadcast_binary_ele( (2, 1, 2), (2, 3, 1), less, np.less) + verify_broadcast_binary_ele( + (2, 1, 2), (2, 3, 1), equal, np.equal, + lhs_min=-2, lhs_max=2, rhs_min=-2, rhs_max=2, dtype='int32') + verify_broadcast_binary_ele( + (2, 1, 2), (2, 3, 1), not_equal, np.not_equal, + lhs_min=-2, lhs_max=2, rhs_min=-2, rhs_max=2, dtype='int32') + verify_broadcast_binary_ele( + (7, 1, 5), (7, 3, 1), greater_equal, np.greater_equal, + lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32') + verify_broadcast_binary_ele( + (7, 1, 5), (7, 3, 1), less_equal, np.less_equal, + lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32') def test_shift(): # explicit specify the output type