From 4b765c511d588d3a21b39d82d80451a0b0742a9f Mon Sep 17 00:00:00 2001
From: Pariksheet Pinjari <pariksheet.pinjari@huawei.com>
Date: Sat, 21 Apr 2018 09:27:04 +0530
Subject: [PATCH] [OP] PReLU Support (#394)
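
Add a parametric ReLU (prelu) operator to NNVM: a PReLUParam with a
configurable channel axis, shape inference that ties alpha's shape to
that axis, schedule and pattern registration on the Python side,
compute lowering through topi::prelu, and tests covering NCHW and
NHWC layouts, shape inference, and input-name listing.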

---
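Notes: a minimal end-to-end usage sketch of the new op, mirroring the
NCHW test below; it assumes an nnvm-era TVM build with the LLVM target
enabled and is not part of the committed change.

    import numpy as np
    import tvm
    import nnvm.symbol as sym
    import nnvm.compiler
    from tvm.contrib import graph_runtime

    # prelu over NCHW data: axis=1 (the default) is the channel axis,
    # so alpha must have shape (channels,); pass axis=3 for NHWC.
    x = sym.Variable("x")
    a = sym.Variable("a")
    y = sym.prelu(data=x, alpha=a)

    shape = {"x": (1, 3, 32, 32), "a": (3,)}
    graph, lib, _ = nnvm.compiler.build(y, "llvm", shape=shape)

    m = graph_runtime.create(graph, lib, tvm.cpu(0))
    xd = np.random.uniform(-1, 1, shape["x"]).astype("float32")
    ad = np.random.uniform(0, 1, shape["a"]).astype("float32")
    m.run(x=xd, a=ad)
    out = m.get_output(0, tvm.nd.empty(shape["x"])).asnumpy()

    # Reference semantics: y = x if x >= 0 else alpha * x, per channel.
    ref = np.where(xd >= 0, xd, xd * ad.reshape(1, 3, 1, 1))
    np.testing.assert_allclose(out, ref, rtol=1e-5)
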
 nnvm/include/nnvm/top/nn.h                    |  7 +++
 nnvm/python/nnvm/top/nn.py                    |  3 ++
 nnvm/src/top/nn/nn.cc                         | 47 +++++++++++++++++++
 nnvm/tests/python/compiler/test_top_level1.py | 39 +++++++++++++++
 .../tests/python/unittest/test_infer_shape.py | 12 +++++
 nnvm/tests/python/unittest/test_top_level3.py |  7 +++
 6 files changed, 115 insertions(+)

diff --git a/nnvm/include/nnvm/top/nn.h b/nnvm/include/nnvm/top/nn.h
index a8d443241..7eb2e5e11 100644
--- a/nnvm/include/nnvm/top/nn.h
+++ b/nnvm/include/nnvm/top/nn.h
@@ -101,6 +101,13 @@ struct LeakyReLUParam : public dmlc::Parameter<LeakyReLUParam> {
   }
 };
 
+struct PReLUParam : public dmlc::Parameter<PReLUParam> {
+  int axis;
+  DMLC_DECLARE_PARAMETER(PReLUParam) {
+    DMLC_DECLARE_FIELD(axis).set_default(1)
+      .describe("Specify which axis of the input shape denotes the channel.");
+  }
+};
 
 struct PadParam : public dmlc::Parameter<PadParam> {
   float pad_value;
diff --git a/nnvm/python/nnvm/top/nn.py b/nnvm/python/nnvm/top/nn.py
index f6dd6f562..d7e6e5a08 100644
--- a/nnvm/python/nnvm/top/nn.py
+++ b/nnvm/python/nnvm/top/nn.py
@@ -18,6 +18,9 @@ reg.register_pattern("relu", OpPattern.ELEMWISE)
 reg.register_schedule("leaky_relu", _fschedule_broadcast)
 reg.register_pattern("leaky_relu", OpPattern.ELEMWISE)
 
+# prelu: alpha is broadcast against the input, hence the broadcast pattern
+reg.register_schedule("prelu", _fschedule_broadcast)
+reg.register_pattern("prelu", OpPattern.BROADCAST)
 
 # flatten
 reg.register_schedule("flatten", _fschedule_broadcast)
diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc
index 5ac4b7662..e3755d952 100644
--- a/nnvm/src/top/nn/nn.cc
+++ b/nnvm/src/top/nn/nn.cc
@@ -417,6 +417,53 @@ NNVM_REGISTER_OP(leaky_relu)
 })
 .set_support_level(1);
 
+// prelu
+DMLC_REGISTER_PARAMETER(PReLUParam);
+
+inline bool PReluInferShape(const nnvm::NodeAttrs &attrs,
+                            std::vector<TShape> *in_shape,
+                            std::vector<TShape> *out_shape) {
+  const PReLUParam &param = nnvm::get<PReLUParam>(attrs.parsed);
+  TShape dshape = in_shape->at(0);
+  NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 0, dshape);
+
+  // Parametric ReLU: input data must be 4D and the channel axis valid.
+  CHECK_EQ(dshape.ndim(), 4) << "Input data should be 4D, but got " << dshape.ndim();
+  CHECK(static_cast<size_t>(param.axis) < dshape.ndim())
+      << "Wrong axis (" << param.axis << ") value.";
+
+  NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, 1, TShape({dshape[param.axis]}));
+
+  TShape oshape(dshape);
+  NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
+  return true;
+}
+
+NNVM_REGISTER_OP(prelu)
+.describe(R"code(Parametric version of a Rectified Linear Unit.
+It accepts two arguments: an input ``x`` and a channelwise slope ``alpha``
+and computes the output as :math:`PReLU(x) = x > 0 ? x : alpha * x`,
+where :math:`*` is a channelwise multiplication for each sample in the
+batch.
+)code" NNVM_ADD_FILELINE)
+.add_argument("data", "Tensor", "Input data.")
+.add_argument("alpha", "Tensor", "Input channelwise alpha.")
+.add_arguments(PReLUParam::__FIELDS__())
+.set_attr_parser(ParamParser<PReLUParam>)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<FInferShape>("FInferShape", PReluInferShape)
+.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data", "alpha"};
+  })
+.set_attr<FTVMCompute>(
+  "FTVMCompute", [](const NodeAttrs& attrs,
+                    const Array<Tensor>& inputs,
+                    const Array<Tensor>& out_info) {
+    const PReLUParam& param = nnvm::get<PReLUParam>(attrs.parsed);
+    return Array<Tensor>{ topi::prelu<float>(inputs[0], inputs[1], param.axis)};
+  })
+.set_support_level(4);
 
 DMLC_REGISTER_PARAMETER(PadParam);
 
diff --git a/nnvm/tests/python/compiler/test_top_level1.py b/nnvm/tests/python/compiler/test_top_level1.py
index 480aa271a..ebf8c6ce6 100644
--- a/nnvm/tests/python/compiler/test_top_level1.py
+++ b/nnvm/tests/python/compiler/test_top_level1.py
@@ -64,6 +64,43 @@ def test_relu():
     inputs = [('x', dshape, x)]
     helper(y, inputs, dtype, forward, backward)
 
+def test_prelu_nchw():
+    x = sym.Variable("x")
+    a = sym.Variable("a")
+    y = sym.prelu(data=x, alpha=a)
+
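+    # NumPy reference: the per-channel slope broadcasts over the spatial dims.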
+    def forward(x, a):
+        return (x < 0) * (x * a.reshape(3, 1, 1)) + (x >= 0) * x
+
+    dtype = "float32"
+    dshape_x = (1, 3, 32, 32)
+    dshape_w = (3,)
+
+    inputs = [
+        ('x', dshape_x, x),
+        ('a', dshape_w, a)
+    ]
+    helper(y, inputs, dtype, forward)
+
+def test_prelu_nhwc():
+    x = sym.Variable("x")
+    a = sym.Variable("a")
+    y = sym.prelu(data=x, alpha=a, axis=3)
+
+    def forward(x, a):
+        return (x < 0) * (x * a.reshape(1, 1, 3)) + (x >= 0) * x
+
+    dtype = "float32"
+    dshape_x = (1, 32, 32, 3)
+    dshape_w = (3,)
+
+    inputs = [
+        ('x', dshape_x, x),
+        ('a', dshape_w, a)
+    ]
+
+    helper(y, inputs, dtype, forward)
 
 def test_sym_scalar_pow():
     scalar = 3
@@ -336,6 +373,8 @@ if __name__ == "__main__":
     test_batchnorm()
     test_dense()
     test_relu()
+    test_prelu_nchw()
+    test_prelu_nhwc()
     test_sym_scalar_pow()
     test_scalar_sym_pow()
     test_exp()
diff --git a/nnvm/tests/python/unittest/test_infer_shape.py b/nnvm/tests/python/unittest/test_infer_shape.py
index 226dedd3d..8011e96f3 100644
--- a/nnvm/tests/python/unittest/test_infer_shape.py
+++ b/nnvm/tests/python/unittest/test_infer_shape.py
@@ -250,6 +250,17 @@ def test_reshape():
     check((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4))
 
 
+def test_prelu():
+    def check(in_shape, axis, out_shape):
+        x = sym.Variable("x", shape=in_shape)
+        w = sym.Variable("w")
+        y = sym.prelu(x, w, axis=axis, name="y")
+        sdict = infer_shape(y)
+        assert(tuple(sdict["y"][0]) == tuple(out_shape))
+    check((1, 3, 2, 2), 1, (1, 3, 2, 2))
+    check((1, 2, 2, 3), 3, (1, 2, 2, 3))
+
+
 # Level 4
 def test_transpose():
     def check(in_shape, out_shape, **kwargs):
@@ -319,3 +330,4 @@ if __name__ == "__main__":
     test_broadcast_binary()
     test_reduce()
     test_transpose()
+    test_prelu()
diff --git a/nnvm/tests/python/unittest/test_top_level3.py b/nnvm/tests/python/unittest/test_top_level3.py
index e37863b2b..47e7a8bce 100644
--- a/nnvm/tests/python/unittest/test_top_level3.py
+++ b/nnvm/tests/python/unittest/test_top_level3.py
@@ -16,8 +16,15 @@ def test_leaky_relu():
     y = sym.leaky_relu(x, alpha=0.1)
     assert(y.list_input_names() == ["x"])
 
+def test_prelu():
+    x = sym.Variable("x")
+    w = sym.Variable("w")
+    y = sym.prelu(x, w)
+    assert(y.list_input_names()[0] == "x")
+    assert(y.list_input_names()[1] == "w")
 
 if __name__ == "__main__":
     test_scalar_op()
     test_reshape()
     test_leaky_relu()
+    test_prelu()
-- 
GitLab