From 113b46ec802101f6b4933c4fc869de99a5be9312 Mon Sep 17 00:00:00 2001
From: Siva <sivar.b@huawei.com>
Date: Tue, 26 Jun 2018 22:45:53 +0530
Subject: [PATCH] [NNVM][ONNX] Shape operator support (limited/differed) -
 #1297 (#1333)

---
 nnvm/python/nnvm/frontend/onnx.py             | 43 ++++++++-
 .../python/frontend/onnx/test_forward.py      | 91 ++++++++++++++++---
 2 files changed, 115 insertions(+), 19 deletions(-)

diff --git a/nnvm/python/nnvm/frontend/onnx.py b/nnvm/python/nnvm/frontend/onnx.py
index d92a856d7..fb21f6b7b 100644
--- a/nnvm/python/nnvm/frontend/onnx.py
+++ b/nnvm/python/nnvm/frontend/onnx.py
@@ -258,10 +258,11 @@ class Reshape(OnnxOpConverter):
     def _impl_v5(cls, inputs, attr, params):
         if inputs[1].list_output_names()[0] in params:
             shape = tuple(params[inputs[1].list_output_names()[0]].asnumpy())
+            out = _sym.reshape(inputs[0], shape=shape)
         else:
-            raise RuntimeError('Shape is not contained in graph initializer.')
-        return _sym.reshape(inputs[0], shape=shape)
+            out = _sym.reshape_like(inputs[0], inputs[1])
 
+        return out
 
 class Scale(OnnxOpConverter):
 
@@ -405,6 +406,36 @@ def _fully_connected(opset):
     return _impl
 
 
+class Shape(OnnxOpConverter):
+    """ Operator converter for Shape.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        # Result of this operator is prominently used by reshape operator.
+        # Just pass the input as it is so that reshape_like can be used there.
+        print("Shape: Differently implemented in NNVM as a bypass (dummy operator)")
+        return inputs[0]
+
+class Cast(OnnxOpConverter):
+    """ Operator converter for Cast.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)
+
+    @classmethod
+    def _impl_v5(cls, inputs, attr, params):
+        try:
+            from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
+            attr['to'] = TENSOR_TYPE_TO_NP_TYPE[attr['to']]
+        except ImportError as e:
+            raise ImportError(
+                "Unable to import onnx.mapping which is required {}".format(e))
+        return AttrCvt(op_name='cast', transforms={'to': 'dtype'})(inputs, attr)
+
+
 # compatible operators that do NOT require any conversion.
 _identity_list = []
 
@@ -505,7 +536,7 @@ def _get_convert_map(opset):
         # 'ArgMin'
 
         # defs/tensor
-        'Cast': AttrCvt('cast', {'to': 'dtype'}),
+        'Cast': Cast.get_converter(opset),
         'Reshape': Reshape.get_converter(opset),
         'Concat': Renamer('concatenate'),
         'Split': AttrCvt('split', {'split': 'indices_or_sections'}),
@@ -514,6 +545,7 @@ def _get_convert_map(opset):
         # 'Gather'
         # 'Squeeze'
         'Pad': Pad.get_converter(opset),
+        'Shape': Shape.get_converter(opset),
     }
 
 
@@ -719,6 +751,9 @@ def from_onnx(model):
     """
     g = GraphProto()
     graph = model.graph
-    opset = model.opset_import[0].version if model.opset_import else 1
+    try:
+        opset = model.opset_import[0].version if model.opset_import else 1
+    except AttributeError:
+        opset = 1
     sym, params = g.from_onnx(graph, opset)
     return sym, params
diff --git a/nnvm/tests/python/frontend/onnx/test_forward.py b/nnvm/tests/python/frontend/onnx/test_forward.py
index 9aef8b2cb..93a458a67 100644
--- a/nnvm/tests/python/frontend/onnx/test_forward.py
+++ b/nnvm/tests/python/frontend/onnx/test_forward.py
@@ -5,6 +5,23 @@ from tvm.contrib import graph_runtime
 from nnvm.testing.config import ctx_list
 import onnx
 from model_zoo import super_resolution, squeezenet1_1, lenet, resnet18_1_0
+from onnx import helper, TensorProto
+
+def get_tvm_output(model, x, target, ctx, out_shape, dtype='float32'):
+    new_sym, params = nnvm.frontend.from_onnx(model)
+    input_name = model.graph.input[0].name
+    shape_dict = {input_name: x.shape}
+    dtype_dict = {input_name: dtype}
+    graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, dtype_dict, params=params)
+    m = graph_runtime.create(graph, lib, ctx)
+    # set inputs
+    m.set_input(input_name, tvm.nd.array(x.astype(dtype)))
+    m.set_input(**params)
+    m.run()
+    # get outputs
+    out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
+    return out.asnumpy()
+
 
 def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
     import caffe2.python.onnx.backend
@@ -14,26 +31,12 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
         c2_out = prepared_backend.run(W)[0]
         return c2_out
 
-    def get_tvm_output(model, x, target, ctx, dtype='float32'):
-        new_sym, params = nnvm.frontend.from_onnx(model)
-        input_name = model.graph.input[0].name
-        shape_dict = {input_name: x.shape}
-        graph, lib, params = nnvm.compiler.build(new_sym, target, shape_dict, params=params)
-        m = graph_runtime.create(graph, lib, ctx)
-        # set inputs
-        m.set_input(input_name, tvm.nd.array(x.astype(dtype)))
-        m.set_input(**params)
-        m.run()
-        # get outputs
-        out = m.get_output(0, tvm.nd.empty(out_shape, dtype))
-        return out.asnumpy()
-
     dtype = 'float32'
     x = np.random.uniform(size=data_shape)
     model = onnx.load(graph_file)
     c2_out = get_caffe2_output(model, x, dtype)
     for target, ctx in ctx_list():
-        tvm_out = get_tvm_output(model, x, target, ctx, dtype)
+        tvm_out = get_tvm_output(model, x, target, ctx, out_shape, dtype)
         np.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
 
 def verify_super_resolution_example():
@@ -48,8 +51,66 @@ def verify_lenet():
 def verify_resnet18():
     verify_onnx_forward_impl(resnet18_1_0, (1, 3, 224, 224), (1, 1000))
 
+
+def test_reshape():
+    in_shape = (4, 3, 3, 4)
+    ref_shape = (3, 4, 4, 3)
+
+    ref_array = np.array(ref_shape)
+    ref_node = onnx.helper.make_node('Constant',
+                                 inputs=[],
+                                 outputs=['ref_in'],
+                                 value=onnx.helper.make_tensor(name = 'const_tensor',
+                                                               data_type = onnx.TensorProto.INT32,
+                                                               dims = ref_array.shape,
+                                                               vals = ref_array.flatten().astype(int)))
+    reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"])
+
+    graph = helper.make_graph([ref_node, reshape_node],
+                              "reshape_test",
+                              inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
+                              outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_shape))])
+
+    model = helper.make_model(graph, producer_name='reshape_test')
+
+    for target, ctx in ctx_list():
+        x = np.random.uniform(size=in_shape)
+        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
+
+    np.testing.assert_allclose(ref_shape, tvm_out.shape)
+
+def test_reshape_like():
+    in_shape = (4, 3, 3, 4)
+    ref_shape = (3, 4, 4, 3)
+
+    ref_array = np.random.uniform(size=ref_shape).astype('float32')
+    ref_node = onnx.helper.make_node('Constant',
+                                 inputs=[],
+                                 outputs=['ref_in'],
+                                 value=onnx.helper.make_tensor(name = 'const_tensor',
+                                                               data_type = onnx.TensorProto.FLOAT,
+                                                               dims = ref_array.shape,
+                                                               vals = ref_array.flatten().astype(float)))
+    copy_node = helper.make_node("Identity", ["ref_in"], ["copy_in"])
+    reshape_node = helper.make_node("Reshape", ["in", "copy_in"], ["out"])
+
+    graph = helper.make_graph([ref_node, copy_node, reshape_node],
+                              "reshape_like_test",
+                              inputs = [helper.make_tensor_value_info("in", TensorProto.FLOAT, list(in_shape))],
+                              outputs = [helper.make_tensor_value_info("out", TensorProto.FLOAT, list(ref_shape))])
+
+    model = helper.make_model(graph, producer_name='reshape_like_test')
+
+    for target, ctx in ctx_list():
+        x = np.random.uniform(size=in_shape)
+        tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
+
+    np.testing.assert_allclose(ref_shape, tvm_out.shape)
+
 if __name__ == '__main__':
     # verify_super_resolution_example()
     # verify_squeezenet1_1()
     # verify_lenet()
     verify_resnet18()
+    test_reshape()
+    test_reshape_like()
-- 
GitLab