From fae3f60cc8edbae14755f33d54bf8c17b1cd1d75 Mon Sep 17 00:00:00 2001
From: lixiaoquan <radioheads@163.com>
Date: Sat, 1 Sep 2018 03:35:05 +0800
Subject: [PATCH] [FRONTEND][TENSORFLOW] Add Transpose support. (#1665)

* [FRONTEND][TENSORFLOW] Add Transpose support.

* [FRONTEND][TENSORFLOW] Get parameter from inputs and fix document style.

* [FRONTEND][TENSORFLOW] Handle the case that perm is not specified.

* [FRONTEND][TENSORFLOW] Convert Rank and Range to param.

* [FRONTEND][TENSORFLOW] Fix a pylint issue.

* [FRONTEND][TENSORFLOW] Implement Rank and Range as normal op.
---
 nnvm/python/nnvm/frontend/tensorflow.py       | 32 +++++++++++++++++++
 .../frontend/tensorflow/test_forward.py       | 23 +++++++++++++
 2 files changed, 55 insertions(+)

diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py
index 6be5333cc..d9406601d 100644
--- a/nnvm/python/nnvm/frontend/tensorflow.py
+++ b/nnvm/python/nnvm/frontend/tensorflow.py
@@ -650,6 +650,35 @@ def _pad(name):
             ignores=['Tpaddings'],)(new_inputs, attr)
     return _impl
 
+def _transpose():
+    def _impl(inputs, attr, params):
+        # If perm is not specified, axes is left empty,
+        # otherwise its value is get from params
+        param_name = inputs[1].list_output_names()[0]
+        axes = params.get(param_name, tvm.nd.array([])).asnumpy()
+        return _sym.transpose(inputs[0], axes=tuple(axes))
+    return _impl
+
+def _rank():
+    def _impl(inputs, attr, params):
+        input_shapes = attr['_input_shapes'][inputs[0]]
+        assert len(inputs) == 1
+
+        name = attr["_node_name"]
+        params[name] = tvm.nd.array([len(input_shapes[0])])
+        return _sym.Variable(name=name, shape=params[name].shape)
+    return _impl
+
+def _range():
+    def _impl(inputs, attr, params):
+        start = params.pop(inputs[0].list_output_names()[0]).asnumpy()[0]
+        limit = params.pop(inputs[1].list_output_names()[0]).asnumpy()[0]
+        delta = params.pop(inputs[2].list_output_names()[0]).asnumpy()[0]
+
+        name = attr["_node_name"]
+        params[name] = tvm.nd.array([start, limit, delta])
+        return _sym.Variable(name=name, shape=params[name].shape)
+    return _impl
 
 # compatible operators that do NOT require any conversion.
 _identity_list = []
@@ -700,6 +729,9 @@ _convert_map = {
     'LRN'                               : _lrn(),
     'Pad'                               : _pad('Pad'),
     'PadV2'                             : _pad('PadV2'),
+    'Range'                             : _range(),
+    'Rank'                              : _rank(),
+    'Transpose'                         : _transpose(),
 }
 
 # _convert_map_rnn defines maps of rnn operator name to
diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py
index e0e18d1bd..b0fb02cf0 100644
--- a/nnvm/tests/python/frontend/tensorflow/test_forward.py
+++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py
@@ -853,11 +853,34 @@ def _test_l2_normalize(ishape, eps, axis):
 def test_forward_l2_normalize():
     _test_l2_normalize((1, 3, 20, 20), 0.001, (0,))
 
+#######################################################################
+# transpose
+# ---------
+def _test_forward_transpose(ishape, axes=None):
+    input = np.random.uniform(size=ishape).astype(np.float32)
+
+    with tf.Graph().as_default():
+        in1 = tf.placeholder(shape=input.shape, dtype=input.dtype, name="transpose_data")
+
+        if axes is None:
+            tf.transpose(in1)
+        else:
+            tf.transpose(in1, perm=axes)
+
+        compare_tf_with_tvm(input, 'transpose_data:0', 'transpose:0')
+
+def test_forward_transpose():
+    _test_forward_transpose((2, 3, 4))
+    _test_forward_transpose((7, 8, 8, 10))
+    _test_forward_transpose((2, 3, 4), (1, 2, 0))
+    _test_forward_transpose((2, 3, 4), (0, 1, 2))
+    _test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2))
 
 #######################################################################
 # Main
 # ----
 if __name__ == '__main__':
+    test_forward_transpose()
     test_forward_convolution()
     test_forward_pooling()
     test_forward_reshape()
-- 
GitLab