diff --git a/nnvm/src/top/nn/nn.cc b/nnvm/src/top/nn/nn.cc
index ab47ae521224256495af4cca0d7471f1b5575e9e..ba89e5ceba58907fa59a50c2e5dc85cc5011be1c 100644
--- a/nnvm/src/top/nn/nn.cc
+++ b/nnvm/src/top/nn/nn.cc
@@ -563,7 +563,7 @@ where :math:`*` is a channelwise multiplication for each sample in the
                     const Array<Tensor>& inputs,
                     const Array<Tensor>& out_info) {
     const PReLUParam& param = nnvm::get<PReLUParam>(attrs.parsed);
-    return Array<Tensor>{ topi::prelu<float>(inputs[0], inputs[1], param.axis)};
+    return Array<Tensor>{ topi::prelu(inputs[0], inputs[1], param.axis)};
   })
 .set_support_level(4);
 
diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h
index ee3101c4cc18363659cfda98ecf9bf70aa184aef..53b899796e37f844e197298a62d2efa63eabc59d 100644
--- a/topi/include/topi/nn.h
+++ b/topi/include/topi/nn.h
@@ -92,7 +92,6 @@ inline tvm::Tensor leaky_relu(const tvm::Tensor& t,
  *
  * \return A Tensor whose op member is the relu operation
  */
-template <typename T>
 inline tvm::Tensor prelu(const tvm::Tensor &x,
                          const tvm::Tensor &slope,
                          const int axis = 1,
diff --git a/topi/src/topi.cc b/topi/src/topi.cc
index 59652ed6680eadb53de76f81d0f9da77f26f619d..fe1f4098570d1f9675ddf06705f07550ef19c868 100644
--- a/topi/src/topi.cc
+++ b/topi/src/topi.cc
@@ -191,7 +191,7 @@ TVM_REGISTER_GLOBAL("topi.nn.leaky_relu")
 
 TVM_REGISTER_GLOBAL("topi.nn.prelu")
 .set_body([](TVMArgs args, TVMRetValue *rv) {
-  *rv = prelu<float>(args[0], args[1]);
+  *rv = prelu(args[0], args[1], args[2]);
   });
 
 TVM_REGISTER_GLOBAL("topi.nn.pad")
diff --git a/topi/tests/python/test_topi_relu.py b/topi/tests/python/test_topi_relu.py
index 2f7898ff242a8c47448b62e31d35f07c8d2b1e6d..28a86be03ea03d6107ec285c03db33b677a397da 100644
--- a/topi/tests/python/test_topi_relu.py
+++ b/topi/tests/python/test_topi_relu.py
@@ -46,16 +46,16 @@ def verify_leaky_relu(m, alpha):
     np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
 
-def verify_prelu(x, w):
+def verify_prelu(x, w, axis, weight_reshape):
     X = tvm.placeholder((x), name='X')
     W = tvm.placeholder((w), name='W')
     x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype)
     w_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype)
 
     def _prelu_numpy(x, W):
-        return (x < 0) * (x *W.reshape(3, 1, 1)) + (x>=0) * x
+        return (x < 0) * (x *W.reshape(weight_reshape)) + (x>=0) * x
 
-    B = topi.nn.prelu(X, W)
+    B = topi.nn.prelu(X, W, axis)
     s = tvm.create_schedule([B.op])
 
     ctx = tvm.cpu(0)
@@ -79,7 +79,8 @@ def test_leaky_relu():
     verify_leaky_relu(100, 0.1)
 
 def test_prelu():
-    verify_prelu((1, 3, 2, 2), (3,))
+    verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1))
+    verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1))
 
 if __name__ == "__main__":
     test_schedule_big_array()
diff --git a/topi/tests/python_cpp/test_topi_relu.py b/topi/tests/python_cpp/test_topi_relu.py
index a5064618b8900ba6a5ca19fe9cf7d5f6c2d02297..6677c1bf55517bcbf0de86999b6ae22ecf11adb4 100644
--- a/topi/tests/python_cpp/test_topi_relu.py
+++ b/topi/tests/python_cpp/test_topi_relu.py
@@ -50,16 +50,16 @@ def verify_leaky_relu(m, alpha):
     foo(a, b)
     np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-def verify_prelu(x, w):
+def verify_prelu(x, w, axis, weight_reshape):
     X = tvm.placeholder((x), name='X')
     W = tvm.placeholder((w), name='W')
     x_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype)
     w_np = np.random.uniform(low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype)
     def _prelu_numpy(x, W):
-        return (x < 0) * (x *W.reshape(3, 1, 1)) + (x>=0) * x
+        return (x < 0) * (x *W.reshape(weight_reshape)) + (x>=0) * x
 
     out_np = _prelu_numpy(x_np, w_np)
-    B = topi.cpp.nn.prelu(X, W)
+    B = topi.cpp.nn.prelu(X, W, axis)
     device = "llvm"
     target = topi.cpp.TEST_create_target(device)
     s = topi.cpp.generic.schedule_injective(target, [B])
@@ -81,7 +81,8 @@ def test_leaky_relu():
     verify_leaky_relu(100, 0.5)
 
 def test_prelu():
-    verify_prelu((1, 3, 2, 2), (3,))
+    verify_prelu((1, 3, 2, 2), (3,), 1, (3, 1, 1))
+    verify_prelu((1, 3, 2, 2), (2,), 2, (2, 1))
 
 if __name__ == "__main__":
     test_relu()