From 4fb58115eaf7d4b70873d95d80d4249c78eb93c8 Mon Sep 17 00:00:00 2001
From: Pariksheet Pinjari <pariksheet.pinjari@huawei.com>
Date: Thu, 28 Jun 2018 20:35:33 +0530
Subject: [PATCH] Strided_slice added in NNVM (#1318)

---
 nnvm/include/nnvm/top/tensor.h                |  16 +++
 nnvm/python/nnvm/top/transform.py             |   4 +
 nnvm/src/top/tensor/transform.cc              | 113 ++++++++++++++++++
 nnvm/tests/python/compiler/test_top_level1.py |  36 ++++++
 4 files changed, 169 insertions(+)

diff --git a/nnvm/include/nnvm/top/tensor.h b/nnvm/include/nnvm/top/tensor.h
index dca02cae9..87128a663 100644
--- a/nnvm/include/nnvm/top/tensor.h
+++ b/nnvm/include/nnvm/top/tensor.h
@@ -48,6 +48,22 @@ struct SplitParam : public dmlc::Parameter<SplitParam> {
   }
 };
 
+struct StridedSliceParam : public dmlc::Parameter<StridedSliceParam> {
+  // numpy convention, only support indices, not support list.
+  Tuple<int64_t> begin;
+  Tuple<int64_t> end;
+  Tuple<int64_t> stride;
+
+  DMLC_DECLARE_PARAMETER(StridedSliceParam) {
+    DMLC_DECLARE_FIELD(begin)
+        .describe("Indices for begin of slice");
+    DMLC_DECLARE_FIELD(end)
+        .describe("Indices for end of the slice");
+    DMLC_DECLARE_FIELD(stride).set_default(Tuple<int64_t>())
+        .describe("Stride values of the slice");
+  }
+};
+
 enum TypeFlag {
   kFloat32 = 0,
   kFloat64 = 1,
diff --git a/nnvm/python/nnvm/top/transform.py b/nnvm/python/nnvm/top/transform.py
index 6900f85f9..c87b4735d 100644
--- a/nnvm/python/nnvm/top/transform.py
+++ b/nnvm/python/nnvm/top/transform.py
@@ -61,6 +61,10 @@ reg.register_schedule("concatenate", _fschedule_injective)
 reg.register_pattern("split", OpPattern.INJECTIVE)
 reg.register_schedule("split", _fschedule_injective)
 
+# strided_slice
+reg.register_pattern("strided_slice", OpPattern.INJECTIVE)
+reg.register_schedule("strided_slice", _fschedule_injective)
+
 # slice_like
 reg.register_pattern("slice_like", OpPattern.INJECTIVE)
 reg.register_schedule("slice_like", _fschedule_injective)
diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc
index 05d3cfcb5..0b0beb5b7 100644
--- a/nnvm/src/top/tensor/transform.cc
+++ b/nnvm/src/top/tensor/transform.cc
@@ -829,6 +829,119 @@ Examples::
     };
 });
 
+// strided_slice
+DMLC_REGISTER_PARAMETER(StridedSliceParam);
+
+inline void StridedSliceParamParser(nnvm::NodeAttrs* attrs) {
+  StridedSliceParam param;
+  param.Init(attrs->dict);
+  attrs->parsed = std::move(param);
+}
+
+inline bool StridedSliceInferShape(const NodeAttrs& attrs,
+                            std::vector<TShape>* in_shape,
+                            std::vector<TShape>* out_shape) {
+  const StridedSliceParam& param = nnvm::get<StridedSliceParam>(attrs.parsed);
+  const TShape& dshape = (*in_shape)[0];
+  if (dshape.ndim() == 0) return false;
+  TShape oshape = dshape;
+  dim_t num_axis = dshape.ndim();
+
+  std::vector<int64_t> begin_vec;
+  std::copy(param.begin.begin(), param.begin.end(), std::back_inserter(begin_vec));
+  for (dim_t i = begin_vec.size(); i < num_axis; ++i) {
+    begin_vec.push_back(0);
+  }
+
+  std::vector<int64_t> end_vec;
+  std::copy(param.end.begin(), param.end.end(), std::back_inserter(end_vec));
+  for (dim_t i = end_vec.size(); i < num_axis; ++i) {
+    end_vec.push_back(dshape[i]);
+  }
+
+  std::vector<int64_t> stride_vec;
+  std::copy(param.stride.begin(), param.stride.end(), std::back_inserter(stride_vec));
+  for (dim_t i = stride_vec.size(); i < num_axis; ++i) {
+    stride_vec.push_back(1);
+  }
+
+  for (dim_t i = 0; i < num_axis; ++i) {
+      int64_t begin_range = stride_vec[i] < 0 ? -1 : 0;
+      int64_t end_range = stride_vec[i] < 0 ? dshape[i] - 1 : dshape[i];
+      int64_t begin = begin_vec[i] < 0 ? dshape[i] + begin_vec[i] : begin_vec[i];
+      int64_t end = end_vec[i] < 0 ? dshape[i] + end_vec[i] : end_vec[i];
+      begin = std::min(std::max(begin, begin_range), end_range);
+      end = std::min(std::max(end, begin_range), end_range);
+
+      int interval = std::abs(end - begin);
+      int slice_size = static_cast<int>((interval
+                                       + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i]));
+      CHECK(stride_vec[i] < 0 ? (end < begin) : (begin < end))
+        << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i]
+        << "] is invalid for axis=" << i;
+      oshape[i] = slice_size;
+  }
+  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
+  return true;
+}
+
+NNVM_REGISTER_OP(strided_slice)
+.describe(R"code(Strided slice of an array.
+
+Examples::
+
+  x = [[  1.,   4.,   7.,  10.],
+       [  2.,   5.,   8.,  11.],
+       [  3.,   6.,   9.,  12.]]
+
+  strided_slice(x, begin=[0, 1], end=[2, 4], stride=[1, 1]) = [[ 4.,  7.,  10.],
+                                                               [ 5.,  8.,  11.]]
+
+  x = [[[ 1.,  2.],
+        [ 3.,  4.]],
+
+       [[ 5.,  6.],
+        [ 7.,  8.]]]
+
+  strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1.,  2.],
+                                                 [ 3.,  4.]],
+
+                                                [[ 5.,  6.],
+                                                 [ 7.,  8.]]]
+)code" NNVM_ADD_FILELINE)
+.add_argument("data", "Tensor", "Array to be sliced")
+.add_arguments(StridedSliceParam::__FIELDS__())
+.set_attr_parser(StridedSliceParamParser)
+.set_attr<FInferShape>("FInferShape", StridedSliceInferShape)
+.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const StridedSliceParam& param = nnvm::get<StridedSliceParam>(attrs.parsed);
+    Array<Expr> begin;
+    Array<Expr> end;
+    Array<Expr> stride;
+
+    for (int64_t i : param.begin) {
+        begin.push_back(tvm::make_const(tvm::Int(32), i));
+    }
+
+    for (int64_t i : param.end) {
+        end.push_back(tvm::make_const(tvm::Int(32), i));
+    }
+
+    for (int64_t i : param.stride) {
+        stride.push_back(tvm::make_const(tvm::Int(32), i));
+    }
+
+    return Array<Tensor>{ topi::strided_slice(inputs[0], begin, end, stride) };
+})
+.set_support_level(1);
+
 // Flip
 DMLC_REGISTER_PARAMETER(FlipParam);
 
diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py
index 73391c80d..b97aff8ef 100644
--- a/nnvm/tests/python/compiler/test_top_level1.py
+++ b/nnvm/tests/python/compiler/test_top_level1.py
@@ -329,6 +329,41 @@ def test_split():
     verify_split((5, 3), [3], axis=0)
     verify_split((5, 9, 3), [3, 4], axis=1)
 
+def verify_strided_slice(ishape, begin, end, strideinp=None):
+    stride = strideinp if strideinp else [1, 1, 1]
+    x = sym.Variable("x")
+    if strideinp:
+        y = sym.strided_slice(x, begin = begin, end = end, stride = stride) + 1
+    else:
+        y = sym.strided_slice(x, begin = begin, end = end) + 1
+    x_np = np.random.uniform(size=ishape).astype("float32")
+    for i in range(len(begin), 3):
+        begin.append(0)
+    for i in range(len(end), 3):
+        end.append(ishape[i])
+    def test_forward(x, begin, end, stride):
+        return x[begin[0]:end[0]:stride[0],
+                    begin[1]:end[1]:stride[1], begin[2]:end[2]:stride[2]] + 1
+
+    for target, ctx in ctx_list():
+        # set input
+        graph, lib, _ = nnvm.compiler.build(y, target, {"x": ishape})
+        m = graph_runtime.create(graph, lib, ctx)
+        m.run(x=x_np)
+        res = test_forward(x_np, begin, end, stride)
+        out = m.get_output(0, tvm.nd.empty(res.shape))
+        np.testing.assert_allclose(out.asnumpy(), res, atol=1e-5, rtol=1e-5)
+
+def test_strided_slice():
+    verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2])
+    verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1])
+    verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1])
+    verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2])
+    verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1])
+    verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
+    verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 1000, 3])
+    verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4])
+    verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3])
 
 def verify_squeeze(dshape, axis):
     x = sym.Variable("x")
@@ -448,3 +483,4 @@ if __name__ == "__main__":
     test_pad()
     test_lrn()
     test_l2_normalize()
+    test_strided_slice()
-- 
GitLab