diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h index 850ff9439933ab1752dc1330eef4c0dfbb93fedc..39d94155a9900479f751c7a6b59c7aa16f41a0cc 100644 --- a/include/tvm/packed_func_ext.h +++ b/include/tvm/packed_func_ext.h @@ -163,7 +163,11 @@ inline TNodeRef TVMRetValue::AsNodeRef() const { "Conversion only works for NodeRef"); if (type_code_ == kNull) return TNodeRef(); TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); - return TNodeRef(*ptr<std::shared_ptr<Node> >()); + std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >(); + CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get())) + << "Expected type " << NodeTypeName<TNodeRef>() + << " but get " << sptr->type_key(); + return TNodeRef(sptr); } inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*) diff --git a/python/tvm/contrib/rpc.py b/python/tvm/contrib/rpc.py index a6b5dfa6c609c6acaa4fe43f2a863d7cdfcda9e3..3ad77cb3474398490cb64a8fd48d7d2d43046606 100644 --- a/python/tvm/contrib/rpc.py +++ b/python/tvm/contrib/rpc.py @@ -228,6 +228,7 @@ class RPCSession(object): ctx = _context(dev_type, dev_id) encode = (self._tbl_index + 1) * RPC_SESS_MASK ctx.device_type += encode + ctx._rpc_sess = self return ctx def cpu(self, dev_id=0):