From 4fb58115eaf7d4b70873d95d80d4249c78eb93c8 Mon Sep 17 00:00:00 2001 From: Pariksheet Pinjari <pariksheet.pinjari@huawei.com> Date: Thu, 28 Jun 2018 20:35:33 +0530 Subject: [PATCH] Strided_slice added in NNVM (#1318) --- nnvm/include/nnvm/top/tensor.h | 16 +++ nnvm/python/nnvm/top/transform.py | 4 + nnvm/src/top/tensor/transform.cc | 113 ++++++++++++++++++ nnvm/tests/python/compiler/test_top_level1.py | 36 ++++++ 4 files changed, 169 insertions(+) diff --git a/nnvm/include/nnvm/top/tensor.h b/nnvm/include/nnvm/top/tensor.h index dca02cae9..87128a663 100644 --- a/nnvm/include/nnvm/top/tensor.h +++ b/nnvm/include/nnvm/top/tensor.h @@ -48,6 +48,22 @@ struct SplitParam : public dmlc::Parameter<SplitParam> { } }; +struct StridedSliceParam : public dmlc::Parameter<StridedSliceParam> { + // numpy convention, only support indices, not support list. + Tuple<int64_t> begin; + Tuple<int64_t> end; + Tuple<int64_t> stride; + + DMLC_DECLARE_PARAMETER(StridedSliceParam) { + DMLC_DECLARE_FIELD(begin) + .describe("Indices for begin of slice"); + DMLC_DECLARE_FIELD(end) + .describe("Indices for end of the slice"); + DMLC_DECLARE_FIELD(stride).set_default(Tuple<int64_t>()) + .describe("Stride values of the slice"); + } +}; + enum TypeFlag { kFloat32 = 0, kFloat64 = 1, diff --git a/nnvm/python/nnvm/top/transform.py b/nnvm/python/nnvm/top/transform.py index 6900f85f9..c87b4735d 100644 --- a/nnvm/python/nnvm/top/transform.py +++ b/nnvm/python/nnvm/top/transform.py @@ -61,6 +61,10 @@ reg.register_schedule("concatenate", _fschedule_injective) reg.register_pattern("split", OpPattern.INJECTIVE) reg.register_schedule("split", _fschedule_injective) +# strided_slice +reg.register_pattern("strided_slice", OpPattern.INJECTIVE) +reg.register_schedule("strided_slice", _fschedule_injective) + # slice_like reg.register_pattern("slice_like", OpPattern.INJECTIVE) reg.register_schedule("slice_like", _fschedule_injective) diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index 05d3cfcb5..0b0beb5b7 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -829,6 +829,119 @@ Examples:: }; }); +// strided_slice +DMLC_REGISTER_PARAMETER(StridedSliceParam); + +inline void StridedSliceParamParser(nnvm::NodeAttrs* attrs) { + StridedSliceParam param; + param.Init(attrs->dict); + attrs->parsed = std::move(param); +} + +inline bool StridedSliceInferShape(const NodeAttrs& attrs, + std::vector<TShape>* in_shape, + std::vector<TShape>* out_shape) { + const StridedSliceParam& param = nnvm::get<StridedSliceParam>(attrs.parsed); + const TShape& dshape = (*in_shape)[0]; + if (dshape.ndim() == 0) return false; + TShape oshape = dshape; + dim_t num_axis = dshape.ndim(); + + std::vector<int64_t> begin_vec; + std::copy(param.begin.begin(), param.begin.end(), std::back_inserter(begin_vec)); + for (dim_t i = begin_vec.size(); i < num_axis; ++i) { + begin_vec.push_back(0); + } + + std::vector<int64_t> end_vec; + std::copy(param.end.begin(), param.end.end(), std::back_inserter(end_vec)); + for (dim_t i = end_vec.size(); i < num_axis; ++i) { + end_vec.push_back(dshape[i]); + } + + std::vector<int64_t> stride_vec; + std::copy(param.stride.begin(), param.stride.end(), std::back_inserter(stride_vec)); + for (dim_t i = stride_vec.size(); i < num_axis; ++i) { + stride_vec.push_back(1); + } + + for (dim_t i = 0; i < num_axis; ++i) { + int64_t begin_range = stride_vec[i] < 0 ? -1 : 0; + int64_t end_range = stride_vec[i] < 0 ? dshape[i] - 1 : dshape[i]; + int64_t begin = begin_vec[i] < 0 ? dshape[i] + begin_vec[i] : begin_vec[i]; + int64_t end = end_vec[i] < 0 ? dshape[i] + end_vec[i] : end_vec[i]; + begin = std::min(std::max(begin, begin_range), end_range); + end = std::min(std::max(end, begin_range), end_range); + + int interval = std::abs(end - begin); + int slice_size = static_cast<int>((interval + + std::abs(stride_vec[i]) - 1) / std::abs(stride_vec[i])); + CHECK(stride_vec[i] < 0 ? (end < begin) : (begin < end)) + << ": Input [Begin=" << begin_vec[i] << ", End=" << end_vec[i] + << "] is invalid for axis=" << i; + oshape[i] = slice_size; + } + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape); + return true; +} + +NNVM_REGISTER_OP(strided_slice) +.describe(R"code(Strided slice of an array. + +Examples:: + + x = [[ 1., 4., 7., 10.], + [ 2., 5., 8., 11.], + [ 3., 6., 9., 12.]] + + strided_slice(x, begin=[0, 1], end=[2, 4], stride=[1, 1]) = [[ 4., 7., 10.], + [ 5., 8., 11.]] + + x = [[[ 1., 2.], + [ 3., 4.]], + + [[ 5., 6.], + [ 7., 8.]]] + + strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1., 2.], + [ 3., 4.]], + + [[ 5., 6.], + [ 7., 8.]]] +)code" NNVM_ADD_FILELINE) +.add_argument("data", "Tensor", "Array to be sliced") +.add_arguments(StridedSliceParam::__FIELDS__()) +.set_attr_parser(StridedSliceParamParser) +.set_attr<FInferShape>("FInferShape", StridedSliceInferShape) +.set_attr<FInferType>("FInferType", ElemwiseType<1, 1>) +.set_attr<FCorrectLayout>("FCorrectLayout", ElemwiseArbitraryLayout<1, 1>) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr<FTVMCompute>( + "FTVMCompute", [](const NodeAttrs& attrs, + const Array<Tensor>& inputs, + const Array<Tensor>& out_info) { + const StridedSliceParam& param = nnvm::get<StridedSliceParam>(attrs.parsed); + Array<Expr> begin; + Array<Expr> end; + Array<Expr> stride; + + for (int64_t i : param.begin) { + begin.push_back(tvm::make_const(tvm::Int(32), i)); + } + + for (int64_t i : param.end) { + end.push_back(tvm::make_const(tvm::Int(32), i)); + } + + for (int64_t i : param.stride) { + stride.push_back(tvm::make_const(tvm::Int(32), i)); + } + + return Array<Tensor>{ topi::strided_slice(inputs[0], begin, end, stride) }; +}) +.set_support_level(1); + // Flip DMLC_REGISTER_PARAMETER(FlipParam); diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py index 73391c80d..b97aff8ef 100644 --- a/nnvm/tests/python/compiler/test_top_level1.py +++ b/nnvm/tests/python/compiler/test_top_level1.py @@ -329,6 +329,41 @@ def test_split(): verify_split((5, 3), [3], axis=0) verify_split((5, 9, 3), [3, 4], axis=1) +def verify_strided_slice(ishape, begin, end, strideinp=None): + stride = strideinp if strideinp else [1, 1, 1] + x = sym.Variable("x") + if strideinp: + y = sym.strided_slice(x, begin = begin, end = end, stride = stride) + 1 + else: + y = sym.strided_slice(x, begin = begin, end = end) + 1 + x_np = np.random.uniform(size=ishape).astype("float32") + for i in range(len(begin), 3): + begin.append(0) + for i in range(len(end), 3): + end.append(ishape[i]) + def test_forward(x, begin, end, stride): + return x[begin[0]:end[0]:stride[0], + begin[1]:end[1]:stride[1], begin[2]:end[2]:stride[2]] + 1 + + for target, ctx in ctx_list(): + # set input + graph, lib, _ = nnvm.compiler.build(y, target, {"x": ishape}) + m = graph_runtime.create(graph, lib, ctx) + m.run(x=x_np) + res = test_forward(x_np, begin, end, stride) + out = m.get_output(0, tvm.nd.empty(res.shape)) + np.testing.assert_allclose(out.asnumpy(), res, atol=1e-5, rtol=1e-5) + +def test_strided_slice(): + verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2]) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1]) + verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1]) + verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2]) + verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1]) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3]) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 1000, 3]) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4]) + verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3]) def verify_squeeze(dshape, axis): x = sym.Variable("x") @@ -448,3 +483,4 @@ if __name__ == "__main__": test_pad() test_lrn() test_l2_normalize() + test_strided_slice() -- GitLab