From 3a0b757c10e7bac05f64a780e17c4fd4210165c1 Mon Sep 17 00:00:00 2001
From: Siva <sivar.b@huawei.com>
Date: Sun, 8 Jul 2018 07:21:57 +0530
Subject: [PATCH] [NNVM][TOP] broadcast versions corresponding to topi: mod,
 max, min, pow, left_shift, right_shift, greater, less, equal, not_equal,
 greater_equal and less_equal. (#1383)

---
 nnvm/python/nnvm/frontend/onnx.py             |   2 +-
 nnvm/python/nnvm/top/tensor.py                |  48 ++++
 nnvm/src/top/tensor/broadcast.cc              | 247 ++++++++++++++++++
 nnvm/tests/python/compiler/test_top_level4.py |  55 +++-
 .../python/frontend/onnx/test_forward.py      |  45 ++++
 5 files changed, 393 insertions(+), 4 deletions(-)

diff --git a/nnvm/python/nnvm/frontend/onnx.py b/nnvm/python/nnvm/frontend/onnx.py
index 0d8915ad4..ee7fcc3c5 100644
--- a/nnvm/python/nnvm/frontend/onnx.py
+++ b/nnvm/python/nnvm/frontend/onnx.py
@@ -544,7 +544,7 @@ def _get_convert_map(opset):
         'Exp': Renamer('exp'),
         'Log': Renamer('log'),
         'Tanh': Renamer('tanh'),
-        # 'Pow'
+        'Pow': Renamer('broadcast_pow'),
         'PRelu': Prelu.get_converter(opset),
         'Sigmoid': Renamer('sigmoid'),
         # 'HardSigmoid'
diff --git a/nnvm/python/nnvm/top/tensor.py b/nnvm/python/nnvm/top/tensor.py
index bd486287a..3ff59e5d0 100644
--- a/nnvm/python/nnvm/top/tensor.py
+++ b/nnvm/python/nnvm/top/tensor.py
@@ -168,6 +168,54 @@ reg.register_schedule("broadcast_mul", _fschedule_broadcast)
 reg.register_pattern("broadcast_div", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_div", _fschedule_broadcast)
 
+# broadcast mod
+reg.register_pattern("broadcast_mod", OpPattern.BROADCAST)
+reg.register_schedule("broadcast_mod", _fschedule_broadcast)
+
+# broadcast max
+reg.register_pattern("broadcast_max", OpPattern.BROADCAST)
+reg.register_schedule("broadcast_max", _fschedule_broadcast)
+
+# broadcast min
+reg.register_pattern("broadcast_min", OpPattern.BROADCAST)
+reg.register_schedule("broadcast_min", _fschedule_broadcast)
+
+# broadcast pow
+reg.register_pattern("broadcast_pow", OpPattern.BROADCAST)
+reg.register_schedule("broadcast_pow", _fschedule_broadcast)
+
+# broadcast left_shift
+reg.register_pattern("broadcast_left_shift", OpPattern.BROADCAST)
+reg.register_schedule("broadcast_left_shift", _fschedule_broadcast)
+
+# broadcast right_shift
+reg.register_pattern("broadcast_right_shift", OpPattern.BROADCAST)
+reg.register_schedule("broadcast_right_shift", _fschedule_broadcast)
+
+# broadcast greater
+reg.register_pattern("broadcast_greater", OpPattern.BROADCAST)
+reg.register_schedule("broadcast_greater", _fschedule_broadcast)
+
+# broadcast less
+reg.register_pattern("broadcast_less", OpPattern.BROADCAST)
+reg.register_schedule("broadcast_less", _fschedule_broadcast)
+
+# broadcast equal
+reg.register_pattern("broadcast_equal", OpPattern.BROADCAST)
+reg.register_schedule("broadcast_equal", _fschedule_broadcast)
+
+# broadcast not_equal
+reg.register_pattern("broadcast_not_equal", OpPattern.BROADCAST)
+reg.register_schedule("broadcast_not_equal", _fschedule_broadcast)
+
+# broadcast greater_equal
+reg.register_pattern("broadcast_greater_equal", OpPattern.BROADCAST)
+reg.register_schedule("broadcast_greater_equal", _fschedule_broadcast)
+
+# broadcast less_equal
+reg.register_pattern("broadcast_less_equal", OpPattern.BROADCAST)
+reg.register_schedule("broadcast_less_equal", _fschedule_broadcast)
+
 # broadcast_to
 reg.register_pattern("broadcast_to", OpPattern.BROADCAST)
 reg.register_schedule("broadcast_to", _fschedule_broadcast)
diff --git a/nnvm/src/top/tensor/broadcast.cc b/nnvm/src/top/tensor/broadcast.cc
index 5f6cc4e66..edf209e3e 100644
--- a/nnvm/src/top/tensor/broadcast.cc
+++ b/nnvm/src/top/tensor/broadcast.cc
@@ -15,6 +15,7 @@
 #include "../op_common.h"
 #include "../elemwise_op_common.h"
 #include "topi/broadcast.h"
+#include "topi/elemwise.h"
 
 namespace nnvm {
 namespace top {
@@ -346,5 +347,251 @@ Example::
     return std::vector<NodeEntry>{ dlhs, drhs };
 });
 
+NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_mod, mod)
+.add_alias("__mod_symbol__")
+.describe(R"code(Returns element-wise mod of the input arrays with broadcasting.
+
+Example::
+
+   x = [[ 1.,  2.,  3.],
+        [ 4.,  5.,  6.]]
+
+   y = [[ 2.],
+        [ 3.]]
+
+   broadcast_mod(x, y) = [[ 1.,  0.,  1.],
+                          [ 1.,  2.,  0.]]
+
+)code" NNVM_ADD_FILELINE);
+
+NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_max, maximum)
+.add_alias("__max_symbol__")
+.describe(R"code(Returns element-wise max of the input arrays with broadcasting.
+
+Example::
+
+   x = [[ 1.,  2.,  3.],
+        [ 4.,  5.,  6.]]
+
+   y = [[ 2.],
+        [ 3.]]
+
+   broadcast_max(x, y) = [[ 2.,  2.,  3.],
+                          [ 4.,  5.,  6.]]
+
+)code" NNVM_ADD_FILELINE);
+
+NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_min, minimum)
+.add_alias("__min_symbol__")
+.describe(R"code(Returns element-wise minimum of the input arrays with broadcasting.
+
+Example::
+
+   x = [[ 1.,  2.,  3.],
+        [ 4.,  5.,  6.]]
+
+   y = [[ 2.],
+        [ 3.]]
+
+   broadcast_min(x, y) = [[ 1.,  2.,  2.],
+                          [ 3.,  3.,  3.]]
+
+)code" NNVM_ADD_FILELINE);
+
+NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_pow, power)
+.add_alias("__pow_symbol__")
+.describe(R"code(Returns element-wise x^y of the input arrays with broadcasting.
+
+Example::
+
+   x = [[ 1.,  2.,  3.],
+        [ 4.,  5.,  6.]]
+
+   y = [[ 1.],
+        [ 2.]]
+
+   broadcast_pow(x, y) = [[ 1.,   2.,   3. ],
+                          [ 16.,  25.,  36.]]
+
+)code" NNVM_ADD_FILELINE);
+
+NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_left_shift, left_shift)
+.add_alias("__left_shift_symbol__")
+.describe(R"code(Returns element-wise x << y of the input arrays with broadcasting.
+
+Example::
+
+   x = [[ 1.,  2.,  3.],
+        [ 4.,  5.,  6.]]
+
+   y = [[ 2.],
+        [ 1.]]
+
+   broadcast_left_shift(x, y) = [[ 4.,  8.,  12.],
+                                 [ 8.,  10., 12.]]
+
+)code" NNVM_ADD_FILELINE);
+
+NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_right_shift, right_shift)
+.add_alias("__right_shift_symbol__")
+.describe(R"code(Returns element-wise x >> y of the input arrays with broadcasting.
+
+Example::
+
+   x = [[ 4.,  8.,  12.],
+        [ 8.,  10., 12.]]
+
+   y = [[ 2.],
+        [ 1.]]
+
+   broadcast_right_shift(x, y) = [[ 1.,  2.,  3.],
+                                  [ 4.,  5.,  6.]]
+
+)code" NNVM_ADD_FILELINE);
+
+NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_greater, greater)
+.add_alias("__greater_symbol__")
+.describe(R"code(Returns element-wise x > y of the input arrays with broadcasting.
+
+Example::
+
+   x = [[ 1.,  2.,  3.],
+        [ 4.,  5.,  6.]]
+
+   y = [[ 2.],
+        [ 3.]]
+
+   broadcast_greater(x, y) = [[ 0.,  0.,  1.],
+                              [ 1.,  1.,  1.]]
+
+)code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::cast(topi::greater(inputs[0], inputs[1]), out_info[0]->dtype) };
+}, 11);
+
+NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_less, less)
+.add_alias("__less_symbol__")
+.describe(R"code(Returns element-wise x < y of the input arrays with broadcasting.
+
+Example::
+
+   x = [[ 1.,  2.,  3.],
+        [ 4.,  5.,  6.]]
+
+   y = [[ 2.],
+        [ 3.]]
+
+   broadcast_less(x, y) = [[ 1.,  0.,  0.],
+                           [ 0.,  0.,  0.]]
+
+)code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::cast(topi::less(inputs[0], inputs[1]), out_info[0]->dtype) };
+}, 11);
+
+NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_equal, equal)
+.add_alias("__equal_symbol__")
+.describe(R"code(Returns element-wise x == y of the input arrays with broadcasting.
+
+Example::
+
+   x = [[ 1.,  2.,  3.],
+        [ 4.,  5.,  6.]]
+
+   y = [[ 2.],
+        [ 5.]]
+
+   broadcast_equal(x, y) = [[ 0.,  1.,  0.],
+                            [ 0.,  1.,  0.]]
+
+)code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::cast(topi::equal(inputs[0], inputs[1]), out_info[0]->dtype) };
+}, 11);
+
+NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_not_equal, not_equal)
+.add_alias("__not_equal_symbol__")
+.describe(R"code(Returns element-wise x != y of the input arrays with broadcasting.
+
+Example::
+
+   x = [[ 1.,  2.,  3.],
+        [ 4.,  5.,  6.]]
+
+   y = [[ 2.],
+        [ 4.]]
+
+   broadcast_not_equal(x, y) = [[ 1.,  0.,  1.],
+                                [ 0.,  1.,  1.]]
+
+)code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::cast(topi::not_equal(inputs[0],
+                                                     inputs[1]),
+                                                     out_info[0]->dtype) };
+}, 11);
+
+NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_greater_equal, greater_equal)
+.add_alias("__greater_equal_symbol__")
+.describe(R"code(Returns element-wise x >= y of the input arrays with broadcasting.
+
+Example::
+
+   x = [[ 1.,  2.,  3.],
+        [ 4.,  5.,  6.]]
+
+   y = [[ 2.],
+        [ 6.]]
+
+   broadcast_greater_equal(x, y) = [[ 0.,  1.,  1.],
+                                    [ 0.,  0.,  1.]]
+
+)code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::cast(topi::greater_equal(inputs[0],
+                                                         inputs[1]),
+                                                         out_info[0]->dtype) };
+}, 11);
+
+NNVM_REGISTER_BINARY_BROADCAST_OP(broadcast_less_equal, less_equal)
+.add_alias("__less_equal_symbol__")
+.describe(R"code(Returns element-wise x <= y of the input arrays with broadcasting.
+
+Example::
+
+   x = [[ 1.,  2.,  3.],
+        [ 4.,  5.,  6.]]
+
+   y = [[ 1.],
+        [ 5.]]
+
+   broadcast_less_equal(x, y) = [[ 1.,  0.,  0.],
+                                 [ 1.,  1.,  0.]]
+
+)code" NNVM_ADD_FILELINE)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    return Array<Tensor>{ topi::cast(topi::less_equal(inputs[0],
+                                                      inputs[1]),
+                                                      out_info[0]->dtype) };
+}, 11);
+
 }  // namespace top
 }  // namespace nnvm
diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py
index 33f3ff3e9..de6a3fa33 100644
--- a/nnvm/tests/python/compiler/test_top_level4.py
+++ b/nnvm/tests/python/compiler/test_top_level4.py
@@ -9,17 +9,23 @@ from nnvm.testing.config import ctx_list
 
 
 def helper(symbol, inputs, dtype,
-           np_forward, np_backward=None, need_input=True, need_head_grads=True):
+           np_forward, np_backward=None,
+           need_input=True, need_head_grads=True, in_range={}):
     ishapes = {}
     input_syms = []
     np_inputs = {}
     for (name, shape, s) in inputs:
         ishapes.update({name: shape})
-        np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
+        if name in in_range:
+            np_inputs.update({name: np.random.uniform(size=shape,
+                                                      low=in_range[name][0],
+                                                      high=in_range[name][1]).astype(dtype)})
+        else:
+            np_inputs.update({name: np.random.uniform(size=shape).astype(dtype)})
         input_syms.append(s)
 
     for target, ctx in ctx_list():
-        graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes)
+        graph, lib, _ = nnvm.compiler.build(symbol, target, ishapes, dtype=dtype)
         m = graph_runtime.create(graph, lib, ctx)
         m.run(**np_inputs)
         y_np = np_forward(**np_inputs)
@@ -228,6 +234,49 @@ def test_broadcast():
         return da, db
     helper(y, inputs, dtype, lambda a, b: a / b, _backward_div)
 
+    y = sym.broadcast_mod(a, b)
+    helper(y, inputs, 'int32',
+           lambda a, b: np.mod(a, b),
+           in_range={'a': (0.001, 100), 'b': (1, 100)})
+
+    y = sym.broadcast_max(a, b)
+    helper(y, inputs, dtype, lambda a, b: np.maximum(a, b))
+
+    y = sym.broadcast_min(a, b)
+    helper(y, inputs, dtype, lambda a, b: np.minimum(a, b))
+
+    y = sym.broadcast_pow(a, b)
+    helper(y, inputs, dtype,
+           lambda a, b: np.power(a, b),
+           in_range={'a': (0.001, 100), 'b': (0.001, 2)})
+
+    y = sym.broadcast_left_shift(a, b)
+    helper(y, inputs, 'int32', lambda a, b: a << b)
+
+    y = sym.broadcast_right_shift(a, b)
+    helper(y, inputs, 'int32', lambda a, b: a >> b)
+
+    y = sym.broadcast_greater(a, b)
+    helper(y, inputs, dtype, lambda a, b: np.greater(a, b))
+
+    y = sym.broadcast_less(a, b)
+    helper(y, inputs, dtype, lambda a, b: np.less(a, b))
+
+    y = sym.broadcast_equal(a, b)
+    helper(y, inputs, 'int32', lambda a, b: np.equal(a, b),
+           in_range={'a': (-2, 2), 'b': (-2, 2)})
+
+    y = sym.broadcast_not_equal(a, b)
+    helper(y, inputs, 'int32', lambda a, b: np.not_equal(a, b),
+           in_range={'a': (-2, 2), 'b': (-2, 2)})
+
+    y = sym.broadcast_greater_equal(a, b)
+    helper(y, inputs, 'int32', lambda a, b: np.greater_equal(a, b),
+           in_range={'a': (-3, 3), 'b': (-3, 3)})
+
+    y = sym.broadcast_less_equal(a, b)
+    helper(y, inputs, 'int32', lambda a, b: np.less_equal(a, b),
+           in_range={'a': (-3, 3), 'b': (-3, 3)})
 
 def test_greater():
     l = sym.Variable("l")
diff --git a/nnvm/tests/python/frontend/onnx/test_forward.py b/nnvm/tests/python/frontend/onnx/test_forward.py
index 0af92e3f4..56f1e0062 100644
--- a/nnvm/tests/python/frontend/onnx/test_forward.py
+++ b/nnvm/tests/python/frontend/onnx/test_forward.py
@@ -108,6 +108,50 @@ def test_reshape_like():
 
     np.testing.assert_allclose(ref_shape, tvm_out.shape)
 
+def _test_power_iteration(x_shape, y_shape):
+    if isinstance(y_shape, int):
+        y_shape = [y_shape]
+
+    x = np.random.uniform(size=x_shape).astype(np.float32)
+    y = np.random.uniform(size=y_shape).astype(np.float32)
+
+    np_res = np.power(x, y).astype(np.float32)
+
+    res = helper.make_node("Pow", ['x', 'y'], ['out'])
+
+    graph = helper.make_graph([res],
+                              'power_test',
+                              inputs = [helper.make_tensor_value_info("x", TensorProto.FLOAT, list(x_shape)),
+                                        helper.make_tensor_value_info("y", TensorProto.FLOAT, list(y_shape))],
+                              outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(np_res.shape))])
+
+    model = helper.make_model(graph, producer_name='power_test')
+
+    for target, ctx in ctx_list():
+        new_sym, params = nnvm.frontend.from_onnx(model)
+
+        input_name = model.graph.input[0].name
+        input_name1 = model.graph.input[1].name
+        shape_dict = {input_name: x.shape, input_name1: y.shape}
+        dtype_dict = {input_name: x.dtype, input_name1: y.dtype}
+
+        graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, dtype_dict, params=params)
+        m = graph_runtime.create(graph, lib, ctx)
+        # set inputs
+        m.set_input(input_name, tvm.nd.array(x))
+        m.set_input(input_name1, tvm.nd.array(y))
+        m.set_input(**params)
+        m.run()
+        # get outputs
+        tvm_out = m.get_output(0, tvm.nd.empty(np_res.shape, np_res.dtype))
+
+        np.testing.assert_allclose(np_res, tvm_out.asnumpy(), rtol=1e-5, atol=1e-5)
+
+def test_power():
+    _test_power_iteration((1, 3), (1))
+    _test_power_iteration((2, 3), (2, 3))
+    _test_power_iteration((2, 3), (1, 3))
+
 def test_squeeze():
     in_shape = (1, 3, 1, 3, 1, 1)
     out_shape = (3, 3)
@@ -247,6 +291,7 @@ if __name__ == '__main__':
     verify_resnet18()
     test_reshape()
     test_reshape_like()
+    test_power()
     test_squeeze()
     test_unsqueeze()
     test_slice()
-- 
GitLab