From 182a7852de5ae3b21985191cc1c9dda23f9268d3 Mon Sep 17 00:00:00 2001
From: ziheng <ziheng@apache.org>
Date: Fri, 10 Nov 2017 19:02:46 -0800
Subject: [PATCH] [NNPACK] Add argument nthreads (#631)

---
 python/tvm/contrib/nnpack.py          | 16 ++++++++--------
 src/contrib/nnpack/convolution.cc     |  2 ++
 src/contrib/nnpack/fully_connected.cc |  4 ++++
 src/contrib/nnpack/nnpack_utils.cc    | 25 +++++++++++++++----------
 src/contrib/nnpack/nnpack_utils.h     |  2 ++
 5 files changed, 31 insertions(+), 18 deletions(-)

diff --git a/python/tvm/contrib/nnpack.py b/python/tvm/contrib/nnpack.py
index 66e7a9494..d6587df26 100644
--- a/python/tvm/contrib/nnpack.py
+++ b/python/tvm/contrib/nnpack.py
@@ -16,7 +16,7 @@ def config(nthreads):
     """
     _Config(nthreads)
 
-def fully_connected_inference(lhs, rhs):
+def fully_connected_inference(lhs, rhs, nthreads=1):
     """Create an extern op that compute fully connected of 1D tensor lhs
     and 2D tensor rhs with nnpack.
 
@@ -37,9 +37,9 @@ def fully_connected_inference(lhs, rhs):
         (m, ), [lhs, rhs],
         lambda ins, outs: _intrin.call_packed(
             "tvm.contrib.nnpack.fully_connected_inference",
-            ins[0], ins[1], outs[0]), name="C")
+            ins[0], ins[1], outs[0], nthreads), name="C")
 
-def fully_connected_output(lhs, rhs):
+def fully_connected_output(lhs, rhs, nthreads=1):
     """Create an extern op that compute fully connected of 2D tensor lhs
     and 2D tensor rhs with nnpack.
 
@@ -61,9 +61,9 @@ def fully_connected_output(lhs, rhs):
         (n, m), [lhs, rhs],
         lambda ins, outs: _intrin.call_packed(
             "tvm.contrib.nnpack.fully_connected_output",
-            ins[0], ins[1], outs[0]), name="C")
+            ins[0], ins[1], outs[0], nthreads), name="C")
 
-def convolution_inference(data, kernel, bias, padding, stride):
+def convolution_inference(data, kernel, bias, padding, stride, nthreads=1):
     """Create an extern op to do inference convolution of 3D tensor data and
     4D tensor kernel and 1D tensor bias with nnpack.
 
@@ -104,9 +104,9 @@ def convolution_inference(data, kernel, bias, padding, stride):
         lambda ins, outs: _intrin.call_packed(
             "tvm.contrib.nnpack.convolution_inference", ins[0], ins[1], ins[2],
             outs[0], padding[0], padding[1], padding[2], padding[3],
-            stride[0], stride[1]), name="C")
+            stride[0], stride[1], nthreads), name="C")
 
-def convolution_output(data, kernel, bias, padding):
+def convolution_output(data, kernel, bias, padding, nthreads=1):
     """Create an extern op to compute convolution of 4D tensor data and
     4D tensor kernel and 1D tensor bias with nnpack.
 
@@ -142,6 +142,6 @@ def convolution_output(data, kernel, bias, padding):
         (batch, output_channels, output_height, output_width), [data, kernel, bias],
         lambda ins, outs: _intrin.call_packed(
             "tvm.contrib.nnpack.convolution_output", ins[0], ins[1], ins[2],
-            outs[0], padding[0], padding[1], padding[2], padding[3]), name="C")
+            outs[0], padding[0], padding[1], padding[2], padding[3], nthreads), name="C")
 
 _init_api("tvm.contrib.nnpack")
diff --git a/src/contrib/nnpack/convolution.cc b/src/contrib/nnpack/convolution.cc
index 8480a100d..9ca02118a 100644
--- a/src/contrib/nnpack/convolution.cc
+++ b/src/contrib/nnpack/convolution.cc
@@ -24,6 +24,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
     nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
     uint64_t stride_width = args[8], stride_height = args[9];
     nnp_size stride_size{stride_width, stride_height};
+    NNPackConfig(args[10]);
 
     CHECK_EQ(input->ndim, 3);
     CHECK_EQ(kernel->ndim, 4);
@@ -80,6 +81,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
     DLTensor* output = args[3];
     uint64_t pad_top = args[4], pad_right = args[5], pad_bottom = args[6], pad_left = args[7];
     nnp_padding input_padding{pad_top, pad_right, pad_bottom, pad_left};
+    NNPackConfig(args[8]);
 
     CHECK_EQ(input->ndim, 4);
     CHECK_EQ(kernel->ndim, 4);
diff --git a/src/contrib/nnpack/fully_connected.cc b/src/contrib/nnpack/fully_connected.cc
index 6793ecaa3..df6356d93 100644
--- a/src/contrib/nnpack/fully_connected.cc
+++ b/src/contrib/nnpack/fully_connected.cc
@@ -21,6 +21,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference")
     DLTensor* A = args[0];
     DLTensor* B = args[1];
     DLTensor* C = args[2];
+    NNPackConfig(args[3]);
+
     CHECK_EQ(A->ndim, 1);
     CHECK_EQ(B->ndim, 2);
     CHECK_EQ(C->ndim, 1);
@@ -49,6 +51,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_output")
     DLTensor* A = args[0];
     DLTensor* B = args[1];
     DLTensor* C = args[2];
+    NNPackConfig(args[3]);
+
     CHECK_EQ(A->ndim, 2);
     CHECK_EQ(B->ndim, 2);
     CHECK_EQ(C->ndim, 2);
diff --git a/src/contrib/nnpack/nnpack_utils.cc b/src/contrib/nnpack/nnpack_utils.cc
index e1e2773c1..631f25b36 100644
--- a/src/contrib/nnpack/nnpack_utils.cc
+++ b/src/contrib/nnpack/nnpack_utils.cc
@@ -14,18 +14,23 @@ NNPackThreadLocalEntry* NNPackThreadLocalEntry::ThreadLocal() {
   return NNPackThreadLocalStore::Get();
 }
 
+bool NNPackConfig(uint64_t nthreads) {
+  NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
+  if (entry->threadpool != NULL &&
+      pthreadpool_get_threads_count(entry->threadpool) != nthreads) {
+    pthreadpool_destroy(entry->threadpool);
+    entry->threadpool = NULL;
+  }
+  if (entry->threadpool == NULL) {
+    entry->threadpool = pthreadpool_create(nthreads);
+  }
+  return true;
+}
+
+
 TVM_REGISTER_GLOBAL("contrib.nnpack._Config")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
-    NNPackThreadLocalEntry *entry = NNPackThreadLocalEntry::ThreadLocal();
-    size_t nthreads = args[0].operator uint64_t();
-    if (entry->threadpool != NULL &&
-        pthreadpool_get_threads_count(entry->threadpool) != nthreads) {
-      pthreadpool_destroy(entry->threadpool);
-      entry->threadpool = NULL;
-    }
-    if (entry->threadpool == NULL) {
-      entry->threadpool = pthreadpool_create(nthreads);
-    }
+    CHECK(NNPackConfig(args[0]));
   });
 }  // namespace contrib
 }  // namespace tvm
diff --git a/src/contrib/nnpack/nnpack_utils.h b/src/contrib/nnpack/nnpack_utils.h
index 7a2232add..fe7420786 100644
--- a/src/contrib/nnpack/nnpack_utils.h
+++ b/src/contrib/nnpack/nnpack_utils.h
@@ -18,6 +18,8 @@ struct NNPackThreadLocalEntry {
   pthreadpool_t threadpool{NULL};
   static NNPackThreadLocalEntry* ThreadLocal();
 };
+
+bool NNPackConfig(uint64_t nthreads);
 }  // namespace contrib
 }  // namespace tvm
 #endif  // TVM_CONTRIB_NNPACK_NNPACK_UTILS_H_
--
GitLab
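
Reviewer note (not part of the patch): a minimal Python sketch of how the new nthreads
argument would be exercised, assuming a TVM build with NNPACK enabled. The tensor sizes,
variable names, and the "llvm" target are illustrative; the placeholder/create_schedule/build
calls follow the pre-0.4 TVM API in use at the time of this commit.

    import numpy as np
    import tvm
    from tvm.contrib import nnpack

    l, m = 128, 235                          # illustrative sizes, not from the patch
    lhs = tvm.placeholder((l,), name="lhs")  # 1D input vector
    rhs = tvm.placeholder((m, l), name="rhs")  # 2D weight matrix
    # nthreads is forwarded as the extra packed argument and consumed by
    # NNPackConfig() on the C++ side, which sizes the NNPACK pthreadpool.
    C = nnpack.fully_connected_inference(lhs, rhs, nthreads=4)
    s = tvm.create_schedule(C.op)
    f = tvm.build(s, [lhs, rhs, C], target="llvm")

    ctx = tvm.cpu(0)
    a = tvm.nd.array(np.random.uniform(size=(l,)).astype("float32"), ctx)
    b = tvm.nd.array(np.random.uniform(size=(m, l)).astype("float32"), ctx)
    c = tvm.nd.array(np.zeros((m,), dtype="float32"), ctx)
    f(a, b, c)                               # c = b @ a, computed by NNPACK

Because NNPackConfig() destroys and recreates the thread-local pool only when the requested
thread count differs from the current one, repeated calls with the same nthreads reuse the
existing pool.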