diff --git a/python/tvm/contrib/nnpack.py b/python/tvm/contrib/nnpack.py index 66e7a9494d89c84f99a6ef6616fdc29511473aa5..d6587df26229729a71fbd7b5a4f8ad0c35abc4df 100644 --- a/python/tvm/contrib/nnpack.py +++ b/python/tvm/contrib/nnpack.py @@ -16,7 +16,7 @@ def config(nthreads): """ _Config(nthreads) -def fully_connected_inference(lhs, rhs): +def fully_connected_inference(lhs, rhs, nthreads=1): """Create an extern op that compute fully connected of 1D tensor lhs and 2D tensor rhs with nnpack. @@ -37,9 +37,9 @@ def fully_connected_inference(lhs, rhs): (m, ), [lhs, rhs], lambda ins, outs: _intrin.call_packed( "tvm.contrib.nnpack.fully_connected_inference", - ins[0], ins[1], outs[0]), name="C") + ins[0], ins[1], outs[0], nthreads), name="C") -def fully_connected_output(lhs, rhs): +def fully_connected_output(lhs, rhs, nthreads=1): """Create an extern op that compute fully connected of 2D tensor lhs and 2D tensor rhs with nnpack. @@ -61,9 +61,9 @@ def fully_connected_output(lhs, rhs): (n, m), [lhs, rhs], lambda ins, outs: _intrin.call_packed( "tvm.contrib.nnpack.fully_connected_output", - ins[0], ins[1], outs[0]), name="C") + ins[0], ins[1], outs[0], nthreads), name="C") -def convolution_inference(data, kernel, bias, padding, stride): +def convolution_inference(data, kernel, bias, padding, stride, nthreads=1): """Create an extern op to do inference convolution of 3D tensor data and 4D tensor kernel and 1D tensor bias with nnpack. @@ -104,9 +104,9 @@ def convolution_inference(data, kernel, bias, padding, stride): lambda ins, outs: _intrin.call_packed( "tvm.contrib.nnpack.convolution_inference", ins[0], ins[1], ins[2], outs[0], padding[0], padding[1], padding[2], padding[3], - stride[0], stride[1]), name="C") + stride[0], stride[1], nthreads), name="C") -def convolution_output(data, kernel, bias, padding): +def convolution_output(data, kernel, bias, padding, nthreads=1): """Create an extern op to compute convolution of 4D tensor data and 4D tensor kernel and 1D tensor bias with nnpack. @@ -142,6 +142,6 @@ def convolution_output(data, kernel, bias, padding): (batch, output_channels, output_height, output_width), [data, kernel, bias], lambda ins, outs: _intrin.call_packed( "tvm.contrib.nnpack.convolution_output", ins[0], ins[1], ins[2], - outs[0], padding[0], padding[1], padding[2], padding[3]), name="C") + outs[0], padding[0], padding[1], padding[2], padding[3], nthreads), name="C") _init_api("tvm.contrib.nnpack") diff --git a/src/contrib/nnpack/convolution.cc b/src/contrib/nnpack/convolution.cc index 8480a100dfd76207d4cca085f9effc7ec3705e7f..9ca02118aeb3d20834f46693eadc01a4f8ef5832 100644 --- a/src/contrib/nnpack/convolution.cc +++ b/src/contrib/nnpack/convolution.cc @@ -24,6 +24,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference") nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left}; uint64_t stride_width = args[8], stride_height = args[9]; nnp_size stride_size{stride_width, stride_height}; + NNPackConfig(args[10]); CHECK_EQ(input->ndim, 3); CHECK_EQ(kernel->ndim, 4); @@ -80,6 +81,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output") DLTensor* output = args[3]; uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], pad_left = args[7]; nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left}; + NNPackConfig(args[8]); CHECK_EQ(input->ndim, 4); CHECK_EQ(kernel->ndim, 4); diff --git a/src/contrib/nnpack/fully_connected.cc b/src/contrib/nnpack/fully_connected.cc index 6793ecaa36a7968a317d788c822ef38caf767f58..df6356d933aa5d0074dc8dbf216f49308c4790d1 100644 --- a/src/contrib/nnpack/fully_connected.cc +++ b/src/contrib/nnpack/fully_connected.cc @@ -21,6 +21,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference") DLTensor* A = args[0]; DLTensor* B = args[1]; DLTensor* C = args[2]; + NNPackConfig(args[3]); + CHECK_EQ(A->ndim, 1); CHECK_EQ(B->ndim, 2); CHECK_EQ(C->ndim, 1); @@ -49,6 +51,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_output") DLTensor* A = args[0]; DLTensor* B = args[1]; DLTensor* C = args[2]; + NNPackConfig(args[3]); + CHECK_EQ(A->ndim, 2); CHECK_EQ(B->ndim, 2); CHECK_EQ(C->ndim, 2); diff --git a/src/contrib/nnpack/nnpack_utils.cc b/src/contrib/nnpack/nnpack_utils.cc index e1e2773c1c8d3891313b6a94bedcc5e1de57e1c8..631f25b36647fae42aba96e0cc640d0154163630 100644 --- a/src/contrib/nnpack/nnpack_utils.cc +++ b/src/contrib/nnpack/nnpack_utils.cc @@ -14,18 +14,23 @@ NNPackThreadLocalEntry* NNPackThreadLocalEntry::ThreadLocal() { return NNPackThreadLocalStore::Get(); } +bool NNPackConfig(uint64_t nthreads) { + NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal(); + if (entry->threadpool != NULL && + pthreadpool_get_threads_count(entry->threadpool) != nthreads) { + pthreadpool_destroy(entry->threadpool); + entry->threadpool = NULL; + } + if (entry->threadpool == NULL) { + entry->threadpool = pthreadpool_create(nthreads); + } + return true; +} + + TVM_REGISTER_GLOBAL("contrib.nnpack._Config") .set_body([](TVMArgs args, TVMRetValue *ret) { - NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal(); - size_t nthreads = args[0].operator uint64_t(); - if (entry->threadpool != NULL && - pthreadpool_get_threads_count(entry->threadpool) != nthreads) { - pthreadpool_destroy(entry->threadpool); - entry->threadpool = NULL; - } - if (entry->threadpool == NULL) { - entry->threadpool = pthreadpool_create(nthreads); - } + CHECK(NNPackConfig(args[0])); }); } // namespace contrib } // namespace tvm diff --git a/src/contrib/nnpack/nnpack_utils.h b/src/contrib/nnpack/nnpack_utils.h index 7a2232add145802954fa352043874616334b0b17..fe7420786bdec2477479cf9be73e072f8bf7c3ad 100644 --- a/src/contrib/nnpack/nnpack_utils.h +++ b/src/contrib/nnpack/nnpack_utils.h @@ -18,6 +18,8 @@ struct NNPackThreadLocalEntry { pthreadpool_t threadpool{NULL}; static NNPackThreadLocalEntry* ThreadLocal(); }; + +bool NNPackConfig(uint64_t nthreads); } // namespace contrib } // namespace tvm #endif // TVM_CONTRIB_NNPACK_NNPACK_UTILS_H_