From 373a8caa79293bcfb9a026fa73e9f25885240dd0 Mon Sep 17 00:00:00 2001
From: Siva <sivar.b@huawei.com>
Date: Tue, 26 Jun 2018 22:45:24 +0530
Subject: [PATCH] [NNVM][TENSORFLOW] Mobilenet support. (#1335)

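TensorFlow MobileNet models can now be imported end to end. The frontend gains
a FusedBatchNorm converter, falls back to reshape_like when the Reshape shape
operand is not a known constant, forwards the input tensor through Shape, and
retains TensorFlow node names during attribute conversion. The model download
and parsing logic in nnvm.testing.tf is factored into a shared get_workload()
helper, and a MobileNet forward test is added.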
---
 nnvm/python/nnvm/frontend/tensorflow.py       | 52 +++++++++-----
 nnvm/python/nnvm/testing/tf.py                | 70 +++++++++++++------
 .../frontend/tensorflow/test_forward.py       | 24 +++++++
 3 files changed, 109 insertions(+), 37 deletions(-)

diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py
index 3ab997617..715365178 100644
--- a/nnvm/python/nnvm/frontend/tensorflow.py
+++ b/nnvm/python/nnvm/frontend/tensorflow.py
@@ -35,6 +35,11 @@ class AttrCvt(object):
         self._ignores.append('use_cudnn_on_gpu')
         self._ignores.append('_node_name')
         self._ignores.append('is_training')
+        # Retain the TF node name so the converted symbol keeps the same name.
+        try:
+            attrs['name'] = attrs['_node_name']
+        except KeyError:
+            pass
         return AttrConvert(self._op_name, self._transforms, self._excludes,
                            self._disables, self._ignores, self._extras,
                            self._custom_check)(inputs, attrs, *args)
@@ -405,13 +410,19 @@ def _concat():
 
 def _reshape():
     def _impl(inputs, attr, params):
-        pop_node = inputs.pop(1)
-        shape_arg = params[pop_node.list_output_names()[0]]
-        params.pop(pop_node.list_output_names()[0])
-        return AttrCvt(
-            op_name="reshape",
-            extras={'shape':tuple(shape_arg.asnumpy())},
-            ignores=['Tshape'])(inputs, attr)
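+        # If the shape operand is a known constant it is found in params and a
+        # static reshape is emitted; otherwise (KeyError) the shape comes from a
+        # runtime tensor, so fall back to reshape_like on that second input.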
+        try:
+            pop_node = inputs[1]
+            shape_arg = params.pop(pop_node.list_output_names()[0])
+            inputs.pop(1)
+
+            return AttrCvt(
+                op_name="reshape",
+                extras={'shape':tuple(shape_arg.asnumpy())},
+                ignores=['Tshape'])(inputs, attr)
+        except KeyError:
+            return AttrCvt(
+                op_name="reshape_like",
+                ignores=['Tshape'])(inputs, attr)
     return _impl
 
 def _bias_add():
@@ -427,6 +438,18 @@ def _squeeze():
             ignores=['T'])(inputs, attr)
     return _impl
 
+def _fused_batch_norm():
+    def _impl(inputs, attr, params):
+        # Tensorflow: (data, gamma, beta, moving_mean, moving_variance)
+        # NNVM:       (data, gamma, beta, moving_mean, moving_variance) - same order
+        return AttrCvt(
+            op_name='batch_norm',
+            transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'},
+            extras={'axis': 3}, # Channel axis for the NHWC layout
+            ignores=['data_format'],
+            disables=['momentum'])(inputs, attr)
+    return _impl
+
 def _batch_norm():
     def _impl(inputs, attr, params):
         # Rearrange inputs from
@@ -445,19 +468,14 @@ def _batch_norm():
 
 def _relu6():
     def _impl(inputs, attr, params):
-        return _sym.clip(inputs[0], a_min=0, a_max=6)
+        return _sym.clip(inputs[0], a_min=0, a_max=6, name=attr['_node_name'])
     return _impl
 
 def _shape():
     def _impl(inputs, attr, params):
-        input_shapes = attr['_input_shapes'][inputs[0]]
-
-        # Fix the -1 dimensions to 1
-        input_shapes[0] = [1 if x == -1 else x for x in input_shapes[0]]
-        params[attr['_node_name']] = tvm.nd.array(input_shapes[0])
-
-        return _sym.Variable(name=attr['_node_name'],
-                             shape=params[attr['_node_name']].shape)
+        # The output of Shape is consumed by the Reshape operator in these models.
+        # Pass the input tensor through unchanged so reshape_like can use it there.
+        return inputs[0]
     return _impl
 
 # compatible operators that do NOT require any conversion.
@@ -491,7 +509,7 @@ _convert_map = {
     'Add'                               : _elemwise('add'),
     'Rsqrt'                             : _rsqrt(),
     'Squeeze'                           : _squeeze(),
-    'FusedBatchNorm'                    : _batch_norm(),
+    'FusedBatchNorm'                    : _fused_batch_norm(),
     'Relu6'                             : _relu6(),
     'DepthwiseConv2dNative'             : _depthwise_conv(),
     'Shape'                             : _shape(),
diff --git a/nnvm/python/nnvm/testing/tf.py b/nnvm/python/nnvm/testing/tf.py
index 3421573e3..1762ce565 100644
--- a/nnvm/python/nnvm/testing/tf.py
+++ b/nnvm/python/nnvm/testing/tf.py
@@ -153,6 +153,35 @@ def read_normalized_tensor_from_image_file(file_name,
     np_array = normalized.eval()
     return np_array
 
+def get_workload(model_path):
+    """ Import workload from frozen protobuf
+
+    Parameters
+    ----------
+    model_path: str
+        Path of the model file on the remote repository to download from.
+
+    Returns
+    -------
+    graph_def: graphdef
+        graph_def is the tensorflow workload loaded from the frozen protobuf.
+
+    """
+
+    repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/'
+    model_name = os.path.basename(model_path)
+    model_url = os.path.join(repo_base, model_path)
+
+    from mxnet.gluon.utils import download
+    download(model_url, model_name)
+
+    # Creates graph from saved graph_def.pb.
+    with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
+        graph_def = tf.GraphDef()
+        graph_def.ParseFromString(f.read())
+        graph = tf.import_graph_def(graph_def, name='')
+        return graph_def
+
 def get_workload_inception_v3():
     """ Import Inception V3 workload from frozen protobuf
 
@@ -168,23 +197,15 @@ def get_workload_inception_v3():
     """
 
     repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV3/'
-    model_name = 'inception_v3_2016_08_28_frozen-with_shapes.pb'
-    model_url = os.path.join(repo_base, model_name)
+    model_path = 'InceptionV3/inception_v3_2016_08_28_frozen-with_shapes.pb'
+
     image_name = 'elephant-299.jpg'
     image_url = os.path.join(repo_base, image_name)
-
     from mxnet.gluon.utils import download
-    download(model_url, model_name)
     download(image_url, image_name)
-
     normalized = read_normalized_tensor_from_image_file(os.path.join("./", image_name))
 
-    # Creates graph from saved graph_def.pb.
-    with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
-        graph_def = tf.GraphDef()
-        graph_def.ParseFromString(f.read())
-        graph = tf.import_graph_def(graph_def, name='')
-        return (normalized, graph_def)
+    return (normalized, get_workload(model_path))
 
 def get_workload_inception_v1():
     """ Import Inception V1 workload from frozen protobuf
@@ -203,13 +224,11 @@ def get_workload_inception_v1():
     """
 
     repo_base = 'https://github.com/dmlc/web-data/raw/master/tensorflow/models/InceptionV1/'
-    model_name = 'classify_image_graph_def-with_shapes.pb'
-    model_url = os.path.join(repo_base, model_name)
+    model_path = 'InceptionV1/classify_image_graph_def-with_shapes.pb'
     image_name = 'elephant-299.jpg'
     image_url = os.path.join(repo_base, image_name)
 
     from mxnet.gluon.utils import download
-    download(model_url, model_name)
     download(image_url, image_name)
 
     if not tf.gfile.Exists(os.path.join("./", image_name)):
@@ -221,9 +240,20 @@ def get_workload_inception_v1():
     tvm_data = Image.open(os.path.join("./", image_name)).resize((299, 299))
     tvm_data = np.array(tvm_data)
 
-    # Creates graph from saved graph_def.pb.
-    with tf.gfile.FastGFile(os.path.join("./", model_name), 'rb') as f:
-        graph_def = tf.GraphDef()
-        graph_def.ParseFromString(f.read())
-        graph = tf.import_graph_def(graph_def, name='')
-        return (image_data, tvm_data, graph_def)
+    return (image_data, tvm_data, get_workload(model_path))
+
+def get_workload_mobilenet():
+    """ Import mobilenet workload from frozen protobuf
+
+    Parameters
+    ----------
+        Nothing.
+
+    Returns
+    -------
+    graph_def: graphdef
+        graph_def is the tensorflow workload for mobilenet.
+
+    """
+
+    return get_workload("MobilenetV1/mobilenet_v1_1.0_224_frozen-with-shapes.pb")
diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py
index 4e742a4a5..6dc8cfab2 100644
--- a/nnvm/tests/python/frontend/tensorflow/test_forward.py
+++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py
@@ -406,6 +406,29 @@ def test_forward_inception_v1():
 
         np.testing.assert_allclose(tf_output, tvm_output, rtol=2e-2, atol=2e-2)
 
+#######################################################################
+# Mobilenet
+# ---------
+def test_forward_mobilenet():
+    '''test mobilenet model'''
+    with tf.Graph().as_default():
+        graph_def = nnvm.testing.tf.get_workload_mobilenet()
+        # Call the utility to import the graph definition into default graph.
+        graph_def = nnvm.testing.tf.ProcessGraphDefParam(graph_def)
+
+        data = np.random.uniform(size=(1, 224, 224, 3)).astype('float32')
+        out_node = 'MobilenetV1/Predictions/Reshape_1'
+
+        with tf.Session() as sess:
+            tf_output = run_tf_graph(sess, data, 'input:0', out_node + ':0')
+
+            out_shape = tf_output.shape
+            tvm_output = run_tvm_graph(graph_def, data, 'input', out_shape, 'float32')
+            top_tvm = np.squeeze(tvm_output).argsort()[-10:][::-1]
+            top_tf = np.squeeze(tf_output).argsort()[-10:][::-1]
+
+            np.testing.assert_allclose(np.squeeze(tvm_output), np.squeeze(tf_output), rtol=1e-5, atol=1e-5)
+
 #######################################################################
 # Main
 # ----
@@ -419,3 +442,4 @@ if __name__ == '__main__':
     test_forward_multi_input()
     test_forward_inception_v3()
     test_forward_inception_v1()
+    test_forward_mobilenet()
-- 
GitLab