From 32a55f88148c0cd2fe764b0000788f1bbc850eab Mon Sep 17 00:00:00 2001
From: ANSHUMAN TRIPATHY <32511895+ANSHUMAN87@users.noreply.github.com>
Date: Sat, 30 Jun 2018 20:25:15 +0530
Subject: [PATCH] Prelu bug fix (#1358)

---
 nnvm/src/top/nn/nn.cc                   | 2 +-
 topi/include/topi/nn.h                  | 1 -
 topi/src/topi.cc                        | 2 +-
 topi/tests/python/test_topi_relu.py     | 9 +++++----
 topi/tests/python_cpp/test_topi_relu.py | 9 +++++----
 5 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc
index ab47ae521..ba89e5ceb 100644
--- a/nnvm/src/top/nn/nn.cc
+++ b/nnvm/src/top/nn/nn.cc
@@ -563,7 +563,7 @@ where :math:`*` is an channelwise multiplication for each sample in the
                     const Array<Tensor>& inputs,
                     const Array<Tensor>& out_info) {
     const PReLUParam& param = nnvm::get<PReLUParam>(attrs.parsed);
-    return Array<Tensor>{ topi::prelu<float>(inputs[0], inputs[1], param.axis)};
+    return Array<Tensor>{ topi::prelu(inputs[0], inputs[1], param.axis)};
   })
 .set_support_level(4);
 
diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h
index ee3101c4c..53b899796 100644
--- a/topi/include/topi/nn.h
+++ b/topi/include/topi/nn.h
@@ -92,7 +92,6 @@ inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
  *
  * \return A Tensor whose op member is the relu operation
  */
-template <typename T>
 inline tvm::Tensor prelu(const tvm::Tensor &x,
                          const tvm::Tensor &slope,
                          const int axis = 1,
diff --git a/topi/src/topi.cc b/topi/src/topi.cc
index 59652ed66..fe1f40985 100644
--- a/topi/src/topi.cc
+++ b/topi/src/topi.cc
@@ -191,7 +191,7 @@ TVM_REGISTER_GLOBAL("topi.nn.leaky_relu")
 
 TVM_REGISTER_GLOBAL("topi.nn.prelu")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
-  *rv = prelu<float>(args[0], args[1]);
+  *rv = prelu(args[0], args[1], args[2]);
   });
 
 TVM_REGISTER_GLOBAL("topi.nn.pad")
diff --git a/topi/tests/python/test_topi_relu.py b/topi/tests/python/test_topi_relu.py
index 2f7898ff2..28a86be03 100644
--- a/topi/tests/python/test_topi_relu.py
+++ b/topi/tests/python/test_topi_relu.py
@@ -46,16 +46,16 @@ def verify_leaky_relu(m, alpha):
     np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
 
-def verify_prelu(x, w):
+def verify_prelu(x, w, axis, weight_reshape):
     X = tvm.placeholder((x), name='X')
     W = tvm.placeholder((w), name='W')
     x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype)
     w_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype)
 
     def _prelu_numpy(x, W):
-        return (x < 0) * (x *W.reshape(3, 1, 1)) + (x>=0) * x
+        return (x < 0) * (x *W.reshape(weight_reshape)) + (x>=0) * x
 
-    B = topi.nn.prelu(X, W)
+    B = topi.nn.prelu(X, W, axis)
     s = tvm.create_schedule([B.op])
 
     ctx = tvm.cpu(0)
@@ -79,7 +79,8 @@ def test_leaky_relu():
     verify_leaky_relu(100, 0.1)
 
 def test_prelu():
-    verify_prelu((1, 3, 2, 2), (3,))
+    verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1))
+    verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1))
 
 if __name__ == "__main__":
     test_schedule_big_array()
diff --git a/topi/tests/python_cpp/test_topi_relu.py b/topi/tests/python_cpp/test_topi_relu.py
index a5064618b..6677c1bf5 100644
--- a/topi/tests/python_cpp/test_topi_relu.py
+++ b/topi/tests/python_cpp/test_topi_relu.py
@@ -50,16 +50,16 @@ def verify_leaky_relu(m, alpha):
     foo(a, b)
     np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-def verify_prelu(x, w):
+def verify_prelu(x, w, axis, weight_reshape):
     X = tvm.placeholder((x), name='X')
     W = tvm.placeholder((w), name='W')
     x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype)
     w_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype)
     def _prelu_numpy(x, W):
-        return (x < 0) * (x *W.reshape(3, 1, 1)) + (x>=0) * x
+        return (x < 0) * (x *W.reshape(weight_reshape)) + (x>=0) * x
 
     out_np = _prelu_numpy(x_np, w_np)
-    B = topi.cpp.nn.prelu(X, W)
+    B = topi.cpp.nn.prelu(X, W, axis)
     device = "llvm"
     target = topi.cpp.TEST_create_target(device)
     s = topi.cpp.generic.schedule_injective(target, [B])
@@ -81,7 +81,8 @@ def test_leaky_relu():
     verify_leaky_relu(100, 0.5)
 
 def test_prelu():
-    verify_prelu((1, 3, 2, 2), (3,))
+    verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1))
+    verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1))
 
 if __name__ == "__main__":
     test_relu()
-- 
GitLab