From 891b4e06ced2eb91fb61b47ea0e9027cdf2f458b Mon Sep 17 00:00:00 2001 From: Pariksheet Pinjari <pariksheet.pinjari@huawei.com> Date: Sat, 26 May 2018 04:44:44 +0530 Subject: [PATCH] Flip operator (#505) --- nnvm/include/nnvm/top/tensor.h | 8 +++ nnvm/python/nnvm/top/transform.py | 4 ++ nnvm/src/top/tensor/transform.cc | 49 +++++++++++++++++++ nnvm/tests/python/compiler/test_top_level4.py | 23 +++++++++ 4 files changed, 84 insertions(+) diff --git a/nnvm/include/nnvm/top/tensor.h b/nnvm/include/nnvm/top/tensor.h index 00bad8245..80947bd23 100644 --- a/nnvm/include/nnvm/top/tensor.h +++ b/nnvm/include/nnvm/top/tensor.h @@ -156,6 +156,14 @@ struct TransposeParam : public dmlc::Parameter<TransposeParam> { } }; +struct FlipParam : public dmlc::Parameter<FlipParam> { + int axis; + DMLC_DECLARE_PARAMETER(FlipParam) { + DMLC_DECLARE_FIELD(axis).set_default(0) + .describe("the axis to be reveresed."); + } +}; + struct BroadcastToParam : public dmlc::Parameter<BroadcastToParam> { TShape shape; diff --git a/nnvm/python/nnvm/top/transform.py b/nnvm/python/nnvm/top/transform.py index c3ceb6868..b4b8779f2 100644 --- a/nnvm/python/nnvm/top/transform.py +++ b/nnvm/python/nnvm/top/transform.py @@ -41,6 +41,10 @@ reg.register_schedule("reshape_like", _fschedule_injective) reg.register_pattern("transpose", OpPattern.INJECTIVE) reg.register_schedule("transpose", _fschedule_injective) +# flip +reg.register_pattern("flip", OpPattern.INJECTIVE) +reg.register_schedule("flip", _fschedule_injective) + # reshape reg.register_pattern("reshape", OpPattern.INJECTIVE) reg.register_schedule("reshape", _fschedule_injective) diff --git a/nnvm/src/top/tensor/transform.cc b/nnvm/src/top/tensor/transform.cc index 48f8428d6..bdc8dc5a9 100644 --- a/nnvm/src/top/tensor/transform.cc +++ b/nnvm/src/top/tensor/transform.cc @@ -830,5 +830,54 @@ Examples:: }; }); +// Flip +DMLC_REGISTER_PARAMETER(FlipParam); + +NNVM_REGISTER_OP(flip) +.describe(R"code(Reverse the elements of an array. + +Examples:: + + x = [[ 1, 2], + [ 3, 4]] + + flip(x) = [[ 3., 4.], + [ 1., 2.]] + + x = [[[ 1., 2.], + [ 3., 4.]], + + [[ 5., 6.], + [ 7., 8.]]] + + flip(x) = [[[ 5., 6.], + [ 7., 8.]], + + [[ 1., 2.], + [ 3., 4.]]] + + flip(x, axis=1) = [[[ 3., 4.], + [ 1., 2.]], + + [[ 7., 8.], + [ 5., 6.]]] +)code" NNVM_ADD_FILELINE) +.add_argument("data", "Tensor", "Source input") +.add_arguments(FlipParam::__FIELDS__()) +.set_attr_parser(ParamParser<FlipParam>) +.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<FlipParam>) +.set_attr<nnvm::FInferShape>("FInferShape", ElemwiseShape<1, 1>) +.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>) +.set_num_inputs(1) +.set_num_outputs(1) +.set_support_level(4) +.set_attr<FTVMCompute>( + "FTVMCompute", [](const NodeAttrs& attrs, + const Array<Tensor>& inputs, + const Array<Tensor>& out_info) { + const FlipParam& param = nnvm::get<FlipParam>(attrs.parsed); + return Array<Tensor>{ topi::flip(inputs[0], param.axis) }; +}); + } // namespace top } // namespace nnvm diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py index c6e8620fc..819768cfb 100644 --- a/nnvm/tests/python/compiler/test_top_level4.py +++ b/nnvm/tests/python/compiler/test_top_level4.py @@ -90,6 +90,28 @@ def test_reduce(): verify_reduce((4, 4, 3), np.min, sym.min, keepdims=True) verify_reduce((4, 4, 3), np.sum, sym.sum, axis=(0, 2)) +def verify_flip(ishape, axis): + x = sym.Variable("x") + y = sym.flip(x, axis=axis) + 1 + dtype = "float32" + x_np = np.random.uniform(size=ishape).astype(dtype) + res = np.flip(x_np, axis) + 1 + + for target, ctx in ctx_list(): + # set input + graph, lib, _ = nnvm.compiler.build(y, target, {"x": ishape}) + m = graph_runtime.create(graph, lib, ctx) + m.run(x=x_np) + out = m.get_output(0, tvm.nd.empty(res.shape)) + np.testing.assert_allclose(out.asnumpy(), res, atol=1e-5, rtol=1e-5) + +def test_flip(): + verify_flip((3, 4, 3), 1) + verify_flip((3, 4, 3), 0) + verify_flip((3, 4, 3), 2) + verify_flip((3, 4, 3), -1) + verify_flip((3, 4, 3), -3) + verify_flip((3, 4, 3), -2) def verify_reshape(dshape, oshape): x = sym.Variable("x") @@ -347,4 +369,5 @@ if __name__ == "__main__": test_elemwise_sum() test_block_grad() test_full() + test_flip() print(nnvm.compiler.engine.dump()) -- GitLab