diff --git a/cmake/modules/contrib/NNPack.cmake b/cmake/modules/contrib/NNPack.cmake
index 82de88a21e63f6216a6550d3b62dd3f3308b0177..4bf844d0c468de32ffcc145c4a0e6bb8506a825c 100644
--- a/cmake/modules/contrib/NNPack.cmake
+++ b/cmake/modules/contrib/NNPack.cmake
@@ -9,6 +9,10 @@ if(USE_NNPACK)
 	include_directories(${PTHREAD_POOL_PATH}/include)
     find_library(NNPACK_CONTRIB_LIB nnpack ${NNPACK_PATH}/lib)
   find_library(NNPACK_PTHREAD_CONTRIB_LIB pthreadpool ${NNPACK_PATH}/lib)
+  find_library(NNPACK_CPUINFO_CONTRIB_LIB cpuinfo ${NNPACK_PATH}/lib)
+  find_library(NNPACK_CLOG_CONTRIB_LIB clog ${NNPACK_PATH}/lib)
   list(APPEND TVM_RUNTIME_LINKER_LIBS ${NNPACK_CONTRIB_LIB})
   list(APPEND TVM_RUNTIME_LINKER_LIBS ${NNPACK_PTHREAD_CONTRIB_LIB})
+  list(APPEND TVM_RUNTIME_LINKER_LIBS ${NNPACK_CPUINFO_CONTRIB_LIB})
+  list(APPEND TVM_RUNTIME_LINKER_LIBS ${NNPACK_CLOG_CONTRIB_LIB})
 endif(USE_NNPACK)
diff --git a/python/tvm/contrib/nnpack.py b/python/tvm/contrib/nnpack.py
index d6587df26229729a71fbd7b5a4f8ad0c35abc4df..36f8a76a87db8310cd5641accae121ee40f94545 100644
--- a/python/tvm/contrib/nnpack.py
+++ b/python/tvm/contrib/nnpack.py
@@ -63,14 +63,32 @@ def fully_connected_output(lhs, rhs, nthreads=1):
             "tvm.contrib.nnpack.fully_connected_output",
             ins[0], ins[1], outs[0], nthreads), name="C")
 
-def convolution_inference(data, kernel, bias, padding, stride, nthreads=1):
-    """Create an extern op to do inference convolution of 3D tensor data and
+
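+# Mirrors NNPACK's nnp_convolution_algorithm enum values.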
+class ConvolutionAlgorithm:
+    AUTO = 0
+    FFT_8x8 = 1
+    FFT_16x16 = 2
+    WT_8x8 = 3
+    IMPLICIT_GEMM = 4
+    DIRECT = 5
+    WT_8x8_FP16 = 6
+
+
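+# Mirrors NNPACK's nnp_convolution_transform_strategy enum values; the reuse
+# strategy is selected internally by convolution_inference_without_weight_transform.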
+class ConvolutionTransformStrategy:
+    COMPUTE = 1
+    PRECOMPUTE = 2
+
+
+def convolution_inference(
+        data, kernel, bias, padding, stride, nthreads=1,
+        algorithm=ConvolutionAlgorithm.AUTO):
+    """Create an extern op to do inference convolution of 4D tensor data and
     4D tensor kernel and 1D tensor bias with nnpack.
 
     Parameters
     ----------
     data : Tensor
-        data 3D tensor input[input_channels][input_height][input_width] of
+        data 4D tensor input[batch][input_channels][input_height][input_width] of
         FP32 elements.
     kernel : Tensor
         kernel 4D tensor kernel[output_channels][input_channels][kernel_height]
@@ -88,23 +106,108 @@ def convolution_inference(data, kernel, bias, padding, stride, nthreads=1):
     Returns
     -------
     output : Tensor
-        output 3D tensor output[output_channels][output_height][output_width]
+        output 4D tensor output[batch][output_channels][output_height][output_width]
         of FP32 elements.
     """
 
     assert isinstance(padding, list) and len(padding) == 4
     assert isinstance(stride, list) and len(stride) == 2
-    _, input_height, input_width = data.shape
+    batch, _, input_height, input_width = data.shape
     output_channels, _, kernel_height, kernel_width = kernel.shape
     output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1
     output_width = (input_width + padding[0] + padding[1] - kernel_width) / stride[1] + 1
 
     return _api.extern(
-        (output_channels, output_height, output_width), [data, kernel, bias],
+        (batch, output_channels, output_height, output_width),
+        [data, kernel, bias] if bias is not None else [data, kernel],
         lambda ins, outs: _intrin.call_packed(
-            "tvm.contrib.nnpack.convolution_inference", ins[0], ins[1], ins[2],
+            "tvm.contrib.nnpack.convolution_inference",
+            ins[0],
+            ins[1],
+            ins[2] if bias is not None else 0,
             outs[0], padding[0], padding[1], padding[2], padding[3],
-            stride[0], stride[1], nthreads), name="C")
+            stride[0], stride[1], nthreads, algorithm), name="C")
+
+def convolution_inference_without_weight_transform(
+        data, transformed_kernel, bias, padding, stride, nthreads=1,
+        algorithm=ConvolutionAlgorithm.AUTO):
+    """Create an extern op to do inference convolution of 4D tensor data and
+    4D pre-transformed tensor kernel and 1D tensor bias with nnpack.
+
+    Parameters
+    ----------
+    data : Tensor
+        data 4D tensor input[batch][input_channels][input_height][input_width] of
+        FP32 elements.
+    transformed_kernel : Tensor
+        transformed_kernel 4D tensor kernel[output_channels][input_channels][tile]
+        [tile] of FP32 elements.
+    bias : Tensor
+        bias 1D array bias[output_channels] of FP32 elements.
+    padding : list
+        padding A 4-dim list of [pad_top, pad_bottom, pad_left, pad_right],
+        which indicates the padding around the feature map.
+    stride : list
+        stride A 2-dim list of [stride_height, stride_width], which indicates
+        the stride.
+
+    Returns
+    -------
+    output : Tensor
+        output 4D tensor output[batch][output_channels][output_height][output_width]
+        of FP32 elements.
+    """
+
+    assert algorithm in (ConvolutionAlgorithm.WT_8x8,
+                         ConvolutionAlgorithm.WT_8x8_FP16)
+    assert isinstance(padding, list) and len(padding) == 4
+    assert isinstance(stride, list) and len(stride) == 2
+    batch, _, input_height, input_width = data.shape
+    output_channels, _, _, _ = transformed_kernel.shape
+    kernel_height, kernel_width = (3, 3)
+    output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1
+    output_width = (input_width + padding[0] + padding[1] - kernel_width) / stride[1] + 1
+
+    return _api.extern(
+        (batch, output_channels, output_height, output_width),
+        [data, transformed_kernel, bias] if bias is not None else [data, transformed_kernel],
+        lambda ins, outs: _intrin.call_packed(
+            "tvm.contrib.nnpack.convolution_inference_without_weight_transform",
+            ins[0],
+            ins[1],
+            ins[2] if bias is not None else 0,
+            outs[0], padding[0], padding[1], padding[2], padding[3],
+            stride[0], stride[1], nthreads, algorithm), name="C")
+
+def convolution_inference_weight_transform(
+        kernel, nthreads=1,
+        algorithm=ConvolutionAlgorithm.AUTO):
+    """Create an extern op to do inference convolution of 3D tensor data and
+    4D tensor kernel and 1D tensor bias with nnpack.
+
+    Parameters
+    ----------
+    kernel : Tensor
+        kernel 4D tensor kernel[output_channels][input_channels][kernel_height]
+        [kernel_width] of FP32 elements.
+
+    Returns
+    -------
+    output : Tensor
+        output 4D tensor output[output_channels][input_channels][tile][tile]
+        of FP32 elements.
+    """
+    assert algorithm in (ConvolutionAlgorithm.WT_8x8, ConvolutionAlgorithm.WT_8x8_FP16)
+    output_channels, input_channels, _, _ = kernel.shape
+
+    transform_tile_size = 8
+    return _api.extern(
+        (output_channels, input_channels, transform_tile_size, transform_tile_size),
+        [kernel],
+        lambda ins, outs: _intrin.call_packed(
+            "tvm.contrib.nnpack.convolution_inference_weight_transform",
+            ins[0], outs[0], nthreads, algorithm), name="transform_kernel")
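+
+# Illustrative sketch of the two-step Winograd flow these ops enable, assuming
+# a 3x3 kernel with unit padding and stride (variable names below are
+# placeholders, not part of this module): the kernel is transformed once and
+# then reused across inference calls.
+#
+#   transformed = convolution_inference_weight_transform(
+#       kernel, algorithm=ConvolutionAlgorithm.WT_8x8)
+#   out = convolution_inference_without_weight_transform(
+#       data, transformed, bias, [1, 1, 1, 1], [1, 1],
+#       algorithm=ConvolutionAlgorithm.WT_8x8)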
 
 def convolution_output(data, kernel, bias, padding, nthreads=1):
     """Create an extern op to compute convolution of 4D tensor data and
@@ -144,4 +247,5 @@ def convolution_output(data, kernel, bias, padding, nthreads=1):
             "tvm.contrib.nnpack.convolution_output", ins[0], ins[1], ins[2],
             outs[0], padding[0], padding[1], padding[2], padding[3], nthreads), name="C")
 
+
 _init_api("tvm.contrib.nnpack")
diff --git a/src/contrib/nnpack/convolution.cc b/src/contrib/nnpack/convolution.cc
index f658a1fe96d4956ce78efaf8a15bae5e7f4d33bb..8bcdd64281cce31bde3aa3f79c93a3dfbb43cab0 100644
--- a/src/contrib/nnpack/convolution.cc
+++ b/src/contrib/nnpack/convolution.cc
@@ -13,62 +13,208 @@ namespace contrib {
 using namespace runtime;
 
 TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
-    nnp_initialize();
-    DLTensor* input  = args[0];
-    DLTensor* kernel = args[1];
-    DLTensor* bias   = args[2];
-    DLTensor* output = args[3];
-    uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], pad_left = args[7];
-    nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
-    uint64_t stride_width = args[8], stride_height = args[9];
-    nnp_size stride_size{stride_width, stride_height};
-    NNPackConfig(args[10]);
+    .set_body([](TVMArgs args, TVMRetValue *ret) {
+      NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
+      static std::once_flag flag;
+      std::call_once(flag,
+                     []() { CHECK_EQ(nnp_initialize(), nnp_status_success); });
+      DLTensor *input = args[0];
+      DLTensor *kernel = args[1];
+      DLTensor *bias = nullptr;
+      if (args[2].type_code() == kArrayHandle) {
+        bias = args[2];
+      }
+      DLTensor *output = args[3];
+      uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6],
+               pad_left = args[7];
+      nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
+      uint64_t stride_width = args[8], stride_height = args[9];
+      nnp_size stride_size{stride_width, stride_height};
+      NNPackConfig(args[10]);
 
-    CHECK_EQ(input->ndim, 3);
-    CHECK_EQ(kernel->ndim, 4);
-    CHECK_EQ(bias->ndim, 1);
-    CHECK_EQ(output->ndim, 3);
-
-    CHECK_EQ(input->shape[0], kernel->shape[1]);
-    size_t input_channels = input->shape[0];
-    CHECK_EQ(output->shape[0], kernel->shape[0]);
-    CHECK_EQ(output->shape[0], bias->shape[0]);
-    size_t output_channels = output->shape[0];
-    nnp_size input_size{static_cast<size_t>(input->shape[1]),
-                        static_cast<size_t>(input->shape[2])};
-    nnp_size kernel_size{static_cast<size_t>(kernel->shape[2]),
-                         static_cast<size_t>(kernel->shape[3])};
+      uint64_t algo_ = args[11];
+      nnp_convolution_algorithm algo =
+          static_cast<nnp_convolution_algorithm>(algo_);
+      CHECK_EQ(input->ndim, 4);
+      CHECK_EQ(kernel->ndim, 4);
+      if (bias) {
+        CHECK_EQ(bias->ndim, 1);
+      }
+      CHECK_EQ(output->ndim, 4);
+      CHECK_EQ(input->shape[1], kernel->shape[1]);
+      CHECK_EQ(input->shape[0], output->shape[0]);
+      size_t input_channels = input->shape[1];
+      CHECK_EQ(output->shape[1], kernel->shape[0]);
+      if (bias) {
+        CHECK_EQ(output->shape[1], bias->shape[0]);
+      }
+      size_t output_channels = output->shape[1];
+      nnp_size input_size{static_cast<size_t>(input->shape[2]),
+                          static_cast<size_t>(input->shape[3])};
+      nnp_size kernel_size{static_cast<size_t>(kernel->shape[2]),
+                           static_cast<size_t>(kernel->shape[3])};
+      CHECK(input->strides == nullptr);
+      CHECK(kernel->strides == nullptr);
+      if (bias) {
+        CHECK(bias->strides == nullptr);
+      }
 
-    CHECK(input->strides == nullptr);
-    CHECK(kernel->strides == nullptr);
-    CHECK(bias->strides == nullptr);
+      CHECK(TypeMatch(input->dtype, kDLFloat, 32));
+      CHECK(TypeMatch(kernel->dtype, kDLFloat, 32));
+      if (bias) {
+        CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
+      }
+      CHECK(TypeMatch(output->dtype, kDLFloat, 32));
 
-    CHECK(TypeMatch(input->dtype, kDLFloat, 32));
-    CHECK(TypeMatch(kernel->dtype, kDLFloat, 32));
-    CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
-    CHECK(TypeMatch(output->dtype, kDLFloat, 32));
+      // Allocate a zero-bias if we don't pass one in.
+      std::unique_ptr<std::vector<float>> zero_bias;
+      if (!bias) {
+        zero_bias.reset(new std::vector<float>(output->shape[1], 0.0));
+      }
 
-    nnp_convolution_inference(nnp_convolution_algorithm_auto,
-                              nnp_convolution_transform_strategy_block_based,
-                              input_channels,
-                              output_channels,
-                              input_size,
-                              input_padding,
-                              kernel_size,
-                              stride_size,
-                              static_cast<float*>(input->data),
-                              static_cast<float*>(kernel->data),
-                              static_cast<float*>(bias->data),
-                              static_cast<float*>(output->data),
-                              NULL,
-                              NULL,
-                              nnp_activation_identity,
-                              NULL,
-                              entry->threadpool,
-                              NULL);
-  });
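+      // NNPACK's inference API operates on a single image; iterate over the
+      // batch dimension and offset the input/output pointers per sample.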
+      for (auto n = 0; n < input->shape[0]; ++n) {
+        nnp_status status = nnp_convolution_inference(
+            algo, nnp_convolution_transform_strategy_compute, input_channels,
+            output_channels, input_size, input_padding, kernel_size,
+            stride_size,
+            static_cast<float *>(input->data) + n * input->shape[1] *
+                                                   input->shape[2] *
+                                                   input->shape[3],
+            static_cast<float *>(kernel->data),
+            bias ? static_cast<float *>(bias->data) : zero_bias->data(),
+            static_cast<float *>(output->data) + n * output->shape[1] *
+                                                    output->shape[2] *
+                                                    output->shape[3],
+            NULL, NULL, nnp_activation_identity, NULL, entry->threadpool, NULL);
+
+        CHECK_EQ(status, nnp_status_success);
+      }
+    });
+
+TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_transform")
+    .set_body([](TVMArgs args, TVMRetValue *ret) {
+      NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
+      static std::once_flag flag;
+      std::call_once(flag,
+                     []() { CHECK_EQ(nnp_initialize(), nnp_status_success); });
+      DLTensor *input = args[0];
+      DLTensor *transformed_kernel = args[1];
+      DLTensor *bias = nullptr;
+      if (args[2].type_code() == kArrayHandle) {
+        bias = args[2];
+      }
+      DLTensor *output = args[3];
+      uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6],
+               pad_left = args[7];
+      nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
+      uint64_t stride_width = args[8], stride_height = args[9];
+      nnp_size stride_size{stride_width, stride_height};
+      NNPackConfig(args[10]);
+
+      uint64_t algo_ = args[11];
+      nnp_convolution_algorithm algo =
+          static_cast<nnp_convolution_algorithm>(algo_);
+      CHECK_EQ(input->ndim, 4);
+      if (bias) {
+        CHECK_EQ(bias->ndim, 1);
+      }
+      CHECK_EQ(output->ndim, 4);
+      CHECK_EQ(input->shape[0], output->shape[0]);
+      size_t input_channels = input->shape[1];
+      if (bias) {
+        CHECK_EQ(output->shape[1], bias->shape[0]);
+      }
+      size_t output_channels = output->shape[1];
+      nnp_size input_size{static_cast<size_t>(input->shape[2]),
+                          static_cast<size_t>(input->shape[3])};
+      nnp_size kernel_size{3, 3};
+      CHECK(input->strides == nullptr);
+      CHECK(transformed_kernel->strides == nullptr);
+      if (bias) {
+        CHECK(bias->strides == nullptr);
+      }
+
+      CHECK(TypeMatch(input->dtype, kDLFloat, 32));
+      CHECK(TypeMatch(transformed_kernel->dtype, kDLFloat, 32));
+      if (bias) {
+        CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
+      }
+      CHECK(TypeMatch(output->dtype, kDLFloat, 32));
+
+      // Allocate a zero-bias if we don't pass one in.
+      std::unique_ptr<std::vector<float>> zero_bias;
+      if (!bias) {
+        zero_bias.reset(new std::vector<float>(output->shape[1], 0.0));
+      }
+
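+      // As above, run NNPACK once per image in the batch, offsetting the
+      // data pointers by the per-sample tensor size.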
+      for (auto n = 0; n < input->shape[0]; ++n) {
+        nnp_status status = nnp_convolution_inference(
+            algo, nnp_convolution_transform_strategy_reuse, input_channels,
+            output_channels, input_size, input_padding, kernel_size,
+            stride_size,
+            static_cast<float *>(input->data) + n * input->shape[1] *
+                                                   input->shape[2] *
+                                                   input->shape[3],
+            static_cast<float *>(transformed_kernel->data),
+            bias ? static_cast<float *>(bias->data) : zero_bias->data(),
+            static_cast<float *>(output->data) + n * output->shape[1] *
+                                                    output->shape[2] *
+                                                    output->shape[3],
+            NULL, NULL, nnp_activation_identity, NULL, entry->threadpool, NULL);
+        CHECK_EQ(status, nnp_status_success);
+      }
+    });
+
+TVM_REGISTER_GLOBAL(
+    "tvm.contrib.nnpack.convolution_inference_weight_transform")
+    .set_body([](TVMArgs args, TVMRetValue *ret) {
+      NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
+      static std::once_flag flag;
+      std::call_once(flag,
+                     []() { CHECK_EQ(nnp_initialize(), nnp_status_success); });
+      DLTensor *kernel = args[0];
+      DLTensor *transformed_kernel = args[1];
+      // Dummy sizes
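+      // The precomputed kernel transform does not depend on the input
+      // geometry, so placeholder padding/stride/input sizes are sufficient
+      // here; only the transformed-kernel buffer is produced.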
+      nnp_padding input_padding{1, 1, 1, 1};
+      nnp_size stride_size{1, 1};
+
+      nnp_size input_size{100, 100};
+
+      NNPackConfig(args[2]);
+
+      uint64_t algo_ = args[3];
+      nnp_convolution_algorithm algo =
+          static_cast<nnp_convolution_algorithm>(algo_);
+      CHECK_EQ(kernel->ndim, 4);
+      size_t input_channels = kernel->shape[1];
+      size_t output_channels = kernel->shape[0];
+      CHECK_EQ(kernel->shape[2], 3);
+      CHECK_EQ(kernel->shape[3], 3);
+      nnp_size kernel_size{static_cast<size_t>(kernel->shape[2]),
+                           static_cast<size_t>(kernel->shape[3])};
+      CHECK(kernel->strides == nullptr);
+      CHECK(TypeMatch(kernel->dtype, kDLFloat, 32));
+
+      size_t transformed_kernel_size = 0;
+      nnp_status status;
+      status = nnp_convolution_inference(
+          algo, nnp_convolution_transform_strategy_precompute, input_channels,
+          output_channels, input_size, input_padding, kernel_size, stride_size,
+          nullptr, nullptr, nullptr, nullptr, nullptr, &transformed_kernel_size,
+          nnp_activation_identity, nullptr, entry->threadpool, nullptr);
+      CHECK_EQ(status, nnp_status_success);
+
+      CHECK_LE(transformed_kernel_size, GetDataSize(*transformed_kernel));
+
+      status = nnp_convolution_inference(
+          algo, nnp_convolution_transform_strategy_precompute, input_channels,
+          output_channels, input_size, input_padding, kernel_size, stride_size,
+          nullptr, static_cast<float *>(kernel->data), nullptr, nullptr,
+          static_cast<float *>(transformed_kernel->data),
+          &transformed_kernel_size, nnp_activation_identity, nullptr,
+          entry->threadpool, nullptr);
+      CHECK_EQ(status, nnp_status_success);
+    });
 
 
 TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
@@ -109,7 +255,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
     CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
     CHECK(TypeMatch(output->dtype, kDLFloat, 32));
 
-    nnp_convolution_output(nnp_convolution_algorithm_auto,
+    nnp_status status = nnp_convolution_output(nnp_convolution_algorithm_auto,
                            batch_size,
                            input_channels,
                            output_channels,
@@ -126,6 +272,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
                            NULL,
                            entry->threadpool,
                            NULL);
+    CHECK_EQ(status, nnp_status_success);
   });
 }  // namespace contrib
 }  // namespace tvm
diff --git a/src/contrib/nnpack/nnpack_utils.cc b/src/contrib/nnpack/nnpack_utils.cc
index 3220d7af339f66e2fb1173dcae20a8360dfa8375..d8ef1d0b83276e0534a1a00e71525ec20bafe92c 100644
--- a/src/contrib/nnpack/nnpack_utils.cc
+++ b/src/contrib/nnpack/nnpack_utils.cc
@@ -10,20 +10,30 @@ using namespace runtime;
 
 typedef dmlc::ThreadLocalStore<NNPackThreadLocalEntry> NNPackThreadLocalStore;
 
+
 NNPackThreadLocalEntry* NNPackThreadLocalEntry::ThreadLocal() {
   return NNPackThreadLocalStore::Get();
 }
 
 bool NNPackConfig(uint64_t nthreads) {
   NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
-  if (entry->threadpool != NULL &&
-      pthreadpool_get_threads_count(entry->threadpool) != nthreads) {
+  if (entry->threadpool && pthreadpool_get_threads_count(entry->threadpool) == nthreads) {
+    CHECK_NE(nthreads, 1);
+    return true;
+  }
+  if (entry->threadpool) {
     pthreadpool_destroy(entry->threadpool);
-    entry->threadpool = NULL;
+    entry->threadpool = nullptr;
   }
-  if (entry->threadpool == NULL) {
-    entry->threadpool = pthreadpool_create(nthreads);
+
+  if (nthreads == 1) {
+    // a null threadpool means the function is invoked on the calling thread,
+    // which is the desired logic for nthreads == 1
+    CHECK(!entry->threadpool);
+    return true;
   }
+
+  entry->threadpool = pthreadpool_create(nthreads);
   return true;
 }
 
diff --git a/src/contrib/nnpack/nnpack_utils.h b/src/contrib/nnpack/nnpack_utils.h
index fe7420786bdec2477479cf9be73e072f8bf7c3ad..1d44adff16ef9a6e53650992fbd04adad7cf7608 100644
--- a/src/contrib/nnpack/nnpack_utils.h
+++ b/src/contrib/nnpack/nnpack_utils.h
@@ -15,7 +15,7 @@ namespace contrib {
 using namespace runtime;
 
 struct NNPackThreadLocalEntry {
-  pthreadpool_t threadpool{NULL};
+  pthreadpool_t threadpool{nullptr};
   static NNPackThreadLocalEntry* ThreadLocal();
 };
 
diff --git a/tests/lint/pylintrc b/tests/lint/pylintrc
index f5c4452cfa163eb65ffd8943ff7d39d25e8d76a5..18f526702ad88deba459efede37b6e6c883e2338 100644
--- a/tests/lint/pylintrc
+++ b/tests/lint/pylintrc
@@ -290,10 +290,10 @@ variable-rgx=[a-z_][a-z0-9_]{2,30}$
 variable-name-hint=[a-z_][a-z0-9_]{2,30}$
 
 # Regular expression matching correct function names
-function-rgx=[a-z_][a-z0-9_]{2,30}$
+function-rgx=[a-z_][a-z0-9_]{2,48}$
 
 # Naming hint for function names
-function-name-hint=[a-z_][a-z0-9_]{2,30}$
+function-name-hint=[a-z_][a-z0-9_]{2,48}$
 
 # Regular expression matching correct class names
 class-rgx=[A-Z_][a-zA-Z0-9]+$
diff --git a/tests/python/contrib/test_nnpack.py b/tests/python/contrib/test_nnpack.py
index a6c6b8158ff3bd06b2c4f528abe479a3c7a6b978..0b275fb812bf268785980a3dc50a6a5e63684219 100644
--- a/tests/python/contrib/test_nnpack.py
+++ b/tests/python/contrib/test_nnpack.py
@@ -100,7 +100,7 @@ def np_conv(na, nw, padding, stride=1):
     return nb
 
 def test_convolution_inference():
-    BATCH = 32
+    BATCH = 8
     IH = 48
     IW = 48
     IC = 16
@@ -111,19 +111,17 @@ def test_convolution_inference():
 
     OH = (IH + 2*PAD - K) + 1
     OW = (IW + 2*PAD - K) + 1
-    dshape = (IC, IH, IW)
+    dshape = (BATCH, IC, IH, IW)
     kshape = (OC, IC, K, K)
     bshape = (OC, )
-    oshape = (OC, OH, OW)
+    oshape = (BATCH, OC, OH, OW)
 
     data = tvm.placeholder(dshape, name='data')
     kernel = tvm.placeholder(kshape, name='kernel')
     bias = tvm.placeholder(bshape, name='bias')
-    output = nnpack.convolution_inference(data, kernel, bias,
-        [PAD, PAD, PAD, PAD], [STRIDE, STRIDE])
-    s = tvm.create_schedule(output.op)
-
-    def verify(target="llvm"):
+    def verify(target="llvm",
+               algorithm=nnpack.ConvolutionAlgorithm.AUTO,
+               with_bias=True):
         if not tvm.module.enabled(target):
             print("skip because %s is not enabled..." % target)
             return
@@ -131,6 +129,12 @@ def test_convolution_inference():
             print("skip because extern function is not available")
             return
         ctx = tvm.cpu(0)
+        output = nnpack.convolution_inference(
+            data, kernel, bias if with_bias else None,
+            [PAD, PAD, PAD, PAD], [STRIDE, STRIDE],
+            algorithm=algorithm)
+        s = tvm.create_schedule(output.op)
+
         f = tvm.build(s, [data, kernel, bias, output], target)
 
         na = np.random.uniform(size=dshape).astype(data.dtype)
@@ -141,10 +145,77 @@ def test_convolution_inference():
         tc = tvm.nd.array(nc, ctx)
         td = tvm.nd.array(np.zeros(oshape, dtype=output.dtype), ctx)
         f(ta, tb, tc, td)
-        nd = np_conv(np.reshape(na, (1, IC, IH, IW)), nb, PAD, STRIDE)
+        nd = np_conv(np.reshape(na, (BATCH, IC, IH, IW)), nb, PAD, STRIDE) + \
+            nc.reshape(1, bshape[0], 1, 1)
         tvm.testing.assert_allclose(
-            td.asnumpy(), nd.reshape(IC, IH, IW), rtol=1e-5)
-    verify()
+            td.asnumpy(), nd.reshape(BATCH, OC, OH, OW), rtol=1e-5)
+    for algorithm in [
+            nnpack.ConvolutionAlgorithm.AUTO,
+            nnpack.ConvolutionAlgorithm.FFT_8x8,
+            nnpack.ConvolutionAlgorithm.FFT_16x16,
+            nnpack.ConvolutionAlgorithm.WT_8x8,
+            nnpack.ConvolutionAlgorithm.IMPLICIT_GEMM,
+            nnpack.ConvolutionAlgorithm.WT_8x8_FP16,
+    ]:
+        for with_bias in [True, False]:
+            verify(algorithm=algorithm, with_bias=with_bias)
+
+
+def test_convolution_inference_without_weight_transform():
+    BATCH = 6
+    IH = 48
+    IW = 48
+    IC = 16
+    OC = 16
+    K = 3
+    PAD = 1
+    STRIDE = 1
+
+    OH = (IH + 2*PAD - K) + 1
+    OW = (IW + 2*PAD - K) + 1
+    dshape = (BATCH, IC, IH, IW)
+    kshape = (OC, IC, K, K)
+    bshape = (OC, )
+    oshape = (BATCH, OC, OH, OW)
+
+    data = tvm.placeholder(dshape, name='data')
+    kernel = tvm.placeholder(kshape, name='kernel')
+    bias = tvm.placeholder(bshape, name='bias')
+    def verify(target="llvm",
+               algorithm=nnpack.ConvolutionAlgorithm.AUTO,
+               with_bias=True):
+        if not tvm.module.enabled(target):
+            print("skip because %s is not enabled..." % target)
+            return
+        if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True):
+            print("skip because extern function is not available")
+            return
+
+        ctx = tvm.cpu(0)
+        transformed_kernel = nnpack.convolution_inference_weight_transform(
+            kernel, algorithm=algorithm)
+        output = nnpack.convolution_inference_without_weight_transform(
+            data, transformed_kernel, bias if with_bias else None,
+            [PAD, PAD, PAD, PAD], [STRIDE, STRIDE],
+            algorithm=algorithm)
+
+        s = tvm.create_schedule(output.op)
+
+        f = tvm.build(s, [data, kernel, bias, output], target)
+
+        na = np.random.uniform(size=dshape).astype(data.dtype)
+        nb = np.random.uniform(size=kshape).astype(kernel.dtype)
+        nc = np.random.uniform(size=bshape).astype(bias.dtype) if with_bias \
+            else np.zeros(bshape, dtype=bias.dtype)
+        ta = tvm.nd.array(na, ctx)
+        tb = tvm.nd.array(nb, ctx)
+        tc = tvm.nd.array(nc, ctx)
+        td = tvm.nd.array(np.zeros(oshape, dtype=output.dtype), ctx)
+        f(ta, tb, tc, td)
+        nd = np_conv(np.reshape(na, (BATCH, IC, IH, IW)), nb, PAD, STRIDE) + \
+            nc.reshape(1, bshape[0], 1, 1)
+        tvm.testing.assert_allclose(
+            td.asnumpy(), nd.reshape(BATCH, OC, OH, OW), rtol=1e-5)
+    for algorithm in [nnpack.ConvolutionAlgorithm.WT_8x8]:
+        for with_bias in [True, False]:
+            verify(algorithm=algorithm, with_bias=with_bias)
 
 def test_convolution_output():
     BATCH = 32