From 3d010ed5e3bfe0816e4d6e260dbce31313d063de Mon Sep 17 00:00:00 2001
From: Liangfu Chen <liangfu.chen@icloud.com>
Date: Thu, 5 Jul 2018 00:04:41 +0800
Subject: [PATCH] support equal and not_equal in topi (#1373)

---
 docs/api/python/topi.rst                 |  6 ++
 topi/include/topi/broadcast.h            | 52 ++++++++++++++++
 topi/python/topi/broadcast.py            | 76 ++++++++++++++++++++++++
 topi/src/topi.cc                         |  4 ++
 topi/tests/python/test_topi_broadcast.py | 20 +++++++
 5 files changed, 158 insertions(+)

diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst
index a2221847d..a14355d0f 100644
--- a/docs/api/python/topi.rst
+++ b/docs/api/python/topi.rst
@@ -46,6 +46,8 @@ List of operators
    topi.max
    topi.sum
    topi.min
+   topi.argmax
+   topi.argmin
    topi.broadcast_to
    topi.add
    topi.subtract
@@ -57,6 +59,10 @@ List of operators
    topi.power
    topi.greater
    topi.less
+   topi.equal
+   topi.not_equal
+   topi.greater_equal
+   topi.less_equal
    topi.image.resize
 
 
diff --git a/topi/include/topi/broadcast.h b/topi/include/topi/broadcast.h
index 1a93d26f8..ad1c04ae1 100644
--- a/topi/include/topi/broadcast.h
+++ b/topi/include/topi/broadcast.h
@@ -257,6 +257,58 @@ TOPI_DEFINE_BCAST_OP(greater, { return (a > b); });
  */
 TOPI_DEFINE_BCAST_OP(less, { return (a < b); });
 
+/*!
+ * \fn equal
+ * \brief Compute (A == B) with auto-broadcasting.
+ *
+ * \param A The first tensor, or Expr
+ * \param B The second tensor, or Expr
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return The result.
+ */
+TOPI_DEFINE_BCAST_OP(equal, { return (a == b); });
+
+/*!
+ * \fn not_equal
+ * \brief Compute (A != B) with auto-broadcasting.
+ *
+ * \param A The first tensor, or Expr
+ * \param B The second tensor, or Expr
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return The result.
+ */
+TOPI_DEFINE_BCAST_OP(not_equal, { return (a != b); });
+
+/*!
+ * \fn greater_equal
+ * \brief Compute (A >= B) with auto-broadcasting.
+ *
+ * \param A The first tensor, or Expr
+ * \param B The second tensor, or Expr
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return The result.
+ */
+TOPI_DEFINE_BCAST_OP(greater_equal, { return (a >= b); });
+
+/*!
+ * \fn less_equal
+ * \brief Compute (A <= B) with auto-broadcasting.
+ *
+ * \param A The first tensor, or Expr
+ * \param B The second tensor, or Expr
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return The result.
+ */
+TOPI_DEFINE_BCAST_OP(less_equal, { return (a <= b); });
+
 }  // namespace topi
 
 #endif  // TOPI_BROADCAST_H_
diff --git a/topi/python/topi/broadcast.py b/topi/python/topi/broadcast.py
index f088e48b0..9cfd99fbe 100644
--- a/topi/python/topi/broadcast.py
+++ b/topi/python/topi/broadcast.py
@@ -249,3 +249,79 @@ def less(lhs, rhs):
         Otherwise returns Tensor.
     """
     return _cpp.less(lhs, rhs)
+
+
+def equal(lhs, rhs):
+    """Compute (lhs==rhs) with auto-broadcasting
+
+    Parameters
+    ----------
+    lhs : tvm.Tensor or Expr
+        The left operand
+    rhs : tvm.Tensor or Expr
+        The right operand
+
+    Returns
+    -------
+    ret : tvm.Tensor or Expr
+        Returns Expr if both operands are Expr.
+        Otherwise returns Tensor.
+    """
+    return _cpp.equal(lhs, rhs)
+
+
+def not_equal(lhs, rhs):
+    """Compute (lhs!=rhs) with auto-broadcasting
+
+    Parameters
+    ----------
+    lhs : tvm.Tensor or Expr
+        The left operand
+    rhs : tvm.Tensor or Expr
+        The right operand
+
+    Returns
+    -------
+    ret : tvm.Tensor or Expr
+        Returns Expr if both operands are Expr.
+        Otherwise returns Tensor.
+    """
+    return _cpp.not_equal(lhs, rhs)
+
+
+def greater_equal(lhs, rhs):
+    """Compute (lhs>=rhs) with auto-broadcasting
+
+    Parameters
+    ----------
+    lhs : tvm.Tensor or Expr
+        The left operand
+    rhs : tvm.Tensor or Expr
+        The right operand
+
+    Returns
+    -------
+    ret : tvm.Tensor or Expr
+        Returns Expr if both operands are Expr.
+        Otherwise returns Tensor.
+    """
+    return _cpp.greater_equal(lhs, rhs)
+
+
+def less_equal(lhs, rhs):
+    """Compute (lhs<=rhs) with auto-broadcasting
+
+    Parameters
+    ----------
+    lhs : tvm.Tensor or Expr
+        The left operand
+    rhs : tvm.Tensor or Expr
+        The right operand
+
+    Returns
+    -------
+    ret : tvm.Tensor or Expr
+        Returns Expr if both operands are Expr.
+        Otherwise returns Tensor.
+    """
+    return _cpp.less_equal(lhs, rhs)
diff --git a/topi/src/topi.cc b/topi/src/topi.cc
index fe1f40985..a4a87b309 100644
--- a/topi/src/topi.cc
+++ b/topi/src/topi.cc
@@ -116,6 +116,10 @@ TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift);
 TOPI_REGISTER_BCAST_OP("topi.right_shift", topi::right_shift);
 TOPI_REGISTER_BCAST_OP("topi.greater", topi::greater);
 TOPI_REGISTER_BCAST_OP("topi.less", topi::less);
+TOPI_REGISTER_BCAST_OP("topi.equal", topi::equal);
+TOPI_REGISTER_BCAST_OP("topi.not_equal", topi::not_equal);
+TOPI_REGISTER_BCAST_OP("topi.greater_equal", topi::greater_equal);
+TOPI_REGISTER_BCAST_OP("topi.less_equal", topi::less_equal);
 
 /* Ops from elemwise.h */
 TVM_REGISTER_GLOBAL("topi.exp")
diff --git a/topi/tests/python/test_topi_broadcast.py b/topi/tests/python/test_topi_broadcast.py
index ecd238e78..c5720050e 100644
--- a/topi/tests/python/test_topi_broadcast.py
+++ b/topi/tests/python/test_topi_broadcast.py
@@ -142,10 +142,30 @@ def test_cmp():
         return topi.greater(x, y).astype("int8")
     def less(x, y):
         return topi.less(x, y).astype("int8")
+    def equal(x, y):
+        return topi.equal(x, y).astype("int8")
+    def not_equal(x, y):
+        return topi.not_equal(x, y).astype("int8")
+    def greater_equal(x, y):
+        return topi.greater_equal(x, y).astype("int8")
+    def less_equal(x, y):
+        return topi.less_equal(x, y).astype("int8")
     verify_broadcast_binary_ele(
         (1, 2, 2), (2,), greater, np.greater)
     verify_broadcast_binary_ele(
         (2, 1, 2), (2, 3, 1), less, np.less)
+    verify_broadcast_binary_ele(
+        (2, 1, 2), (2, 3, 1), equal, np.equal,
+        lhs_min=-2, lhs_max=2, rhs_min=-2, rhs_max=2, dtype='int32')
+    verify_broadcast_binary_ele(
+        (2, 1, 2), (2, 3, 1), not_equal, np.not_equal,
+        lhs_min=-2, lhs_max=2, rhs_min=-2, rhs_max=2, dtype='int32')
+    verify_broadcast_binary_ele(
+        (7, 1, 5), (7, 3, 1), greater_equal, np.greater_equal,
+        lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32')
+    verify_broadcast_binary_ele(
+        (7, 1, 5), (7, 3, 1), less_equal, np.less_equal,
+        lhs_min=-3, lhs_max=3, rhs_min=-3, rhs_max=3, dtype='int32')
 
 def test_shift():
     # explicit specify the output type
-- 
GitLab