From 6737739c3753f3ebce96b3426de7cd0e546582fa Mon Sep 17 00:00:00 2001
From: Ashutosh Parkhi <ashutosh.parkhi@imgtec.com>
Date: Mon, 14 Jan 2019 16:15:42 +0530
Subject: [PATCH] [Tensorflow] Support for Crop (#2285)

fixes

fixes
---
 nnvm/python/nnvm/frontend/tensorflow.py        | 16 ++++++++++++++++
 .../python/frontend/tensorflow/test_forward.py | 18 ++++++++++++++++++
 2 files changed, 34 insertions(+)

diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py
index a869abac9..c0848bb10 100644
--- a/nnvm/python/nnvm/frontend/tensorflow.py
+++ b/nnvm/python/nnvm/frontend/tensorflow.py
@@ -388,6 +388,21 @@ def _pack():
 
     return _impl
 
+def _slice():
+    def _impl(inputs, attr, params):
+        begin = params.pop(inputs[1].list_output_names()[0]).asnumpy().tolist()
+        size = params.pop(inputs[2].list_output_names()[0]).asnumpy().tolist()
+        data_shape = attr['_input_shapes'][inputs[0]]
+        data_dim = len(data_shape)
+        end = size
+        for i in range(data_dim):
+            if size[i] == -1:
+                end[i] = data_shape[i] - begin[i]
+            else:
+                end[i] += begin[i]
+        return _sym.strided_slice(inputs[0], begin=begin, end=size)
+    return _impl
+
 def _reshape():
     def _impl(inputs, attr, params):
         try:
@@ -883,6 +898,7 @@ _convert_map = {
     'Sum'                               : _sum(),
     'Square'                            : _square(),
     'Pack'                              : _pack(),
+    'Slice'                             : _slice(),
     'LeakyRelu'                         : AttrCvt('leaky_relu'),
     'Relu'                              : AttrCvt('relu'),
     'Reshape'                           : _reshape(),
diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py
index 5b8f11695..0ea92248f 100644
--- a/nnvm/tests/python/frontend/tensorflow/test_forward.py
+++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py
@@ -655,6 +655,23 @@ def test_forward_resize_bilinear():
     _test_resize_bilinear((6, 32, 64, 64), [20, 20], True)
 
 
+#######################################################################
+# Crop to bounding box
+# --------------------
+
+def _test_crop(in_shape, off_h, off_w, tar_h, tar_w):
+    """ Crop to bounding box """
+    data = np.random.uniform(size=in_shape).astype('float32')
+    with tf.Graph().as_default():
+        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
+        tf.image.crop_to_bounding_box(in_data, off_h, off_w, tar_h, tar_w)
+        compare_tf_with_tvm(data, 'Placeholder:0', 'crop_to_bounding_box/Slice:0')
+
+def test_forward_crop():
+    """ Crop to bounding box """
+    _test_crop((1, 224, 224, 3), 20, 20, 120, 120)
+
+
 #######################################################################
 # LSTM
 # ----
@@ -1139,6 +1156,7 @@ if __name__ == '__main__':
     test_forward_squeeze()
     test_forward_pack()
     test_forward_resize_bilinear()
+    test_forward_crop()
     test_forward_pad()
     test_forward_gather()
     test_forward_stridedslice()
-- 
GitLab