From b3f09b019e3e150f416d6e45ad01e252b6b893d8 Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Thu, 12 Apr 2018 17:23:20 -0700
Subject: [PATCH] [RPC] LocalSession to provide RPCSession back by local env
 (#1102)

---
 python/tvm/contrib/rpc/__init__.py        |  2 +-
 python/tvm/contrib/rpc/client.py          | 86 ++++++++++++++++-------
 python/tvm/contrib/rpc/server.py          | 13 +---
 python/tvm/module.py                      | 19 ++++-
 tests/python/unittest/test_runtime_rpc.py | 26 +++++--
 5 files changed, 102 insertions(+), 44 deletions(-)

diff --git a/python/tvm/contrib/rpc/__init__.py b/python/tvm/contrib/rpc/__init__.py
index 3335240ac..c3fbf5be2 100644
--- a/python/tvm/contrib/rpc/__init__.py
+++ b/python/tvm/contrib/rpc/__init__.py
@@ -10,4 +10,4 @@ upload and run remote RPC server, get the result back to verify correctness.
 """
 
 from .server import Server
-from .client import RPCSession, connect, connect_tracker
+from .client import RPCSession, LocalSession, connect, connect_tracker
diff --git a/python/tvm/contrib/rpc/client.py b/python/tvm/contrib/rpc/client.py
index 24b3b9582..f409b2f72 100644
--- a/python/tvm/contrib/rpc/client.py
+++ b/python/tvm/contrib/rpc/client.py
@@ -7,8 +7,11 @@ import struct
 import time
 
 from . import base
+from .. import util
 from ..._ffi.base import TVMError
-from ..._ffi.ndarray import context as _context
+from ..._ffi import function as function
+from ..._ffi import ndarray as nd
+from ...module import load as _load_module
 
 
 class RPCSession(object):
@@ -51,36 +54,12 @@ class RPCSession(object):
         ctx: TVMContext
             The corresponding encoded remote context.
         """
-        ctx = _context(dev_type, dev_id)
+        ctx = nd.context(dev_type, dev_id)
         encode = (self._tbl_index + 1) * base.RPC_SESS_MASK
         ctx.device_type += encode
         ctx._rpc_sess = self
         return ctx
 
-    def cpu(self, dev_id=0):
-        """Construct remote CPU device."""
-        return self.context(1, dev_id)
-
-    def gpu(self, dev_id=0):
-        """Construct remote GPU device."""
-        return self.context(2, dev_id)
-
-    def cl(self, dev_id=0):
-        """Construct remote OpenCL device."""
-        return self.context(4, dev_id)
-
-    def metal(self, dev_id=0):
-        """Construct remote Metal device."""
-        return self.context(8, dev_id)
-
-    def opengl(self, dev_id=0):
-        """Construct remote OpenGL device."""
-        return self.context(11, dev_id)
-
-    def ext_dev(self, dev_id=0):
-        """Construct remote extension device."""
-        return self.context(12, dev_id)
-
     def upload(self, data, target=None):
         """Upload file to remote runtime temp folder
 
@@ -139,6 +118,61 @@ class RPCSession(object):
         """
         return base._LoadRemoteModule(self._sess, path)
 
+    def cpu(self, dev_id=0):
+        """Construct CPU device."""
+        return self.context(1, dev_id)
+
+    def gpu(self, dev_id=0):
+        """Construct GPU device."""
+        return self.context(2, dev_id)
+
+    def cl(self, dev_id=0):
+        """Construct OpenCL device."""
+        return self.context(4, dev_id)
+
+    def metal(self, dev_id=0):
+        """Construct Metal device."""
+        return self.context(8, dev_id)
+
+    def opengl(self, dev_id=0):
+        """Construct OpenGL device."""
+        return self.context(11, dev_id)
+
+    def ext_dev(self, dev_id=0):
+        """Construct extension device."""
+        return self.context(12, dev_id)
+
+
+class LocalSession(RPCSession):
+    """RPCSession interface backed by local environment.
+
+    This class can be used to implement functions that
+    need to be ran both locally and remotely.
+    """
+    def __init__(self):
+        # pylint: disable=super-init-not-called
+        self.context = nd.context
+        self.get_function = function.get_global_func
+        self._temp = util.tempdir()
+
+    def upload(self, data, target=None):
+        if isinstance(data, bytearray):
+            if not target:
+                raise ValueError("target must present when file is a bytearray")
+            blob = data
+        else:
+            blob = bytearray(open(data, "rb").read())
+            if not target:
+                target = os.path.basename(data)
+        with open(self._temp.relpath(target), "wb") as f:
+            f.write(blob)
+
+    def download(self, path):
+        return bytearray(open(self._temp.relpath(path), "rb").read())
+
+    def load_module(self, path):
+        return _load_module(self._temp.relpath(path))
+
 
 class TrackerSession(object):
     """Tracker client session.
diff --git a/python/tvm/contrib/rpc/server.py b/python/tvm/contrib/rpc/server.py
index 8cd1c4ee8..6759f13b6 100644
--- a/python/tvm/contrib/rpc/server.py
+++ b/python/tvm/contrib/rpc/server.py
@@ -22,7 +22,7 @@ import time
 from ..._ffi.function import register_func
 from ..._ffi.base import py_str
 from ...module import load as _load_module
-from .. import util, cc, tar
+from .. import util
 from . import base
 from . base import TrackerCode
 
@@ -38,17 +38,6 @@ def _server_env():
     def load_module(file_name):
         """Load module from remote side."""
         path = temp.relpath(file_name)
-        # Try create a shared library in remote
-        if path.endswith(".o"):
-            logging.info("Create shared library based on %s", path)
-            cc.create_shared(path + ".so", path)
-            path += ".so"
-        elif path.endswith(".tar"):
-            tar_temp = util.tempdir()
-            tar.untar(path, tar_temp.temp_dir)
-            files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
-            cc.create_shared(path + ".so", files)
-            path += ".so"
         m = _load_module(path)
         logging.info("load_module %s", path)
         return m
diff --git a/python/tvm/module.py b/python/tvm/module.py
index 6459733fa..1b83c9b26 100644
--- a/python/tvm/module.py
+++ b/python/tvm/module.py
@@ -186,7 +186,7 @@ def system_lib():
 
 
 def load(path, fmt=""):
-    """Load module from file
+    """Load module from file.
 
     Parameters
     ----------
@@ -201,7 +201,24 @@ def load(path, fmt=""):
     -------
     module : Module
         The loaded module
+
+    Note
+    ----
+    This function will automatically call
+    cc.create_shared if the path is in format .o or .tar
     """
+    # High level handling for .o and .tar file.
+    # We support this to be consistent with RPC module load.
+    if path.endswith(".o"):
+        _cc.create_shared(path + ".so", path)
+        path += ".so"
+    elif path.endswith(".tar"):
+        tar_temp = _util.tempdir()
+        _tar.untar(path, tar_temp.temp_dir)
+        files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
+        _cc.create_shared(path + ".so", files)
+        path += ".so"
+    # Redirect to the load API
     return _LoadFromFile(path, fmt)
 
 
diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py
index 46e84d887..581890efe 100644
--- a/tests/python/unittest/test_runtime_rpc.py
+++ b/tests/python/unittest/test_runtime_rpc.py
@@ -61,14 +61,14 @@ def test_rpc_remote_module():
     if not tvm.module.enabled("rpc"):
         return
     server = rpc.Server("localhost")
-    remote = rpc.connect(server.host, server.port)
+    client = rpc.connect(server.host, server.port)
     # graph
     n = tvm.convert(1024)
     A = tvm.placeholder((n,), name='A')
     B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
     s = tvm.create_schedule(B.op)
 
-    def check_remote():
+    def check_remote(remote):
         if not tvm.module.enabled("llvm"):
             print("Skip because llvm is not enabled")
             return
@@ -86,7 +86,7 @@ def test_rpc_remote_module():
         print('%g secs/op' % cost)
         np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
 
-    def check_remote_link_cl():
+    def check_remote_link_cl(remote):
         """Test function to run remote code such as cl
 
         This is not enabled because there is forking issue
@@ -134,7 +134,9 @@ def test_rpc_remote_module():
         fhost(a, b)
         np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)
 
-    check_remote()
+    check_remote(client)
+    check_remote(rpc.LocalSession())
+
 
 def test_rpc_return_func():
     @tvm.register_func("rpc.test.remote_func")
@@ -147,6 +149,21 @@ def test_rpc_return_func():
     assert fadd(12) == 22
 
 
+def test_local_func():
+    @tvm.register_func("rpc.test.remote_func2")
+    def addone(x):
+        return lambda y: x+y
+    client = rpc.LocalSession()
+    f1 = client.get_function("rpc.test.remote_func2")
+    fadd = f1(10)
+    assert fadd(12) == 22
+
+    blob = bytearray(np.random.randint(0, 10, size=(10)))
+    client.upload(blob, "dat.bin")
+    rev = client.download("dat.bin")
+    assert rev == blob
+
+
 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
     test_rpc_remote_module()
@@ -154,3 +171,4 @@ if __name__ == "__main__":
     test_rpc_file_exchange()
     test_rpc_array()
     test_rpc_simple()
+    test_local_func()
-- 
GitLab