diff --git a/nnvm/python/nnvm/frontend/tensorflow.py b/nnvm/python/nnvm/frontend/tensorflow.py index 13ed717b045097658eb626ffc6dcc3e5023a127f..b01d489fb0423fef179fff299b634bfd0e580845 100644 --- a/nnvm/python/nnvm/frontend/tensorflow.py +++ b/nnvm/python/nnvm/frontend/tensorflow.py @@ -569,6 +569,7 @@ def _stridedSlice(): m_begin = [0] * data_dim m_end = [0] * data_dim m_stride = [0] * data_dim + fshape_indices = [] #Count new axis after ellipsis_mask, consider while applying ellipsis_mask. ellipsis_seen = False new_axes_after_ellipsis = 0 @@ -593,7 +594,10 @@ def _stridedSlice(): m_begin[final_index] = 0 m_end[final_index] = data_shape[0][final_index] m_stride[final_index] = 1 + fshape_indices.append(final_index) final_index += 1 + elif mask &new_axis_mask: + fshape_indices.append(-1) elif not mask & new_axis_mask: if final_index == len(m_begin): break @@ -614,28 +618,30 @@ def _stridedSlice(): if begin[index] < 0 else begin[index] m_end[final_index] = begin[index] + 1 m_stride[final_index] = 1 + fshape_indices.append(-2) + else: + fshape_indices.append(final_index) + final_index += 1 - return m_begin, m_end, m_stride + return m_begin, m_end, m_stride, fshape_indices + fshape_indices = None if begin_mask or end_mask or ellipsis_mask or new_axis_mask or shrink_axis_mask: - begin, end, stride = _transform_mask(stride_dim, ellipsis_mask) + begin, end, stride, fshape_indices = _transform_mask(stride_dim, ellipsis_mask) out = _sym.strided_slice(inputs[0], begin=begin, end=end, stride=stride) out_shape = _infer_out_shapes(out, params)[0] + if not fshape_indices: + fshape_indices = range(len(out_shape)) #Create final output shape. final_output = [] - out_index = 0 - index = 0 - while out_index != len(out_shape): - #axis with shrink_axis_mask dimension=1 and it is ignored. - mask = 1 << index - if (new_axis_mask & mask) and not ellipsis_mask & mask: + for gather_index in fshape_indices: + if gather_index == -1: final_output.append(1) - elif (not mask & shrink_axis_mask) or index >= stride_dim: - #Shrink is considered till stride_dim - final_output.append(out_shape[out_index]) - out_index += 1 - index += 1 + elif gather_index == -2: + pass + else: + final_output.append(out_shape[gather_index]) return _sym.reshape(out, shape=tuple(final_output)) return _impl diff --git a/nnvm/tests/python/frontend/tensorflow/test_forward.py b/nnvm/tests/python/frontend/tensorflow/test_forward.py index e93f14ceb96892a875e289815d900a3b0eaa99d2..c98748c0fc03303258ab09559f93760040fd8812 100644 --- a/nnvm/tests/python/frontend/tensorflow/test_forward.py +++ b/nnvm/tests/python/frontend/tensorflow/test_forward.py @@ -435,11 +435,15 @@ def _test_stridedslice(ip_shape, begin, end, stride, dtype, def test_forward_stridedslice(): '''test StridedSlice''' - return + _test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32') _test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1], 'float32', ellipsis_mask=8) + _test_stridedslice((3, 4, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2) + _test_stridedslice((3, 4, 5, 3), [1, 0], [4, 2], [2, 1], 'float32', ellipsis_mask=2) + _test_stridedslice((3, 4, 5, 3), [1, 0, 1], [4, 2, 2], [2, 1, 1], 'float32', ellipsis_mask=2) _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 2], [2, 1, 1], 'float32', new_axis_mask=5) _test_stridedslice((3, 4, 3), [1, 1, 1], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=4) + _test_stridedslice((6, 4, 5), [1, 1, 1], [6, 3, 4], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=5) _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=4, new_axis_mask=2) _test_stridedslice((3, 4, 3), [1, 1, 2], [4, 4, 3], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3) _test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 1], [2, 1, 1], 'float32', ellipsis_mask=2, new_axis_mask=3) @@ -1056,7 +1060,7 @@ if __name__ == '__main__': test_forward_resize_bilinear() test_forward_pad() test_forward_gather() - #test_forward_stridedslice() + test_forward_stridedslice() # Activations test_forward_sigmoid()