diff --git a/python/tvm/contrib/rpc.py b/python/tvm/contrib/rpc.py index 3448c4c554d104f2af4e516dc5251631e8339cb8..831481509c76794e0bc4f125b1464d7f57ff202d 100644 --- a/python/tvm/contrib/rpc.py +++ b/python/tvm/contrib/rpc.py @@ -74,10 +74,16 @@ def _recvall(sock, nbytes): return b"".join(res) -def _listen_loop(sock): +def _listen_loop(sock, exclusive): """Lisenting loop""" + last_proc = None while True: conn, addr = sock.accept() + + if last_proc and last_proc.is_alive() and exclusive: + logging.info("Kill last call") + last_proc.terminate() + logging.info("RPCServer: connection from %s", addr) magic = struct.unpack("@i", _recvall(conn, 4))[0] if magic != RPC_MAGIC: @@ -90,9 +96,11 @@ def _listen_loop(sock): else: conn.sendall(struct.pack("@i", RPC_MAGIC)) logging.info("Connection from %s", addr) + process = multiprocessing.Process(target=_serve_loop, args=(conn, addr)) process.deamon = True process.start() + last_proc = process # close from our side. conn.close() @@ -158,6 +166,11 @@ class Server(object): This is recommended to switch on if we want to do local RPC demonstration for GPU devices to avoid fork safety issues. + exclusive : bool, optional + If this is enabled, the server will kill old connection + when new connection comes. This can make sure the current call + monopolize the hardware resource. + key : str, optional The key used to identify the server in Proxy connection. """ @@ -167,6 +180,7 @@ class Server(object): port_end=9199, is_proxy=False, use_popen=False, + exclusive=False, key=""): self.host = host self.port = port @@ -201,7 +215,7 @@ class Server(object): sock.listen(1) self.sock = sock self.proc = multiprocessing.Process( - target=_listen_loop, args=(self.sock,)) + target=_listen_loop, args=(self.sock, exclusive)) self.proc.deamon = True self.proc.start() else: @@ -210,8 +224,6 @@ class Server(object): self.proc.deamon = True self.proc.start() - - def terminate(self): """Terminate the server process""" if self.proc: @@ -222,7 +234,6 @@ class Server(object): self.terminate() - class RPCSession(object): """RPC Client session module diff --git a/python/tvm/exec/rpc_server.py b/python/tvm/exec/rpc_server.py index 432860f58d1e434bd95a4c1f0c6d45ac650159ed..deb830bdc583998c609f4c23b34fca6fe82b279d 100644 --- a/python/tvm/exec/rpc_server.py +++ b/python/tvm/exec/rpc_server.py @@ -21,6 +21,9 @@ def main(): help="Whether to load executor runtime") parser.add_argument('--load-library', type=str, default="", help="Additional library to load") + parser.add_argument('--exclusive', action='store_true', + help="If this is enabled, the server will kill old connection" + "when new connection comes") args = parser.parse_args() logging.basicConfig(level=logging.INFO) @@ -35,7 +38,7 @@ def main(): libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL)) logging.info("Load additional library %s", file_name) - server = rpc.Server(args.host, args.port, args.port_end) + server = rpc.Server(args.host, args.port, args.port_end, exclusive=args.exclusive) server.libs += libs server.proc.join()