From fae3f60cc8edbae14755f33d54bf8c17b1cd1d75 Mon Sep 17 00:00:00 2001 From: lixiaoquan <radioheads@163.com> Date: Sat, 1 Sep 2018 03:35:05 +0800 Subject: [PATCH] [FRONTEND][TENSORFLOW] Add Transpose support. (#1665) * [FRONTEND][TENSORFLOW] Add Transpose support. * [FRONTEND][TENSORFLOW] Get parameter from inputs and fix document style. * [FRONTEND][TENSORFLOW] Handle the case that perm is not specified. * [FRONTEND][TENSORFLOW] Convert Rank and Range to param. * [FRONTEND][TENSORFLOW] Fix a pylint issue. * [FRONTEND][TENSORFLOW] Implement Rank and Range as normal op. --- nnvm/python/nnvm/frontend/tensorflow.py | 32 +++++++++++++++++++ .../frontend/tensorflow/test_forward.py | 23 +++++++++++++ 2 files changed, 55 insertions(+) diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 6be5333cc..d9406601d 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -650,6 +650,35 @@ def _pad(name): ignores=['Tpaddings'],)(new_inputs, attr) return _impl +def _transpose(): + def _impl(inputs, attr, params): + # If perm is not specified, axes is left empty, + # otherwise its value is get from params + param_name = inputs[1].list_output_names()[0] + axes = params.get(param_name, tvm.nd.array([])).asnumpy() + return _sym.transpose(inputs[0], axes=tuple(axes)) + return _impl + +def _rank(): + def _impl(inputs, attr, params): + input_shapes = attr['_input_shapes'][inputs[0]] + assert len(inputs) == 1 + + name = attr["_node_name"] + params[name] = tvm.nd.array([len(input_shapes[0])]) + return _sym.Variable(name=name, shape=params[name].shape) + return _impl + +def _range(): + def _impl(inputs, attr, params): + start = params.pop(inputs[0].list_output_names()[0]).asnumpy()[0] + limit = params.pop(inputs[1].list_output_names()[0]).asnumpy()[0] + delta = params.pop(inputs[2].list_output_names()[0]).asnumpy()[0] + + name = attr["_node_name"] + params[name] = tvm.nd.array([start, limit, delta]) + return _sym.Variable(name=name, shape=params[name].shape) + return _impl # compatible operators that do NOT require any conversion. _identity_list = [] @@ -700,6 +729,9 @@ _convert_map = { 'LRN' : _lrn(), 'Pad' : _pad('Pad'), 'PadV2' : _pad('PadV2'), + 'Range' : _range(), + 'Rank' : _rank(), + 'Transpose' : _transpose(), } # _convert_map_rnn defines maps of rnn operator name to diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index e0e18d1bd..b0fb02cf0 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -853,11 +853,34 @@ def _test_l2_normalize(ishape, eps, axis): def test_forward_l2_normalize(): _test_l2_normalize((1, 3, 20, 20), 0.001, (0,)) +####################################################################### +# transpose +# --------- +def _test_forward_transpose(ishape, axes=None): + input = np.random.uniform(size=ishape).astype(np.float32) + + with tf.Graph().as_default(): + in1 = tf.placeholder(shape=input.shape, dtype=input.dtype, name="transpose_data") + + if axes is None: + tf.transpose(in1) + else: + tf.transpose(in1, perm=axes) + + compare_tf_with_tvm(input, 'transpose_data:0', 'transpose:0') + +def test_forward_transpose(): + _test_forward_transpose((2, 3, 4)) + _test_forward_transpose((7, 8, 8, 10)) + _test_forward_transpose((2, 3, 4), (1, 2, 0)) + _test_forward_transpose((2, 3, 4), (0, 1, 2)) + _test_forward_transpose((2, 3, 4, 5), (3, 0, 1, 2)) ####################################################################### # Main # ---- if __name__ == '__main__': + test_forward_transpose() test_forward_convolution() test_forward_pooling() test_forward_reshape() -- GitLab