From 42608ddafb4e1952f3b11933ad663c73c3a44524 Mon Sep 17 00:00:00 2001 From: tqchen <tianqi.tchen@gmail.com> Date: Wed, 30 May 2018 16:49:24 -0700 Subject: [PATCH] [IO] Support cross-endian --- dmlc-core | 2 +- include/tvm/runtime/serializer.h | 50 ++++++++ nnvm/src/compiler/graph_runtime.cc | 74 ++++++----- python/tvm/contrib/rpc/base.py | 4 +- python/tvm/contrib/rpc/client.py | 4 +- python/tvm/contrib/rpc/proxy.py | 28 ++--- python/tvm/contrib/rpc/server.py | 22 ++-- python/tvm/contrib/rpc/tracker.py | 12 +- src/runtime/file_util.cc | 1 + src/runtime/graph/graph_runtime.cc | 30 +++-- src/runtime/rpc/rpc_session.cc | 196 +++++++++++++++++------------ 11 files changed, 265 insertions(+), 158 deletions(-) create mode 100644 include/tvm/runtime/serializer.h diff --git a/dmlc-core b/dmlc-core index d3f7fbb53..9b3f9753a 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit d3f7fbb53e5b037c0f5bf6bd21871ccc720690cc +Subproject commit 9b3f9753ae81d657743c555e0cacc4e43f0bed2d diff --git a/include/tvm/runtime/serializer.h b/include/tvm/runtime/serializer.h new file mode 100644 index 000000000..391c7806a --- /dev/null +++ b/include/tvm/runtime/serializer.h @@ -0,0 +1,50 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file tvm/runtime/serializer.h + * \brief Serializer extension to support TVM data types + * Include this file to enable serialization of DLDataType, DLContext + */ +#ifndef TVM_RUNTIME_SERIALIZER_H_ +#define TVM_RUNTIME_SERIALIZER_H_ + +#include <dmlc/io.h> +#include <dmlc/serializer.h> +#include "./c_runtime_api.h" + +namespace dmlc { +namespace serializer { + +template<> +struct Handler<DLDataType> { + inline static void Write(Stream *strm, const DLDataType& dtype) { + Handler<uint8_t>::Write(strm, dtype.code); + Handler<uint8_t>::Write(strm, dtype.bits); + Handler<uint16_t>::Write(strm, dtype.lanes); + } + inline static bool Read(Stream *strm, DLDataType* dtype) { + if (!Handler<uint8_t>::Read(strm, &(dtype->code))) return false; + if (!Handler<uint8_t>::Read(strm, &(dtype->bits))) return false; + if (!Handler<uint16_t>::Read(strm, &(dtype->lanes))) return false; + return true; + } +}; + +template<> +struct Handler<DLContext> { + inline static void Write(Stream *strm, const DLContext& ctx) { + int32_t device_type = static_cast<int32_t>(ctx.device_type); + Handler<int32_t>::Write(strm, device_type); + Handler<int32_t>::Write(strm, ctx.device_id); + } + inline static bool Read(Stream *strm, DLContext* ctx) { + int32_t device_type = 0; + if (!Handler<int32_t>::Read(strm, &(device_type))) return false; + ctx->device_type = static_cast<DLDeviceType>(device_type); + if (!Handler<int32_t>::Read(strm, &(ctx->device_id))) return false; + return true; + } +}; + +} // namespace serializer +} // namespace dmlc +#endif // TVM_RUNTIME_SERIALIZER_H_ diff --git a/nnvm/src/compiler/graph_runtime.cc b/nnvm/src/compiler/graph_runtime.cc index 51a240a46..689ed70ce 100644 --- a/nnvm/src/compiler/graph_runtime.cc +++ b/nnvm/src/compiler/graph_runtime.cc @@ -7,6 +7,7 @@ #include <tvm/runtime/packed_func.h> #include <tvm/runtime/c_runtime_api.h> #include <tvm/runtime/registry.h> +#include <tvm/runtime/serializer.h> #include "./graph_runtime.h" namespace nnvm { @@ -38,46 +39,53 @@ NNVM_REGISTER_OP(tvm_op) bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) { uint64_t header = kTVMNDArrayMagic, reserved = 0; - strm->Write(&header, sizeof(header)); - strm->Write(&reserved, sizeof(reserved)); - - strm->Write(&tensor->ctx, sizeof(tensor->ctx)); - strm->Write(&tensor->ndim, sizeof(tensor->ndim)); - strm->Write(&tensor->dtype, sizeof(tensor->dtype)); - + strm->Write(header); + strm->Write(reserved); + strm->Write(tensor->ctx); + strm->Write(tensor->ndim); + strm->Write(tensor->dtype); int ndim = tensor->ndim; - strm->Write(tensor->shape, sizeof(int64_t) * ndim); + strm->WriteArray(tensor->shape, ndim); - int type_size = tensor->dtype.bits / 8; - int64_t size = 1; + int type_bytes = tensor->dtype.bits / 8; + int64_t num_elems = 1; for (int i = 0; i < ndim; ++i) { - size *= tensor->shape[i]; + num_elems *= tensor->shape[i]; + } + int64_t data_byte_size = type_bytes * num_elems; + strm->Write(data_byte_size); + // handle endianness of data correctly. + if (DMLC_IO_NO_ENDIAN_SWAP) { + strm->Write(tensor->data, data_byte_size); + } else { + uint8_t* dptr = reinterpret_cast<uint8_t*>(tensor->data); + std::vector<uint8_t> bytes(dptr, dptr + data_byte_size); + dmlc::ByteSwap(dmlc::BeginPtr(bytes), type_bytes, num_elems); + strm->Write(dmlc::BeginPtr(bytes), data_byte_size); } - int64_t data_byte_size = type_size * size; - strm->Write(&data_byte_size, sizeof(data_byte_size)); - strm->Write(tensor->data, data_byte_size); return true; } DLTensor* LoadDLTensor(dmlc::Stream* strm) { uint64_t header, reserved; - CHECK(strm->Read(&header, sizeof(header))) + CHECK(strm->Read(&header)) << "Invalid DLTensor file format"; - CHECK(strm->Read(&reserved, sizeof(reserved))) + CHECK(strm->Read(&reserved)) << "Invalid DLTensor file format"; CHECK(header == kTVMNDArrayMagic) << "Invalid DLTensor file format"; - DLTensor tensor; - CHECK(strm->Read(&tensor.ctx, sizeof(tensor.ctx))) + CHECK(strm->Read(&(tensor.ctx))) << "Invalid DLTensor file format"; - CHECK(strm->Read(&tensor.ndim, sizeof(tensor.ndim))) + CHECK(strm->Read(&(tensor.ndim))) << "Invalid DLTensor file format"; - CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype))) + CHECK(strm->Read(&(tensor.dtype))) << "Invalid DLTensor file format"; std::vector<int64_t> shape(tensor.ndim); - CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim)) - << "Invalid DLTensor file format"; + if (tensor.ndim != 0) { + CHECK(strm->ReadArray(&shape[0], tensor.ndim)) + << "Invalid DLTensor file format"; + } DLTensor* ret; CHECK_EQ(TVMArrayAlloc(shape.data(), tensor.ndim, @@ -87,18 +95,21 @@ DLTensor* LoadDLTensor(dmlc::Stream* strm) { static_cast<int>(tensor.ctx.device_type), tensor.ctx.device_id, &ret), 0) << TVMGetLastError(); - int64_t size = 1; - int type_size = ret->dtype.bits / 8; + int64_t num_elems = 1; + int elem_bytes = (ret->dtype.bits + 7) / 8; for (int i = 0; i < ret->ndim; ++i) { - size *= ret->shape[i]; + num_elems *= ret->shape[i]; } int64_t data_byte_size; - CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size))) + CHECK(strm->Read(&data_byte_size)) << "Invalid DLTensor file format"; - CHECK(data_byte_size == type_size * size) + CHECK(data_byte_size == num_elems * elem_bytes) << "Invalid DLTensor file format"; - CHECK(strm->Read(ret->data, type_size * size)) + CHECK(strm->Read(ret->data, data_byte_size)) << "Invalid DLTensor file format"; + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(ret->data, elem_bytes, num_elems); + } return ret; } @@ -118,12 +129,12 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict") dmlc::MemoryStringStream strm(&bytes); dmlc::Stream* fo = &strm; uint64_t header = kTVMNDArrayListMagic, reserved = 0; - fo->Write(&header, sizeof(header)); - fo->Write(&reserved, sizeof(reserved)); + fo->Write(header); + fo->Write(reserved); fo->Write(names); { uint64_t sz = static_cast<uint64_t>(arrays.size()); - fo->Write(&sz, sizeof(sz)); + fo->Write(sz); for (size_t i = 0; i < sz; ++i) { SaveDLTensor(fo, arrays[i]); } @@ -150,7 +161,6 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict") << "Invalid parameters file format"; CHECK(strm->Read(&reserved)) << "Invalid parameters file format"; - CHECK(strm->Read(&names)) << "Invalid parameters file format"; uint64_t sz; diff --git a/python/tvm/contrib/rpc/base.py b/python/tvm/contrib/rpc/base.py index a9fd232f5..67e6d6b43 100644 --- a/python/tvm/contrib/rpc/base.py +++ b/python/tvm/contrib/rpc/base.py @@ -73,7 +73,7 @@ def sendjson(sock, data): Python value to be sent. """ data = json.dumps(data) - sock.sendall(struct.pack("@i", len(data))) + sock.sendall(struct.pack("<i", len(data))) sock.sendall(data.encode("utf-8")) @@ -90,7 +90,7 @@ def recvjson(sock): value : object The value received. """ - size = struct.unpack("@i", recvall(sock, 4))[0] + size = struct.unpack("<i", recvall(sock, 4))[0] data = json.loads(py_str(recvall(sock, size))) return data diff --git a/python/tvm/contrib/rpc/client.py b/python/tvm/contrib/rpc/client.py index f409b2f72..637aed6e4 100644 --- a/python/tvm/contrib/rpc/client.py +++ b/python/tvm/contrib/rpc/client.py @@ -192,8 +192,8 @@ class TrackerSession(object): def _connect(self): self._sock = base.connect_with_retry(self._addr) - self._sock.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC)) - magic = struct.unpack("@i", base.recvall(self._sock, 4))[0] + self._sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC)) + magic = struct.unpack("<i", base.recvall(self._sock, 4))[0] if magic != base.RPC_TRACKER_MAGIC: raise RuntimeError("%s is not RPC Tracker" % str(self._addr)) diff --git a/python/tvm/contrib/rpc/proxy.py b/python/tvm/contrib/rpc/proxy.py index 857e69433..315354ede 100644 --- a/python/tvm/contrib/rpc/proxy.py +++ b/python/tvm/contrib/rpc/proxy.py @@ -58,14 +58,14 @@ class ForwardHandler(object): def _init_step(self, message): if self._magic is None: assert len(message) == 4 - self._magic = struct.unpack('@i', message)[0] + self._magic = struct.unpack('<i', message)[0] if self._magic != base.RPC_MAGIC: logging.info("Invalid RPC magic from %s", self.name()) self.close() self._init_req_nbytes = 4 elif self._rpc_key_length is None: assert len(message) == 4 - self._rpc_key_length = struct.unpack('@i', message)[0] + self._rpc_key_length = struct.unpack('<i', message)[0] self._init_req_nbytes = self._rpc_key_length elif self.rpc_key is None: assert len(message) == self._rpc_key_length @@ -269,12 +269,12 @@ class ProxyServerHandler(object): lhs.forward_proxy = rhs rhs.forward_proxy = lhs - lhs.send_data(struct.pack('@i', base.RPC_CODE_SUCCESS)) - lhs.send_data(struct.pack('@i', len(rhs.rpc_key))) + lhs.send_data(struct.pack('<i', base.RPC_CODE_SUCCESS)) + lhs.send_data(struct.pack('<i', len(rhs.rpc_key))) lhs.send_data(rhs.rpc_key.encode("utf-8")) - rhs.send_data(struct.pack('@i', base.RPC_CODE_SUCCESS)) - rhs.send_data(struct.pack('@i', len(lhs.rpc_key))) + rhs.send_data(struct.pack('<i', base.RPC_CODE_SUCCESS)) + rhs.send_data(struct.pack('<i', len(lhs.rpc_key))) rhs.send_data(lhs.rpc_key.encode("utf-8")) logging.info("Pairup connect %s and %s", lhs.name(), rhs.name()) @@ -299,8 +299,8 @@ class ProxyServerHandler(object): if self._tracker_conn is None: self._tracker_conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._tracker_conn.connect(self._tracker_addr) - self._tracker_conn.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC)) - magic = struct.unpack("@i", base.recvall(self._tracker_conn, 4))[0] + self._tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC)) + magic = struct.unpack("<i", base.recvall(self._tracker_conn, 4))[0] if magic != base.RPC_TRACKER_MAGIC: self.loop.stop() raise RuntimeError("%s is not RPC Tracker" % str(self._tracker_addr)) @@ -371,7 +371,7 @@ class ProxyServerHandler(object): if handler.match_key in self._server_pool: self._pair_up(self._server_pool.pop(handler.match_key), handler) else: - handler.send_data(struct.pack('@i', base.RPC_CODE_MISMATCH)) + handler.send_data(struct.pack('<i', base.RPC_CODE_MISMATCH)) handler.signal_close() def _handler_ready_proxy_mode(self, handler): @@ -395,12 +395,12 @@ class ProxyServerHandler(object): logging.info("Timeout client connection %s, cannot find match key=%s", handler.name(), key) pool_dst.pop(key) - handler.send_data(struct.pack('@i', base.RPC_CODE_MISMATCH)) + handler.send_data(struct.pack('<i', base.RPC_CODE_MISMATCH)) handler.signal_close() self.loop.call_later(timeout, cleanup) else: logging.info("Duplicate connection with same key=%s", key) - handler.send_data(struct.pack('@i', base.RPC_CODE_DUPLICATE)) + handler.send_data(struct.pack('<i', base.RPC_CODE_DUPLICATE)) handler.signal_close() def handler_ready(self, handler): @@ -538,13 +538,13 @@ def websocket_proxy_server(url, key=""): on_message = create_on_message(conn) temp = _server_env(None) # Start connecton - conn.write_message(struct.pack('@i', base.RPC_MAGIC), binary=True) + conn.write_message(struct.pack('<i', base.RPC_MAGIC), binary=True) key = "server:" + key - conn.write_message(struct.pack('@i', len(key)), binary=True) + conn.write_message(struct.pack('<i', len(key)), binary=True) conn.write_message(key.encode("utf-8"), binary=True) msg = yield conn.read_message() assert len(msg) >= 4 - magic = struct.unpack('@i', msg[:4])[0] + magic = struct.unpack('<i', msg[:4])[0] if magic == base.RPC_CODE_DUPLICATE: raise RuntimeError("key: %s has already been used in proxy" % key) elif magic == base.RPC_CODE_MISMATCH: diff --git a/python/tvm/contrib/rpc/server.py b/python/tvm/contrib/rpc/server.py index b49a6900b..e3c731a20 100644 --- a/python/tvm/contrib/rpc/server.py +++ b/python/tvm/contrib/rpc/server.py @@ -124,23 +124,23 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library): unmatch_period_count = 0 continue conn, addr = listen_sock.accept() - magic = struct.unpack("@i", base.recvall(conn, 4))[0] + magic = struct.unpack("<i", base.recvall(conn, 4))[0] if magic != base.RPC_MAGIC: conn.close() continue - keylen = struct.unpack("@i", base.recvall(conn, 4))[0] + keylen = struct.unpack("<i", base.recvall(conn, 4))[0] key = py_str(base.recvall(conn, keylen)) arr = key.split() expect_header = "client:" + matchkey server_key = "server:" + rpc_key if arr[0] != expect_header: - conn.sendall(struct.pack("@i", base.RPC_CODE_MISMATCH)) + conn.sendall(struct.pack("<i", base.RPC_CODE_MISMATCH)) conn.close() logging.info("RPCServer: mismatch key from %s", addr) continue else: - conn.sendall(struct.pack("@i", base.RPC_CODE_SUCCESS)) - conn.sendall(struct.pack("@i", len(server_key))) + conn.sendall(struct.pack("<i", base.RPC_CODE_SUCCESS)) + conn.sendall(struct.pack("<i", len(server_key))) conn.sendall(server_key.encode("utf-8")) return conn, addr, _parse_server_opt(arr[1:]) @@ -151,8 +151,8 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library): # step 1: setup tracker and report to tracker if tracker_addr and tracker_conn is None: tracker_conn = base.connect_with_retry(tracker_addr) - tracker_conn.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC)) - magic = struct.unpack("@i", base.recvall(tracker_conn, 4))[0] + tracker_conn.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC)) + magic = struct.unpack("<i", base.recvall(tracker_conn, 4))[0] if magic != base.RPC_TRACKER_MAGIC: raise RuntimeError("%s is not RPC Tracker" % str(tracker_addr)) # report status of current queue @@ -193,17 +193,17 @@ def _connect_proxy_loop(addr, key, load_library): try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(addr) - sock.sendall(struct.pack("@i", base.RPC_MAGIC)) - sock.sendall(struct.pack("@i", len(key))) + sock.sendall(struct.pack("<i", base.RPC_MAGIC)) + sock.sendall(struct.pack("<i", len(key))) sock.sendall(key.encode("utf-8")) - magic = struct.unpack("@i", base.recvall(sock, 4))[0] + magic = struct.unpack("<i", base.recvall(sock, 4))[0] if magic == base.RPC_CODE_DUPLICATE: raise RuntimeError("key: %s has already been used in proxy" % key) elif magic == base.RPC_CODE_MISMATCH: logging.info("RPCProxy do not have matching client key %s", key) elif magic != base.RPC_CODE_SUCCESS: raise RuntimeError("%s is not RPC Proxy" % str(addr)) - keylen = struct.unpack("@i", base.recvall(sock, 4))[0] + keylen = struct.unpack("<i", base.recvall(sock, 4))[0] remote_key = py_str(base.recvall(sock, keylen)) opts = _parse_server_opt(remote_key.split()[1:]) logging.info("RPCProxy connected to %s", str(addr)) diff --git a/python/tvm/contrib/rpc/tracker.py b/python/tvm/contrib/rpc/tracker.py index 5e5620042..165ff5b80 100644 --- a/python/tvm/contrib/rpc/tracker.py +++ b/python/tvm/contrib/rpc/tracker.py @@ -143,11 +143,11 @@ class TCPEventHandler(tornado_util.TCPHandler): if len(message) != 4: logging.info("Invalid connection from %s", self.name()) self.close() - magic = struct.unpack('@i', message)[0] + magic = struct.unpack('<i', message)[0] if magic != RPC_TRACKER_MAGIC: logging.info("Invalid magic from %s", self.name()) self.close() - self.write_message(struct.pack('@i', RPC_TRACKER_MAGIC), binary=True) + self.write_message(struct.pack('<i', RPC_TRACKER_MAGIC), binary=True) self._init_req_nbytes = 0 def on_message(self, message): @@ -168,7 +168,7 @@ class TCPEventHandler(tornado_util.TCPHandler): while True: if self._msg_size == 0: if len(self._data) >= 4: - self._msg_size = struct.unpack('@i', self._data[:4])[0] + self._msg_size = struct.unpack('<i', self._data[:4])[0] else: return if self._msg_size != 0 and len(self._data) >= self._msg_size + 4: @@ -184,7 +184,7 @@ class TCPEventHandler(tornado_util.TCPHandler): """return value to the output""" data = json.dumps(data) self.write_message( - struct.pack('@i', len(data)), binary=True) + struct.pack('<i', len(data)), binary=True) self.write_message(data.encode("utf-8"), binary=True) def call_handler(self, args): @@ -355,8 +355,8 @@ class Tracker(object): def _stop_tracker(self): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect((self.host, self.port)) - sock.sendall(struct.pack("@i", base.RPC_TRACKER_MAGIC)) - magic = struct.unpack("@i", base.recvall(sock, 4))[0] + sock.sendall(struct.pack("<i", base.RPC_TRACKER_MAGIC)) + magic = struct.unpack("<i", base.recvall(sock, 4))[0] assert magic == base.RPC_TRACKER_MAGIC base.sendjson(sock, [TrackerCode.STOP, self.stop_key]) assert base.recvjson(sock) == TrackerCode.SUCCESS diff --git a/src/runtime/file_util.cc b/src/runtime/file_util.cc index cfd2d72f1..7606bf89c 100644 --- a/src/runtime/file_util.cc +++ b/src/runtime/file_util.cc @@ -4,6 +4,7 @@ */ #include <dmlc/json.h> #include <dmlc/logging.h> +#include <tvm/runtime/serializer.h> #include <fstream> #include "./file_util.h" diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index ce1c51ba5..89d5e7a28 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -4,6 +4,7 @@ */ #include <tvm/runtime/packed_func.h> #include <tvm/runtime/registry.h> +#include <tvm/runtime/serializer.h> #include <dmlc/memory_io.h> #include <dmlc/json.h> #include <numeric> @@ -397,24 +398,25 @@ class GraphRuntime : public ModuleNode { void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) { + // always use strm->Read to maintain endianness conversion uint64_t header, reserved; - CHECK(strm->Read(&header, sizeof(header))) + CHECK(strm->Read(&header)) << "Invalid DLTensor file format"; - CHECK(strm->Read(&reserved, sizeof(reserved))) + CHECK(strm->Read(&reserved)) << "Invalid DLTensor file format"; CHECK(header == kTVMNDArrayMagic) << "Invalid DLTensor file format"; DLTensor tensor; - CHECK(strm->Read(&tensor.ctx, sizeof(tensor.ctx))) + CHECK(strm->Read(&(tensor.ctx))) << "Invalid DLTensor file format"; - CHECK(strm->Read(&tensor.ndim, sizeof(tensor.ndim))) + CHECK(strm->Read(&(tensor.ndim))) << "Invalid DLTensor file format"; - CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype))) + CHECK(strm->Read(&(tensor.dtype))) << "Invalid DLTensor file format"; std::vector<int64_t> shape(tensor.ndim); if (tensor.ndim != 0) { - CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim)) + CHECK(strm->ReadArray(&shape[0], tensor.ndim)) << "Invalid DLTensor file format"; } CHECK_EQ(tensor.ndim, dst->ndim) << "param dimension mismatch"; @@ -425,18 +427,23 @@ void GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* dst) { CHECK_EQ(shape[i], dst->shape[i]) << "param shape mismatch"; } size_t bits = dst->dtype.bits * dst->dtype.lanes; - size_t size = (bits + 7) / 8; + size_t elem_bytes = (bits + 7) / 8; + size_t num_elems = 1; for (int i = 0; i < dst->ndim; ++i) { - size *= dst->shape[i]; + num_elems *= dst->shape[i]; } uint64_t data_byte_size; - CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size))) + CHECK(strm->Read(&data_byte_size)) << "Invalid DLTensor file format"; - CHECK(data_byte_size == size) + CHECK_EQ(data_byte_size, elem_bytes * num_elems) << "Invalid DLTensor file format"; std::vector<uint8_t> bytes(data_byte_size + 1); CHECK(strm->Read(&bytes[0], data_byte_size)) << "Invalid DLTensor file format"; + // explicitly swap endian when necessary. + if (!DMLC_IO_NO_ENDIAN_SWAP) { + dmlc::ByteSwap(&bytes[0], elem_bytes, num_elems); + } TVM_CCALL(TVMArrayCopyFromBytes(dst, &bytes[0], data_byte_size)); } @@ -453,9 +460,8 @@ void GraphRuntime::LoadParams(dmlc::Stream* strm) { CHECK(strm->Read(&names)) << "Invalid parameters file format"; uint64_t sz; - strm->Read(&sz, sizeof(sz)); + strm->Read(&sz); size_t size = static_cast<size_t>(sz); - CHECK(size == names.size()) << "Invalid parameters file format"; for (size_t i = 0; i < size; ++i) { diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc index 923194cf4..d5eee1697 100644 --- a/src/runtime/rpc/rpc_session.cc +++ b/src/runtime/rpc/rpc_session.cc @@ -6,6 +6,7 @@ #include <tvm/runtime/packed_func.h> #include <tvm/runtime/device_api.h> #include <tvm/runtime/registry.h> +#include <tvm/runtime/serializer.h> #include <memory> #include <array> #include <string> @@ -44,7 +45,7 @@ struct RPCArgBuffer { }; // Event handler for RPC events. -class RPCSession::EventHandler { +class RPCSession::EventHandler : public dmlc::Stream { public: EventHandler(common::RingBuffer* reader, common::RingBuffer* writer, @@ -71,6 +72,15 @@ class RPCSession::EventHandler { return 0; } } + // Request number of bytes from reader. + void RequestBytes(size_t nbytes) { + pending_request_bytes_ += nbytes; + reader_->Reserve(pending_request_bytes_); + } + // Whether we are ready to handle next request. + bool Ready() { + return reader_->bytes_available() >= pending_request_bytes_; + } bool CanCleanShutdown() const { return state_ == kRecvCode; } @@ -86,12 +96,12 @@ class RPCSession::EventHandler { case kInitHeader: HandleInitHeader(); break; case kRecvCode: HandleRecvCode(); break; case kRecvCallHandle: { - this->Read(&call_handle_, sizeof(call_handle_)); + CHECK(this->Read(&call_handle_)); this->SwitchToState(kRecvPackedSeqNumArgs); break; } case kRecvPackedSeqNumArgs: { - this->Read(&num_packed_args_, sizeof(num_packed_args_)); + CHECK(this->Read(&num_packed_args_)); arg_buf_.reset(new RPCArgBuffer()); arg_buf_->value.resize(num_packed_args_); arg_buf_->tcode.resize(num_packed_args_); @@ -100,7 +110,7 @@ class RPCSession::EventHandler { } case kRecvPackedSeqTypeCode: { if (num_packed_args_ != 0) { - this->Read(arg_buf_->tcode.data(), sizeof(int) * num_packed_args_); + this->ReadArray(arg_buf_->tcode.data(), num_packed_args_); } arg_index_ = 0; arg_recv_stage_ = 0; @@ -164,8 +174,8 @@ class RPCSession::EventHandler { } // send Packed sequence to writer. void SendPackedSeq(const TVMValue* arg_values, const int* type_codes, int n) { - writer_->Write(&n, sizeof(n)); - writer_->Write(type_codes, sizeof(int) * n); + this->Write(n); + this->WriteArray(type_codes, n); // Argument packing. for (int i = 0; i < n; ++i) { int tcode = type_codes[i]; @@ -173,14 +183,20 @@ class RPCSession::EventHandler { switch (tcode) { case kDLInt: case kDLUInt: - case kDLFloat: + case kDLFloat: { + this->Write<int64_t>(value.v_int64); + break; + } case kTVMType: { - writer_->Write(&value, sizeof(TVMValue)); + this->Write(value.v_type); + // padding + int32_t padding = 0; + this->Write<int32_t>(padding); break; } case kTVMContext: { value.v_ctx = StripSessMask(value.v_ctx); - writer_->Write(&value, sizeof(TVMValue)); + this->Write(value.v_ctx); break; } case kFuncHandle: @@ -188,7 +204,7 @@ class RPCSession::EventHandler { case kHandle: { // always send handle in 64 bit. uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle); - writer_->Write(&handle, sizeof(uint64_t)); + this->Write(handle); break; } case kArrayHandle: { @@ -196,11 +212,11 @@ class RPCSession::EventHandler { TVMContext ctx = StripSessMask(arr->ctx); uint64_t data = reinterpret_cast<uint64_t>( static_cast<RemoteSpace*>(arr->data)->data); - writer_->Write(&data, sizeof(uint64_t)); - writer_->Write(&ctx, sizeof(ctx)); - writer_->Write(&(arr->ndim), sizeof(int)); - writer_->Write(&(arr->dtype), sizeof(DLDataType)); - writer_->Write(arr->shape, sizeof(int64_t) * arr->ndim); + this->Write(data); + this->Write(ctx); + this->Write(arr->ndim); + this->Write(arr->dtype); + this->WriteArray(arr->shape, arr->ndim); CHECK(arr->strides == nullptr) << "Donot support strided remote array"; CHECK_EQ(arr->byte_offset, 0) @@ -211,15 +227,15 @@ class RPCSession::EventHandler { case kStr: { const char* s = value.v_str; uint64_t len = strlen(s); - writer_->Write(&len, sizeof(len)); - writer_->Write(s, sizeof(char) * len); + this->Write(len); + this->WriteArray(s, len); break; } case kBytes: { TVMByteArray* bytes = static_cast<TVMByteArray*>(arg_values[i].v_handle); uint64_t len = bytes->size; - writer_->Write(&len, sizeof(len)); - writer_->Write(bytes->data, sizeof(char) * len); + this->Write(len); + this->WriteArray(bytes->data, len); break; } default: { @@ -230,6 +246,23 @@ class RPCSession::EventHandler { } } + // Endian aware IO handling + using Stream::Read; + using Stream::Write; + using Stream::ReadArray; + using Stream::WriteArray; + + inline bool Read(RPCCode* code) { + int cdata; + if (!this->Read(&cdata)) return false; + *code = static_cast<RPCCode>(cdata); + return true; + } + inline void Write(RPCCode code) { + int cdata = static_cast<int>(code); + this->Write(cdata); + } + protected: enum State { kInitHeader, @@ -370,10 +403,22 @@ class RPCSession::EventHandler { switch (tcode) { case kDLInt: case kDLUInt: - case kDLFloat: - case kTVMType: + case kDLFloat: { + this->Read<int64_t>(&(value.v_int64)); + ++arg_index_; + this->SwitchToState(kRecvPackedSeqArg); + break; + } + case kTVMType: { + this->Read(&(value.v_type)); + int32_t padding = 0; + this->Read<int32_t>(&padding); + ++arg_index_; + this->SwitchToState(kRecvPackedSeqArg); + break; + } case kTVMContext: { - this->Read(&value, sizeof(TVMValue)); + this->Read(&(value.v_ctx)); ++arg_index_; this->SwitchToState(kRecvPackedSeqArg); break; @@ -383,7 +428,7 @@ class RPCSession::EventHandler { case kHandle: { // always send handle in 64 bit. uint64_t handle; - this->Read(&handle, sizeof(handle)); + this->Read(&handle); value.v_handle = reinterpret_cast<void*>(handle); ++arg_index_; this->SwitchToState(kRecvPackedSeqArg); @@ -398,7 +443,7 @@ class RPCSession::EventHandler { case kStr: case kBytes: { uint64_t len; - this->Read(&len, sizeof(len)); + this->Read(&len); temp_bytes_.reset( new RPCByteArrayBuffer()); temp_bytes_->data.resize(len); arg_recv_stage_ = 1; @@ -409,12 +454,12 @@ class RPCSession::EventHandler { case kArrayHandle: { temp_array_.reset(new RPCDataArrayBuffer()); uint64_t handle; - this->Read(&handle, sizeof(handle)); + this->Read(&handle); DLTensor& tensor = temp_array_->tensor; tensor.data = reinterpret_cast<void*>(handle); - this->Read(&(tensor.ctx), sizeof(TVMContext)); - this->Read(&(tensor.ndim), sizeof(int)); - this->Read(&(tensor.dtype), sizeof(DLDataType)); + this->Read(&(tensor.ctx)); + this->Read(&(tensor.ndim)); + this->Read(&(tensor.dtype)); temp_array_->shape.resize(tensor.ndim); tensor.shape = temp_array_->shape.data(); arg_recv_stage_ = 1; @@ -432,7 +477,7 @@ class RPCSession::EventHandler { CHECK_EQ(arg_recv_stage_, 1); if (tcode == kStr || tcode == kBytes) { if (temp_bytes_->data.size() != 0) { - this->Read(&(temp_bytes_->data[0]), temp_bytes_->data.size()); + this->ReadArray(&(temp_bytes_->data[0]), temp_bytes_->data.size()); } if (tcode == kStr) { value.v_str = temp_bytes_->data.c_str(); @@ -445,7 +490,7 @@ class RPCSession::EventHandler { } else { CHECK_EQ(tcode, kArrayHandle); DLTensor& tensor = temp_array_->tensor; - this->Read(tensor.shape, tensor.ndim * sizeof(int64_t)); + this->ReadArray(tensor.shape, tensor.ndim); value.v_handle = &tensor; arg_buf_->temp_array.emplace_back(std::move(temp_array_)); } @@ -458,20 +503,20 @@ class RPCSession::EventHandler { void HandleInitHeader() { if (init_header_step_ == 0) { int32_t len; - this->Read(&len, sizeof(len)); + this->Read(&len); remote_key_->resize(len); init_header_step_ = 1; this->RequestBytes(len); return; } else { CHECK_EQ(init_header_step_, 1); - this->Read(dmlc::BeginPtr(*remote_key_), remote_key_->length()); + this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length()); this->SwitchToState(kRecvCode); } } // Handler for read code. void HandleRecvCode() { - this->Read(&code_, sizeof(code_)); + this->Read(&code_); if (code_ > RPCCode::kSystemFuncStart) { SwitchToState(kRecvPackedSeqNumArgs); return; @@ -511,14 +556,14 @@ class RPCSession::EventHandler { void HandleCopyFromRemote() { uint64_t handle, offset, size; TVMContext ctx; - this->Read(&handle, sizeof(handle)); - this->Read(&offset, sizeof(offset)); - this->Read(&size, sizeof(size)); - this->Read(&ctx, sizeof(ctx)); + this->Read(&handle); + this->Read(&offset); + this->Read(&size); + this->Read(&ctx); if (ctx.device_type == kDLCPU) { RPCCode code = RPCCode::kCopyAck; - writer_->Write(&code, sizeof(code)); - writer_->Write(reinterpret_cast<char*>(handle) + offset, size); + this->Write(code); + this->WriteArray(reinterpret_cast<char*>(handle) + offset, size); } else { temp_data_.resize(size + 1); try { @@ -530,11 +575,11 @@ class RPCSession::EventHandler { dmlc::BeginPtr(temp_data_), 0, size, ctx, cpu_ctx, nullptr); RPCCode code = RPCCode::kCopyAck; - writer_->Write(&code, sizeof(code)); - writer_->Write(&temp_data_[0], size); + this->Write(code); + this->WriteArray(&temp_data_[0], size); } catch (const std::runtime_error &e) { RPCCode code = RPCCode::kException; - writer_->Write(&code, sizeof(code)); + this->Write(code); TVMValue ret_value; ret_value.v_str = e.what(); int ret_tcode = kStr; @@ -548,10 +593,10 @@ class RPCSession::EventHandler { // use static variable to persist state. // This only works if next stage is immediately after this. if (arg_recv_stage_ == 0) { - this->Read(©_handle_, sizeof(uint64_t)); - this->Read(©_offset_, sizeof(uint64_t)); - this->Read(©_size_, sizeof(uint64_t)); - this->Read(©_ctx_, sizeof(TVMContext)); + CHECK(this->Read(©_handle_)); + CHECK(this->Read(©_offset_)); + CHECK(this->Read(©_size_)); + CHECK(this->Read(©_ctx_)); arg_recv_stage_ = 1; CHECK_EQ(pending_request_bytes_, 0U); this->RequestBytes(copy_size_); @@ -563,11 +608,11 @@ class RPCSession::EventHandler { RPCCode code = RPCCode::kReturn; std::string errmsg; if (copy_ctx_.device_type == kDLCPU) { - this->Read( + this->ReadArray( reinterpret_cast<char*>(copy_handle_) + copy_offset_, copy_size_); } else { temp_data_.resize(copy_size_ + 1); - this->Read(&temp_data_[0], copy_size_); + this->ReadArray(&temp_data_[0], copy_size_); try { TVMContext cpu_ctx; cpu_ctx.device_type = kDLCPU; @@ -583,7 +628,7 @@ class RPCSession::EventHandler { ret_tcode = kStr; } } - writer_->Write(&code, sizeof(code)); + this->Write(code); SendPackedSeq(&ret_value, &ret_tcode, 1); arg_recv_stage_ = 0; this->SwitchToState(kRecvCode); @@ -603,7 +648,7 @@ class RPCSession::EventHandler { std::unique_ptr<RPCArgBuffer> args = std::move(arg_buf_); f(args->AsTVMArgs(), &rv); RPCCode code = RPCCode::kReturn; - writer_->Write(&code, sizeof(code)); + this->Write(code); if (rv.type_code() == kStr) { ret_value.v_str = rv.ptr<std::string>()->c_str(); ret_tcode = kStr; @@ -630,7 +675,7 @@ class RPCSession::EventHandler { } } catch (const std::runtime_error& e) { RPCCode code = RPCCode::kException; - writer_->Write(&code, sizeof(code)); + this->Write(code); ret_value.v_str = e.what(); ret_tcode = kStr; SendPackedSeq(&ret_value, &ret_tcode, 1); @@ -640,19 +685,14 @@ class RPCSession::EventHandler { private: // Utility functions // Internal read function, update pending_request_bytes_ - void Read(void* data, size_t size) { + size_t Read(void* data, size_t size) final { CHECK_LE(size, pending_request_bytes_); reader_->Read(data, size); pending_request_bytes_ -= size; + return size; } - // Request number of bytes from reader. - void RequestBytes(size_t nbytes) { - pending_request_bytes_ += nbytes; - reader_->Reserve(pending_request_bytes_); - } - // Whether we are ready to handle next request. - bool Ready() { - return reader_->bytes_available() >= pending_request_bytes_; + void Write(const void* data, size_t size) final { + writer_->Write(data, size); } // Number of pending bytes requests size_t pending_request_bytes_; @@ -766,7 +806,7 @@ RPCSession::~RPCSession() { void RPCSession::Shutdown() { if (channel_ != nullptr) { RPCCode code = RPCCode::kShutdown; - writer_.Write(&code, sizeof(code)); + handler_->Write(code); // flush all writing buffer to output channel. try { while (writer_.bytes_available() != 0) { @@ -788,7 +828,6 @@ void RPCSession::ServerLoop() { } TVMRetValue rv; CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown); - LOG(INFO) << "Shutdown..."; if (const auto* f = Registry::Get("tvm.contrib.rpc.server.shutdown")) { (*f)(); } @@ -821,9 +860,9 @@ void RPCSession::CallFunc(void* h, const PackedFunc* fwrap) { std::lock_guard<std::recursive_mutex> lock(mutex_); RPCCode code = RPCCode::kCallFunc; - writer_.Write(&code, sizeof(code)); + handler_->Write(code); uint64_t handle = reinterpret_cast<uint64_t>(h); - writer_.Write(&handle, sizeof(handle)); + handler_->Write(handle); handler_->SendPackedSeq(args.values, args.type_codes, args.num_args); code = HandleUntilReturnEvent(rv, true, fwrap); CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code); @@ -838,15 +877,15 @@ void RPCSession::CopyToRemote(void* from, std::lock_guard<std::recursive_mutex> lock(mutex_); ctx_to = handler_->StripSessMask(ctx_to); RPCCode code = RPCCode::kCopyToRemote; - writer_.Write(&code, sizeof(code)); + handler_->Write(code); uint64_t handle = reinterpret_cast<uint64_t>(to); - writer_.Write(&handle, sizeof(handle)); + handler_->Write(handle); uint64_t offset = static_cast<uint64_t>(to_offset); - writer_.Write(&offset, sizeof(offset)); + handler_->Write(offset); uint64_t size = static_cast<uint64_t>(data_size); - writer_.Write(&size, sizeof(size)); - writer_.Write(&ctx_to, sizeof(ctx_to)); - writer_.Write(reinterpret_cast<char*>(from) + from_offset, data_size); + handler_->Write(size); + handler_->Write(ctx_to); + handler_->WriteArray(reinterpret_cast<char*>(from) + from_offset, data_size); TVMRetValue rv; CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kReturn); } @@ -860,26 +899,27 @@ void RPCSession::CopyFromRemote(void* from, std::lock_guard<std::recursive_mutex> lock(mutex_); ctx_from = handler_->StripSessMask(ctx_from); RPCCode code = RPCCode::kCopyFromRemote; - writer_.Write(&code, sizeof(code)); + handler_->Write(code); uint64_t handle = reinterpret_cast<uint64_t>(from); - writer_.Write(&handle, sizeof(handle)); + handler_->Write(handle); uint64_t offset = static_cast<uint64_t>(from_offset); - writer_.Write(&offset, sizeof(offset)); + handler_->Write(offset); uint64_t size = static_cast<uint64_t>(data_size); - writer_.Write(&size, sizeof(size)); - writer_.Write(&ctx_from, sizeof(ctx_from)); + handler_->Write(size); + handler_->Write(ctx_from); TVMRetValue rv; CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kCopyAck); reader_.Reserve(data_size); - while (reader_.bytes_available() < data_size) { - size_t bytes_needed = data_size - reader_.bytes_available(); + handler_->RequestBytes(data_size); + while (!handler_->Ready()) { + size_t bytes_needed = handler_->BytesNeeded(); reader_.WriteWithCallback([this](void* data, size_t size) { size_t n = channel_->Recv(data, size); CHECK_NE(n, 0U) << "Channel closes before we get neded bytes"; return n; }, bytes_needed); } - reader_.Read(reinterpret_cast<char*>(to) + to_offset, data_size); + handler_->ReadArray(reinterpret_cast<char*>(to) + to_offset, data_size); handler_->FinishCopyAck(); } -- GitLab