From 1077f8e8147ba335098daa6b120499a1904d6f69 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Sun, 10 Sep 2017 23:56:46 -0700
Subject: [PATCH] [RUNTIME][RPC] Enable remote linking of device code. (#444)

* [RUNTIME][RPC] Enable remote linking of device code.

* fix build
---
 python/tvm/contrib/rpc.py                 | 34 +++++++-----
 python/tvm/contrib/tar.py                 | 68 +++++++++++++++++++++++
 python/tvm/module.py                      |  8 ++-
 src/runtime/module.cc                     | 11 ++++
 src/runtime/rpc/rpc_module.cc             | 20 +++++++
 src/runtime/rpc/rpc_session.cc            |  8 +++
 src/runtime/rpc/rpc_session.h             |  1 +
 tests/python/unittest/test_runtime_rpc.py | 51 ++++++++++++++++-
 8 files changed, 184 insertions(+), 17 deletions(-)
 create mode 100644 python/tvm/contrib/tar.py

diff --git a/python/tvm/contrib/rpc.py b/python/tvm/contrib/rpc.py
index 6376f476d..a6b5dfa6c 100644
--- a/python/tvm/contrib/rpc.py
+++ b/python/tvm/contrib/rpc.py
@@ -15,7 +15,7 @@ import socket
 import struct
 import logging
 import multiprocessing
-from . import util, cc
+from . import util, cc, tar
 from ..module import load as _load_module
 from .._ffi.function import _init_api, register_func
 from .._ffi.ndarray import context as _context
@@ -37,10 +37,16 @@ def _server_env():
         """Load module from remote side."""
         path = temp.relpath(file_name)
         # Try create a shared library in remote
-        if path.endswith('.o'):
-            logging.info('Create shared library based on %s', path)
-            cc.create_shared(path + '.so', path)
-            path += '.so'
+        if path.endswith(".o"):
+            logging.info("Create shared library based on %s", path)
+            cc.create_shared(path + ".so", path)
+            path += ".so"
+        elif path.endswith(".tar"):
+            tar_temp = util.tempdir()
+            tar.untar(path, tar_temp.temp_dir)
+            files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
+            cc.create_shared(path + ".so", files)
+            path += ".so"
         m = _load_module(path)
         logging.info("load_module %s", path)
         return m
@@ -63,7 +69,7 @@ def _recvall(sock, nbytes):
         chunk = sock.recv(min(nbytes - nread, 1024))
         nread += len(chunk)
         res.append(chunk)
-    return b''.join(res)
+    return b"".join(res)
 
 
 def _listen_loop(sock):
@@ -71,16 +77,16 @@ def _listen_loop(sock):
     while True:
         conn, addr = sock.accept()
         logging.info("RPCServer: connection from %s", addr)
-        magic = struct.unpack('@i', _recvall(conn, 4))[0]
+        magic = struct.unpack("@i", _recvall(conn, 4))[0]
         if magic != RPC_MAGIC:
             conn.close()
             continue
-        keylen = struct.unpack('@i', _recvall(conn, 4))[0]
+        keylen = struct.unpack("@i", _recvall(conn, 4))[0]
         key = py_str(_recvall(conn, keylen))
         if not key.startswith("client:"):
-            conn.sendall(struct.pack('@i', RPC_MAGIC + 2))
+            conn.sendall(struct.pack("@i", RPC_MAGIC + 2))
         else:
-            conn.sendall(struct.pack('@i', RPC_MAGIC))
+            conn.sendall(struct.pack("@i", RPC_MAGIC))
         logging.info("Connection from %s", addr)
         process = multiprocessing.Process(target=_serve_loop, args=(conn, addr))
         process.deamon = True
@@ -94,10 +100,10 @@ def _connect_proxy_loop(addr, key):
     while True:
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.connect(addr)
-        sock.sendall(struct.pack('@i', RPC_MAGIC))
-        sock.sendall(struct.pack('@i', len(key)))
+        sock.sendall(struct.pack("@i", RPC_MAGIC))
+        sock.sendall(struct.pack("@i", len(key)))
         sock.sendall(key.encode("utf-8"))
-        magic = struct.unpack('@i', _recvall(sock, 4))[0]
+        magic = struct.unpack("@i", _recvall(sock, 4))[0]
         if magic == RPC_MAGIC + 1:
             raise RuntimeError("key: %s has already been used in proxy" % key)
         elif magic == RPC_MAGIC + 2:
@@ -321,7 +327,7 @@ def connect(url, port, key=""):
     try:
         sess = _Connect(url, port, key)
     except NameError:
-        raise RuntimeError('Please compile with USE_RPC=1')
+        raise RuntimeError("Please compile with USE_RPC=1")
     return RPCSession(sess)
 
 _init_api("tvm.contrib.rpc")
diff --git a/python/tvm/contrib/tar.py b/python/tvm/contrib/tar.py
new file mode 100644
index 000000000..ca3bf3478
--- /dev/null
+++ b/python/tvm/contrib/tar.py
@@ -0,0 +1,68 @@
+
+"""Util to invoke tarball in the system."""
+# pylint: disable=invalid-name
+from __future__ import absolute_import as _abs
+import os
+import shutil
+import subprocess
+from . import util
+
+def tar(output, files):
+    """Create a gzip-compressed tarball that bundles the given files.
+
+    Parameters
+    ----------
+    output : str
+        Path of the tarball to create.
+
+    files : list
+        Files to bundle; each is stored under its base name.
+
+    Raises
+    ------
+    ValueError
+        If two input files share the same base name.
+    RuntimeError
+        If the tar command exits with a non-zero status.
+    """
+    temp = util.tempdir()
+    fset = set()
+    for fname in files:
+        base = os.path.basename(fname)
+        if base in fset:
+            raise ValueError("duplicate file name %s" % base)
+        fset.add(base)
+        # Stage a flat copy so the archive holds base names only.
+        shutil.copy(fname, temp.relpath(base))
+    cmd = ["tar", "-czf", output, "-C", temp.temp_dir] + temp.listdir()
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    (out, _) = proc.communicate()
+    if proc.returncode != 0:
+        # communicate() yields bytes on Python 3; decode before joining with str.
+        raise RuntimeError("Tar error:\n" + out.decode("utf-8", "replace"))
+
+
+def untar(tar_file, directory):
+    """Unpack the contents of a tar archive into a directory.
+
+    Parameters
+    ----------
+    tar_file : str
+        The source tar file.
+
+    directory : str
+        The target directory.
+
+    Raises
+    ------
+    RuntimeError
+        If the tar command exits with a non-zero status.
+    """
+    cmd = ["tar", "-xf", tar_file, "-C", directory]
+    # Merge stderr into stdout so a failure message is captured in `out`.
+    proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+    (out, _) = proc.communicate()
+
+    if proc.returncode != 0:
+        # communicate() yields bytes on Python 3; decode before joining with str.
+        raise RuntimeError("Tar error:\n" + out.decode("utf-8", "replace"))
diff --git a/python/tvm/module.py b/python/tvm/module.py
index d5bf49cf8..9c44a9f1e 100644
--- a/python/tvm/module.py
+++ b/python/tvm/module.py
@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs
 from collections import namedtuple
 from ._ffi.function import ModuleBase, _set_class_module
 from ._ffi.function import _init_api
-from .contrib import cc as _cc, util as _util
+from .contrib import cc as _cc, tar as _tar, util as _util
 
 ProfileResult = namedtuple("ProfileResult", ["mean"])
 
@@ -100,7 +100,11 @@ class Module(ModuleBase):
             with open(path_cc, "w") as f:
                 f.write(_PackImportsToC(self, is_system_lib))
             files.append(path_cc)
-        fcompile = fcompile if fcompile else _cc.create_shared
+        if not fcompile:
+            if file_name.endswith(".tar"):
+                fcompile = _tar.tar
+            else:
+                fcompile = _cc.create_shared
         fcompile(file_name, files, **kwargs)
 
     def time_evaluator(self, func_name, ctx, number):
diff --git a/src/runtime/module.cc b/src/runtime/module.cc
index 6afe98016..2f2b0a214 100644
--- a/src/runtime/module.cc
+++ b/src/runtime/module.cc
@@ -7,6 +7,7 @@
 #include <tvm/runtime/registry.h>
 #include <tvm/runtime/packed_func.h>
 #include <unordered_set>
+#include <cstring>
 #include "./file_util.h"
 
 namespace tvm {
@@ -26,6 +27,16 @@ PackedFunc Module::GetFunction(
 }
 
 void Module::Import(Module other) {
+  // specially handle rpc
+  if (!std::strcmp((*this)->type_key(), "rpc")) {
+    static const PackedFunc* fimport_ = nullptr;
+    if (fimport_ == nullptr) {
+      fimport_ = runtime::Registry::Get("contrib.rpc._ImportRemoteModule");
+      CHECK(fimport_ != nullptr);
+    }
+    (*fimport_)(*this, other);
+    return;
+  }
   // cyclic detection.
   std::unordered_set<const ModuleNode*> visited{other.node_.get()};
   std::vector<const ModuleNode*> stack{other.node_.get()};
diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc
index 9eca74620..a0952e8d4 100644
--- a/src/runtime/rpc/rpc_module.cc
+++ b/src/runtime/rpc/rpc_module.cc
@@ -5,6 +5,7 @@
  */
 #include <tvm/runtime/registry.h>
 #include <memory>
+#include <cstring>
 #include "./rpc_session.h"
 
 namespace tvm {
@@ -83,6 +84,10 @@ class RPCModuleNode final : public ModuleNode {
     return WrapRemote(handle);
   }
 
+  void* module_handle() const {  // raw handle passed to remote calls (e.g. kModuleImport) for this module
+    return module_handle_;
+  }
+
  private:
   PackedFunc WrapRemote(RPCFuncHandle handle) {
     if (handle == nullptr) return PackedFunc();
@@ -162,6 +167,21 @@ TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule")
     *rv = Module(n);
   });
 
+TVM_REGISTER_GLOBAL("contrib.rpc._ImportRemoteModule")
+.set_body([](TVMArgs args, TVMRetValue* rv) {  // args: (parent, child); nothing is written to rv
+    Module parent = args[0];
+    Module child = args[1];
+    CHECK(!std::strcmp(parent->type_key(), "rpc") &&
+          !std::strcmp(child->type_key(), "rpc"));  // both must be RPC modules; guards the casts below
+    auto* pmod = static_cast<RPCModuleNode*>(parent.operator->());
+    auto* cmod = static_cast<RPCModuleNode*>(child.operator->());
+    CHECK(pmod->sess().get() == cmod->sess().get())
+        << "Import of remote module need to belong to same session.";
+    pmod->sess()->CallRemote(RPCCode::kModuleImport,  // the import itself is executed on the remote
+                             pmod->module_handle(),
+                             cmod->module_handle());
+  });
+
 TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex")
 .set_body([](TVMArgs args, TVMRetValue* rv) {
     Module m = args[0];
diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc
index 7354789d0..7bdb4d3ef 100644
--- a/src/runtime/rpc/rpc_session.cc
+++ b/src/runtime/rpc/rpc_session.cc
@@ -940,6 +940,13 @@ void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) {
   *rv = static_cast<void*>(new Module(m));
 }
 
+void RPCModuleImport(TVMArgs args, TVMRetValue *rv) {  // handler for RPCCode::kModuleImport; rv is not set
+  void* pmod = args[0];  // handle of the importing (parent) module
+  void* cmod = args[1];  // handle of the module being imported
+  static_cast<Module*>(pmod)->Import(
+      *static_cast<Module*>(cmod));
+}
+
 void RPCModuleFree(TVMArgs args, TVMRetValue *rv) {
   void* mhandle = args[0];
   delete static_cast<Module*>(mhandle);
@@ -1006,6 +1013,7 @@ void RPCSession::EventHandler::HandlePackedCall() {
     case RPCCode::kDevStreamSync: CallHandler(RPCDevStreamSync); break;
     case RPCCode::kCopyAmongRemote: CallHandler(RPCCopyAmongRemote); break;
     case RPCCode::kModuleLoad: CallHandler(RPCModuleLoad); break;
+    case RPCCode::kModuleImport: CallHandler(RPCModuleImport); break;
     case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break;
     case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break;
     case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break;
diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h
index b84f3d59e..80dde9171 100644
--- a/src/runtime/rpc/rpc_session.h
+++ b/src/runtime/rpc/rpc_session.h
@@ -44,6 +44,7 @@ enum class RPCCode : int {
   kDevStreamSync,
   kCopyAmongRemote,
   kModuleLoad,
+  kModuleImport,
   kModuleFree,
   kModuleGetFunc,
   kModuleGetSource,
diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py
index 2a0b078c9..cae65176e 100644
--- a/tests/python/unittest/test_runtime_rpc.py
+++ b/tests/python/unittest/test_runtime_rpc.py
@@ -86,6 +86,55 @@ def test_rpc_remote_module():
         cost = time_f(a, b).mean
         print('%g secs/op' % cost)
         np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
+
+    def check_remote_link_cl():
+        """Example of linking host and OpenCL device code on the remote.
+
+        This check is not enabled: the TVM runtime has a forking issue
+        when the server launches after the OpenCL runtime initializes.
+        It is kept as an example of how to use RPC when the linking
+        has to happen on the remote side.
+        """
+        if not tvm.module.enabled("llvm"):
+            print("Skip because llvm is not enabled")
+            return
+        if not tvm.module.enabled("opencl"):
+            print("Skip because opencl is not enabled")
+            return
+        temp = util.tempdir()
+        ctx = remote.cl(0)
+        s = tvm.create_schedule(B.op)
+        xo, xi = s[B].split(B.op.axis[0], factor=32)
+        s[B].bind(xo, tvm.thread_axis("blockIdx.x"))
+        s[B].bind(xi, tvm.thread_axis("threadIdx.x"))
+        f = tvm.build(s, [A, B], "opencl", target_host="llvm", name="myadd")
+        # Option 1: save host and device modules separately; link them on the remote.
+        path_o = temp.relpath("myadd.o")
+        path_cl = temp.relpath("myadd.cl")
+        path_json = temp.relpath("myadd.tvm_meta.json")
+        f.save(path_o)
+        f.imported_modules[0].save(path_cl)
+        remote.upload(path_o)
+        remote.upload(path_cl)
+        # Upload the metadata file (presumably written next to myadd.cl by save() -- confirm).
+        remote.upload(path_json)
+        fhost = remote.load_module("myadd.o")
+        fdev = remote.load_module("myadd.cl")
+        fhost.import_module(fdev)
+        a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
+        fhost(a, b)
+        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
+        # Option 2: export the library as a tarball and let the remote compile and link it.
+        path_tar = temp.relpath("myadd.tar")
+        f.export_library(path_tar)
+        remote.upload(path_tar)
+        fhost = remote.load_module("myadd.tar")
+        a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
+        fhost(a, b)
+        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
+
     check_remote()
 
 def test_rpc_return_func():
@@ -101,8 +150,8 @@ def test_rpc_return_func():
 
 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
+    test_rpc_remote_module()
     test_rpc_return_func()
     test_rpc_file_exchange()
     test_rpc_array()
-    test_rpc_remote_module()
     test_rpc_simple()
-- 
GitLab