From ff5dffa440a8787e94efe3a69972d7f094a32166 Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Thu, 22 Nov 2018 10:32:00 -0800 Subject: [PATCH] [APPS] add an external dll call example (#2156) --- apps/extension/python/tvm_ext/__init__.py | 4 +++- apps/extension/src/tvm_ext.cc | 5 +++++ apps/extension/tests/test_ext.py | 20 ++++++++++++++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/apps/extension/python/tvm_ext/__init__.py b/apps/extension/python/tvm_ext/__init__.py index 5045a9ec0..25286f67b 100644 --- a/apps/extension/python/tvm_ext/__init__.py +++ b/apps/extension/python/tvm_ext/__init__.py @@ -8,7 +8,9 @@ import tvm def load_lib(): """Load library, the functions will be registered into TVM""" curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - lib = ctypes.CDLL(os.path.join(curr_path, "../../lib/libtvm_ext.so")) + # load in as global so the global extern symbol is visible to other dll. + lib = ctypes.CDLL( + os.path.join(curr_path, "../../lib/libtvm_ext.so"), ctypes.RTLD_GLOBAL) return lib _LIB = load_lib() diff --git a/apps/extension/src/tvm_ext.cc b/apps/extension/src/tvm_ext.cc index bb8b4b694..362ac62de 100644 --- a/apps/extension/src/tvm_ext.cc +++ b/apps/extension/src/tvm_ext.cc @@ -66,6 +66,11 @@ TVM_REGISTER_GLOBAL("device_api.ext_dev") }); } // namespace tvm_ext +// External function exposed to runtime. +extern "C" float TVMTestAddOne(float y) { + return y + 1; +} + // This callback approach allows extension allows tvm to extract // This way can be helpful when we want to use a header only // minimum version of TVM Runtime. diff --git a/apps/extension/tests/test_ext.py b/apps/extension/tests/test_ext.py index b7b97897a..def308031 100644 --- a/apps/extension/tests/test_ext.py +++ b/apps/extension/tests/test_ext.py @@ -49,7 +49,27 @@ def test_extract_ext(): assert fdict["mul"](3, 4) == 12 +def test_extern_call(): + n = 10 + A = tvm.placeholder((n,), name='A') + B = tvm.compute((n,), lambda *i: tvm.call_extern("float32", "TVMTestAddOne", A(*i)), name='B') + s = tvm.create_schedule(B.op) + + def check_llvm(): + if not tvm.module.enabled("llvm"): + return + f = tvm.build(s, [A, B], "llvm") + ctx = tvm.cpu(0) + # launch the kernel. + a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx) + f(a, b) + tvm.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1) + check_llvm() + + if __name__ == "__main__": + test_extern_call() test_ext_dev() test_ext_vec() test_bind_add() -- GitLab