From b3f09b019e3e150f416d6e45ad01e252b6b893d8 Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Thu, 12 Apr 2018 17:23:20 -0700 Subject: [PATCH] [RPC] LocalSession to provide RPCSession back by local env (#1102) --- python/tvm/contrib/rpc/__init__.py | 2 +- python/tvm/contrib/rpc/client.py | 86 ++++++++++++++++------- python/tvm/contrib/rpc/server.py | 13 +--- python/tvm/module.py | 19 ++++- tests/python/unittest/test_runtime_rpc.py | 26 +++++-- 5 files changed, 102 insertions(+), 44 deletions(-) diff --git a/python/tvm/contrib/rpc/__init__.py b/python/tvm/contrib/rpc/__init__.py index 3335240ac..c3fbf5be2 100644 --- a/python/tvm/contrib/rpc/__init__.py +++ b/python/tvm/contrib/rpc/__init__.py @@ -10,4 +10,4 @@ upload and run remote RPC server, get the result back to verify correctness. """ from .server import Server -from .client import RPCSession, connect, connect_tracker +from .client import RPCSession, LocalSession, connect, connect_tracker diff --git a/python/tvm/contrib/rpc/client.py b/python/tvm/contrib/rpc/client.py index 24b3b9582..f409b2f72 100644 --- a/python/tvm/contrib/rpc/client.py +++ b/python/tvm/contrib/rpc/client.py @@ -7,8 +7,11 @@ import struct import time from . import base +from .. import util from ..._ffi.base import TVMError -from ..._ffi.ndarray import context as _context +from ..._ffi import function as function +from ..._ffi import ndarray as nd +from ...module import load as _load_module class RPCSession(object): @@ -51,36 +54,12 @@ class RPCSession(object): ctx: TVMContext The corresponding encoded remote context. """ - ctx = _context(dev_type, dev_id) + ctx = nd.context(dev_type, dev_id) encode = (self._tbl_index + 1) * base.RPC_SESS_MASK ctx.device_type += encode ctx._rpc_sess = self return ctx - def cpu(self, dev_id=0): - """Construct remote CPU device.""" - return self.context(1, dev_id) - - def gpu(self, dev_id=0): - """Construct remote GPU device.""" - return self.context(2, dev_id) - - def cl(self, dev_id=0): - """Construct remote OpenCL device.""" - return self.context(4, dev_id) - - def metal(self, dev_id=0): - """Construct remote Metal device.""" - return self.context(8, dev_id) - - def opengl(self, dev_id=0): - """Construct remote OpenGL device.""" - return self.context(11, dev_id) - - def ext_dev(self, dev_id=0): - """Construct remote extension device.""" - return self.context(12, dev_id) - def upload(self, data, target=None): """Upload file to remote runtime temp folder @@ -139,6 +118,61 @@ class RPCSession(object): """ return base._LoadRemoteModule(self._sess, path) + def cpu(self, dev_id=0): + """Construct CPU device.""" + return self.context(1, dev_id) + + def gpu(self, dev_id=0): + """Construct GPU device.""" + return self.context(2, dev_id) + + def cl(self, dev_id=0): + """Construct OpenCL device.""" + return self.context(4, dev_id) + + def metal(self, dev_id=0): + """Construct Metal device.""" + return self.context(8, dev_id) + + def opengl(self, dev_id=0): + """Construct OpenGL device.""" + return self.context(11, dev_id) + + def ext_dev(self, dev_id=0): + """Construct extension device.""" + return self.context(12, dev_id) + + +class LocalSession(RPCSession): + """RPCSession interface backed by local environment. + + This class can be used to implement functions that + need to be ran both locally and remotely. + """ + def __init__(self): + # pylint: disable=super-init-not-called + self.context = nd.context + self.get_function = function.get_global_func + self._temp = util.tempdir() + + def upload(self, data, target=None): + if isinstance(data, bytearray): + if not target: + raise ValueError("target must present when file is a bytearray") + blob = data + else: + blob = bytearray(open(data, "rb").read()) + if not target: + target = os.path.basename(data) + with open(self._temp.relpath(target), "wb") as f: + f.write(blob) + + def download(self, path): + return bytearray(open(self._temp.relpath(path), "rb").read()) + + def load_module(self, path): + return _load_module(self._temp.relpath(path)) + class TrackerSession(object): """Tracker client session. diff --git a/python/tvm/contrib/rpc/server.py b/python/tvm/contrib/rpc/server.py index 8cd1c4ee8..6759f13b6 100644 --- a/python/tvm/contrib/rpc/server.py +++ b/python/tvm/contrib/rpc/server.py @@ -22,7 +22,7 @@ import time from ..._ffi.function import register_func from ..._ffi.base import py_str from ...module import load as _load_module -from .. import util, cc, tar +from .. import util from . import base from . base import TrackerCode @@ -38,17 +38,6 @@ def _server_env(): def load_module(file_name): """Load module from remote side.""" path = temp.relpath(file_name) - # Try create a shared library in remote - if path.endswith(".o"): - logging.info("Create shared library based on %s", path) - cc.create_shared(path + ".so", path) - path += ".so" - elif path.endswith(".tar"): - tar_temp = util.tempdir() - tar.untar(path, tar_temp.temp_dir) - files = [tar_temp.relpath(x) for x in tar_temp.listdir()] - cc.create_shared(path + ".so", files) - path += ".so" m = _load_module(path) logging.info("load_module %s", path) return m diff --git a/python/tvm/module.py b/python/tvm/module.py index 6459733fa..1b83c9b26 100644 --- a/python/tvm/module.py +++ b/python/tvm/module.py @@ -186,7 +186,7 @@ def system_lib(): def load(path, fmt=""): - """Load module from file + """Load module from file. Parameters ---------- @@ -201,7 +201,24 @@ def load(path, fmt=""): ------- module : Module The loaded module + + Note + ---- + This function will automatically call + cc.create_shared if the path is in format .o or .tar """ + # High level handling for .o and .tar file. + # We support this to be consistent with RPC module load. + if path.endswith(".o"): + _cc.create_shared(path + ".so", path) + path += ".so" + elif path.endswith(".tar"): + tar_temp = _util.tempdir() + _tar.untar(path, tar_temp.temp_dir) + files = [tar_temp.relpath(x) for x in tar_temp.listdir()] + _cc.create_shared(path + ".so", files) + path += ".so" + # Redirect to the load API return _LoadFromFile(path, fmt) diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 46e84d887..581890efe 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -61,14 +61,14 @@ def test_rpc_remote_module(): if not tvm.module.enabled("rpc"): return server = rpc.Server("localhost") - remote = rpc.connect(server.host, server.port) + client = rpc.connect(server.host, server.port) # graph n = tvm.convert(1024) A = tvm.placeholder((n,), name='A') B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B') s = tvm.create_schedule(B.op) - def check_remote(): + def check_remote(remote): if not tvm.module.enabled("llvm"): print("Skip because llvm is not enabled") return @@ -86,7 +86,7 @@ def test_rpc_remote_module(): print('%g secs/op' % cost) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) - def check_remote_link_cl(): + def check_remote_link_cl(remote): """Test function to run remote code such as cl This is not enabled because there is forking issue @@ -134,7 +134,9 @@ def test_rpc_remote_module(): fhost(a, b) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) - check_remote() + check_remote(client) + check_remote(rpc.LocalSession()) + def test_rpc_return_func(): @tvm.register_func("rpc.test.remote_func") @@ -147,6 +149,21 @@ def test_rpc_return_func(): assert fadd(12) == 22 +def test_local_func(): + @tvm.register_func("rpc.test.remote_func2") + def addone(x): + return lambda y: x+y + client = rpc.LocalSession() + f1 = client.get_function("rpc.test.remote_func2") + fadd = f1(10) + assert fadd(12) == 22 + + blob = bytearray(np.random.randint(0, 10, size=(10))) + client.upload(blob, "dat.bin") + rev = client.download("dat.bin") + assert rev == blob + + if __name__ == "__main__": logging.basicConfig(level=logging.INFO) test_rpc_remote_module() @@ -154,3 +171,4 @@ if __name__ == "__main__": test_rpc_file_exchange() test_rpc_array() test_rpc_simple() + test_local_func() -- GitLab