From fe564d9037e15324d897dd10ff92ea43d220e52b Mon Sep 17 00:00:00 2001 From: Tianqi Chen <tqchen@users.noreply.github.com> Date: Sun, 17 Sep 2017 19:05:55 -0700 Subject: [PATCH] [RPC] Include rpc session info into context (#458) * [RPC] Include rpc session info into context * add type checker in return converison --- include/tvm/packed_func_ext.h | 6 +++++- python/tvm/contrib/rpc.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index 850ff9439..39d94155a 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -163,7 +163,11 @@ inline TNodeRef TVMRetValue::AsNodeRef() const { "Conversion only works for NodeRef"); if (type_code_ == kNull) return TNodeRef(); TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - return TNodeRef(*ptr<std::shared_ptr<Node> >()); + std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >(); + CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get())) + << "Expected type " << NodeTypeName<TNodeRef>() + << " but get " << sptr->type_key(); + return TNodeRef(sptr); } inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*) diff --git a/python/tvm/contrib/rpc.py b/python/tvm/contrib/rpc.py index a6b5dfa6c..3ad77cb34 100644 --- a/python/tvm/contrib/rpc.py +++ b/python/tvm/contrib/rpc.py @@ -228,6 +228,7 @@ class RPCSession(object): ctx = _context(dev_type, dev_id) encode = (self._tbl_index + 1) * RPC_SESS_MASK ctx.device_type += encode + ctx._rpc_sess = self return ctx def cpu(self, dev_id=0): -- GitLab