From 32a55f88148c0cd2fe764b0000788f1bbc850eab Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY <32511895+ANSHUMAN87@users.noreply.github.com> Date: Sat, 30 Jun 2018 20:25:15 +0530 Subject: [PATCH] Prelu bug fix (#1358) --- nnvm/src/top/nn/nn.cc | 2 +- topi/include/topi/nn.h | 1 - topi/src/topi.cc | 2 +- topi/tests/python/test_topi_relu.py | 9 +++++---- topi/tests/python_cpp/test_topi_relu.py | 9 +++++---- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc index ab47ae521..ba89e5ceb 100644 --- a/nnvm/src/top/nn/nn.cc +++ b/nnvm/src/top/nn/nn.cc @@ -563,7 +563,7 @@ where :math:`*` is an channelwise multiplication for each sample in the const Array<Tensor>& inputs, const Array<Tensor>& out_info) { const PReLUParam& param = nnvm::get<PReLUParam>(attrs.parsed); - return Array<Tensor>{ topi::prelu<float>(inputs[0], inputs[1], param.axis)}; + return Array<Tensor>{ topi::prelu(inputs[0], inputs[1], param.axis)}; }) .set_support_level(4); diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h index ee3101c4c..53b899796 100644 --- a/topi/include/topi/nn.h +++ b/topi/include/topi/nn.h @@ -92,7 +92,6 @@ inline tvm::Tensor leaky_relu(const tvm::Tensor& t, * * \return A Tensor whose op member is the relu operation */ -template <typename T> inline tvm::Tensor prelu(const tvm::Tensor &x, const tvm::Tensor &slope, const int axis = 1, diff --git a/topi/src/topi.cc b/topi/src/topi.cc index 59652ed66..fe1f40985 100644 --- a/topi/src/topi.cc +++ b/topi/src/topi.cc @@ -191,7 +191,7 @@ TVM_REGISTER_GLOBAL("topi.nn.leaky_relu") TVM_REGISTER_GLOBAL("topi.nn.prelu") .set_body([](TVMArgs args, TVMRetValue *rv) { - *rv = prelu<float>(args[0], args[1]); + *rv = prelu(args[0], args[1], args[2]); }); TVM_REGISTER_GLOBAL("topi.nn.pad") diff --git a/topi/tests/python/test_topi_relu.py b/topi/tests/python/test_topi_relu.py index 2f7898ff2..28a86be03 100644 --- a/topi/tests/python/test_topi_relu.py +++ b/topi/tests/python/test_topi_relu.py @@ -46,16 +46,16 @@ def verify_leaky_relu(m, alpha): np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) -def verify_prelu(x, w): +def verify_prelu(x, w, axis, weight_reshape): X = tvm.placeholder((x), name='X') W = tvm.placeholder((w), name='W') x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype) w_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype) def _prelu_numpy(x, W): - return (x < 0) * (x *W.reshape(3, 1, 1)) + (x>=0) * x + return (x < 0) * (x *W.reshape(weight_reshape)) + (x>=0) * x - B = topi.nn.prelu(X, W) + B = topi.nn.prelu(X, W, axis) s = tvm.create_schedule([B.op]) ctx = tvm.cpu(0) @@ -79,7 +79,8 @@ def test_leaky_relu(): verify_leaky_relu(100, 0.1) def test_prelu(): - verify_prelu((1, 3, 2, 2), (3,)) + verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1)) + verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1)) if __name__ == "__main__": test_schedule_big_array() diff --git a/topi/tests/python_cpp/test_topi_relu.py b/topi/tests/python_cpp/test_topi_relu.py index a5064618b..6677c1bf5 100644 --- a/topi/tests/python_cpp/test_topi_relu.py +++ b/topi/tests/python_cpp/test_topi_relu.py @@ -50,16 +50,16 @@ def verify_leaky_relu(m, alpha): foo(a, b) np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5) -def verify_prelu(x, w): +def verify_prelu(x, w, axis, weight_reshape): X = tvm.placeholder((x), name='X') W = tvm.placeholder((w), name='W') x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype) w_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype) def _prelu_numpy(x, W): - return (x < 0) * (x *W.reshape(3, 1, 1)) + (x>=0) * x + return (x < 0) * (x *W.reshape(weight_reshape)) + (x>=0) * x out_np = _prelu_numpy(x_np, w_np) - B = topi.cpp.nn.prelu(X, W) + B = topi.cpp.nn.prelu(X, W, axis) device = "llvm" target = topi.cpp.TEST_create_target(device) s = topi.cpp.generic.schedule_injective(target, [B]) @@ -81,7 +81,8 @@ def test_leaky_relu(): verify_leaky_relu(100, 0.5) def test_prelu(): - verify_prelu((1, 3, 2, 2), (3,)) + verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1)) + verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1)) if __name__ == "__main__": test_relu() -- GitLab