diff --git a/cmake/modules/contrib/NNPack.cmake b/cmake/modules/contrib/NNPack.cmake
index 82de88a21e63f6216a6550d3b62dd3f3308b0177..4bf844d0c468de32ffcc145c4a0e6bb8506a825c 100644
--- a/cmake/modules/contrib/NNPack.cmake
+++ b/cmake/modules/contrib/NNPack.cmake
@@ -9,6 +9,10 @@ if(USE_NNPACK)
   include_directories(${PTHREAD_POOL_PATH}/include)
   find_library(NNPACK_CONTRIB_LIB nnpack ${NNPACK_PATH}/lib)
   find_library(NNPACK_PTHREAD_CONTRIB_LIB pthreadpool ${NNPACK_PATH}/lib)
+  find_library(NNPACK_CPUINFO_CONTRIB_LIB cpuinfo ${NNPACK_PATH}/lib)
+  find_library(NNPACK_CLOG_CONTRIB_LIB clog ${NNPACK_PATH}/lib)
   list(APPEND TVM_RUNTIME_LINKER_LIBS ${NNPACK_CONTRIB_LIB})
   list(APPEND TVM_RUNTIME_LINKER_LIBS ${NNPACK_PTHREAD_CONTRIB_LIB})
+  list(APPEND TVM_RUNTIME_LINKER_LIBS ${NNPACK_CPUINFO_CONTRIB_LIB})
+  list(APPEND TVM_RUNTIME_LINKER_LIBS ${NNPACK_CLOG_CONTRIB_LIB})
 endif(USE_NNPACK)
diff --git a/python/tvm/contrib/nnpack.py b/python/tvm/contrib/nnpack.py
index d6587df26229729a71fbd7b5a4f8ad0c35abc4df..36f8a76a87db8310cd5641accae121ee40f94545 100644
--- a/python/tvm/contrib/nnpack.py
+++ b/python/tvm/contrib/nnpack.py
@@ -63,14 +63,32 @@ def fully_connected_output(lhs, rhs, nthreads=1):
             "tvm.contrib.nnpack.fully_connected_output",
             ins[0], ins[1], outs[0], nthreads), name="C")
 
-def convolution_inference(data, kernel, bias, padding, stride, nthreads=1):
-    """Create an extern op to do inference convolution of 3D tensor data and
+
+class ConvolutionAlgorithm:
+    AUTO = 0
+    FFT_8x8 = 1
+    FFT_16x16 = 2
+    WT_8x8 = 3
+    IMPLICIT_GEMM = 4
+    DIRECT = 5
+    WT_8x8_FP16 = 6
+
+
+class ConvolutionTransformStrategy:
+    COMPUTE = 1
+    PRECOMPUTE = 2
+
+
+def convolution_inference(
+        data, kernel, bias, padding, stride, nthreads=1,
+        algorithm=ConvolutionAlgorithm.AUTO):
+    """Create an extern op to do inference convolution of 4D tensor data and
     4D tensor kernel and 1D tensor bias with nnpack.
 
     Parameters
     ----------
     data : Tensor
-        data 3D tensor input[input_channels][input_height][input_width] of
+        data 4D tensor input[batch][input_channels][input_height][input_width] of
        FP32 elements.
     kernel : Tensor
         kernel 4D tensor kernel[output_channels][input_channels][kernel_height]
@@ -88,23 +106,108 @@ def convolution_inference(data, kernel, bias, padding, stride, nthreads=1):
     Returns
     -------
     output : Tensor
-        output 3D tensor output[output_channels][output_height][output_width]
+        output 4D tensor output[batch][output_channels][output_height][output_width]
         of FP32 elements.
""" assert isinstance(padding, list) and len(padding) == 4 assert isinstance(stride, list) and len(stride) == 2 - _, input_height, input_width = data.shape + batch, _, input_height, input_width = data.shape output_channels, _, kernel_height, kernel_width = kernel.shape output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1 output_width = (input_width + padding[0] + padding[1] - kernel_width) / stride[1] + 1 return _api.extern( - (output_channels, output_height, output_width), [data, kernel, bias], + (batch, output_channels, output_height, output_width), + [data, kernel, bias] if bias is not None else [data, kernel], lambda ins, outs: _intrin.call_packed( - "tvm.contrib.nnpack.convolution_inference", ins[0], ins[1], ins[2], + "tvm.contrib.nnpack.convolution_inference", + ins[0], + ins[1], + ins[2] if bias is not None else 0, outs[0], padding[0], padding[1], padding[2], padding[3], - stride[0], stride[1], nthreads), name="C") + stride[0], stride[1], nthreads, algorithm), name="C") + +def convolution_inference_without_weight_transform( + data, transformed_kernel, bias, padding, stride, nthreads=1, + algorithm=ConvolutionAlgorithm.AUTO): + """Create an extern op to do inference convolution of 4D tensor data and + 4D pre-transformed tensor kernel and 1D tensor bias with nnpack. + + Parameters + ---------- + data : Tensor + data 4D tensor input[batch][input_channels][input_height][input_width] of + FP32 elements. + transformed_kernel : Tensor + transformed_kernel 4D tensor kernel[output_channels][input_channels][tile] + [tile] of FP32 elements. + bias : Tensor + bias 1D array bias[output_channels][input_channels][kernel_height] + [kernel_width] of FP32 elements. + padding : list + padding A 4-dim list of [pad_top, pad_bottom, pad_left, pad_right], + which indicates the padding around the feature map. + stride : list + stride A 2-dim list of [stride_height, stride_width], which indicates + the stride. + + Returns + ------- + output : Tensor + output 4D tensor output[batch][output_channels][output_height][output_width] + of FP32 elements. + """ + + assert algorithm in (ConvolutionAlgorithm.WT_8x8, + ConvolutionAlgorithm.WT_8x8_FP16) + assert isinstance(padding, list) and len(padding) == 4 + assert isinstance(stride, list) and len(stride) == 2 + batch, _, input_height, input_width = data.shape + output_channels, _, _, _ = transformed_kernel.shape + kernel_height, kernel_width = (3, 3) + output_height = (input_height + padding[0] + padding[1] - kernel_height) / stride[0] + 1 + output_width = (input_width + padding[0] + padding[1] - kernel_width) / stride[1] + 1 + + return _api.extern( + (batch, output_channels, output_height, output_width), + [data, transformed_kernel, bias] if bias is not None else [data, transformed_kernel], + lambda ins, outs: _intrin.call_packed( + "tvm.contrib.nnpack.convolution_inference_without_weight_transform", + ins[0], + ins[1], + ins[2] if bias is not None else 0, + outs[0], padding[0], padding[1], padding[2], padding[3], + stride[0], stride[1], nthreads, algorithm), name="C") + +def convolution_inference_weight_transform( + kernel, nthreads=1, + algorithm=ConvolutionAlgorithm.AUTO): + """Create an extern op to do inference convolution of 3D tensor data and + 4D tensor kernel and 1D tensor bias with nnpack. + + Parameters + ---------- + kernel : Tensor + kernel 4D tensor kernel[output_channels][input_channels][kernel_height] + [kernel_width] of FP32 elements. 
+
+    Returns
+    -------
+    output : Tensor
+        output 4D tensor output[output_channels][input_channels][tile][tile]
+        of FP32 elements.
+    """
+    assert algorithm in (ConvolutionAlgorithm.WT_8x8, ConvolutionAlgorithm.WT_8x8_FP16)
+    output_channels, input_channels, _, _ = kernel.shape
+
+    transform_tile_size = 8
+    return _api.extern(
+        (output_channels, input_channels, transform_tile_size, transform_tile_size),
+        [kernel],
+        lambda ins, outs: _intrin.call_packed(
+            "tvm.contrib.nnpack.convolution_inference_weight_transform",
+            ins[0], outs[0], nthreads, algorithm), name="transform_kernel")
 
 def convolution_output(data, kernel, bias, padding, nthreads=1):
     """Create an extern op to compute convolution of 4D tensor data and
@@ -144,4 +247,5 @@ def convolution_output(data, kernel, bias, padding, nthreads=1):
             "tvm.contrib.nnpack.convolution_output", ins[0], ins[1], ins[2],
             outs[0], padding[0], padding[1], padding[2], padding[3],
             nthreads), name="C")
+
 _init_api("tvm.contrib.nnpack")
diff --git a/src/contrib/nnpack/convolution.cc b/src/contrib/nnpack/convolution.cc
index f658a1fe96d4956ce78efaf8a15bae5e7f4d33bb..8bcdd64281cce31bde3aa3f79c93a3dfbb43cab0 100644
--- a/src/contrib/nnpack/convolution.cc
+++ b/src/contrib/nnpack/convolution.cc
@@ -13,62 +13,208 @@ namespace contrib {
 using namespace runtime;
 
 TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
-.set_body([](TVMArgs args, TVMRetValue *ret) {
-    NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
-    nnp_initialize();
-    DLTensor* input = args[0];
-    DLTensor* kernel = args[1];
-    DLTensor* bias = args[2];
-    DLTensor* output = args[3];
-    uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], pad_left = args[7];
-    nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
-    uint64_t stride_width = args[8], stride_height = args[9];
-    nnp_size stride_size{stride_width, stride_height};
-    NNPackConfig(args[10]);
+    .set_body([](TVMArgs args, TVMRetValue *ret) {
+      NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
+      static std::once_flag flag;
+      std::call_once(flag,
+                     []() { CHECK_EQ(nnp_initialize(), nnp_status_success); });
+      DLTensor *input = args[0];
+      DLTensor *kernel = args[1];
+      DLTensor *bias = nullptr;
+      if (args[2].type_code() == kArrayHandle) {
+        bias = args[2];
+      }
+      DLTensor *output = args[3];
+      uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6],
+               pad_left = args[7];
+      nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
+      uint64_t stride_width = args[8], stride_height = args[9];
+      nnp_size stride_size{stride_width, stride_height};
+      NNPackConfig(args[10]);
 
-    CHECK_EQ(input->ndim, 3);
-    CHECK_EQ(kernel->ndim, 4);
-    CHECK_EQ(bias->ndim, 1);
-    CHECK_EQ(output->ndim, 3);
-
-    CHECK_EQ(input->shape[0], kernel->shape[1]);
-    size_t input_channels = input->shape[0];
-    CHECK_EQ(output->shape[0], kernel->shape[0]);
-    CHECK_EQ(output->shape[0], bias->shape[0]);
-    size_t output_channels = output->shape[0];
-    nnp_size input_size{static_cast<size_t>(input->shape[1]),
-                        static_cast<size_t>(input->shape[2])};
-    nnp_size kernel_size{static_cast<size_t>(kernel->shape[2]),
-                         static_cast<size_t>(kernel->shape[3])};
+      uint64_t algo_ = args[11];
+      nnp_convolution_algorithm algo =
+          static_cast<nnp_convolution_algorithm>(algo_);
+      CHECK_EQ(input->ndim, 4);
+      CHECK_EQ(kernel->ndim, 4);
+      if (bias) {
+        CHECK_EQ(bias->ndim, 1);
+      }
+      CHECK_EQ(output->ndim, 4);
+      CHECK_EQ(input->shape[1], kernel->shape[1]);
+      CHECK_EQ(input->shape[0], output->shape[0]);
+      size_t input_channels = input->shape[1];
+      CHECK_EQ(output->shape[1], kernel->shape[0]);
+      if (bias) {
+        CHECK_EQ(output->shape[1], bias->shape[0]);
+      }
+      size_t output_channels = output->shape[1];
+      nnp_size input_size{static_cast<size_t>(input->shape[2]),
+                          static_cast<size_t>(input->shape[3])};
+      nnp_size kernel_size{static_cast<size_t>(kernel->shape[2]),
+                           static_cast<size_t>(kernel->shape[3])};
+      CHECK(input->strides == nullptr);
+      CHECK(kernel->strides == nullptr);
+      if (bias) {
+        CHECK(bias->strides == nullptr);
+      }
 
-    CHECK(input->strides == nullptr);
-    CHECK(kernel->strides == nullptr);
-    CHECK(bias->strides == nullptr);
+      CHECK(TypeMatch(input->dtype, kDLFloat, 32));
+      CHECK(TypeMatch(kernel->dtype, kDLFloat, 32));
+      if (bias) {
+        CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
+      }
+      CHECK(TypeMatch(output->dtype, kDLFloat, 32));
 
-    CHECK(TypeMatch(input->dtype, kDLFloat, 32));
-    CHECK(TypeMatch(kernel->dtype, kDLFloat, 32));
-    CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
-    CHECK(TypeMatch(output->dtype, kDLFloat, 32));
+      // Allocate a zero-bias if we don't pass one in.
+      std::unique_ptr<std::vector<float>> zero_bias;
+      if (!bias) {
+        zero_bias.reset(new std::vector<float>(output->shape[1], 0.0));
+      }
 
-    nnp_convolution_inference(nnp_convolution_algorithm_auto,
-                              nnp_convolution_transform_strategy_block_based,
-                              input_channels,
-                              output_channels,
-                              input_size,
-                              input_padding,
-                              kernel_size,
-                              stride_size,
-                              static_cast<float*>(input->data),
-                              static_cast<float*>(kernel->data),
-                              static_cast<float*>(bias->data),
-                              static_cast<float*>(output->data),
-                              NULL,
-                              NULL,
-                              nnp_activation_identity,
-                              NULL,
-                              entry->threadpool,
-                              NULL);
-  });
+      for (auto n = 0; n < input->shape[0]; ++n) {
+        nnp_status status = nnp_convolution_inference(
+            algo, nnp_convolution_transform_strategy_compute, input_channels,
+            output_channels, input_size, input_padding, kernel_size,
+            stride_size,
+            static_cast<float *>(input->data) + n * input->shape[1] *
+                                                    input->shape[2] *
+                                                    input->shape[3],
+            static_cast<float *>(kernel->data),
+            bias ? static_cast<float *>(bias->data) : zero_bias->data(),
+            static_cast<float *>(output->data) + n * output->shape[1] *
+                                                     output->shape[2] *
+                                                     output->shape[3],
+            NULL, NULL, nnp_activation_identity, NULL, entry->threadpool, NULL);
+
+        CHECK_EQ(status, nnp_status_success);
+      }
+    });
+
+TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_transform")
+    .set_body([](TVMArgs args, TVMRetValue *ret) {
+      NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
+      static std::once_flag flag;
+      std::call_once(flag,
+                     []() { CHECK_EQ(nnp_initialize(), nnp_status_success); });
+      DLTensor *input = args[0];
+      DLTensor *transformed_kernel = args[1];
+      DLTensor *bias = nullptr;
+      if (args[2].type_code() == kArrayHandle) {
+        bias = args[2];
+      }
+      DLTensor *output = args[3];
+      uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6],
+               pad_left = args[7];
+      nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
+      uint64_t stride_width = args[8], stride_height = args[9];
+      nnp_size stride_size{stride_width, stride_height};
+      NNPackConfig(args[10]);
+
+      uint64_t algo_ = args[11];
+      nnp_convolution_algorithm algo =
+          static_cast<nnp_convolution_algorithm>(algo_);
+      CHECK_EQ(input->ndim, 4);
+      if (bias) {
+        CHECK_EQ(bias->ndim, 1);
+      }
+      CHECK_EQ(output->ndim, 4);
+      CHECK_EQ(input->shape[0], output->shape[0]);
+      size_t input_channels = input->shape[1];
+      if (bias) {
+        CHECK_EQ(output->shape[1], bias->shape[0]);
+      }
+      size_t output_channels = output->shape[1];
+      nnp_size input_size{static_cast<size_t>(input->shape[2]),
+                          static_cast<size_t>(input->shape[3])};
+      nnp_size kernel_size{3, 3};
+      CHECK(input->strides == nullptr);
+      CHECK(transformed_kernel->strides == nullptr);
+      if (bias) {
+        CHECK(bias->strides == nullptr);
+      }
+
+      CHECK(TypeMatch(input->dtype, kDLFloat, 32));
+      CHECK(TypeMatch(transformed_kernel->dtype, kDLFloat, 32));
+      if (bias) {
+        CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
+      }
+      CHECK(TypeMatch(output->dtype, kDLFloat, 32));
+
+      // Allocate a zero-bias if we don't pass one in.
+      std::unique_ptr<std::vector<float>> zero_bias;
+      if (!bias) {
+        zero_bias.reset(new std::vector<float>(output->shape[1], 0.0));
+      }
+
+      for (auto n = 0; n < input->shape[0]; ++n) {
+        nnp_status status = nnp_convolution_inference(
+            algo, nnp_convolution_transform_strategy_reuse, input_channels, output_channels,
+            input_size, input_padding, kernel_size, stride_size,
+            static_cast<float *>(input->data) + n * input->shape[1] *
+                                                    input->shape[2] *
+                                                    input->shape[3],
+            static_cast<float *>(transformed_kernel->data),
+            bias ? static_cast<float *>(bias->data) : zero_bias->data(),
+            static_cast<float *>(output->data) + n * output->shape[1] *
+                                                     output->shape[2] *
+                                                     output->shape[3],
+            NULL, NULL,
+            nnp_activation_identity, NULL, entry->threadpool, NULL);
+        CHECK_EQ(status, nnp_status_success);
+      }
+    });
+
+TVM_REGISTER_GLOBAL(
+    "tvm.contrib.nnpack.convolution_inference_weight_transform")
+    .set_body([](TVMArgs args, TVMRetValue *ret) {
+      NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
+      static std::once_flag flag;
+      std::call_once(flag,
+                     []() { CHECK_EQ(nnp_initialize(), nnp_status_success); });
+      DLTensor *kernel = args[0];
+      DLTensor *transformed_kernel = args[1];
+      // Dummy sizes
+      nnp_padding input_padding{1, 1, 1, 1};
+      nnp_size stride_size{1, 1};
+
+      nnp_size input_size{100, 100};
+
+      NNPackConfig(args[2]);
+
+      uint64_t algo_ = args[3];
+      nnp_convolution_algorithm algo =
+          static_cast<nnp_convolution_algorithm>(algo_);
+      CHECK_EQ(kernel->ndim, 4);
+      size_t input_channels = kernel->shape[1];
+      size_t output_channels = kernel->shape[0];
+      CHECK_EQ(kernel->shape[2], 3);
+      CHECK_EQ(kernel->shape[3], 3);
+      nnp_size kernel_size{static_cast<size_t>(kernel->shape[2]),
+                           static_cast<size_t>(kernel->shape[3])};
+      CHECK(kernel->strides == nullptr);
+      CHECK(TypeMatch(kernel->dtype, kDLFloat, 32));
+
+      size_t transformed_kernel_size = 0;
+      nnp_status status;
+      status = nnp_convolution_inference(
+          algo, nnp_convolution_transform_strategy_precompute, input_channels,
+          output_channels, input_size, input_padding, kernel_size, stride_size,
+          nullptr, nullptr, nullptr, nullptr, nullptr, &transformed_kernel_size,
+          nnp_activation_identity, nullptr, entry->threadpool, nullptr);
+      CHECK_EQ(status, nnp_status_success);
+
+      CHECK_LE(transformed_kernel_size, GetDataSize(*transformed_kernel));
+
+      status = nnp_convolution_inference(
+          algo, nnp_convolution_transform_strategy_precompute, input_channels,
+          output_channels, input_size, input_padding, kernel_size, stride_size,
+          nullptr, static_cast<float *>(kernel->data), nullptr, nullptr,
+          static_cast<float *>(transformed_kernel->data),
+          &transformed_kernel_size, nnp_activation_identity, nullptr,
+          entry->threadpool, nullptr);
+      CHECK_EQ(status, nnp_status_success);
+    });
 
 TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
@@ -109,7 +255,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
     CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
     CHECK(TypeMatch(output->dtype, kDLFloat, 32));
 
-    nnp_convolution_output(nnp_convolution_algorithm_auto,
+    nnp_status status = nnp_convolution_output(nnp_convolution_algorithm_auto,
                            batch_size,
                            input_channels,
                            output_channels,
@@ -126,6 +272,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
                            NULL,
                            entry->threadpool,
                            NULL);
+    CHECK_EQ(status, nnp_status_success);
   });
 }  // namespace contrib
 }  // namespace tvm
diff --git a/src/contrib/nnpack/nnpack_utils.cc b/src/contrib/nnpack/nnpack_utils.cc
index 3220d7af339f66e2fb1173dcae20a8360dfa8375..d8ef1d0b83276e0534a1a00e71525ec20bafe92c 100644
--- a/src/contrib/nnpack/nnpack_utils.cc
+++ b/src/contrib/nnpack/nnpack_utils.cc
@@ -10,20 +10,30 @@ using namespace runtime;
 typedef dmlc::ThreadLocalStore<NNPackThreadLocalEntry> NNPackThreadLocalStore;
+
 NNPackThreadLocalEntry* NNPackThreadLocalEntry::ThreadLocal() {
   return NNPackThreadLocalStore::Get();
 }
 
 bool NNPackConfig(uint64_t nthreads) {
   NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
-  if (entry->threadpool != NULL &&
-      pthreadpool_get_threads_count(entry->threadpool) != nthreads) {
+  if (entry->threadpool && pthreadpool_get_threads_count(entry->threadpool) == nthreads) {
+    CHECK_NE(nthreads, 1);
+    return true;
+  }
+  if (entry->threadpool) {
     pthreadpool_destroy(entry->threadpool);
-    entry->threadpool = NULL;
+    entry->threadpool = nullptr;
   }
-  if (entry->threadpool == NULL) {
-    entry->threadpool = pthreadpool_create(nthreads);
+
+  if (nthreads == 1) {
+    // a null threadpool means the function is invoked on the calling thread,
+    // which is the desired logic for nthreads == 1
+    CHECK(!entry->threadpool);
+    return true;
   }
+
+  entry->threadpool = pthreadpool_create(nthreads);
   return true;
 }
diff --git a/src/contrib/nnpack/nnpack_utils.h b/src/contrib/nnpack/nnpack_utils.h
index fe7420786bdec2477479cf9be73e072f8bf7c3ad..1d44adff16ef9a6e53650992fbd04adad7cf7608 100644
--- a/src/contrib/nnpack/nnpack_utils.h
+++ b/src/contrib/nnpack/nnpack_utils.h
@@ -15,7 +15,7 @@ namespace contrib {
 using namespace runtime;
 
 struct NNPackThreadLocalEntry {
-  pthreadpool_t threadpool{NULL};
+  pthreadpool_t threadpool{nullptr};
   static NNPackThreadLocalEntry* ThreadLocal();
 };
diff --git a/tests/lint/pylintrc b/tests/lint/pylintrc
index f5c4452cfa163eb65ffd8943ff7d39d25e8d76a5..18f526702ad88deba459efede37b6e6c883e2338 100644
--- a/tests/lint/pylintrc
+++ b/tests/lint/pylintrc
@@ -290,10 +290,10 @@ variable-rgx=[a-z_][a-z0-9_]{2,30}$
 variable-name-hint=[a-z_][a-z0-9_]{2,30}$
 
 # Regular expression matching correct function names
-function-rgx=[a-z_][a-z0-9_]{2,30}$
+function-rgx=[a-z_][a-z0-9_]{2,48}$
 
 # Naming hint for function names
-function-name-hint=[a-z_][a-z0-9_]{2,30}$
+function-name-hint=[a-z_][a-z0-9_]{2,48}$
 
 # Regular expression matching correct class names
 class-rgx=[A-Z_][a-zA-Z0-9]+$
diff --git a/tests/python/contrib/test_nnpack.py b/tests/python/contrib/test_nnpack.py
index a6c6b8158ff3bd06b2c4f528abe479a3c7a6b978..0b275fb812bf268785980a3dc50a6a5e63684219 100644
--- a/tests/python/contrib/test_nnpack.py
+++ b/tests/python/contrib/test_nnpack.py
@@ -100,7 +100,7 @@ def np_conv(na, nw, padding, stride=1):
     return nb
 
 def test_convolution_inference():
-    BATCH = 32
+    BATCH = 8
     IH = 48
     IW = 48
     IC = 16
@@ -111,19 +111,17 @@ def test_convolution_inference():
     OH = (IH + 2*PAD - K) + 1
     OW = (IW + 2*PAD - K) + 1
 
-    dshape = (IC, IH, IW)
+    dshape = (BATCH, IC, IH, IW)
     kshape = (OC, IC, K, K)
     bshape = (OC, )
-    oshape = (OC, OH, OW)
+    oshape = (BATCH, OC, OH, OW)
 
     data = tvm.placeholder(dshape, name='data')
     kernel = tvm.placeholder(kshape, name='kernel')
     bias = tvm.placeholder(bshape, name='bias')
-    output = nnpack.convolution_inference(data, kernel, bias,
-        [PAD, PAD, PAD, PAD], [STRIDE, STRIDE])
-    s = tvm.create_schedule(output.op)
-
-    def verify(target="llvm"):
+    def verify(target="llvm",
+               algorithm=nnpack.ConvolutionAlgorithm.AUTO,
+               with_bias=True):
         if not tvm.module.enabled(target):
             print("skip because %s is not enabled..." % target)
             return
@@ -131,6 +129,12 @@ def test_convolution_inference():
             print("skip because extern function is not available")
             return
         ctx = tvm.cpu(0)
+        output = nnpack.convolution_inference(
+            data, kernel, bias if with_bias else None,
+            [PAD, PAD, PAD, PAD], [STRIDE, STRIDE],
+            algorithm=algorithm)
+        s = tvm.create_schedule(output.op)
+
         f = tvm.build(s, [data, kernel, bias, output], target)
@@ -141,10 +145,77 @@ def test_convolution_inference():
         tc = tvm.nd.array(nc, ctx)
         td = tvm.nd.array(np.zeros(oshape, dtype=output.dtype), ctx)
         f(ta, tb, tc, td)
-        nd = np_conv(np.reshape(na, (1, IC, IH, IW)), nb, PAD, STRIDE)
+        nd = np_conv(np.reshape(na, (BATCH, IC, IH, IW)), nb, PAD, STRIDE) + nc.reshape(1, bshape[0], 1, 1)
         tvm.testing.assert_allclose(
-            td.asnumpy(), nd.reshape(IC, IH, IW), rtol=1e-5)
-    verify()
+            td.asnumpy(), nd.reshape(BATCH, OC, OH, OW), rtol=1e-5)
+    for algorithm in [
+            nnpack.ConvolutionAlgorithm.AUTO,
+            nnpack.ConvolutionAlgorithm.FFT_8x8,
+            nnpack.ConvolutionAlgorithm.FFT_16x16,
+            nnpack.ConvolutionAlgorithm.WT_8x8,
+            nnpack.ConvolutionAlgorithm.IMPLICIT_GEMM,
+            nnpack.ConvolutionAlgorithm.WT_8x8_FP16,
+    ]:
+        for with_bias in [True, False]:
+            verify(algorithm=algorithm, with_bias=with_bias)
+
+
+def test_convolution_inference_without_weight_transform():
+    BATCH = 6
+    IH = 48
+    IW = 48
+    IC = 16
+    OC = 16
+    K = 3
+    PAD = 1
+    STRIDE = 1
+
+    OH = (IH + 2*PAD - K) + 1
+    OW = (IW + 2*PAD - K) + 1
+    dshape = (BATCH, IC, IH, IW)
+    kshape = (OC, IC, K, K)
+    bshape = (OC, )
+    oshape = (BATCH, OC, OH, OW)
+
+    data = tvm.placeholder(dshape, name='data')
+    kernel = tvm.placeholder(kshape, name='kernel')
+    bias = tvm.placeholder(bshape, name='bias')
+    def verify(target="llvm",
+               algorithm=nnpack.ConvolutionAlgorithm.AUTO,
+               with_bias=True):
+        if not tvm.module.enabled(target):
+            print("skip because %s is not enabled..." % target)
+            return
+        if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True):
+            print("skip because extern function is not available")
+            return
+
+        ctx = tvm.cpu(0)
+        transformed_kernel = nnpack.convolution_inference_weight_transform(
+            kernel, algorithm=algorithm)
+        output = nnpack.convolution_inference_without_weight_transform(
+            data, transformed_kernel, bias if with_bias else None,
+            [PAD, PAD, PAD, PAD], [STRIDE, STRIDE],
+            algorithm=algorithm)
+
+        s = tvm.create_schedule(output.op)
+
+        f = tvm.build(s, [data, kernel, bias, output], target)
+
+        na = np.random.uniform(size=dshape).astype(data.dtype)
+        nb = np.random.uniform(size=kshape).astype(kernel.dtype)
+        nc = np.random.uniform(size=bshape).astype(bias.dtype) if with_bias else np.zeros(bshape, dtype=bias.dtype)
+        ta = tvm.nd.array(na, ctx)
+        tb = tvm.nd.array(nb, ctx)
+        tc = tvm.nd.array(nc, ctx)
+        td = tvm.nd.array(np.zeros(oshape, dtype=output.dtype), ctx)
+        f(ta, tb, tc, td)
+        nd = np_conv(np.reshape(na, (BATCH, IC, IH, IW)), nb, PAD, STRIDE) + nc.reshape(1, bshape[0], 1, 1)
+        tvm.testing.assert_allclose(
+            td.asnumpy(), nd.reshape(BATCH, OC, OH, OW), rtol=1e-5)
+    for algorithm in [nnpack.ConvolutionAlgorithm.WT_8x8]:
+        for with_bias in [True, False]:
+            verify(algorithm=algorithm, with_bias=with_bias)
 
 def test_convolution_output():
     BATCH = 32
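
Usage note (not part of the patch): a minimal sketch of the two-step path this diff adds, mirroring tests/python/contrib/test_nnpack.py -- pre-transform the kernel once with convolution_inference_weight_transform, then run convolution_inference_without_weight_transform against the transformed kernel. It assumes a TVM build with USE_NNPACK enabled and a stride-1, 3x3 kernel, which is what the WT_8x8 algorithm requires; shapes and variable names here are illustrative only.

    import numpy as np
    import tvm
    from tvm.contrib import nnpack

    # Illustrative shapes: NCHW data, OIHW kernel, 3x3 window, padding 1, stride 1.
    batch, ic, oc, ih, iw, k, pad, stride = 1, 16, 16, 48, 48, 3, 1, 1
    oh, ow = ih + 2 * pad - k + 1, iw + 2 * pad - k + 1

    data = tvm.placeholder((batch, ic, ih, iw), name='data')
    kernel = tvm.placeholder((oc, ic, k, k), name='kernel')
    bias = tvm.placeholder((oc,), name='bias')

    # Step 1: pre-transform the kernel (output shape is [oc][ic][8][8] for WT_8x8).
    transformed_kernel = nnpack.convolution_inference_weight_transform(
        kernel, algorithm=nnpack.ConvolutionAlgorithm.WT_8x8)
    # Step 2: convolve against the pre-transformed kernel (bias may also be None,
    # in which case the C++ side substitutes a zero bias).
    output = nnpack.convolution_inference_without_weight_transform(
        data, transformed_kernel, bias,
        [pad, pad, pad, pad], [stride, stride],
        algorithm=nnpack.ConvolutionAlgorithm.WT_8x8)

    s = tvm.create_schedule(output.op)
    f = tvm.build(s, [data, kernel, bias, output], target="llvm")

    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(size=(batch, ic, ih, iw)).astype("float32"), ctx)
    b = tvm.nd.array(np.random.uniform(size=(oc, ic, k, k)).astype("float32"), ctx)
    c = tvm.nd.array(np.random.uniform(size=(oc,)).astype("float32"), ctx)
    d = tvm.nd.array(np.zeros((batch, oc, oh, ow), dtype="float32"), ctx)
    f(a, b, c, d)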