From 4b765c511d588d3a21b39d82d80451a0b0742a9f Mon Sep 17 00:00:00 2001 From: Pariksheet Pinjari <pariksheet.pinjari@huawei.com> Date: Sat, 21 Apr 2018 09:27:04 +0530 Subject: [PATCH] [OP] PReLU Support (#394) --- nnvm/include/nnvm/top/nn.h | 7 +++ nnvm/python/nnvm/top/nn.py | 3 ++ nnvm/src/top/nn/nn.cc | 47 +++++++++++++++++++ nnvm/tests/python/compiler/test_top_level1.py | 39 +++++++++++++++ .../tests/python/unittest/test_infer_shape.py | 12 +++++ nnvm/tests/python/unittest/test_top_level3.py | 7 +++ 6 files changed, 115 insertions(+) diff --git a/nnvm/include/nnvm/top/nn.h b/nnvm/include/nnvm/top/nn.h index a8d443241..7eb2e5e11 100644 --- a/nnvm/include/nnvm/top/nn.h +++ b/nnvm/include/nnvm/top/nn.h @@ -101,6 +101,13 @@ struct LeakyReLUParam : public dmlc::Parameter<LeakyReLUParam> { } }; +struct PReLUParam : public dmlc::Parameter<PReLUParam> { + int axis; + DMLC_DECLARE_PARAMETER(PReLUParam) { + DMLC_DECLARE_FIELD(axis).set_default(1) + .describe("Specify which shape axis the channel is specified."); + } +}; struct PadParam : public dmlc::Parameter<PadParam> { float pad_value; diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py index f6dd6f562..d7e6e5a08 100644 --- a/nnvm/python/nnvm/top/nn.py +++ b/nnvm/python/nnvm/top/nn.py @@ -18,6 +18,9 @@ reg.register_pattern("relu", OpPattern.ELEMWISE) reg.register_schedule("leaky_relu", _fschedule_broadcast) reg.register_pattern("leaky_relu", OpPattern.ELEMWISE) +# prelu +reg.register_schedule("prelu", _fschedule_broadcast) +reg.register_pattern("prelu", OpPattern.BROADCAST) # flatten reg.register_schedule("flatten", _fschedule_broadcast) diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc index 5ac4b7662..e3755d952 100644 --- a/nnvm/src/top/nn/nn.cc +++ b/nnvm/src/top/nn/nn.cc @@ -417,6 +417,53 @@ NNVM_REGISTER_OP(leaky_relu) }) .set_support_level(1); +// prelu +DMLC_REGISTER_PARAMETER(PReLUParam); + +inline bool PReluInferShape(const nnvm::NodeAttrs &attrs, + std::vector<TShape> *in_shape, + std::vector<TShape> *out_shape) { + const PReLUParam ¶m = nnvm::get<PReLUParam>(attrs.parsed); + TShape dshape = in_shape->at(0); + NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 0, dshape); + + // The case of parametric relu + CHECK_EQ(dshape.ndim(), 4) << "Input data should be 4D, but got " << dshape.ndim(); + CHECK(size_t(param.axis) < dshape.Size()) + << "Wrong axis (" << param.axis << ")value."; + + NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 1, TShape({dshape[param.axis]})); + + TShape oshape(dshape); + NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape); + return true; +} + +NNVM_REGISTER_OP(prelu) +.describe(R"code(Parametric version of a Rectified Linear Unit. +It accepts two arguments: an input ``x`` and a channelwise slope ``alpha`` +and computes the output as :math:`PReLU(x) y = x > 0 ? x : alpha * x`, +where :math:`*` is an channelwise multiplication for each sample in the + +)code" NNVM_ADD_FILELINE) +.add_argument("data", "Tensor", "Input data.") +.add_argument("alpha", "Tensor", "Input channelwise alpha.") +.add_arguments(PReLUParam::__FIELDS__()) +.set_attr_parser(ParamParser<PReLUParam>) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr<FInferShape>("FInferShape", PReluInferShape) +.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) { + return std::vector<std::string>{"data", "alpha"}; + }) +.set_attr<FTVMCompute>( + "FTVMCompute", [](const NodeAttrs& attrs, + const Array<Tensor>& inputs, + const Array<Tensor>& out_info) { + const PReLUParam& param = nnvm::get<PReLUParam>(attrs.parsed); + return Array<Tensor>{ topi::prelu<float>(inputs[0], inputs[1], param.axis)}; + }) +.set_support_level(4); DMLC_REGISTER_PARAMETER(PadParam); diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py index 480aa271a..ebf8c6ce6 100644 --- a/nnvm/tests/python/compiler/test_top_level1.py +++ b/nnvm/tests/python/compiler/test_top_level1.py @@ -64,6 +64,43 @@ def test_relu(): inputs = [('x', dshape, x)] helper(y, inputs, dtype, forward, backward) +def test_prelu_nchw(): + x = sym.Variable("x") + a = sym.Variable("a") + y = sym.prelu(data=x, alpha=a) + + def forward(x, a): + return (x < 0) * (x * a.reshape(3, 1, 1)) + (x>=0) * x + + dtype = "float32" + dshape_x = (1, 3, 32, 32) + dshape_w = (3,) + + inputs = [ + ('x', dshape_x, x), + ('a', dshape_w, a) + ] + helper(y, inputs, dtype, forward) + +def test_prelu_nhwc(): + x = sym.Variable("x") + a = sym.Variable("a") + y = sym.prelu(data=x, alpha=a, axis=3) + + def forward(x, a): + return (x < 0) * (x * a.reshape(1, 1, 3)) + (x>=0) * x + + dtype = "float32" + dshape_x = (1, 32, 32, 3) + dshape_w = (3,) + + inputs = [ + ('x', dshape_x, x), + ('a', dshape_w, a) + ] + + + helper(y, inputs, dtype, forward) def test_sym_scalar_pow(): scalar = 3 @@ -336,6 +373,8 @@ if __name__ == "__main__": test_batchnorm() test_dense() test_relu() + test_prelu_nchw() + test_prelu_nhwc() test_sym_scalar_pow() test_scalar_sym_pow() test_exp() diff --git a/nnvm/tests/python/unittest/test_infer_shape.py b/nnvm/tests/python/unittest/test_infer_shape.py index 226dedd3d..8011e96f3 100644 --- a/nnvm/tests/python/unittest/test_infer_shape.py +++ b/nnvm/tests/python/unittest/test_infer_shape.py @@ -250,6 +250,17 @@ def test_reshape(): check((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4)) +def test_prelu(): + def check(in_shape, axis, out_shape): + x = sym.Variable("x", shape=in_shape) + w = sym.Variable("w") + y = sym.prelu(x, w, axis=axis, name="y") + sdict = infer_shape(y) + assert(tuple(sdict["y"][0]) == tuple(out_shape)) + check((1, 3, 2, 2), 1, (1, 3, 2, 2)) + check((1, 2, 2, 3), 3, (1, 2, 2, 3)) + + # Level 4 def test_transpose(): def check(in_shape, out_shape, **kwargs): @@ -319,3 +330,4 @@ if __name__ == "__main__": test_broadcast_binary() test_reduce() test_transpose() + test_prelu() diff --git a/nnvm/tests/python/unittest/test_top_level3.py b/nnvm/tests/python/unittest/test_top_level3.py index e37863b2b..47e7a8bce 100644 --- a/nnvm/tests/python/unittest/test_top_level3.py +++ b/nnvm/tests/python/unittest/test_top_level3.py @@ -16,8 +16,15 @@ def test_leaky_relu(): y = sym.leaky_relu(x, alpha=0.1) assert(y.list_input_names() == ["x"]) +def test_prelu(): + x = sym.Variable("x") + w = sym.Variable("w") + y = sym.prelu(x, w) + assert(y.list_input_names()[0] == 'x') + assert(y.list_input_names()[1] == 'w') if __name__ == "__main__": test_scalar_op() test_reshape() test_leaky_relu() + test_prelu() -- GitLab