From 9b0e499054f204cee411afc1965c95a87d6dab1c Mon Sep 17 00:00:00 2001 From: Sergey Mironov <grrwlf@gmail.com> Date: Sat, 18 Aug 2018 07:40:52 +0300 Subject: [PATCH] [NNVM] TF: Add Pack operation (#1570) --- nnvm/include/nnvm/top/tensor.h | 2 +- nnvm/python/nnvm/frontend/tensorflow.py | 9 ++++++ nnvm/src/top/tensor/transform.cc | 19 ++++++------ .../frontend/tensorflow/test_forward.py | 29 ++++++++++++++++++- 4 files changed, 48 insertions(+), 11 deletions(-) diff --git a/nnvm/include/nnvm/top/tensor.h b/nnvm/include/nnvm/top/tensor.h index 22ee9d711..53ed5b3b0 100644 --- a/nnvm/include/nnvm/top/tensor.h +++ b/nnvm/include/nnvm/top/tensor.h @@ -16,7 +16,7 @@ namespace top { struct ConcatenateParam : public dmlc::Parameter<ConcatenateParam> { int axis; DMLC_DECLARE_PARAMETER(ConcatenateParam) { - DMLC_DECLARE_FIELD(axis).set_lower_bound(0).set_default(1) + DMLC_DECLARE_FIELD(axis).set_default(1) .describe("the axis to be concated."); } }; diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index d761e34c7..092b8fa20 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -339,6 +339,14 @@ def _concat(): extras={'axis': axis.asnumpy()[0]})(inputs, attr) return _impl +def _pack(): + def _impl(inputs, attr, params): + axis = int(attr["axis"]) + inputs_reshaped = [_sym.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs] + return _sym.concatenate(*inputs_reshaped, axis=axis) + + return _impl + def _reshape(): def _impl(inputs, attr, params): try: @@ -673,6 +681,7 @@ _convert_map = { 'Minimum' : _elemwise('min'), 'Sum' : _sum(), 'Square' : _square(), + 'Pack' : _pack(), 'Relu' : AttrCvt('relu'), 'Reshape' : _reshape(), 'ResizeBilinear' : _resize_bilinear(), diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index 78255d20f..52dca5654 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -93,23 +93,24 @@ inline bool ConcatenateInferShape(const NodeAttrs& attrs, TShape dshape; dim_t size = 0; bool has_zero = false; + int axis = param.axis >= 0 ? param.axis : in_shape->at(0).ndim() + param.axis; for (size_t i = 0; i < in_shape->size(); ++i) { TShape tmp = (*in_shape)[i]; if (tmp.ndim()) { - CHECK_LT(static_cast<dim_t>(param.axis), tmp.ndim()) - << "concat dim " << param.axis << " out of range of input shape " << tmp; - has_zero = tmp[param.axis] == 0 || has_zero; - size += tmp[param.axis]; - tmp[param.axis] = 0; + CHECK_LT(static_cast<dim_t>(axis), tmp.ndim()) + << "concat dim " << axis << " out of range of input shape " << tmp; + has_zero = tmp[axis] == 0 || has_zero; + size += tmp[axis]; + tmp[axis] = 0; shape_assign(&dshape, tmp); } } TShape tmp = (*out_shape)[0]; if (tmp.ndim()) { - CHECK_LT(static_cast<dim_t>(param.axis), tmp.ndim()) - << "concat dim " << param.axis << " out of range of input shape " << tmp; - tmp[param.axis] = 0; + CHECK_LT(static_cast<dim_t>(axis), tmp.ndim()) + << "concat dim " << axis << " out of range of input shape " << tmp; + tmp[axis] = 0; shape_assign(&dshape, tmp); } @@ -119,7 +120,7 @@ inline bool ConcatenateInferShape(const NodeAttrs& attrs, NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, i, dshape); } - if (!has_zero) dshape[param.axis] = size; + if (!has_zero) dshape[axis] = size; NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, dshape); return dshape.Size() != 0; } diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index 64c57c126..6fa020a03 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -342,7 +342,7 @@ def _test_argx(func, data, **kwargs): compare_tf_with_tvm(data, 'c0:0', 'argx0:0') -def test_argmin_argmax(): +def test_forward_argminmax(): for axis in [None,0,1,2]: data = np.random.uniform(size=(8,4,9)).astype('float32') _test_argx(tf.argmax, data=data, axis=axis) @@ -555,6 +555,31 @@ def test_forward_lstm(): _test_lstm_cell(1, 2, 1, 0.0, 'float32') + + +####################################################################### +# Pack +# --- +def _test_pack(axis, shape, **kwargs): + + a = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + b = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + + with tf.Graph().as_default(): + tf_a = array_ops.placeholder(shape=shape, dtype='float32', name='pl_a') + tf_b = array_ops.placeholder(shape=shape, dtype='float32', name='pl_b') + tf_c = tf.stack([tf_a,tf_b], axis=axis, **kwargs) + assert tf_c.op.op_def.name == 'Pack', "tf.stack() is expected to produce 'Pack' operation" + + compare_tf_with_tvm([a,b], ['pl_a:0','pl_b:0'], 'stack:0') + +def test_forward_pack(): + for axis in range(-3,3): + _test_pack(axis, [3,2,1]) + for axis in range(-1,1): + _test_pack(axis, [3]) + _test_pack(0, []) + ####################################################################### # Pad # --- @@ -818,9 +843,11 @@ if __name__ == '__main__': test_forward_reshape() test_forward_squeeze() test_forward_sigmoid() + test_forward_argminmax() if tf.__version__ == '1.4.1': _test_forward_concat_v2() test_forward_multi_input() + test_forward_pack() test_forward_inception_v3() test_forward_inception_v1() test_forward_mobilenet() -- GitLab