From fe564d9037e15324d897dd10ff92ea43d220e52b Mon Sep 17 00:00:00 2001
From: Tianqi Chen <tqchen@users.noreply.github.com>
Date: Sun, 17 Sep 2017 19:05:55 -0700
Subject: [PATCH] [RPC] Include rpc session info into context (#458)

* [RPC] Include rpc session info into context

* add type checker in return converison
---
 include/tvm/packed_func_ext.h | 6 +++++-
 python/tvm/contrib/rpc.py     | 1 +
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/include/tvm/packed_func_ext.h b/include/tvm/packed_func_ext.h
index 850ff9439..39d94155a 100644
--- a/include/tvm/packed_func_ext.h
+++ b/include/tvm/packed_func_ext.h
@@ -163,7 +163,11 @@ inline TNodeRef TVMRetValue::AsNodeRef() const {
       "Conversion only works for NodeRef");
   if (type_code_ == kNull) return TNodeRef();
   TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
-  return TNodeRef(*ptr<std::shared_ptr<Node> >());
+  std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
+  CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
+      << "Expected type " << NodeTypeName<TNodeRef>()
+      << " but get " << sptr->type_key();
+  return TNodeRef(sptr);
 }
 
 inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const {  // NOLINT(*)
diff --git a/python/tvm/contrib/rpc.py b/python/tvm/contrib/rpc.py
index a6b5dfa6c..3ad77cb34 100644
--- a/python/tvm/contrib/rpc.py
+++ b/python/tvm/contrib/rpc.py
@@ -228,6 +228,7 @@ class RPCSession(object):
         ctx = _context(dev_type, dev_id)
         encode = (self._tbl_index + 1) * RPC_SESS_MASK
         ctx.device_type += encode
+        ctx._rpc_sess = self
         return ctx
 
     def cpu(self, dev_id=0):
-- 
GitLab