From 1f2c815671353a4cb7f337eca71ba67d56b4d799 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Tue, 13 Nov 2018 13:32:38 -0800
Subject: [PATCH] [RELAY][OP] strided_slice (#2094)

---
 docs/langref/relay_op.rst                     |   2 +
 include/tvm/relay/attrs/transform.h           |  15 ++
 nnvm/src/top/tensor/transform.cc              |  30 +++-
 python/tvm/_ffi/node_generic.py               |   2 +
 python/tvm/relay/op/__init__.py               |   1 +
 python/tvm/relay/op/_transform.py             |   8 +
 python/tvm/relay/op/transform.py              |  27 +++
 src/api/api_lang.cc                           |   6 +-
 src/relay/ir/text_printer.cc                  |   6 +-
 src/relay/op/tensor/transform.cc              | 168 ++++++++++++++++++
 tests/python/relay/test_op_level4.py          |  38 +++-
 topi/include/topi/transform.h                 |  55 ++++--
 topi/python/topi/testing/__init__.py          |   1 +
 .../topi/testing/strided_slice_python.py      |  32 ++++
 topi/tests/python/test_topi_transform.py      |  17 +-
 15 files changed, 371 insertions(+), 37 deletions(-)
 create mode 100644 python/tvm/relay/op/_transform.py
 create mode 100644 topi/python/topi/testing/strided_slice_python.py

diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index 405f071e3..e99ac3c97 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -123,6 +123,7 @@ This level enables additional math and transform operators.
    tvm.relay.min
    tvm.relay.mean
    tvm.relay.prod
+   tvm.relay.strided_slice
 
 
 **Level 5: Vision/Image Operators**
@@ -227,6 +228,7 @@ Level 4 Definitions
 .. autofunction:: tvm.relay.min
 .. autofunction:: tvm.relay.mean
 .. autofunction:: tvm.relay.prod
+.. autofunction:: tvm.relay.strided_slice
 
 
 
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index cb87d358e..4d2008628 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -123,6 +123,21 @@ struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
   }
 };
 
+/*! \brief Attributes for StridedSlice operator */
+struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
+  Array<Integer> begin;
+  Array<Integer> end;
+  Array<Integer> strides;
+
+  TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") {
+    TVM_ATTR_FIELD(begin)
+        .describe("Indices for begin of slice, begin index is also inclusive");
+    TVM_ATTR_FIELD(end)
+        .describe("Indices for end of slice, end index is also inclusive");
+    TVM_ATTR_FIELD(strides).set_default(Array<Integer>({}))
+        .describe("Stride values of the slice");
+  }
+};
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
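
For readers new to the convention these attributes encode: begin is inclusive
and end is exclusive, matching plain Python slicing. A one-line illustration
(ordinary Python, nothing patch-specific):

  x = [10, 20, 30, 40, 50]
  assert x[1:4:1] == [20, 30, 40]   # begin=1 (inclusive), end=4 (exclusive), strides=1
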
diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc
index 4d08bf761..2f42727d6 100644
--- a/nnvm/src/top/tensor/transform.cc
+++ b/nnvm/src/top/tensor/transform.cc
@@ -980,23 +980,25 @@ Examples::
                     const Array<Tensor>& inputs,
                     const Array<Tensor>& out_info) {
     const StridedSliceParam& param = nnvm::get<StridedSliceParam>(attrs.parsed);
-    Array<Expr> begin;
-    Array<Expr> end;
-    Array<Expr> stride;
+    Array<Integer> begin;
+    Array<Integer> end;
+    Array<Integer> stride;
 
     for (int64_t i : param.begin) {
-        begin.push_back(tvm::make_const(tvm::Int(32), i));
+      begin.push_back(static_cast<int>(i));
     }
 
     for (int64_t i : param.end) {
-        end.push_back(tvm::make_const(tvm::Int(32), i));
+      end.push_back(static_cast<int>(i));
     }
 
     for (int64_t i : param.stride) {
-        stride.push_back(tvm::make_const(tvm::Int(32), i));
+      stride.push_back(static_cast<int>(i));
     }
 
-    return Array<Tensor>{ topi::strided_slice(inputs[0], begin, end, stride) };
+    return Array<Tensor>{
+      topi::strided_slice(inputs[0], begin, end, stride)
+    };
 })
 .set_support_level(1);
 
@@ -1210,6 +1212,15 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
+// Adapter function that checks and converts an Expr array to an Integer array.
+Array<Integer> GetIntArray(Array<Expr> arr) {
+  for (size_t i = 0; i < arr.size(); ++i) {
+    CHECK(!arr[i].defined() || arr[i].as<IntImm>())
+        << "Expect an int array";
+  }
+  return Array<Integer>(arr.node_);
+}
+
 NNVM_REGISTER_OP(slice_like)
 .describe(R"code(Slice the first input respect to the second input.
 )code" NNVM_ADD_FILELINE)
@@ -1261,7 +1272,10 @@ NNVM_REGISTER_OP(slice_like)
       }
     }
     return Array<Tensor>{
-      topi::strided_slice(inputs[0], begin_idx, end_idx, strides)
+      topi::strided_slice(inputs[0],
+                          GetIntArray(begin_idx),
+                          GetIntArray(end_idx),
+                          GetIntArray(strides))
     };
 })
 .set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
diff --git a/python/tvm/_ffi/node_generic.py b/python/tvm/_ffi/node_generic.py
index b7230f29d..e86453499 100644
--- a/python/tvm/_ffi/node_generic.py
+++ b/python/tvm/_ffi/node_generic.py
@@ -56,6 +56,8 @@ def convert_to_node(value):
         return _api_internal._Map(*vlist)
     elif isinstance(value, NodeGeneric):
         return value.asnode()
+    elif value is None:
+        return None
     else:
         raise ValueError("don't know how to convert type %s to node" % type(value))
 
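
This `convert_to_node` change is what lets `None` entries survive conversion
into TVM arrays; the new operator relies on it to mark unspecified begin/end
indices. A minimal sketch of what it enables (assuming `tvm.convert` routes
list conversion through `convert_to_node`; the relay call mirrors the dynamic
shape test case added below):

  import tvm
  from tvm import relay

  arr = tvm.convert([None, 1, 2])   # None entries become null nodes instead of raising
  x = relay.var("x", relay.TensorType((3, 4, 3), "float32"))
  z = relay.strided_slice(x, begin=[None, None, 1], end=[None, None, 2], strides=None)
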
diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py
index 9b5814866..30aef433d 100644
--- a/python/tvm/relay/op/__init__.py
+++ b/python/tvm/relay/op/__init__.py
@@ -13,6 +13,7 @@ from . import vision
 
 # operator registry
 from . import _tensor
+from . import _transform
 from ..expr import Expr
 from ..base import register_relay_node
 
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
new file mode 100644
index 000000000..7867336d0
--- /dev/null
+++ b/python/tvm/relay/op/_transform.py
@@ -0,0 +1,8 @@
+#pylint: disable=invalid-name, unused-argument
+"""Backend compiler related feature registration"""
+from __future__ import absolute_import
+from . import op as _reg
+from .op import schedule_injective
+
+# strided_slice
+_reg.register_schedule("strided_slice", schedule_injective)
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 909b175f0..e43a4a573 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -334,3 +334,30 @@ def split(data, indices_or_sections, axis=0):
     else:
         ret_size = len(indices_or_sections) + 1
     return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size)
+
+
+def strided_slice(data, begin, end, strides=None):
+    """Strided slice of an array..
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The source array to be sliced.
+
+    begin : list of int
+        The indices to begin the slice with (inclusive).
+
+    end : list of int
+        Indices indicating the end of the slice (exclusive).
+
+    strides : list of int, optional
+        Specifies the stride values; they can be negative,
+        in which case the input tensor will be reversed in that axis.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+    """
+    strides = strides or []
+    return _make.strided_slice(data, list(begin), list(end), list(strides))
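
A short usage sketch of the new Python API, mirroring the example in the op
description later in this patch (`relay.var`, `relay.Function`, and
`relay.ir_pass.infer_type` are standard Relay helpers, as used in the tests
below; the inferred type follows the relation added in
src/relay/op/tensor/transform.cc):

  import tvm
  from tvm import relay

  x = relay.var("x", relay.TensorType((3, 4), "float32"))
  z = relay.strided_slice(x, begin=[0, 1], end=[2, 4], strides=[1, 1])
  func = relay.ir_pass.infer_type(relay.Function([x], z))
  # inferred result type: Tensor[(2, 3), float32]
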
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
index 75365da5b..3525e23b8 100644
--- a/src/api/api_lang.cc
+++ b/src/api/api_lang.cc
@@ -47,7 +47,11 @@ TVM_REGISTER_API("_Array")
 .set_body([](TVMArgs args,  TVMRetValue* ret) {
     std::vector<NodePtr<Node> > data;
     for (int i = 0; i < args.size(); ++i) {
-      data.push_back(args[i].node_sptr());
+      if (args[i].type_code() != kNull) {
+        data.push_back(args[i].node_sptr());
+      } else {
+        data.push_back(NodePtr<Node>(nullptr));
+      }
     }
     auto node = make_node<ArrayNode>();
     node->data = std::move(data);
diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc
index 93ed76bed..bfc5f0db5 100644
--- a/src/relay/ir/text_printer.cc
+++ b/src/relay/ir/text_printer.cc
@@ -403,7 +403,11 @@ class TextPrinter :
    * \param os The output type.
    */
   void PrintAttr(const NodeRef& value, std::ostream& os) {  // NOLINT(*)
-    this->VisitAttr(value, os);
+    if (value.defined()) {
+      this->VisitAttr(value, os);
+    } else {
+      os << "None";
+    }
   }
   //------------------------------------
   // Overload of Attr printing functions
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 20e0e3adb..98ac1c30b 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -7,6 +7,7 @@
 #include <tvm/relay/attrs/transform.h>
 #include <tvm/ir_operator.h>
 #include <tvm/ir.h>
+#include <topi/transform.h>
 #include <vector>
 #include "../op_common.h"
 
@@ -890,6 +891,173 @@ RELAY_REGISTER_OP("broadcast_to_like")
 .set_support_level(10)
 .add_type_rel("BroadCastToLike", BroadCastToLikeRel);
 
+
+// strided_slice
+TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
+bool StridedSliceRel(const Array<Type>& types,
+                     int num_inputs,
+                     const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  const StridedSliceAttrs *param = attrs.as<StridedSliceAttrs>();
+  CHECK(param != nullptr);
+
+  auto dshape = data->shape;
+  auto num_axis = dshape.size();
+
+  std::vector<int64_t> stride_vec;
+  for (Integer i : param->strides) {
+    CHECK(i.defined());
+    stride_vec.push_back(i->value);
+  }
+  for (size_t i = stride_vec.size(); i < num_axis; ++i) {
+    stride_vec.push_back(1);
+  }
+  const int64_t max_range = std::numeric_limits<int64_t>::max();
+
+  std::vector<int64_t> begin_vec;
+  for (size_t i = 0; i < param->begin.size(); ++i) {
+    if (!param->begin[i].defined()) {
+      // value=None
+      begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
+    } else {
+      begin_vec.push_back(param->begin[i]->value);
+    }
+  }
+  for (size_t i = begin_vec.size(); i < num_axis; ++i) {
+    begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
+  }
+
+  std::vector<int64_t> end_vec;
+  for (size_t i = 0; i < param->end.size(); ++i) {
+    // allow end to be None
+    if (!param->end[i].defined()) {
+      end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
+    } else {
+      end_vec.push_back(param->end[i]->value);
+    }
+  }
+  for (size_t i = end_vec.size(); i < num_axis; ++i) {
+    end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
+  }
+
+  std::vector<IndexExpr> oshape(dshape.size());
+  for (size_t i = 0; i < num_axis; ++i) {
+    int64_t stride_v = stride_vec[i];
+    int64_t begin_v = begin_vec[i];
+    int64_t end_v = end_vec[i];
+
+    if ((stride_v == 1 &&
+         begin_v == 0 &&
+         end_v == max_range) ||
+        (stride_v == -1 &&
+         begin_v == max_range &&
+         end_v == 0)) {
+      // Quick path, do not slice this dimension.
+      oshape[i] = dshape[i];
+      continue;
+    }
+    // Normal path: require the sliced dimension to be a concrete integer,
+    // since symbolic inference of the min/max bounds gets complicated
+    // without being very helpful.
+    const int64_t* p_dim_size = as_const_int(dshape[i]);
+    CHECK(p_dim_size)
+        << "strided_slice requires sliced dimension to be concrete int";
+    int64_t dim_size = p_dim_size[0];
+    begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v;
+    end_v = (end_v < 0) ? dim_size + end_v : end_v;
+
+    int64_t slice_range, step;
+    if (stride_v < 0) {
+      if (end_v < -1) end_v = -1;
+      CHECK_LT(end_v, begin_v)
+          << "strided_slice got an empty slice at axis " << i;
+      begin_v = std::min(dim_size - 1, begin_v);
+      slice_range = begin_v - end_v;
+      step = -stride_v;
+    } else {
+      if (begin_v < 0) begin_v = 0;
+      CHECK_GE(stride_v, 0);
+      CHECK_LT(begin_v, end_v)
+          << "strided_slice got an empty slice at axis " << i;
+      end_v = std::min(dim_size, end_v);
+      slice_range = end_v - begin_v;
+      step = stride_v;
+    }
+    oshape[i] = make_const(dshape[i].type(), (slice_range + step - 1) / step);
+  }
+  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}
+
+
+// Positional relay function to create StridedSlice operator used by frontend FFI.
+Expr MakeStridedSlice(Expr data,
+                      Array<Integer> begin,
+                      Array<Integer> end,
+                      Array<Integer> strides) {
+  auto attrs = make_node<StridedSliceAttrs>();
+  attrs->begin = std::move(begin);
+  attrs->end = std::move(end);
+  attrs->strides = std::move(strides);
+  static const Op& op = Op::Get("strided_slice");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+Array<Tensor> StridedSliceCompute(const Attrs& attrs,
+                                  const Array<Tensor>& inputs,
+                                  const Type& out_type,
+                                  const Target& target) {
+  const StridedSliceAttrs *param = attrs.as<StridedSliceAttrs>();
+  CHECK(param != nullptr);
+  return Array<Tensor>{
+    topi::strided_slice(inputs[0], param->begin, param->end, param->strides)
+  };
+}
+
+
+TVM_REGISTER_API("relay.op._make.strided_slice")
+  .set_body([](const TVMArgs& args, TVMRetValue* rv) {
+      runtime::detail::unpack_call<Expr, 4>(MakeStridedSlice, args, rv);
+  });
+
+
+RELAY_REGISTER_OP("strided_slice")
+    .describe(R"code(Strided slice of an array.
+
+Examples::
+
+  x = [[  1.,   4.,   7.,  10.],
+       [  2.,   5.,   8.,  11.],
+       [  3.,   6.,   9.,  12.]]
+
+  strided_slice(x, begin=[0, 1], end=[2, 4], strides=[1, 1]) = [[ 4.,  7.,  10.],
+                                                                [ 5.,  8.,  11.]]
+
+  x = [[[ 1.,  2.],
+        [ 3.,  4.]],
+
+       [[ 5.,  6.],
+        [ 7.,  8.]]]
+
+  strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1.,  2.],
+                                                 [ 3.,  4.]],
+
+                                                [[ 5.,  6.],
+                                                 [ 7.,  8.]]]
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(4)
+.set_attrs_type_key("relay.attrs.StridedSliceAttrs")
+.add_type_rel("StridedSlice", StridedSliceRel)
+.set_attr<FTVMCompute>("FTVMCompute", StridedSliceCompute)
+.set_attr<TOpPattern>("TOpPattern", kInjective);
+
+
 // Split
 TVM_REGISTER_NODE_TYPE(SplitAttrs);
 
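
For reviewers checking `StridedSliceRel`, a small Python mirror of the
normal-path per-axis output-size rule above (an illustrative sketch of the
same arithmetic, not shared code; `slice_axis_size` is a made-up name, and
the quick path for untouched axes and the None defaults are handled
separately in the relation):

  def slice_axis_size(dim_size, begin, end, stride):
      # Negative indices count from the end of the axis.
      begin = dim_size + begin if begin < 0 else begin
      end = dim_size + end if end < 0 else end
      if stride < 0:
          end = max(end, -1)
          begin = min(dim_size - 1, begin)
          slice_range, step = begin - end, -stride
      else:
          begin = max(begin, 0)
          end = min(dim_size, end)
          slice_range, step = end - begin, stride
      return (slice_range + step - 1) // step

  # Middle axis of verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)):
  assert slice_axis_size(4, -1, -5, -1) == 4
  # First axis of verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3)):
  assert slice_axis_size(3, 1, 4, 2) == 1
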
diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py
index 6fd70c386..dd12dc7cf 100644
--- a/tests/python/relay/test_op_level4.py
+++ b/tests/python/relay/test_op_level4.py
@@ -2,7 +2,7 @@ import tvm
 import numpy as np
 from tvm import relay
 from tvm.relay.testing import ctx_list
-
+import topi.testing
 
 def test_binary_op():
     def check_binary_op(opfunc, ref):
@@ -142,7 +142,43 @@ def test_reduce_functions():
         verify_reduce(func, (128, 24, 128), (0, 1), True, False, (1, 1, 128))
         verify_reduce(func, (128, 24, 128), (0, 2), True, False, (1, 24, 1))
 
+
+def test_strided_slice():
+    def verify(dshape, begin, end, strides, output, test_ref=True):
+        x = relay.var("x", relay.TensorType(dshape, "float32"))
+        z = relay.strided_slice(x, begin=begin, end=end, strides=strides)
+        func = relay.Function([x], z)
+        func = relay.ir_pass.infer_type(func)
+        text = func.astext()
+        assert "begin=" in text
+        assert "end=" in text
+        if output:
+            assert func.body.checked_type == relay.ty.TensorType(output, "float32")
+        if not test_ref:
+            return
+        x_data = np.random.uniform(size=dshape).astype("float32")
+        ref_res = topi.testing.strided_slice_python(
+            x_data, begin, end, strides)
+        for target, ctx in ctx_list():
+            intrp = relay.create_executor("graph", ctx=ctx, target=target)
+            op_res = intrp.evaluate(func)(x_data)
+            tvm.testing.assert_allclose(op_res.asnumpy(), ref_res)
+
+    d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
+    verify((d1, d2, 3), [None, None, 1], [None, None, 2], None, (d1, d2, 1), False)
+    verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2))
+    verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3))
+    verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
+    verify((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2], (1, 2, 2))
+    verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
+    verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
+    verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3))
+    verify((3, 4, 3), [1, 1, 0], [4, 4], None, (2, 3, 3))
+    verify((3, 4, 3), [1, 1], [4, 4, 3], None, (2, 3, 3))
+
+
 if __name__ == "__main__":
+    test_strided_slice()
     test_binary_op()
     test_cmp_type()
     test_binary_int_broadcast()
diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h
index 7fc408c2c..cb09f1cb4 100644
--- a/topi/include/topi/transform.h
+++ b/topi/include/topi/transform.h
@@ -10,6 +10,7 @@
 #include <vector>
 #include <iterator>
 #include <algorithm>
+#include <limits>
 
 #include "topi/tags.h"
 #include "topi/detail/ravel_unravel.h"
@@ -403,31 +404,51 @@ inline Array<Tensor> split(const Tensor& x,
 * \return A Tensor whose op member is the split operation
 */
 inline Tensor strided_slice(const Tensor& x,
-                            const Array<Expr>& begin,
-                            const Array<Expr>& end,
-                            const Array<Expr>& strides,
+                            const Array<Integer>& begin,
+                            const Array<Integer>& end,
+                            const Array<Integer>& strides,
                             std::string name = "tensor",
                             std::string tag = kInjective) {
   size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
-  std::vector<int64_t> begin_vec = GetConstInt64Values(begin, "begin");
-  std::vector<int64_t> end_vec = GetConstInt64Values(end, "end");
-  std::vector<int64_t> stride_vec = GetConstInt64Values(strides, "strides");
-  // in case user has not provided begin indices for all the axes,
-  // then inflate it with default value = 0
-  for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) {
-    begin_vec.push_back(0);
-  }
-  // in case user has not provided end indices for all the axes,
-  // then inflate it with default value = input_tensor.shape[axis]
-  for (size_t i = end_vec.size(); i < src_tensor_dim; ++i) {
-    end_vec.push_back(GetConstInt(x->shape[i]));
+  // Setup the ranges.
+  // NOTE: this code duplicates the shape inference logic in the relay.op
+  // type relation; consider refactoring it in the future.
+  std::vector<int64_t> stride_vec;
+  for (Integer i : strides) {
+    CHECK(i.defined());
+    stride_vec.push_back(i->value);
   }
-  // in case user has not provided stride values,
-  // then inflate it with default value = 1
   for (size_t i = stride_vec.size(); i < src_tensor_dim; ++i) {
     stride_vec.push_back(1);
   }
+  const int64_t max_range = std::numeric_limits<int64_t>::max();
+
+  std::vector<int64_t> begin_vec;
+  for (size_t i = 0; i < begin.size(); ++i) {
+    if (!begin[i].defined()) {
+      // value=None
+      begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
+    } else {
+      begin_vec.push_back(begin[i]->value);
+    }
+  }
+  for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) {
+    begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
+  }
 
+  std::vector<int64_t> end_vec;
+  for (size_t i = 0; i < end.size(); ++i) {
+    // allow end to be None
+    if (!end[i].defined()) {
+      end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
+    } else {
+      end_vec.push_back(end[i]->value);
+    }
+  }
+  for (size_t i = end_vec.size(); i < src_tensor_dim; ++i) {
+    end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range);
+  }
+  // Compute
   Array<Expr> out_shape;
   Array<Expr> begin_expr;
   Array<Expr> strides_expr;
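
Note the padding behavior here: begin/end/strides shorter than the input rank
are extended with full-range defaults, so with positive strides a rank-3 call
with begin=[1, 1] behaves like begin=[1, 1, 0]. A plain-numpy restatement of
that semantics (values match the rank-3 test cases added in
tests/python/relay/test_op_level4.py):

  import numpy as np
  x = np.arange(36, dtype="float32").reshape(3, 4, 3)
  assert x[1:4, 1:4, :].shape == (2, 3, 3)   # begin=[1, 1], end=[4, 4], axis 2 defaulted
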
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 8a3269ba8..c496e08c1 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -19,3 +19,4 @@ from .shortcut_python import shortcut_python
 from .lrn_python import lrn_python
 from .l2_normalize_python import l2_normalize_python
 from .gather_nd_python import gather_nd_python
+from .strided_slice_python import strided_slice_python
diff --git a/topi/python/topi/testing/strided_slice_python.py b/topi/python/topi/testing/strided_slice_python.py
new file mode 100644
index 000000000..4407b3bec
--- /dev/null
+++ b/topi/python/topi/testing/strided_slice_python.py
@@ -0,0 +1,32 @@
+"""gather_nd in python"""
+
+def strided_slice_python(data, begin, end, strides):
+    """Python version of strided slice operator.
+
+    Parameters
+    ----------
+    data : numpy.ndarray
+        Input data
+
+    begin : list
+        Beginning of the slices.
+
+    end : list
+        End of the slices.
+
+    strides : list
+        The stride of each slice.
+
+    Returns
+    -------
+    result : numpy.ndarray
+        The sliced result.
+    """
+    strides = [] if strides is None else strides
+    slices = []
+    for i in range(len(data.shape)):
+        slices.append(slice(
+            begin[i] if i < len(begin) else None,
+            end[i] if i < len(end) else None,
+            strides[i] if i < len(strides) else None))
+    return data[tuple(slices)]
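
A quick sanity check of this reference implementation against the example in
the Relay op description earlier in the patch:

  import numpy as np
  from topi.testing import strided_slice_python

  x = np.array([[1., 4., 7., 10.],
                [2., 5., 8., 11.],
                [3., 6., 9., 12.]])
  out = strided_slice_python(x, [0, 1], [2, 4], [1, 1])
  # out == [[4., 7., 10.],
  #         [5., 8., 11.]]
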
diff --git a/topi/tests/python/test_topi_transform.py b/topi/tests/python/test_topi_transform.py
index 75e4d3b67..dc3c3fb70 100644
--- a/topi/tests/python/test_topi_transform.py
+++ b/topi/tests/python/test_topi_transform.py
@@ -249,13 +249,11 @@ def verify_take(src_shape, indices_src, axis=None):
     for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
         check_device(device)
 
-def verify_strided_slice(in_shape, begin, end, stride=None):
-    stride = stride if stride else [1, 1, 1]
+def verify_strided_slice(in_shape, begin, end, strides=None):
     A = tvm.placeholder(shape=in_shape, name="A")
-    B = topi.strided_slice(A, begin, end, stride) + 1
-    def test_forward(x, begin, end, stride):
-        return x[begin[0]:end[0]:stride[0],
-                    begin[1]:end[1]:stride[1], begin[2]:end[2]:stride[2]] + 1
+    strides = [1, 1, 1] if strides is None else strides
+    B = topi.strided_slice(A, begin, end, strides) + 1
+
     def check_device(device):
         ctx = tvm.context(device, 0)
         if not ctx.exist:
@@ -267,7 +265,8 @@ def verify_strided_slice(in_shape, begin, end, stride=None):
 
         foo = tvm.build(s, [A, B], device, name="stride_slice")
         x_np = np.random.uniform(size=in_shape).astype(A.dtype)
-        out_npy = test_forward(x_np, begin, end, stride)
+        out_npy = topi.testing.strided_slice_python(
+            x_np, begin, end, strides) + 1
         data_nd = tvm.nd.array(x_np, ctx)
         out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=A.dtype)
         foo(data_nd, out_nd)
@@ -298,7 +297,7 @@ def verify_gather_nd(src_shape, indices_src, indices_dtype):
             shape_size = shape_size * src_shape[i]
         data_npy = np.arange(shape_size, dtype=src_dtype).reshape((src_shape))
         out_npys = topi.testing.gather_nd_python(data_npy, indices_src)
-        
+
         data_nd = tvm.nd.array(data_npy, ctx)
         indices_nd = tvm.nd.array(indices_src, ctx)
         out_nd = tvm.nd.empty(out_npys.shape, ctx=ctx, dtype=src_dtype)
@@ -412,6 +411,7 @@ def test_gather_nd():
                          indices_dtype)
 
 if __name__ == "__main__":
+    test_strided_slice()
     test_concatenate()
     test_tranpose()
     test_expand_dims()
@@ -421,5 +421,4 @@ if __name__ == "__main__":
     test_flip()
     test_expand_like()
     test_take()
-    test_strided_slice()
     test_gather_nd()
-- 
GitLab