rpc_session.h 9.18 KB
Newer Older
1 2 3 4 5 6 7 8 9
/*!
 *  Copyright (c) 2017 by Contributors
 * \file rpc_session.h
 * \brief Base RPC session interface.
 */
#ifndef TVM_RUNTIME_RPC_RPC_SESSION_H_
#define TVM_RUNTIME_RPC_RPC_SESSION_H_

#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>

#include <memory>
#include <mutex>
#include <string>
#include <utility>

#include "../../common/ring_buffer.h"

namespace tvm {
namespace runtime {

/*!
 * \brief Magic number of the RPC protocol.
 *  NOTE(review): presumably exchanged when a connection is set up so that
 *  peers can recognize each other; confirm against rpc_session.cc.
 */
const int kRPCMagic = 0xff271;

/*! \brief The remote function handle (an opaque pointer). */
using RPCFuncHandle = void*;

/*! \brief Forward declaration; the definition is not part of this header. */
struct RPCArgBuffer;

/*!
 * \brief The RPC code.
 *
 *  Enumerators use their implicit, zero-based sequential values.
 *  RPCSession::CallRemote writes the raw code bytes into the session's write
 *  buffer, so reordering or inserting entries changes the encoded values and
 *  would presumably break compatibility with an existing peer.
 */
enum class RPCCode : int {
  kNone,
  kCallFunc,
  kReturn,
  kException,
  kShutdown,
  kCopyFromRemote,
  kCopyToRemote,
  kCopyAck,
  // The following are codes that can be sent via CallRemote.
  kSystemFuncStart,
  kGetGlobalFunc,
  kGetTimeEvaluator,
  kFreeFunc,
  kDevSetDevice,
  kDevGetAttr,
  kDevAllocData,
  kDevFreeData,
  kDevStreamSync,
  kCopyAmongRemote,
  kModuleLoad,
  kModuleImport,
  kModuleFree,
  kModuleGetFunc,
  kModuleGetSource,
  kNDArrayFree
};

/*!
 * \brief Abstract channel interface used to create RPCSession.
 */
class RPCChannel {
 public:
  /*! \brief Virtual destructor, so channels can be owned through a base pointer. */
  virtual ~RPCChannel() = default;
  /*!
   * \brief Send data over the channel.
   * \param data The data pointer.
   * \param size The size of the data in bytes.
   * \return The actual number of bytes sent.
   *  NOTE(review): the separate return value suggests this may be less than
   *  size; confirm whether callers are expected to retry.
   */
  virtual size_t Send(const void* data, size_t size) = 0;
  /*!
   * \brief Receive data from the channel.
   *
   * \param data The destination buffer.
   * \param size The capacity of the destination buffer in bytes.
   * \return The actual number of bytes received.
   */
  virtual size_t Recv(void* data, size_t size) = 0;
};

78 79 80 81 82 83 84 85 86 87
// Bidirectional Communication Session of PackedRPC
class RPCSession {
 public:
  /*! \brief virtual destructor */
  ~RPCSession();
  /*!
   *  \brief The server loop that server runs to handle RPC calls.
   */
  void ServerLoop();
  /*!
88 89 90
   * \brief Message handling function for event driven server.
   *  Called when the server receives a message.
   *  Event driven handler will never call recv on the channel
91
   *  and always relies on the ServerEventHandler.
92 93
   *  to receive the data.
   *
94 95 96 97 98 99
   * \param in_bytes The incoming bytes.
   * \param event_flag  1: read_available, 2: write_avaiable.
   * \return State flag.
   *     1: continue running, no need to write,
   *     2: need to write
   *     0: shutdown
100
   */
101 102
  int ServerEventHandler(const std::string& in_bytes,
                         int event_flag);
103
  /*!
104 105 106 107
   * \brief Call into remote function
   * \param handle The function handle
   * \param args The arguments
   * \param rv The return value.
108
   * \param fwrap Wrapper function to turn Function/Module handle into real return.
109 110 111
   */
  void CallFunc(RPCFuncHandle handle,
                TVMArgs args,
112 113
                TVMRetValue* rv,
                const PackedFunc* fwrap);
114 115 116 117 118 119
  /*!
   * \brief Copy bytes into remote array content.
   * \param from The source host data.
   * \param from_offset The byte offeset in the from.
   * \param to The target array.
   * \param to_offset The byte offset in the to.
120
   * \param nbytes The size of the memory in bytes.
121
   * \param ctx_to The target context.
122
   * \param type_hint Hint of content data type.
123 124 125 126 127
   */
  void CopyToRemote(void* from,
                    size_t from_offset,
                    void* to,
                    size_t to_offset,
128 129 130
                    size_t nbytes,
                    TVMContext ctx_to,
                    TVMType type_hint);
131 132 133 134 135 136
  /*!
   * \brief Copy bytes from remote array content.
   * \param from The source host data.
   * \param from_offset The byte offeset in the from.
   * \param to The target array.
   * \param to_offset The byte offset in the to.
137
   * \param nbytes The size of the memory in bytes.
138
   * \param ctx_from The source context.
139
   * \param type_hint Hint of content data type.
140 141 142 143 144
   */
  void CopyFromRemote(void* from,
                      size_t from_offset,
                      void* to,
                      size_t to_offset,
145 146 147
                      size_t nbytes,
                      TVMContext ctx_from,
                      TVMType type_hint);
148
  /*!
149 150 151 152 153
   * \brief Get a remote timer function on ctx.
   *  This function consumes fhandle, caller should not call Free on fhandle.
   *
   * \param fhandle The function handle.
   * \param ctx The ctx to run measurement on.
154 155 156 157 158 159 160 161 162 163 164 165 166
   * \param number The number of times to run this function for taking average.
          We call these runs as one `repeat` of measurement.
   * \param repeat The number of times to repeat the measurement.
          In total, the function will be invoked (1 + number x repeat) times,
          where the first one is warm up and will be discarded.
          The returned result contains `repeat` costs,
          each of which is an average of `number` costs.
   * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds.
          By default, one `repeat` contains `number` runs. If this parameter is set,
          the parameters `number` will be dynamically adjusted to meet the
          minimum duration requirement of one `repeat`.
          i.e., When the run time of one `repeat` falls below this time,
          the `number` parameter will be automatically increased.
167 168 169 170
   * \return A remote timer function
   */
  RPCFuncHandle GetTimeEvaluator(RPCFuncHandle fhandle,
                                 TVMContext ctx,
171
                                 int number,
172 173
                                 int repeat,
                                 int min_repeat_ms);
174
  /*!
175 176 177 178 179 180 181 182 183 184 185 186 187 188
   * \brief Call a remote defined system function with arguments.
   * \param fcode The function code.
   * \param args The arguments
   * \return The returned remote value.
   */
  template<typename... Args>
  inline TVMRetValue CallRemote(RPCCode fcode, Args&& ...args);
  /*!
   * \return The session table index of the session.
   */
  int table_index() const {
    return table_index_;
  }
  /*!
189
   * \brief Create a RPC session with given channel.
190
   * \param channel The communication channel.
191 192 193 194
   * \param name The local name of the session, used for debug
   * \param remote_key The remote key of the session
   *   if remote_key equals "%toinit", we need to re-intialize
   *   it by event handler.
195
   */
196 197
  static std::shared_ptr<RPCSession> Create(
      std::unique_ptr<RPCChannel> channel,
198 199
      std::string name,
      std::string remote_key);
200 201 202 203 204 205 206 207
  /*!
   * \brief Try get session from the global session table by table index.
   * \param table_index The table index of the session.
   * \return The shared_ptr to the session, can be nullptr.
   */
  static std::shared_ptr<RPCSession> Get(int table_index);

 private:
208 209 210
  class EventHandler;
  // Handle events until receives a return
  // Also flushes channels so that the function advances.
211 212
  RPCCode HandleUntilReturnEvent(
      TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap);
213
  // Initalization
214
  void Init();
215
  // Shutdown
216
  void Shutdown();
217 218
  // Internal channel.
  std::unique_ptr<RPCChannel> channel_;
219 220
  // Internal mutex
  std::recursive_mutex mutex_;
221 222 223 224
  // Internal ring buffer.
  common::RingBuffer reader_, writer_;
  // Event handler.
  std::shared_ptr<EventHandler> handler_;
225
  // call remote with specified function code.
226 227 228
  PackedFunc call_remote_;
  // The index of this session in RPC session table.
  int table_index_{0};
229 230
  // The name of the session.
  std::string name_;
231 232
  // The remote key
  std::string remote_key_;
233 234
};

235
/*!
236
 * \brief Wrap a timer function to measure the time cost of a given packed function.
237 238
 * \param f The function argument.
 * \param ctx The context.
239 240 241 242 243 244 245 246 247 248 249 250 251 252
 * \param number The number of times to run this function for taking average.
          We call these runs as one `repeat` of measurement.
 * \param repeat The number of times to repeat the measurement.
          In total, the function will be invoked (1 + number x repeat) times,
          where the first one is warm up and will be discarded.
          The returned result contains `repeat` costs,
          each of which is an average of `number` costs.
 * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds.
          By default, one `repeat` contains `number` runs. If this parameter is set,
          the parameters `number` will be dynamically adjusted to meet the
          minimum duration requirement of one `repeat`.
          i.e., When the run time of one `repeat` falls below this time,
          the `number` parameter will be automatically increased.
 * \return f_timer A timer function.
253
 */
254 255 256 257 258
PackedFunc WrapTimeEvaluator(PackedFunc f,
                             TVMContext ctx,
                             int number,
                             int repeat,
                             int min_repeat_ms);
259

260 261 262 263 264 265 266
/*!
 * \brief Create a Global RPC module that refers to the session.
 * \param sess The RPC session of the global module.
 * \return The created module.
 */
Module CreateRPCModule(std::shared_ptr<RPCSession> sess);

267 268 269 270 271 272 273 274 275 276
// Remote space pointer.
struct RemoteSpace {
  void* data;
  std::shared_ptr<RPCSession> sess;
};

// implementation of inline functions

/*!
 * \brief Definition of RPCSession::CallRemote.
 *  Writes the raw bytes of code into writer_, then forwards args to
 *  call_remote_ and returns its result. mutex_ is held for the whole sequence,
 *  so the code and the call that follows are not interleaved with other code
 *  paths that also lock mutex_.
 */
template<typename... Args>
inline TVMRetValue RPCSession::CallRemote(RPCCode code, Args&& ...args) {
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  writer_.Write(&code, sizeof(code));
  return call_remote_(std::forward<Args>(args)...);
}
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_RUNTIME_RPC_RPC_SESSION_H_