/*! * Copyright (c) 2017 by Contributors * \file rpc_session.h * \brief Base RPC session interface. */ #ifndef TVM_RUNTIME_RPC_RPC_SESSION_H_ #define TVM_RUNTIME_RPC_RPC_SESSION_H_ #include <tvm/runtime/packed_func.h> #include <tvm/runtime/device_api.h> #include <mutex> #include <string> #include "../../common/ring_buffer.h" namespace tvm { namespace runtime { const int kRPCMagic = 0xff271; /*! \brief The remote functio handle */ using RPCFuncHandle = void*; struct RPCArgBuffer; /*! \brief The RPC code */ enum class RPCCode : int { kNone, kCallFunc, kReturn, kException, kShutdown, kCopyFromRemote, kCopyToRemote, kCopyAck, // The following are code that can send over CallRemote kSystemFuncStart, kGetGlobalFunc, kGetTimeEvaluator, kFreeFunc, kDevSetDevice, kDevGetAttr, kDevAllocData, kDevFreeData, kDevStreamSync, kCopyAmongRemote, kModuleLoad, kModuleImport, kModuleFree, kModuleGetFunc, kModuleGetSource, kNDArrayFree }; /*! * \brief Abstract channel interface used to create RPCSession. */ class RPCChannel { public: /*! \brief virtual destructor */ virtual ~RPCChannel() {} /*! * \brief Send data over to the channel. * \param data The data pointer. * \param size The size fo the data. * \return The actual bytes sent. */ virtual size_t Send(const void* data, size_t size) = 0; /*! e * \brief Recv data from channel. * * \param data The data pointer. * \param size The size fo the data. * \return The actual bytes received. */ virtual size_t Recv(void* data, size_t size) = 0; }; // Bidirectional Communication Session of PackedRPC class RPCSession { public: /*! \brief virtual destructor */ ~RPCSession(); /*! * \brief The server loop that server runs to handle RPC calls. */ void ServerLoop(); /*! * \brief Message handling function for event driven server. * Called when the server receives a message. * Event driven handler will never call recv on the channel * and always relies on the ServerEventHandler. * to receive the data. * * \param in_bytes The incoming bytes. * \param event_flag 1: read_available, 2: write_avaiable. * \return State flag. * 1: continue running, no need to write, * 2: need to write * 0: shutdown */ int ServerEventHandler(const std::string& in_bytes, int event_flag); /*! * \brief Call into remote function * \param handle The function handle * \param args The arguments * \param rv The return value. * \param fwrap Wrapper function to turn Function/Module handle into real return. */ void CallFunc(RPCFuncHandle handle, TVMArgs args, TVMRetValue* rv, const PackedFunc* fwrap); /*! * \brief Copy bytes into remote array content. * \param from The source host data. * \param from_offset The byte offeset in the from. * \param to The target array. * \param to_offset The byte offset in the to. * \param nbytes The size of the memory in bytes. * \param ctx_to The target context. * \param type_hint Hint of content data type. */ void CopyToRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes, TVMContext ctx_to, TVMType type_hint); /*! * \brief Copy bytes from remote array content. * \param from The source host data. * \param from_offset The byte offeset in the from. * \param to The target array. * \param to_offset The byte offset in the to. * \param nbytes The size of the memory in bytes. * \param ctx_from The source context. * \param type_hint Hint of content data type. */ void CopyFromRemote(void* from, size_t from_offset, void* to, size_t to_offset, size_t nbytes, TVMContext ctx_from, TVMType type_hint); /*! * \brief Get a remote timer function on ctx. * This function consumes fhandle, caller should not call Free on fhandle. * * \param fhandle The function handle. * \param ctx The ctx to run measurement on. * \param number How many steps to run in each time evaluation * \param repeat How many times to repeat the timer * \return A remote timer function */ RPCFuncHandle GetTimeEvaluator(RPCFuncHandle fhandle, TVMContext ctx, int number, int repeat); /*! * \brief Call a remote defined system function with arguments. * \param fcode The function code. * \param args The arguments * \return The returned remote value. */ template<typename... Args> inline TVMRetValue CallRemote(RPCCode fcode, Args&& ...args); /*! * \return The session table index of the session. */ int table_index() const { return table_index_; } /*! * \brief Create a RPC session with given channel. * \param channel The communication channel. * \param name The local name of the session, used for debug * \param remote_key The remote key of the session * if remote_key equals "%toinit", we need to re-intialize * it by event handler. */ static std::shared_ptr<RPCSession> Create( std::unique_ptr<RPCChannel> channel, std::string name, std::string remote_key); /*! * \brief Try get session from the global session table by table index. * \param table_index The table index of the session. * \return The shared_ptr to the session, can be nullptr. */ static std::shared_ptr<RPCSession> Get(int table_index); private: class EventHandler; // Handle events until receives a return // Also flushes channels so that the function advances. RPCCode HandleUntilReturnEvent( TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap); // Initalization void Init(); // Shutdown void Shutdown(); // Internal channel. std::unique_ptr<RPCChannel> channel_; // Internal mutex std::recursive_mutex mutex_; // Internal ring buffer. common::RingBuffer reader_, writer_; // Event handler. std::shared_ptr<EventHandler> handler_; // call remote with specified function code. PackedFunc call_remote_; // The index of this session in RPC session table. int table_index_{0}; // The name of the session. std::string name_; // The remote key std::string remote_key_; }; /*! * \brief Wrap a timer function for a given packed function. * \param f The function argument. * \param ctx The context. * \param number Number of steps in the inner iteration * \param repeat How many steps to repeat the time evaluation. */ PackedFunc WrapTimeEvaluator(PackedFunc f, TVMContext ctx, int number, int repeat); /*! * \brief Create a Global RPC module that refers to the session. * \param sess The RPC session of the global module. * \return The created module. */ Module CreateRPCModule(std::shared_ptr<RPCSession> sess); // Remote space pointer. struct RemoteSpace { void* data; std::shared_ptr<RPCSession> sess; }; // implementation of inline functions template<typename... Args> inline TVMRetValue RPCSession::CallRemote(RPCCode code, Args&& ...args) { std::lock_guard<std::recursive_mutex> lock(mutex_); writer_.Write(&code, sizeof(code)); return call_remote_(std::forward<Args>(args)...); } } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_RPC_RPC_SESSION_H_