/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file rpc_device_api.cc
 * \brief RPC module.
 */
#include <tvm/runtime/registry.h>
#include <memory>
#include <cstring>
#include "rpc_session.h"

namespace tvm {
namespace runtime {

/*!
 * \brief Wraps a function handle that lives on the remote side of an RPC
 *  session so it can be called locally.
 *
 *  The wrapper owns the remote handle: its destructor asks the remote side
 *  to free it (RPCCode::kFreeFunc). Copying is therefore disabled so the
 *  handle is released exactly once; share it through std::shared_ptr.
 */
class RPCWrappedFunc {
 public:
  /*!
   * \param handle The remote function handle (ownership is taken).
   * \param sess The session the handle belongs to.
   */
  RPCWrappedFunc(void* handle, std::shared_ptr<RPCSession> sess)
      : handle_(handle), sess_(sess) {
    // Passed to CallFunc below; it routes handle-typed values to WrapRemote.
    // NOTE(review): presumably the session applies it to remote return
    // values -- confirm against RPCSession::CallFunc.
    fwrap_ = PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
        WrapRemote(sess, args, rv);
      });
  }
  // The destructor frees handle_ remotely; a copy would free it twice.
  RPCWrappedFunc(const RPCWrappedFunc&) = delete;
  RPCWrappedFunc& operator=(const RPCWrappedFunc&) = delete;

  /*! \brief Invoke the remote function through the session. */
  void operator()(TVMArgs args, TVMRetValue* rv) const {
    sess_->CallFunc(handle_, args, rv, &fwrap_);
  }

  ~RPCWrappedFunc() {
    try {
      sess_->CallRemote(RPCCode::kFreeFunc, handle_);
    } catch (const dmlc::Error&) {
      // fault tolerance to remote close
    }
  }

  /*!
   * \brief Wrap a handle-typed value (function, module or NDArray) coming
   *  from the remote side into a local object stored in *rv.
   *  Defined after RPCModuleNode because the module case needs that class.
   */
  static void WrapRemote(std::shared_ptr<RPCSession> sess,
                         TVMArgs args,
                         TVMRetValue* rv);

  /*!
   * \brief Deleter of RPC remote arrays.
   *
   *  Asks the remote side to free the array, then always releases the local
   *  RemoteSpace and container. A closed session must not leak them or let
   *  the error escape, because this runs while an NDArray is being released.
   */
  static void RemoteNDArrayDeleter(NDArray::Container* ptr) {
    RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
    try {
      space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx);
    } catch (const dmlc::Error&) {
      // fault tolerance to remote close; local cleanup still happens below.
    }
    delete space;
    delete ptr;
  }

  /*!
   * \brief Wrap a value returned from the remote side as a remote NDArray.
   * \param sess The session owning the remote array.
   * \param tensor Tensor metadata sent by the remote side (strides must be null).
   * \param nd_handle Remote NDArray handle, freed by RemoteNDArrayDeleter.
   * \return A local NDArray whose data is a RemoteSpace and whose device type
   *  encodes the session table index via kRPCSessMask.
   */
  static NDArray WrapRemoteNDArray(std::shared_ptr<RPCSession> sess,
                                   DLTensor* tensor,
                                   void* nd_handle) {
    NDArray::Container* data = new NDArray::Container();
    data->manager_ctx = nd_handle;
    data->deleter = RemoteNDArrayDeleter;
    RemoteSpace* space = new RemoteSpace();
    space->sess = sess;
    space->data = tensor->data;
    data->dl_tensor.data = space;
    NDArray ret(data);
    // RAII now in effect: ret releases data/space if anything below fails.
    data->shape_ = std::vector<int64_t>(
        tensor->shape, tensor->shape + tensor->ndim);
    data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);
    data->dl_tensor.ndim = static_cast<int>(data->shape_.size());
    // setup dtype
    data->dl_tensor.dtype = tensor->dtype;
    // setup ctx, encode as remote session
    data->dl_tensor.ctx.device_id = tensor->ctx.device_id;
    data->dl_tensor.ctx.device_type = static_cast<DLDeviceType>(
        static_cast<int>(tensor->ctx.device_type) +
        kRPCSessMask * (sess->table_index() + 1));
    // check strides.
    CHECK(tensor->strides == nullptr);
    // setup byteoffset
    data->dl_tensor.byte_offset = tensor->byte_offset;
    return ret;
  }

 private:
  /*! \brief Wrapper for handle-typed values, handed to CallFunc. */
  PackedFunc fwrap_;
  /*! \brief The remote function handle (owned). */
  void* handle_{nullptr};
  /*! \brief The session the handle belongs to. */
  std::shared_ptr<RPCSession> sess_;
};

// RPC that represents a remote module session.
class RPCModuleNode final : public ModuleNode { public: RPCModuleNode(void* module_handle, std::shared_ptr<RPCSession> sess) : module_handle_(module_handle), sess_(sess) { } ~RPCModuleNode() { if (module_handle_ != nullptr) { try { sess_->CallRemote(RPCCode::kModuleFree, module_handle_); } catch (const dmlc::Error& e) { // fault tolerance to remote close } module_handle_ = nullptr; } } const char* type_key() const final { return "rpc"; } PackedFunc GetFunction( const std::string& name, const std::shared_ptr<ModuleNode>& sptr_to_self) final { RPCFuncHandle handle = GetFuncHandle(name); return WrapRemote(handle); } std::string GetSource(const std::string& format) final { if (module_handle_ != nullptr) { std::string ret = sess_->CallRemote( RPCCode::kModuleGetSource, module_handle_, format); } return ""; } std::shared_ptr<RPCSession>& sess() { return sess_; } PackedFunc GetTimeEvaluator(const std::string& name, TVMContext ctx, int number, int repeat, int min_repeat_ms) { RPCFuncHandle handle = GetFuncHandle(name); if (handle == nullptr) return PackedFunc(); handle = sess_->GetTimeEvaluator(handle, ctx, number, repeat, min_repeat_ms); return WrapRemote(handle); } void* module_handle() const { return module_handle_; } private: PackedFunc WrapRemote(RPCFuncHandle handle) { if (handle == nullptr) return PackedFunc(); auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_); return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { return wf->operator()(args, rv); }); } RPCFuncHandle GetFuncHandle(const std::string& name) { RPCFuncHandle handle = nullptr; if (module_handle_ == nullptr) { handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name); } else { handle = sess_->CallRemote( RPCCode::kModuleGetFunc, module_handle_, name); } return handle; } // The module handle void* module_handle_{nullptr}; // The local channel std::shared_ptr<RPCSession> sess_; // Wrap function to wrap remote module/function. 
PackedFunc fwrap_; }; void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess, TVMArgs args, TVMRetValue *rv) { void* handle = args.values[0].v_handle; int tcode = args.type_codes[0]; if (handle == nullptr) return; if (tcode == kFuncHandle) { auto wf = std::make_shared<RPCWrappedFunc>(handle, sess); *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { return wf->operator()(args, rv); }); } else if (tcode == kModuleHandle) { std::shared_ptr<RPCModuleNode> n = std::make_shared<RPCModuleNode>(handle, sess); *rv = Module(n); } else if (tcode == kArrayHandle || tcode == kNDArrayContainer) { CHECK_EQ(args.size(), 2); DLTensor* tensor = args[0]; void* nd_handle = args[1]; *rv = WrapRemoteNDArray(sess, tensor, nd_handle); } else { LOG(FATAL) << "Cannot wrap tcode=" << tcode; } } Module CreateRPCModule(std::shared_ptr<RPCSession> sess) { std::shared_ptr<RPCModuleNode> n = std::make_shared<RPCModuleNode>(nullptr, sess); return Module(n); } TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator") .set_body([](TVMArgs args, TVMRetValue* rv) { Module m = args[0]; std::string tkey = m->type_key(); TVMContext ctx; ctx.device_type = static_cast<DLDeviceType>(args[2].operator int()); ctx.device_id = args[3]; if (tkey == "rpc") { *rv = static_cast<RPCModuleNode*>(m.operator->()) ->GetTimeEvaluator(args[1], ctx, args[4], args[5], args[6]); } else { *rv = WrapTimeEvaluator( m.GetFunction(args[1], false), ctx, args[4], args[5], args[6]); } }); TVM_REGISTER_GLOBAL("rpc._LoadRemoteModule") .set_body([](TVMArgs args, TVMRetValue* rv) { Module m = args[0]; std::string tkey = m->type_key(); CHECK_EQ(tkey, "rpc"); auto& sess = static_cast<RPCModuleNode*>(m.operator->())->sess(); void* mhandle = sess->CallRemote(RPCCode::kModuleLoad, args[1]); std::shared_ptr<RPCModuleNode> n = std::make_shared<RPCModuleNode>(mhandle, sess); *rv = Module(n); }); TVM_REGISTER_GLOBAL("rpc._ImportRemoteModule") .set_body([](TVMArgs args, TVMRetValue* rv) { Module parent = args[0]; Module child = args[1]; 
CHECK(!std::strcmp(parent->type_key(), "rpc") && !std::strcmp(child->type_key(), "rpc")); auto* pmod = static_cast<RPCModuleNode*>(parent.operator->()); auto* cmod = static_cast<RPCModuleNode*>(child.operator->()); CHECK(pmod->sess().get() == cmod->sess().get()) << "Import of remote module need to belong to same session."; pmod->sess()->CallRemote(RPCCode::kModuleImport, pmod->module_handle(), cmod->module_handle()); }); TVM_REGISTER_GLOBAL("rpc._ModuleHandle") .set_body([](TVMArgs args, TVMRetValue* rv) { Module m = args[0]; std::string tkey = m->type_key(); CHECK_EQ(tkey, "rpc"); *rv = static_cast<RPCModuleNode*>(m.operator->())->module_handle(); }); TVM_REGISTER_GLOBAL("rpc._SessTableIndex") .set_body([](TVMArgs args, TVMRetValue* rv) { Module m = args[0]; std::string tkey = m->type_key(); CHECK_EQ(tkey, "rpc"); *rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index(); }); } // namespace runtime } // namespace tvm