From 2e94a4b55f976d39814996367115a8c5476d2d69 Mon Sep 17 00:00:00 2001
From: yuruofeifei <yuruofeifei@gmail.com>
Date: Fri, 2 Feb 2018 14:33:29 -0800
Subject: [PATCH] [TOPI] Add compute for more operators (#849)

* [TOPI] Add compute for more operators

* Remove device except llvm

* Address comments

* Remove matmul compute

* Add out_type to boolean operators

* Address comments
---
 python/tvm/api.py                        |   1 +
 topi/python/topi/__init__.py             |   2 +-
 topi/python/topi/tensor.py               | 116 ++++++++++++++++++++
 topi/python/topi/transform.py            |  55 ++++++++++
 topi/tests/python/test_topi_math.py      |  31 ++++++
 topi/tests/python/test_topi_tensor.py    | 129 +++++++++++++++++++++++
 topi/tests/python/test_topi_transform.py |  44 ++++++++
 7 files changed, 377 insertions(+), 1 deletion(-)
 create mode 100644 topi/python/topi/tensor.py
 create mode 100644 topi/tests/python/test_topi_math.py
 create mode 100644 topi/tests/python/test_topi_tensor.py

diff --git a/python/tvm/api.py b/python/tvm/api.py
index 08b3d95dc..7c90b0ec9 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -19,6 +19,7 @@ from . import schedule as _schedule
 from . import container as _container
 from . import tag as _tag
 
+int8 = "int8"
 int32 = "int32"
 float32 = "float32"
 handle = "handle"
diff --git a/topi/python/topi/__init__.py b/topi/python/topi/__init__.py
index 4979bdbaa..1648de33b 100644
--- a/topi/python/topi/__init__.py
+++ b/topi/python/topi/__init__.py
@@ -12,6 +12,7 @@ from __future__ import absolute_import as _abs
 from tvm._ffi.libinfo import __version__
 
 from .math import *
+from .tensor import *
 from .reduction import *
 from .transform import *
 from .broadcast import *
@@ -23,4 +24,3 @@ from . import mali
 from . import testing
 from . import util
 from . import rocm
-from . import cpp
diff --git a/topi/python/topi/tensor.py b/topi/python/topi/tensor.py
new file mode 100644
index 000000000..3e2e0abb4
--- /dev/null
+++ b/topi/python/topi/tensor.py
@@ -0,0 +1,116 @@
+# pylint: disable=invalid-name,consider-using-enumerate,unused-argument,len-as-condition
+"""Elementwise operators"""
+from __future__ import absolute_import as _abs
+import tvm
+from . import tag
+
+@tvm.tag_scope(tag=tag.ELEMWISE)
+def elemwise_sum(xs, num_args):
+    """Perform element-wise sum on inputs
+
+    Parameters
+    ----------
+    xs : list of tvm.Tensor
+        Input arguments.
+    num_args : int
+        Number of arguments
+
+    Returns
+    -------
+    y : tvm.Tensor
+        The result.
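+
+    Examples::
+
+        # a minimal usage sketch (shapes illustrative); num_args is
+        # assumed to match len(xs)
+        a = tvm.placeholder((4,), name="a")
+        b = tvm.placeholder((4,), name="b")
+        c = elemwise_sum([a, b], num_args=2)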
+    """
+    assert len(xs) > 0, "elemwise sum must have at least one input tensor."
+
+    def _compute(*i):
+        return sum([x(*i) for x in xs])
+
+    return tvm.compute(xs[0].shape, _compute)
+
+@tvm.tag_scope(tag=tag.ELEMWISE)
+def full(shape, dtype, fill_value):
+    """Fill tensor with fill_value
+
+    Parameters
+    ----------
+    shape : tuple
+        Input tensor shape.
+    dtype : str
+        Data type
+    fill_value : float
+        Value to be filled
+
+    Returns
+    -------
+    y : tvm.Tensor
+        The result.
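+
+    Examples::
+
+        # a minimal usage sketch: a (2, 3) float32 tensor of zeros
+        z = full((2, 3), "float32", 0.0)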
+    """
+    return tvm.compute(shape, lambda *i: tvm.const(fill_value, dtype))
+
+@tvm.tag_scope(tag=tag.ELEMWISE)
+def full_like(x, fill_value):
+    """Construct a tensor with same shape as input tensor,
+       then fill tensor with fill_value.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        Input argument.
+    fill_value : float
+        Value to be filled
+
+    Returns
+    -------
+    y : tvm.Tensor
+        The result.
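+
+    Examples::
+
+        # a minimal usage sketch: ones shaped and typed like x
+        x = tvm.placeholder((2, 3), name="x")
+        y = full_like(x, 1.0)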
+    """
+    dtype = x.dtype
+    return tvm.compute(x.shape, lambda *i: tvm.const(fill_value, dtype))
+
+@tvm.tag_scope(tag=tag.ELEMWISE)
+def greater(lhs, rhs, out_type=tvm.int8):
+    """Compare two input tensors element-wise and return an mask tensor
+       which contains 1 if lhs > rhs holds else 0
+
+    Parameters
+    ----------
+    lhs : tvm.Tensor
+        Left input argument.
+    rhs : tvm.Tensor
+        Right input argument.
+    out_type : str
+        Output data type. Defaults to int8.
+
+    Returns
+    -------
+    y : tvm.Tensor
+        The result.
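+
+    Examples::
+
+        # a minimal usage sketch: int8 mask that is 1 where a > b
+        a = tvm.placeholder((4,), name="a")
+        b = tvm.placeholder((4,), name="b")
+        mask = greater(a, b)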
+    """
+    return tvm.compute(lhs.shape,
+                       lambda *i: tvm.select(lhs(*i) > rhs(*i),
+                                             tvm.const(1, out_type),
+                                             tvm.const(0, out_type)))
+
+@tvm.tag_scope(tag=tag.ELEMWISE)
+def less(lhs, rhs, out_type=tvm.int8):
+    """Compare two input tensors element-wise and return an mask tensor
+       which contains 1 if lhs < rhs holds else 0
+
+    Parameters
+    ----------
+    lhs : tvm.Tensor
+        Left input argument.
+    rhs : tvm.Tensor
+        Right input argument.
+    out_type : str
+        Output data type. Defaults to int8.
+
+    Returns
+    -------
+    y : tvm.Tensor
+        The result.
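+
+    Examples::
+
+        # a minimal usage sketch: int8 mask that is 1 where a < b
+        a = tvm.placeholder((4,), name="a")
+        b = tvm.placeholder((4,), name="b")
+        mask = less(a, b)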
+    """
+    return tvm.compute(lhs.shape,
+                       lambda *i: tvm.select(lhs(*i) < rhs(*i),
+                                             tvm.const(1, out_type),
+                                             tvm.const(0, out_type)))
diff --git a/topi/python/topi/transform.py b/topi/python/topi/transform.py
index 3194a5601..46998ace2 100644
--- a/topi/python/topi/transform.py
+++ b/topi/python/topi/transform.py
@@ -2,6 +2,7 @@
 """Injective transformation operators"""
 from __future__ import absolute_import as _abs
 import tvm
+import topi
 from . import tag
 from .util import ravel_index, unravel_index, get_const_int, get_const_tuple
 
@@ -29,6 +30,60 @@ def expand_dims(a, axis, num_newaxis=1):
     return tvm.compute(new_shape, _compute)
 
 
+@tvm.tag_scope(tag=tag.BROADCAST)
+def expand_like(a, shape_like, axis):
+    """Expand an input array with the shape of second array.
+    This operation can always be composed of unsqueezing and
+    expanding dims on those unsqueezed axes.
+
+    Examples::
+
+        input = [ 12.  19.  27.]
+        input.shape = (3,)
+
+        new_shape_array = [[[1,2],[2,3],[1,3]],
+                           [[1,4],[4,3],[5,2]],
+                           [[7,1],[7,2],[7,3]]]
+        new_shape_array.shape = (3, 3, 2)
+
+        expand_like(input, new_shape_array, [1,2]) =
+                          [[[12,12],[12,12],[12,12]],
+                           [[19,19],[19,19],[19,19]],
+                           [[27,27],[27,27],[27,27]]]
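+
+    Equivalently, a sketch of the same computation in NumPy terms
+    (``np`` is standard NumPy; illustrative, not the implementation)::
+
+        out = input
+        for ax in sorted(axis):
+            out = np.expand_dims(out, ax)
+        out = np.broadcast_to(out, new_shape_array.shape)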
+
+    Parameters
+    ----------
+    a : tvm.Tensor
+        The tensor to be expanded.
+    shape_like : tvm.Tensor
+        The tensor with the target shape.
+    axis : list of int
+        Axes to be expanded on.
+
+    Returns
+    -------
+    ret : tvm.Tensor
+        The result.
+    """
+    odim = len(axis) + len(a.shape)
+    if odim != len(shape_like.shape):
+        raise ValueError("shape inconsistent when expand_like ({}, {}, {})".format(
+            len(axis), len(a.shape), len(shape_like.shape)))
+
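+    # _get_real_axis normalizes the axis list against the rank of
+    # shape_like (e.g. resolving negative axis values)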
+    real_axis = topi.reduction._get_real_axis(len(shape_like.shape), axis)
+    real_axis = sorted(real_axis)
+
+    if not real_axis:
+        return a
+
+    def _compute(*idxs):
+        # keep only the indices of non-expanded axes; they address `a`
+        indices = [idxs[i] for i in range(len(idxs)) if i not in real_axis]
+        return a(*indices)
+    return tvm.compute(shape_like.shape, _compute)
+
+
 @tvm.tag_scope(tag=tag.INJECTIVE)
 def transpose(a, axes=None):
     """Permute the dimensions of an array.
diff --git a/topi/tests/python/test_topi_math.py b/topi/tests/python/test_topi_math.py
new file mode 100644
index 000000000..937639785
--- /dev/null
+++ b/topi/tests/python/test_topi_math.py
@@ -0,0 +1,31 @@
+import numpy as np
+import tvm
+import topi
+from topi import util
+
+
+def test_util():
+    x = tvm.const(100)
+    assert util.get_const_int(x) == 100
+    assert util.get_const_tuple((x, x)) == (100, 100)
+
+
+def test_ewise():
+    m = tvm.var('m')
+    l = tvm.var('l')
+    A = tvm.placeholder((m, l), name='A')
+
+    def test_apply(func, name):
+        B = func(A)
+        assert tuple(B.shape) == tuple(A.shape)
+        assert B.op.body[0].name == name
+
+    test_apply(topi.exp, "exp")
+    test_apply(topi.tanh, "tanh")
+    test_apply(topi.sigmoid, "sigmoid")
+    test_apply(topi.log, "log")
+    test_apply(topi.sqrt, "sqrt")
+
+if __name__ == "__main__":
+    test_util()
+    test_ewise()
diff --git a/topi/tests/python/test_topi_tensor.py b/topi/tests/python/test_topi_tensor.py
new file mode 100644
index 000000000..1489ad073
--- /dev/null
+++ b/topi/tests/python/test_topi_tensor.py
@@ -0,0 +1,129 @@
+"""Test code for tensor operator"""
+import numpy as np
+import tvm
+import topi
+from tvm.contrib.pickle_memoize import memoize
+
+def verify_elemwise_sum(num_args, dtype):
+    shape = (3,5,4)
+
+    tvm_placeholders = []
+    for i in range(num_args):
+        tvm_placeholders.append(
+            tvm.placeholder(shape, name="data"+str(i), dtype=dtype))
+    esum = topi.elemwise_sum(tvm_placeholders, num_args=num_args)
+    s = tvm.create_schedule([esum.op])
+
+    @memoize("topi.tests.test_topi_elemwise_sum")
+    def get_ref_data():
+        np_nd = [np.random.uniform(0, 10, size=shape).astype(dtype)
+                 for i in range(num_args)]
+        return np_nd
+    np_nd = get_ref_data()
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+
+        ctx = tvm.context(device, 0)
+        out = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx)
+        f = tvm.build(s, tvm_placeholders + [esum], device, name="elemwise_sum")
+        tvm_nd = [tvm.nd.array(nd, ctx) for nd in np_nd] + [out]
+        f(*tvm_nd)
+        np_out = np.sum(np.array(np_nd), axis=0)
+        np.testing.assert_allclose(out.asnumpy(), np_out, rtol=1e-5)
+
+    for device in ["llvm"]:
+        check_device(device)
+
+
+def verify_full(shape, dtype, fill_value):
+    A = tvm.placeholder(shape, dtype=dtype, name="A")
+    B = topi.full_like(A, fill_value=fill_value)
+    C = topi.full(shape=shape, dtype=dtype, fill_value=fill_value)
+    s1 = tvm.create_schedule([B.op])
+    s2 = tvm.create_schedule([C.op])
+
+    @memoize("topi.tests.test_topi_full")
+    def get_ref_data():
+        return np.full(shape, fill_value, dtype)
+    np_nd = get_ref_data()
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+
+        ctx = tvm.context(device, 0)
+        out = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx)
+        f = tvm.build(s1, [A, B], device, name="full_like")
+        f(tvm.nd.array(np.zeros(shape, dtype), ctx), out)
+        np.testing.assert_allclose(out.asnumpy(), np_nd, rtol=1e-5)
+
+        f = tvm.build(s2, [C], device, name="full")
+        f(out)
+        np.testing.assert_allclose(out.asnumpy(), np_nd, rtol=1e-5)
+
+    for device in ["llvm"]:
+        check_device(device)
+
+
+def verify_comparator(shape, dtype, out_type='int8'):
+    A = tvm.placeholder(shape, dtype, name="A")
+    B = tvm.placeholder(shape, dtype, name="B")
+    C = topi.less(A, B)
+    s_less = tvm.create_schedule([C.op])
+
+    D = tvm.placeholder(shape, dtype, name="D")
+    E = tvm.placeholder(shape, dtype, name="E")
+    F = topi.greater(D, E, out_type)
+    s_greater = tvm.create_schedule([F.op])
+
+    @memoize("topi.tests.test_topi_indicator")
+    def get_ref_data():
+        return [np.random.uniform(0, 10, size=shape).astype(dtype),
+                np.random.uniform(0, 10, size=shape).astype(dtype)]
+    [np_l, np_r] = get_ref_data()
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+
+        ctx = tvm.context(device, 0)
+        out = tvm.nd.array(np.zeros(shape, dtype=out_type), ctx)
+        tvm_l = tvm.nd.array(np_l, ctx)
+        tvm_r = tvm.nd.array(np_r, ctx)
+
+        f = tvm.build(s_less, [A, B, C], device, name="less")
+        f(tvm_l, tvm_r, out)
+        np.testing.assert_allclose(out.asnumpy(), np.less(np_l, np_r).astype(out_type), rtol=1e-5)
+
+        f = tvm.build(s_greater, [D, E, F], device, name="greater")
+        f(tvm_l, tvm_r, out)
+        np.testing.assert_allclose(out.asnumpy(), np.greater(np_l, np_r).astype(out_type), rtol=1e-5)
+
+    for device in ["llvm"]:
+        check_device(device)
+
+def test_elemwise_sum():
+    verify_elemwise_sum(1, "float32")
+    verify_elemwise_sum(5, "float32")
+    verify_elemwise_sum(4, "int32")
+
+
+def test_full():
+    verify_full((3,4,5), "float32", 3.14)
+    verify_full((10,), "int32", 7)
+
+
+def test_comparator():
+    verify_comparator((3,4,5), "float32")
+    verify_comparator((7,), "int32")
+    verify_comparator((3,4,5), "float32", "int8")
+
+if __name__ == "__main__":
+    test_elemwise_sum()
+    test_full()
+    test_comparator()
diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py
index 46c860fb4..f1f39816b 100644
--- a/topi/tests/python/test_topi_transform.py
+++ b/topi/tests/python/test_topi_transform.py
@@ -150,6 +150,41 @@ def verify_split(src_shape, indices_or_sections, axis):
         check_device(device)
 
 
+def verify_expand_like(in_shape, out_shape, axis):
+    A = tvm.placeholder(shape=in_shape, name="A")
+    B = tvm.placeholder(shape=out_shape, name="B")
+    C = topi.expand_like(A, B, axis)
+    s = tvm.create_schedule([C.op])
+
+    def check_device(device):
+        if not tvm.module.enabled(device):
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+
+        ctx = tvm.context(device, 0)
+        f = tvm.build(s, [A, B, C], device, name="expand_like")
+        input = np.random.uniform(size=in_shape).astype(A.dtype)
+        tvm_input = tvm.nd.array(input, ctx)
+
+        odim = len(out_shape)
+        real_axis = [x if x >= 0 else x + odim for x in axis]
+        real_axis = sorted(real_axis)
+        for x in real_axis:
+            input = np.expand_dims(input, x).astype(A.dtype)
+        for x in real_axis:
+            input = np.concatenate([input]*out_shape[x], axis=x).astype(A.dtype)
+        assert input.shape == out_shape
+
+        tvm_shape_like = tvm.nd.array(np.zeros(out_shape).astype(B.dtype), ctx)
+        out = tvm.nd.array(np.zeros(out_shape).astype(A.dtype), ctx)
+        f(tvm_input, tvm_shape_like, out)
+        np.testing.assert_allclose(out.asnumpy(), input)
+
+    for device in ["llvm"]:
+        check_device(device)
+
+
 def test_expand_dims():
     verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
     verify_expand_dims((3, 10), (1, 3, 10), -3, 1)
@@ -191,6 +226,14 @@ def test_split():
     verify_split((2, 12, 3), [2, 4], 1)
     verify_split((10, 12, 24), [5, 7, 9], -1)
 
+
+def test_expand_like():
+    verify_expand_like((3,), (2, 3), [0])
+    verify_expand_like((2,), (2, 3), [1])
+    verify_expand_like((3, 4), (3, 5, 4), [1])
+    verify_expand_like((5, 7), (5, 6, 7, 8), [1, 3])
+
+
 if __name__ == "__main__":
     test_concatenate()
     test_tranpose()
@@ -198,3 +241,4 @@ if __name__ == "__main__":
     test_reshape()
     test_squeeze()
     test_split()
+    test_expand_like()
-- 
GitLab