/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file rpc_session.h
 * \brief Base RPC session interface.
 */
#ifndef TVM_RUNTIME_RPC_RPC_SESSION_H_
#define TVM_RUNTIME_RPC_RPC_SESSION_H_

#include <tvm/runtime/device_api.h>
#include <tvm/runtime/packed_func.h>

#include <cstddef>
#include <memory>
#include <mutex>
#include <string>
#include <utility>

#include "../../support/ring_buffer.h"

namespace tvm {
namespace runtime {

// Magic header for the RPC data plane.
constexpr int kRPCMagic = 0xff271;
// Magic header for the RPC tracker (control plane).
constexpr int kRPCTrackerMagic = 0x2f271;
// Success response code (equal to kRPCMagic).
constexpr int kRPCSuccess = kRPCMagic + 0;
// Response code: no matching key was found on the server.
constexpr int kRPCMismatch = kRPCMagic + 2;

/*!
 * \brief Enumeration code for the RPC tracker.
 *
 * The underlying type is int and every enumerator has an explicit value,
 * so inserting or reordering entries does not change existing codes.
 */
enum class TrackerCode : int {
  kFail = -1,
  kSuccess = 0,
  kPing = 1,
  kStop = 2,
  kPut = 3,
  kRequest = 4,
  kUpdateInfo = 5,
  kSummary = 6,
  kGetPendingMatchKeys = 7
};
/*! \brief The remote function handle (an opaque pointer). */
using RPCFuncHandle = void*;

// Forward declaration only; the definition is not part of this header.
struct RPCArgBuffer;

/*!
 * \brief The RPC code.
 *
 * Enumerator values are assigned implicitly, starting at 0 for kNone and
 * increasing by one per entry. RPCSession::CallRemote writes the raw value
 * of the code into the session's write buffer, so inserting or reordering
 * entries changes the values that get written.
 */
enum class RPCCode : int {
  kNone,
  kCallFunc,
  kReturn,
  kException,
  kShutdown,
  kCopyFromRemote,
  kCopyToRemote,
  kCopyAck,
  // The following are codes that can be sent over CallRemote.
  kSystemFuncStart,
  kGetGlobalFunc,
  kGetTimeEvaluator,
  kFreeFunc,
  kDevSetDevice,
  kDevGetAttr,
  kDevAllocData,
  kDevFreeData,
  kDevStreamSync,
  kCopyAmongRemote,
  kModuleLoad,
  kModuleImport,
  kModuleFree,
  kModuleGetFunc,
  kModuleGetSource,
  kNDArrayFree
};

93
/*!
94 95 96 97 98 99 100 101 102 103
 * \brief Function that unwraps a remote object to its handle.
 * \param rpc_sess_table_index RPC session table index for validation.
 * \param obj Handle to the object argument.
 * \return The corresponding handle.
 */
typedef void* (*FUnwrapRemoteObject)(
    int rpc_sess_table_index,
    const TVMArgValue& obj);

/*!
 * \brief Abstract channel interface used to create RPCSession.
 *
 * Implementations provide the two byte-oriented primitives below.
 */
class RPCChannel {
 public:
  /*! \brief virtual destructor */
  virtual ~RPCChannel() = default;
  /*!
   * \brief Send data over to the channel.
   * \param data The data pointer.
   * \param size The size of the data.
   * \return The actual bytes sent.
   */
  virtual size_t Send(const void* data, size_t size) = 0;
  /*!
   * \brief Recv data from channel.
   *
   * \param data The data pointer.
   * \param size The size of the data.
   * \return The actual bytes received.
   */
  virtual size_t Recv(void* data, size_t size) = 0;
};

127 128 129 130 131 132 133 134 135 136
// Bidirectional Communication Session of PackedRPC
class RPCSession {
 public:
  /*! \brief virtual destructor */
  ~RPCSession();
  /*!
   *  \brief The server loop that server runs to handle RPC calls.
   */
  void ServerLoop();
  /*!
137 138 139
   * \brief Message handling function for event driven server.
   *  Called when the server receives a message.
   *  Event driven handler will never call recv on the channel
140
   *  and always relies on the ServerEventHandler.
141 142
   *  to receive the data.
   *
143 144 145 146 147 148
   * \param in_bytes The incoming bytes.
   * \param event_flag  1: read_available, 2: write_avaiable.
   * \return State flag.
   *     1: continue running, no need to write,
   *     2: need to write
   *     0: shutdown
149
   */
150 151
  int ServerEventHandler(const std::string& in_bytes,
                         int event_flag);
152
  /*!
153 154 155 156
   * \brief Call into remote function
   * \param handle The function handle
   * \param args The arguments
   * \param rv The return value.
157
   * \param funpwrap Function that takes a remote object and returns the raw handle.
158
   * \param fwrap Wrapper function to turn Function/Module handle into real return.
159 160 161
   */
  void CallFunc(RPCFuncHandle handle,
                TVMArgs args,
162
                TVMRetValue* rv,
163
                FUnwrapRemoteObject funwrap,
164
                const PackedFunc* fwrap);
165 166 167 168 169 170
  /*!
   * \brief Copy bytes into remote array content.
   * \param from The source host data.
   * \param from_offset The byte offeset in the from.
   * \param to The target array.
   * \param to_offset The byte offset in the to.
171
   * \param nbytes The size of the memory in bytes.
172
   * \param ctx_to The target context.
173
   * \param type_hint Hint of content data type.
174 175 176 177 178
   */
  void CopyToRemote(void* from,
                    size_t from_offset,
                    void* to,
                    size_t to_offset,
179 180
                    size_t nbytes,
                    TVMContext ctx_to,
181
                    DLDataType type_hint);
182 183 184 185 186 187
  /*!
   * \brief Copy bytes from remote array content.
   * \param from The source host data.
   * \param from_offset The byte offeset in the from.
   * \param to The target array.
   * \param to_offset The byte offset in the to.
188
   * \param nbytes The size of the memory in bytes.
189
   * \param ctx_from The source context.
190
   * \param type_hint Hint of content data type.
191 192 193 194 195
   */
  void CopyFromRemote(void* from,
                      size_t from_offset,
                      void* to,
                      size_t to_offset,
196 197
                      size_t nbytes,
                      TVMContext ctx_from,
198
                      DLDataType type_hint);
199
  /*!
200 201 202 203 204
   * \brief Get a remote timer function on ctx.
   *  This function consumes fhandle, caller should not call Free on fhandle.
   *
   * \param fhandle The function handle.
   * \param ctx The ctx to run measurement on.
205 206 207 208 209 210 211 212 213 214 215 216 217
   * \param number The number of times to run this function for taking average.
          We call these runs as one `repeat` of measurement.
   * \param repeat The number of times to repeat the measurement.
          In total, the function will be invoked (1 + number x repeat) times,
          where the first one is warm up and will be discarded.
          The returned result contains `repeat` costs,
          each of which is an average of `number` costs.
   * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds.
          By default, one `repeat` contains `number` runs. If this parameter is set,
          the parameters `number` will be dynamically adjusted to meet the
          minimum duration requirement of one `repeat`.
          i.e., When the run time of one `repeat` falls below this time,
          the `number` parameter will be automatically increased.
218 219 220 221
   * \return A remote timer function
   */
  RPCFuncHandle GetTimeEvaluator(RPCFuncHandle fhandle,
                                 TVMContext ctx,
222
                                 int number,
223 224
                                 int repeat,
                                 int min_repeat_ms);
225
  /*!
226 227 228 229 230 231 232 233 234 235 236 237 238 239
   * \brief Call a remote defined system function with arguments.
   * \param fcode The function code.
   * \param args The arguments
   * \return The returned remote value.
   */
  template<typename... Args>
  inline TVMRetValue CallRemote(RPCCode fcode, Args&& ...args);
  /*!
   * \return The session table index of the session.
   */
  int table_index() const {
    return table_index_;
  }
  /*!
240
   * \brief Create a RPC session with given channel.
241
   * \param channel The communication channel.
242 243 244 245
   * \param name The local name of the session, used for debug
   * \param remote_key The remote key of the session
   *   if remote_key equals "%toinit", we need to re-intialize
   *   it by event handler.
246
   */
247 248
  static std::shared_ptr<RPCSession> Create(
      std::unique_ptr<RPCChannel> channel,
249 250
      std::string name,
      std::string remote_key);
251 252 253 254 255 256 257 258
  /*!
   * \brief Try get session from the global session table by table index.
   * \param table_index The table index of the session.
   * \return The shared_ptr to the session, can be nullptr.
   */
  static std::shared_ptr<RPCSession> Get(int table_index);

 private:
259 260 261
  class EventHandler;
  // Handle events until receives a return
  // Also flushes channels so that the function advances.
262 263
  RPCCode HandleUntilReturnEvent(
      TVMRetValue* rv, bool client_mode, const PackedFunc* fwrap);
264
  // Initalization
265
  void Init();
266
  // Shutdown
267
  void Shutdown();
268 269
  // Internal channel.
  std::unique_ptr<RPCChannel> channel_;
270 271
  // Internal mutex
  std::recursive_mutex mutex_;
272
  // Internal ring buffer.
273
  support::RingBuffer reader_, writer_;
274 275
  // Event handler.
  std::shared_ptr<EventHandler> handler_;
276
  // call remote with specified function code.
277 278 279
  PackedFunc call_remote_;
  // The index of this session in RPC session table.
  int table_index_{0};
280 281
  // The name of the session.
  std::string name_;
282 283
  // The remote key
  std::string remote_key_;
284 285
};

286
/*!
287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317
 * \brief RPC channel which callback
 * frontend (Python/Java/etc.)'s send & recv function
 */
class CallbackChannel final : public RPCChannel {
 public:
  explicit CallbackChannel(PackedFunc fsend, PackedFunc frecv)
      : fsend_(std::move(fsend)), frecv_(std::move(frecv)) {}

  ~CallbackChannel() {}
  /*!
   * \brief Send data over to the channel.
   * \param data The data pointer.
   * \param size The size fo the data.
   * \return The actual bytes sent.
   */
  size_t Send(const void* data, size_t size) final;
  /*!
   * \brief Recv data from channel.
   *
   * \param data The data pointer.
   * \param size The size fo the data.
   * \return The actual bytes received.
   */
  size_t Recv(void* data, size_t size) final;

 private:
  PackedFunc fsend_;
  PackedFunc frecv_;
};

/*!
318
 * \brief Wrap a timer function to measure the time cost of a given packed function.
319 320
 * \param f The function argument.
 * \param ctx The context.
321 322 323 324 325 326 327 328 329 330 331 332 333 334
 * \param number The number of times to run this function for taking average.
          We call these runs as one `repeat` of measurement.
 * \param repeat The number of times to repeat the measurement.
          In total, the function will be invoked (1 + number x repeat) times,
          where the first one is warm up and will be discarded.
          The returned result contains `repeat` costs,
          each of which is an average of `number` costs.
 * \param min_repeat_ms The minimum duration of one `repeat` in milliseconds.
          By default, one `repeat` contains `number` runs. If this parameter is set,
          the parameters `number` will be dynamically adjusted to meet the
          minimum duration requirement of one `repeat`.
          i.e., When the run time of one `repeat` falls below this time,
          the `number` parameter will be automatically increased.
 * \return f_timer A timer function.
335
 */
336 337 338 339 340
PackedFunc WrapTimeEvaluator(PackedFunc f,
                             TVMContext ctx,
                             int number,
                             int repeat,
                             int min_repeat_ms);
341

342 343 344 345 346 347 348
/*!
 * \brief Create a Global RPC module that refers to the session.
 * \param sess The RPC session of the global module.
 * \return The created module.
 */
Module CreateRPCModule(std::shared_ptr<RPCSession> sess);

349 350 351 352 353 354 355 356 357 358
// Remote space pointer.
struct RemoteSpace {
  void* data;
  std::shared_ptr<RPCSession> sess;
};

// Implementation of inline functions.
/*!
 * \brief Call a remote defined system function with arguments.
 *
 * Holds the session mutex for the whole call, writes the raw bytes of
 * the function code into the write buffer, then invokes call_remote_
 * with the perfectly forwarded arguments.
 *
 * \param code The function code.
 * \param args The arguments forwarded to call_remote_.
 * \return The value returned by call_remote_.
 */
template<typename... Args>
inline TVMRetValue RPCSession::CallRemote(RPCCode code, Args&& ...args) {
  // The lock is released when this function returns, i.e. after call_remote_ finishes.
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  // The code goes into the write buffer before call_remote_ runs.
  writer_.Write(&code, sizeof(code));
  return call_remote_(std::forward<Args>(args)...);
}
}  // namespace runtime
}  // namespace tvm
#endif  // TVM_RUNTIME_RPC_RPC_SESSION_H_