From 9b0e499054f204cee411afc1965c95a87d6dab1c Mon Sep 17 00:00:00 2001
From: Sergey Mironov <grrwlf@gmail.com>
Date: Sat, 18 Aug 2018 07:40:52 +0300
Subject: [PATCH] [NNVM] TF: Add Pack operation (#1570)

---
 nnvm/include/nnvm/top/tensor.h                |  2 +-
 nnvm/python/nnvm/frontend/tensorflow.py       |  9 ++++++
 nnvm/src/top/tensor/transform.cc              | 19 ++++++------
 .../frontend/tensorflow/test_forward.py       | 29 ++++++++++++++++++-
 4 files changed, 48 insertions(+), 11 deletions(-)

diff --git a/nnvm/include/nnvm/top/tensor.h b/nnvm/include/nnvm/top/tensor.h
index 22ee9d711..53ed5b3b0 100644
--- a/nnvm/include/nnvm/top/tensor.h
+++ b/nnvm/include/nnvm/top/tensor.h
@@ -16,7 +16,7 @@ namespace top {
 struct ConcatenateParam : public dmlc::Parameter<ConcatenateParam> {
   int axis;
   DMLC_DECLARE_PARAMETER(ConcatenateParam) {
-    DMLC_DECLARE_FIELD(axis).set_lower_bound(0).set_default(1)
+    DMLC_DECLARE_FIELD(axis).set_default(1)
     .describe("the axis to be concated.");
   }
 };
diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py
index d761e34c7..092b8fa20 100644
--- a/nnvm/python/nnvm/frontend/tensorflow.py
+++ b/nnvm/python/nnvm/frontend/tensorflow.py
@@ -339,6 +339,14 @@ def _concat():
             extras={'axis': axis.asnumpy()[0]})(inputs, attr)
     return _impl
 
+def _pack():
+    def _impl(inputs, attr, params):
+        axis = int(attr["axis"])
+        inputs_reshaped = [_sym.expand_dims(i, axis=axis, num_newaxis=1) for i in inputs]
+        return _sym.concatenate(*inputs_reshaped, axis=axis)
+
+    return _impl
+
 def _reshape():
     def _impl(inputs, attr, params):
         try:
@@ -673,6 +681,7 @@ _convert_map = {
     'Minimum'                           : _elemwise('min'),
     'Sum'                               : _sum(),
     'Square'                            : _square(),
+    'Pack'                              : _pack(),
     'Relu'                              : AttrCvt('relu'),
     'Reshape'                           : _reshape(),
     'ResizeBilinear'                    : _resize_bilinear(),
diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc
index 78255d20f..52dca5654 100644
--- a/nnvm/src/top/tensor/transform.cc
+++ b/nnvm/src/top/tensor/transform.cc
@@ -93,23 +93,24 @@ inline bool ConcatenateInferShape(const NodeAttrs& attrs,
   TShape dshape;
   dim_t size = 0;
   bool has_zero = false;
+  int axis = param.axis >= 0 ? param.axis : in_shape->at(0).ndim() + param.axis;
   for (size_t i = 0; i < in_shape->size(); ++i) {
     TShape tmp = (*in_shape)[i];
     if (tmp.ndim()) {
-      CHECK_LT(static_cast<dim_t>(param.axis), tmp.ndim())
-          << "concat dim " << param.axis << " out of range of input shape " << tmp;
-      has_zero = tmp[param.axis] == 0 || has_zero;
-      size += tmp[param.axis];
-      tmp[param.axis] = 0;
+      CHECK_LT(static_cast<dim_t>(axis), tmp.ndim())
+          << "concat dim " << axis << " out of range of input shape " << tmp;
+      has_zero = tmp[axis] == 0 || has_zero;
+      size += tmp[axis];
+      tmp[axis] = 0;
       shape_assign(&dshape, tmp);
     }
   }
 
   TShape tmp = (*out_shape)[0];
   if (tmp.ndim()) {
-    CHECK_LT(static_cast<dim_t>(param.axis), tmp.ndim())
-        << "concat dim " << param.axis << " out of range of input shape " << tmp;
-    tmp[param.axis] = 0;
+    CHECK_LT(static_cast<dim_t>(axis), tmp.ndim())
+        << "concat dim " << axis << " out of range of input shape " << tmp;
+    tmp[axis] = 0;
     shape_assign(&dshape, tmp);
   }
 
@@ -119,7 +120,7 @@ inline bool ConcatenateInferShape(const NodeAttrs& attrs,
     NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, i, dshape);
   }
 
-  if (!has_zero) dshape[param.axis] = size;
+  if (!has_zero) dshape[axis] = size;
   NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, dshape);
   return dshape.Size() != 0;
 }
diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py
index 64c57c126..6fa020a03 100644
--- a/nnvm/tests/python/frontend/tensorflow/test_forward.py
+++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py
@@ -342,7 +342,7 @@ def _test_argx(func, data, **kwargs):
 
         compare_tf_with_tvm(data, 'c0:0', 'argx0:0')
 
-def test_argmin_argmax():
+def test_forward_argminmax():
     for axis in [None,0,1,2]:
         data = np.random.uniform(size=(8,4,9)).astype('float32')
         _test_argx(tf.argmax, data=data, axis=axis)
@@ -555,6 +555,31 @@ def test_forward_lstm():
 
     _test_lstm_cell(1, 2, 1, 0.0, 'float32')
 
+
+
+#######################################################################
+# Pack
+# ---
+def _test_pack(axis, shape, **kwargs):
+
+    a = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+    b = np.arange(np.prod(shape), dtype=np.float32).reshape(shape)
+
+    with tf.Graph().as_default():
+        tf_a = array_ops.placeholder(shape=shape, dtype='float32', name='pl_a')
+        tf_b = array_ops.placeholder(shape=shape, dtype='float32', name='pl_b')
+        tf_c = tf.stack([tf_a,tf_b], axis=axis, **kwargs)
+        assert tf_c.op.op_def.name == 'Pack', "tf.stack() is expected to produce 'Pack' operation"
+
+        compare_tf_with_tvm([a,b], ['pl_a:0','pl_b:0'], 'stack:0')
+
+def test_forward_pack():
+    for axis in range(-3,3):
+        _test_pack(axis, [3,2,1])
+    for axis in range(-1,1):
+        _test_pack(axis, [3])
+    _test_pack(0, [])
+
 #######################################################################
 # Pad
 # ---
@@ -818,9 +843,11 @@ if __name__ == '__main__':
     test_forward_reshape()
     test_forward_squeeze()
     test_forward_sigmoid()
+    test_forward_argminmax()
     if tf.__version__ == '1.4.1':
         _test_forward_concat_v2()
     test_forward_multi_input()
+    test_forward_pack()
     test_forward_inception_v3()
     test_forward_inception_v1()
     test_forward_mobilenet()
-- 
GitLab