Commit fe564d90 by Tianqi Chen Committed by GitHub

[RPC] Include rpc session info into context (#458)

* [RPC] Include rpc session info into context

* add type checker in return converison
parent 3a0d3a39
......@@ -163,7 +163,11 @@ inline TNodeRef TVMRetValue::AsNodeRef() const {
"Conversion only works for NodeRef");
if (type_code_ == kNull) return TNodeRef();
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
return TNodeRef(*ptr<std::shared_ptr<Node> >());
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<TNodeRef>()
<< " but get " << sptr->type_key();
return TNodeRef(sptr);
}
inline void TVMArgsSetter::operator()(size_t i, const NodeRef& other) const { // NOLINT(*)
......
......@@ -228,6 +228,7 @@ class RPCSession(object):
ctx = _context(dev_type, dev_id)
encode = (self._tbl_index + 1) * RPC_SESS_MASK
ctx.device_type += encode
ctx._rpc_sess = self
return ctx
def cpu(self, dev_id=0):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment