diff --git a/python/tvm/contrib/nnpack.py b/python/tvm/contrib/nnpack.py index 36f8a76a87db8310cd5641accae121ee40f94545..3fb00a3f85e5b034f72888eb5d4707a3f514c415 100644 --- a/python/tvm/contrib/nnpack.py +++ b/python/tvm/contrib/nnpack.py @@ -5,16 +5,11 @@ from .. import api as _api from .. import intrin as _intrin from .._ffi.function import _init_api -def config(nthreads): - """Configure the nnpack library. - - Parameters - ---------- - nthreads : int - The threads number of nnpack thread pool, must be a nonnegative. - +def is_available(): + """Check whether NNPACK is available, that is, `nnp_initialize()` + returns `nnp_status_success`. """ - _Config(nthreads) + return _initialize() == 0 def fully_connected_inference(lhs, rhs, nthreads=1): """Create an extern op that compute fully connected of 1D tensor lhs and diff --git a/src/contrib/nnpack/nnpack_utils.cc b/src/contrib/nnpack/nnpack_utils.cc index d8ef1d0b83276e0534a1a00e71525ec20bafe92c..12eb828cc7e6a77a08ce967e5f7b8b81adfc7089 100644 --- a/src/contrib/nnpack/nnpack_utils.cc +++ b/src/contrib/nnpack/nnpack_utils.cc @@ -38,9 +38,10 @@ bool NNPackConfig(uint64_t nthreads) { } -TVM_REGISTER_GLOBAL("contrib.nnpack._Config") +TVM_REGISTER_GLOBAL("contrib.nnpack._initialize") .set_body([](TVMArgs args, TVMRetValue *ret) { - CHECK(NNPackConfig(args[0])); + *ret = nnp_initialize(); }); + } // namespace contrib } // namespace tvm diff --git a/tests/python/contrib/test_nnpack.py b/tests/python/contrib/test_nnpack.py index 151869729d42b90ce2c27ce8f69b3cd029af2c2e..a4b77a39af63da817b7eaa64acc6e62f558ce075 100644 --- a/tests/python/contrib/test_nnpack.py +++ b/tests/python/contrib/test_nnpack.py @@ -21,7 +21,9 @@ def test_fully_connected_output(): if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_output", True): print("skip because extern function is not available") return - return + if not nnpack.is_available(): + return + ctx = tvm.cpu(0) f = tvm.build(s, [A, B, D, bias], target) a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), ctx) @@ -52,7 +54,9 @@ def test_fully_connected_inference(): if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True): print("skip because extern function is not available") return - return + if not nnpack.is_available(): + return + ctx = tvm.cpu(0) f = tvm.build(s, [A, B, D, bias], target) a = tvm.nd.array(np.random.uniform(size=(l)).astype(A.dtype), ctx) @@ -130,7 +134,9 @@ def test_convolution_inference(): if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True): print("skip because extern function is not available") return - return + if not nnpack.is_available(): + return + ctx = tvm.cpu(0) output = nnpack.convolution_inference( data, kernel, bias if with_bias else None, @@ -192,7 +198,9 @@ def test_convolution_inference_without_weight_transform(): if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True): print("skip because extern function is not available") return - return + if not nnpack.is_available(): + return + ctx = tvm.cpu(0) transformed_kernel = nnpack.convolution_inference_weight_transform( kernel, algorithm=algorithm) @@ -249,7 +257,9 @@ def test_convolution_output(): if not tvm.get_global_func("tvm.contrib.nnpack.fully_connected_inference", True): print("skip because extern function is not available") return - return + if not nnpack.is_available(): + return + ctx = tvm.cpu(0) f = tvm.build(s, [data, kernel, bias, output], target)