From 1077f8e8147ba335098daa6b120499a1904d6f69 Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Sun, 10 Sep 2017 23:56:46 -0700 Subject: [PATCH] [RUNTIME][RPC] Enable remote linking of device code. (#444) * [RUNTIME][RPC] Enable remote linking of device code. * fix build --- python/tvm/contrib/rpc.py | 34 +++++++----- python/tvm/contrib/tar.py | 68 +++++++++++++++++++++++ python/tvm/module.py | 8 ++- src/runtime/module.cc | 11 ++++ src/runtime/rpc/rpc_module.cc | 20 +++++++ src/runtime/rpc/rpc_session.cc | 8 +++ src/runtime/rpc/rpc_session.h | 1 + tests/python/unittest/test_runtime_rpc.py | 51 ++++++++++++++++- 8 files changed, 184 insertions(+), 17 deletions(-) create mode 100644 python/tvm/contrib/tar.py diff --git a/python/tvm/contrib/rpc.py b/python/tvm/contrib/rpc.py index 6376f476d..a6b5dfa6c 100644 --- a/python/tvm/contrib/rpc.py +++ b/python/tvm/contrib/rpc.py @@ -15,7 +15,7 @@ import socket import struct import logging import multiprocessing -from . import util, cc +from . import util, cc, tar from ..module import load as _load_module from .._ffi.function import _init_api, register_func from .._ffi.ndarray import context as _context @@ -37,10 +37,16 @@ def _server_env(): """Load module from remote side.""" path = temp.relpath(file_name) # Try create a shared library in remote - if path.endswith('.o'): - logging.info('Create shared library based on %s', path) - cc.create_shared(path + '.so', path) - path += '.so' + if path.endswith(".o"): + logging.info("Create shared library based on %s", path) + cc.create_shared(path + ".so", path) + path += ".so" + elif path.endswith(".tar"): + tar_temp = util.tempdir() + tar.untar(path, tar_temp.temp_dir) + files = [tar_temp.relpath(x) for x in tar_temp.listdir()] + cc.create_shared(path + ".so", files) + path += ".so" m = _load_module(path) logging.info("load_module %s", path) return m @@ -63,7 +69,7 @@ def _recvall(sock, nbytes): chunk = sock.recv(min(nbytes - nread, 1024)) nread += len(chunk) res.append(chunk) - return b''.join(res) + return b"".join(res) def _listen_loop(sock): @@ -71,16 +77,16 @@ def _listen_loop(sock): while True: conn, addr = sock.accept() logging.info("RPCServer: connection from %s", addr) - magic = struct.unpack('@i', _recvall(conn, 4))[0] + magic = struct.unpack("@i", _recvall(conn, 4))[0] if magic != RPC_MAGIC: conn.close() continue - keylen = struct.unpack('@i', _recvall(conn, 4))[0] + keylen = struct.unpack("@i", _recvall(conn, 4))[0] key = py_str(_recvall(conn, keylen)) if not key.startswith("client:"): - conn.sendall(struct.pack('@i', RPC_MAGIC + 2)) + conn.sendall(struct.pack("@i", RPC_MAGIC + 2)) else: - conn.sendall(struct.pack('@i', RPC_MAGIC)) + conn.sendall(struct.pack("@i", RPC_MAGIC)) logging.info("Connection from %s", addr) process = multiprocessing.Process(target=_serve_loop, args=(conn, addr)) process.deamon = True @@ -94,10 +100,10 @@ def _connect_proxy_loop(addr, key): while True: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(addr) - sock.sendall(struct.pack('@i', RPC_MAGIC)) - sock.sendall(struct.pack('@i', len(key))) + sock.sendall(struct.pack("@i", RPC_MAGIC)) + sock.sendall(struct.pack("@i", len(key))) sock.sendall(key.encode("utf-8")) - magic = struct.unpack('@i', _recvall(sock, 4))[0] + magic = struct.unpack("@i", _recvall(sock, 4))[0] if magic == RPC_MAGIC + 1: raise RuntimeError("key: %s has already been used in proxy" % key) elif magic == RPC_MAGIC + 2: @@ -321,7 +327,7 @@ def connect(url, port, key=""): try: sess = _Connect(url, port, key) except NameError: - raise RuntimeError('Please compile with USE_RPC=1') + raise RuntimeError("Please compile with USE_RPC=1") return RPCSession(sess) _init_api("tvm.contrib.rpc") diff --git a/python/tvm/contrib/tar.py b/python/tvm/contrib/tar.py new file mode 100644 index 000000000..ca3bf3478 --- /dev/null +++ b/python/tvm/contrib/tar.py @@ -0,0 +1,68 @@ + +"""Util to invoke tarball in the system.""" +# pylint: disable=invalid-name +from __future__ import absolute_import as _abs +import os +import shutil +import subprocess +from . import util + +def tar(output, files): + """Create tarball containing all files in root. + + Parameters + ---------- + output : str + The target shared library. + + files : list + List of files to be bundled. + """ + cmd = ["tar"] + cmd += ["-czf"] + temp = util.tempdir() + fset = set() + for fname in files: + base = os.path.basename(fname) + if base in fset: + raise ValueError("duplicate file name %s" % base) + fset.add(base) + shutil.copy(fname, temp.relpath(base)) + cmd += [output] + cmd += ["-C", temp.temp_dir] + cmd += temp.listdir() + proc = subprocess.Popen(cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + (out, _) = proc.communicate() + + if proc.returncode != 0: + msg = "Tar error:\n" + msg += out + raise RuntimeError(msg) + + +def untar(tar_file, directory): + """Unpack all tar files into the directory + + Parameters + ---------- + tar_file : str + The source tar file. + + directory : str + The target directory + """ + cmd = ["tar"] + cmd += ["-xf"] + cmd += [tar_file] + cmd += ["-C", directory] + proc = subprocess.Popen(cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT) + (out, _) = proc.communicate() + + if proc.returncode != 0: + msg = "Tar error:\n" + msg += out + raise RuntimeError(msg) diff --git a/python/tvm/module.py b/python/tvm/module.py index d5bf49cf8..9c44a9f1e 100644 --- a/python/tvm/module.py +++ b/python/tvm/module.py @@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs from collections import namedtuple from ._ffi.function import ModuleBase, _set_class_module from ._ffi.function import _init_api -from .contrib import cc as _cc, util as _util +from .contrib import cc as _cc, tar as _tar, util as _util ProfileResult = namedtuple("ProfileResult", ["mean"]) @@ -100,7 +100,11 @@ class Module(ModuleBase): with open(path_cc, "w") as f: f.write(_PackImportsToC(self, is_system_lib)) files.append(path_cc) - fcompile = fcompile if fcompile else _cc.create_shared + if not fcompile: + if file_name.endswith(".tar"): + fcompile = _tar.tar + else: + fcompile = _cc.create_shared fcompile(file_name, files, **kwargs) def time_evaluator(self, func_name, ctx, number): diff --git a/src/runtime/module.cc b/src/runtime/module.cc index 6afe98016..2f2b0a214 100644 --- a/src/runtime/module.cc +++ b/src/runtime/module.cc @@ -7,6 +7,7 @@ #include <tvm/runtime/registry.h> #include <tvm/runtime/packed_func.h> #include <unordered_set> +#include <cstring> #include "./file_util.h" namespace tvm { @@ -26,6 +27,16 @@ PackedFunc Module::GetFunction( } void Module::Import(Module other) { + // specially handle rpc + if (!std::strcmp((*this)->type_key(), "rpc")) { + static const PackedFunc* fimport_ = nullptr; + if (fimport_ == nullptr) { + fimport_ = runtime::Registry::Get("contrib.rpc._ImportRemoteModule"); + CHECK(fimport_ != nullptr); + } + (*fimport_)(*this, other); + return; + } // cyclic detection. std::unordered_set<const ModuleNode*> visited{other.node_.get()}; std::vector<const ModuleNode*> stack{other.node_.get()}; diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 9eca74620..a0952e8d4 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -5,6 +5,7 @@ */ #include <tvm/runtime/registry.h> #include <memory> +#include <cstring> #include "./rpc_session.h" namespace tvm { @@ -83,6 +84,10 @@ class RPCModuleNode final : public ModuleNode { return WrapRemote(handle); } + void* module_handle() const { + return module_handle_; + } + private: PackedFunc WrapRemote(RPCFuncHandle handle) { if (handle == nullptr) return PackedFunc(); @@ -162,6 +167,21 @@ TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule") *rv = Module(n); }); +TVM_REGISTER_GLOBAL("contrib.rpc._ImportRemoteModule") +.set_body([](TVMArgs args, TVMRetValue* rv) { + Module parent = args[0]; + Module child = args[1]; + CHECK(!std::strcmp(parent->type_key(), "rpc") && + !std::strcmp(child->type_key(), "rpc")); + auto* pmod = static_cast<RPCModuleNode*>(parent.operator->()); + auto* cmod = static_cast<RPCModuleNode*>(child.operator->()); + CHECK(pmod->sess().get() == cmod->sess().get()) + << "Import of remote module need to belong to same session."; + pmod->sess()->CallRemote(RPCCode::kModuleImport, + pmod->module_handle(), + cmod->module_handle()); + }); + TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex") .set_body([](TVMArgs args, TVMRetValue* rv) { Module m = args[0]; diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index 7354789d0..7bdb4d3ef 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -940,6 +940,13 @@ void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) { *rv = static_cast<void*>(new Module(m)); } +void RPCModuleImport(TVMArgs args, TVMRetValue *rv) { + void* pmod = args[0]; + void* cmod = args[1]; + static_cast<Module*>(pmod)->Import( + *static_cast<Module*>(cmod)); +} + void RPCModuleFree(TVMArgs args, TVMRetValue *rv) { void* mhandle = args[0]; delete static_cast<Module*>(mhandle); @@ -1006,6 +1013,7 @@ void RPCSession::EventHandler::HandlePackedCall() { case RPCCode::kDevStreamSync: CallHandler(RPCDevStreamSync); break; case RPCCode::kCopyAmongRemote: CallHandler(RPCCopyAmongRemote); break; case RPCCode::kModuleLoad: CallHandler(RPCModuleLoad); break; + case RPCCode::kModuleImport: CallHandler(RPCModuleImport); break; case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break; case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break; case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break; diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index b84f3d59e..80dde9171 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -44,6 +44,7 @@ enum class RPCCode : int { kDevStreamSync, kCopyAmongRemote, kModuleLoad, + kModuleImport, kModuleFree, kModuleGetFunc, kModuleGetSource, diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index 2a0b078c9..cae65176e 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -86,6 +86,55 @@ def test_rpc_remote_module(): cost = time_f(a, b).mean print('%g secs/op' % cost) np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) + + def check_remote_link_cl(): + """Test function to run remote code such as cl + + This is not enabled because there is forking issue + of TVM runtime when server launches after OpenCL + runtime initializes. We leave it as an example + on how to do rpc when we want to do linking on remote. + """ + if not tvm.module.enabled("llvm"): + print("Skip because llvm is not enabled") + return + if not tvm.module.enabled("opencl"): + print("Skip because opencl is not enabled") + return + temp = util.tempdir() + ctx = remote.cl(0) + s = tvm.create_schedule(B.op) + xo, xi = s[B].split(B.op.axis[0], factor=32) + s[B].bind(xo, tvm.thread_axis("blockIdx.x")) + s[B].bind(xi, tvm.thread_axis("threadIdx.x")) + f = tvm.build(s, [A, B], "opencl", target_host="llvm", name="myadd") + # Option 1: save modules separately and rely on remote compiler + path_o = temp.relpath("myadd.o") + path_cl = temp.relpath("myadd.cl") + path_json = temp.relpath("myadd.tvm_meta.json") + f.save(path_o) + f.imported_modules[0].save(path_cl) + remote.upload(path_o) + remote.upload(path_cl) + # upload meta data + remote.upload(path_json) + fhost = remote.load_module("myadd.o") + fdev = remote.load_module("myadd.cl") + fhost.import_module(fdev) + a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) + fhost(a, b) + np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) + # Option 2: export library as a tar ball then handled by remote compiler + path_tar = temp.relpath("myadd.tar") + f.export_library(path_tar) + remote.upload(path_tar) + fhost = remote.load_module("myadd.tar") + a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx) + b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx) + fhost(a, b) + np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1) + check_remote() def test_rpc_return_func(): @@ -101,8 +150,8 @@ def test_rpc_return_func(): if __name__ == "__main__": logging.basicConfig(level=logging.INFO) + test_rpc_remote_module() test_rpc_return_func() test_rpc_file_exchange() test_rpc_array() - test_rpc_remote_module() test_rpc_simple() -- GitLab