/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2017 by Contributors
 * \file rpc_session.cc
 * \brief RPC session for remote function call.
 */
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>
#include <algorithm>
#include <array>
#include <chrono>
#include <cmath>
#include <cstring>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>
#include "rpc_session.h"
#include "../../common/ring_buffer.h"

namespace tvm {
namespace runtime {
// Temp buffer for a received string/bytes argument: `data` owns the bytes,
// and `arr` is a TVMByteArray view that is pointed at `data` when needed.
struct RPCByteArrayBuffer {
  TVMByteArray arr;
  std::string data;
};
// Temp buffer for a received array argument: `tensor` is the DLTensor
// descriptor, and `shape` owns the storage that tensor.shape points into.
struct RPCDataArrayBuffer {
  DLTensor tensor;
  std::vector<int64_t> shape;
};
/*!
 * \brief Temporal argument buffer.
 *
 * Holds the decoded values and type codes of a received packed sequence.
 * It also owns the temporary storage (string/bytes payloads and array
 * descriptors) that those values point into, so the pointers stored in
 * `value` stay valid for as long as this buffer lives.
 */
struct RPCArgBuffer {
  // The argument values
  std::vector<TVMValue> value;
  // The type codes.
  std::vector<int> tcode;
  // Owned storage for kStr/kBytes arguments referenced from `value`.
  std::vector<std::unique_ptr<RPCByteArrayBuffer> > temp_bytes;
  // Owned storage for kArrayHandle descriptors referenced from `value`.
  std::vector<std::unique_ptr<RPCDataArrayBuffer> > temp_array;
  // View the buffer as TVMArgs (non-owning; valid only while *this lives).
  TVMArgs AsTVMArgs() const {
    return TVMArgs(value.data(), tcode.data(), static_cast<int>(value.size()));
  }
};

70
// Event handler for RPC events.
tqchen committed
71
class RPCSession::EventHandler : public dmlc::Stream {
72 73 74 75
 public:
  EventHandler(common::RingBuffer* reader,
               common::RingBuffer* writer,
               int rpc_sess_table_index,
76 77 78 79
               std::string name,
               std::string* remote_key)
      : reader_(reader),
        writer_(writer),
80
        rpc_sess_table_index_(rpc_sess_table_index),
81 82
        name_(name),
        remote_key_(remote_key) {
83
    this->Clear();
84 85 86 87 88
    if (*remote_key == "%toinit") {
      state_ = kInitHeader;
      remote_key_->resize(0);
      pending_request_bytes_ = sizeof(int32_t);
    }
89 90 91 92 93 94 95 96 97
  }
  // Bytes needed to fulfill current request
  size_t BytesNeeded() {
    if (reader_->bytes_available() < pending_request_bytes_) {
      return pending_request_bytes_ - reader_->bytes_available();
    } else {
      return 0;
    }
  }
tqchen committed
98 99 100 101 102 103 104 105 106
  // Request number of bytes from reader.
  void RequestBytes(size_t nbytes) {
    pending_request_bytes_ += nbytes;
    reader_->Reserve(pending_request_bytes_);
  }
  // Whether we are ready to handle next request.
  bool Ready() {
    return reader_->bytes_available() >= pending_request_bytes_;
  }
107 108 109 110 111 112
  bool CanCleanShutdown() const {
    return state_ == kRecvCode;
  }
  void FinishCopyAck() {
    this->SwitchToState(kRecvCode);
  }
113 114 115 116
  RPCCode HandleNextEvent(TVMRetValue* rv,
                          bool client_mode,
                          const PackedFunc* fwrap) {
    std::swap(client_mode_, client_mode);
117 118
    while (this->Ready()) {
      switch (state_) {
119
        case kInitHeader: HandleInitHeader(); break;
120 121
        case kRecvCode: HandleRecvCode(); break;
        case kRecvCallHandle: {
tqchen committed
122
          CHECK(this->Read(&call_handle_));
123 124 125 126
          this->SwitchToState(kRecvPackedSeqNumArgs);
          break;
        }
        case kRecvPackedSeqNumArgs: {
tqchen committed
127
          CHECK(this->Read(&num_packed_args_));
128 129 130 131 132 133 134 135
          arg_buf_.reset(new RPCArgBuffer());
          arg_buf_->value.resize(num_packed_args_);
          arg_buf_->tcode.resize(num_packed_args_);
          this->SwitchToState(kRecvPackedSeqTypeCode);
          break;
        }
        case kRecvPackedSeqTypeCode: {
          if (num_packed_args_ != 0) {
tqchen committed
136
            this->ReadArray(arg_buf_->tcode.data(), num_packed_args_);
137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
          }
          arg_index_ = 0;
          arg_recv_stage_ = 0;
          this->SwitchToState(kRecvPackedSeqArg);
          break;
        }
        case kRecvPackedSeqArg: {
          this->HandleRecvPackedSeqArg();
          break;
        }
        case kDoCopyFromRemote: {
          this->HandleCopyFromRemote();
          break;
        }
        case kDoCopyToRemote: {
          this->HandleCopyToRemote();
          break;
        }
        case kReturnReceived: {
156 157
          CHECK_GE(arg_buf_->value.size(), 1U);

158
          TVMArgValue argv = arg_buf_->AsTVMArgs()[0];
159
          if (argv.type_code() == kFuncHandle ||
160 161
              argv.type_code() == kModuleHandle ||
              argv.type_code() == kArrayHandle) {
162 163 164
            CHECK(fwrap != nullptr) << "function/module wrapper not available";
            fwrap->CallPacked(arg_buf_->AsTVMArgs(), rv);
          } else {
165
            CHECK_EQ(arg_buf_->value.size(), 1U);
166 167
            *rv = argv;
          }
168 169
          arg_buf_.reset();
          this->SwitchToState(kRecvCode);
170
          std::swap(client_mode_, client_mode);
171
          return RPCCode::kReturn;
172 173
        }
        case kCopyAckReceived: {
174
          std::swap(client_mode_, client_mode);
175 176 177
          return RPCCode::kCopyAck;
        }
        case kShutdownReceived: {
178
          std::swap(client_mode_, client_mode);
179 180 181 182
          return RPCCode::kShutdown;
        }
      }
    }
183
    std::swap(client_mode_, client_mode);
184 185 186 187 188 189 190 191 192
    return RPCCode::kNone;
  }
  // Reset and clear all states.
  void Clear() {
    state_ = kRecvCode;
    pending_request_bytes_ = sizeof(RPCCode);
    arg_recv_stage_ = 0;
    arg_buf_.reset();
  }
193
  // strip session on mask
194 195 196
  TVMContext StripSessMask(TVMContext ctx) {
    int dev_type = ctx.device_type;
    CHECK_EQ(dev_type / kRPCSessMask, rpc_sess_table_index_ + 1)
197
        << "Can not pass in local context or context with a different remote session";
198 199 200
    ctx.device_type = static_cast<DLDeviceType>(dev_type % kRPCSessMask);
    return ctx;
  }
201 202 203 204 205 206 207 208 209
  // Send Packed sequence to writer.
  // return_ndarray is a special flag to handle returning of ndarray
  //    In this case, we return the shape, context and data of the array,
  //    as well as a customized PackedFunc that handles deletion of
  //    the array in the remote.
  void SendPackedSeq(const TVMValue* arg_values,
                     const int* type_codes,
                     int n,
                     bool return_ndarray = false) {
tqchen committed
210
    this->Write(n);
211 212 213 214 215
    for (int i = 0; i < n; ++i) {
      int tcode = type_codes[i];
      if (tcode == kNDArrayContainer) tcode = kArrayHandle;
      this->Write(tcode);
    }
216

217 218 219 220 221
    // Argument packing.
    for (int i = 0; i < n; ++i) {
      int tcode = type_codes[i];
      TVMValue value = arg_values[i];
      switch (tcode) {
222 223
        case kDLInt:
        case kDLUInt:
tqchen committed
224 225 226 227
        case kDLFloat: {
          this->Write<int64_t>(value.v_int64);
          break;
        }
228
        case kTVMType: {
tqchen committed
229 230 231 232
          this->Write(value.v_type);
          // padding
          int32_t padding = 0;
          this->Write<int32_t>(padding);
233 234 235 236
          break;
        }
        case kTVMContext: {
          value.v_ctx = StripSessMask(value.v_ctx);
tqchen committed
237
          this->Write(value.v_ctx);
238 239
          break;
        }
240 241
        case kFuncHandle:
        case kModuleHandle:
242 243 244
        case kHandle: {
          // always send handle in 64 bit.
          uint64_t handle = reinterpret_cast<uint64_t>(value.v_handle);
tqchen committed
245
          this->Write(handle);
246 247
          break;
        }
248
        case kNDArrayContainer:
249 250
        case kArrayHandle: {
          DLTensor* arr = static_cast<DLTensor*>(value.v_handle);
251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267
          TVMContext ctx;
          uint64_t data;
          if (!return_ndarray) {
            // in the client mode
            // ctx contains the remote table index
            // the space is wrapped by an RemoteSpace
            // that holds reference to the session.
            ctx = StripSessMask(arr->ctx);
            data = reinterpret_cast<uint64_t>(
                static_cast<RemoteSpace*>(arr->data)->data);
          } else {
            // When we return NDArray, we directly return
            // the space and the context
            // The client will be further wrapping
            ctx = arr->ctx;
            data = reinterpret_cast<uint64_t>(arr->data);
          }
tqchen committed
268 269 270 271 272
          this->Write(data);
          this->Write(ctx);
          this->Write(arr->ndim);
          this->Write(arr->dtype);
          this->WriteArray(arr->shape, arr->ndim);
273
          CHECK(arr->strides == nullptr)
Siju committed
274
              << "Do not support strided remote array";
275
          CHECK_EQ(arr->byte_offset, 0)
Siju committed
276
              << "Do not support send byte offset";
277 278 279 280 281 282
          break;
        }
        case kNull: break;
        case kStr: {
          const char* s = value.v_str;
          uint64_t len = strlen(s);
tqchen committed
283 284
          this->Write(len);
          this->WriteArray(s, len);
285 286 287 288 289
          break;
        }
        case kBytes: {
          TVMByteArray* bytes = static_cast<TVMByteArray*>(arg_values[i].v_handle);
          uint64_t len = bytes->size;
tqchen committed
290 291
          this->Write(len);
          this->WriteArray(bytes->data, len);
292 293 294 295 296 297 298 299 300 301
          break;
        }
        default: {
          LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
          break;
        }
      }
    }
  }

tqchen committed
302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318
  // Endian aware IO handling
  using Stream::Read;
  using Stream::Write;
  using Stream::ReadArray;
  using Stream::WriteArray;

  inline bool Read(RPCCode* code) {
    int cdata;
    if (!this->Read(&cdata)) return false;
    *code = static_cast<RPCCode>(cdata);
    return true;
  }
  inline void Write(RPCCode code) {
    int cdata = static_cast<int>(code);
    this->Write(cdata);
  }

319 320
 protected:
  enum State {
321
    kInitHeader,
322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
    kRecvCode,
    kRecvCallHandle,
    kRecvPackedSeqNumArgs,
    kRecvPackedSeqTypeCode,
    kRecvPackedSeqArg,
    kDoCopyFromRemote,
    kDoCopyToRemote,
    kReturnReceived,
    kCopyAckReceived,
    kShutdownReceived
  };
  // Current state;
  State state_;
  // The RPCCode to be read.
  RPCCode code_;
  // Handle for the remote function call.
  uint64_t call_handle_;
339 340
  // Initialize remote header
  bool init_header_step_{0};
341 342 343 344 345 346
  // Number of packed arguments.
  int num_packed_args_;
  // Current argument index.
  int arg_index_;
  // The stage of each argument receiver.
  int arg_recv_stage_;
347 348
  // Whether current handler is client or server mode.
  bool client_mode_{false};
349 350 351 352 353 354 355 356 357 358
  // Argument buffer
  std::unique_ptr<RPCArgBuffer> arg_buf_;
  // Temp byte buffer.
  std::unique_ptr<RPCByteArrayBuffer> temp_bytes_;
  // Temp array buffer.
  std::unique_ptr<RPCDataArrayBuffer> temp_array_;
  // Internal temporal data space.
  std::string temp_data_;
  // Temp variables for copy request state.
  TVMContext copy_ctx_;
359
  TVMType copy_dtype_;
360 361 362 363 364 365 366 367
  uint64_t copy_handle_, copy_offset_, copy_size_;
  // State switcher
  void SwitchToState(State state) {
    // invariant
    CHECK_EQ(pending_request_bytes_, 0U)
        << "state=" << state;
    state_ = state;
    switch (state) {
368 369 370 371
      case kInitHeader: {
        LOG(FATAL) << "cannot switch to init header";
        break;
      }
372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400
      case kRecvCode: {
        this->RequestBytes(sizeof(RPCCode));
        break;
      }
      case kRecvCallHandle: {
        this->RequestBytes(sizeof(call_handle_));
        break;
      }
      case kRecvPackedSeqNumArgs: {
        this->RequestBytes(sizeof(num_packed_args_));
        break;
      }
      case kRecvPackedSeqTypeCode: {
        this->RequestBytes(sizeof(int) * num_packed_args_);
        break;
      }
      case kRecvPackedSeqArg: {
        CHECK_LE(arg_index_, num_packed_args_);
        if (arg_index_ == num_packed_args_) {
          // The function can change state_ again.
          HandlePackedCall();
        } else {
          RequestRecvPackedSeqArg();
        }
        break;
      }
      case kDoCopyFromRemote: {
        this->RequestBytes(sizeof(uint64_t) * 3);
        this->RequestBytes(sizeof(TVMContext));
401
        this->RequestBytes(sizeof(TVMType));
402 403 404 405 406
        break;
      }
      case kDoCopyToRemote: {
        this->RequestBytes(sizeof(uint64_t) * 3);
        this->RequestBytes(sizeof(TVMContext));
407
        this->RequestBytes(sizeof(TVMType));
408 409 410 411 412 413 414 415 416 417 418 419 420 421 422
        break;
      }
      case kCopyAckReceived:
      case kReturnReceived:
      case kShutdownReceived: {
        break;
      }
    }
  }
  // Requets bytes needed for next computation.
  void RequestRecvPackedSeqArg() {
    CHECK_EQ(arg_recv_stage_, 0);
    int tcode = arg_buf_->tcode[arg_index_];
    static_assert(sizeof(TVMValue) == sizeof(uint64_t), "invariant");
    switch (tcode) {
423 424 425
      case kDLInt:
      case kDLUInt:
      case kDLFloat:
426 427 428 429
      case kTVMType:
      case kHandle:
      case kStr:
      case kBytes:
430 431 432 433 434 435 436 437 438
      case kTVMContext: {
        this->RequestBytes(sizeof(TVMValue)); break;
      }
      case kFuncHandle:
      case kModuleHandle: {
        CHECK(client_mode_)
            << "Only client can receive remote functions";
        this->RequestBytes(sizeof(TVMValue)); break;
      }
439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459
      case kNull: break;
      case kArrayHandle: {
        this->RequestBytes(sizeof(uint64_t));
        this->RequestBytes(sizeof(TVMContext));
        this->RequestBytes(sizeof(int));
        this->RequestBytes(sizeof(DLDataType));
        break;
      }
      default: {
        LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
        break;
      }
    }
  }
  // Handler for packed sequence argument receive.
  void HandleRecvPackedSeqArg() {
    CHECK_LT(arg_index_, num_packed_args_);
    int tcode = arg_buf_->tcode[arg_index_];
    TVMValue& value = arg_buf_->value[arg_index_];
    if (arg_recv_stage_ == 0) {
      switch (tcode) {
460 461
        case kDLInt:
        case kDLUInt:
tqchen committed
462 463 464 465 466 467 468 469 470 471 472 473 474 475
        case kDLFloat: {
          this->Read<int64_t>(&(value.v_int64));
          ++arg_index_;
          this->SwitchToState(kRecvPackedSeqArg);
          break;
        }
        case kTVMType: {
          this->Read(&(value.v_type));
          int32_t padding = 0;
          this->Read<int32_t>(&padding);
          ++arg_index_;
          this->SwitchToState(kRecvPackedSeqArg);
          break;
        }
476
        case kTVMContext: {
tqchen committed
477
          this->Read(&(value.v_ctx));
478 479 480 481
          ++arg_index_;
          this->SwitchToState(kRecvPackedSeqArg);
          break;
        }
482 483
        case kFuncHandle:
        case kModuleHandle:
484 485 486
        case kHandle: {
          // always send handle in 64 bit.
          uint64_t handle;
tqchen committed
487
          this->Read(&handle);
488 489 490 491 492 493 494 495 496 497 498 499 500 501
          value.v_handle = reinterpret_cast<void*>(handle);
          ++arg_index_;
          this->SwitchToState(kRecvPackedSeqArg);
          break;
        }
        case kNull: {
          value.v_handle = nullptr;
          ++arg_index_;
          this->SwitchToState(kRecvPackedSeqArg);
          break;
        }
        case kStr:
        case kBytes: {
          uint64_t len;
tqchen committed
502
          this->Read(&len);
503 504 505 506 507
          temp_bytes_.reset( new RPCByteArrayBuffer());
          temp_bytes_->data.resize(len);
          arg_recv_stage_ = 1;
          this->RequestBytes(len);
          break;
508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529
        }
        case kArrayHandle: {
          temp_array_.reset(new RPCDataArrayBuffer());
          uint64_t handle;
          this->Read(&handle);
          DLTensor& tensor = temp_array_->tensor;
          tensor.data = reinterpret_cast<void*>(handle);
          this->Read(&(tensor.ctx));
          this->Read(&(tensor.ndim));
          this->Read(&(tensor.dtype));
          temp_array_->shape.resize(tensor.ndim);
          tensor.shape = temp_array_->shape.data();
          arg_recv_stage_ = 1;
          tensor.strides = nullptr;
          tensor.byte_offset = 0;
          this->RequestBytes(sizeof(int64_t) * tensor.ndim);
          break;
        }
        default: {
          LOG(FATAL) << "RPC cannot handle type " << TypeCode2Str(tcode);
          break;
        }
530 531 532 533 534
      }
    } else {
      CHECK_EQ(arg_recv_stage_, 1);
      if (tcode == kStr || tcode == kBytes) {
        if (temp_bytes_->data.size() != 0) {
tqchen committed
535
          this->ReadArray(&(temp_bytes_->data[0]), temp_bytes_->data.size());
536 537 538 539 540 541 542 543 544 545 546 547
        }
        if (tcode == kStr) {
          value.v_str = temp_bytes_->data.c_str();
        } else {
          temp_bytes_->arr.size = static_cast<size_t>(temp_bytes_->data.size());
          temp_bytes_->arr.data = dmlc::BeginPtr(temp_bytes_->data);
          value.v_handle = &(temp_bytes_->arr);
        }
        arg_buf_->temp_bytes.emplace_back(std::move(temp_bytes_));
      } else {
        CHECK_EQ(tcode, kArrayHandle);
        DLTensor& tensor = temp_array_->tensor;
tqchen committed
548
        this->ReadArray(tensor.shape, tensor.ndim);
549 550 551 552 553 554 555 556
        value.v_handle = &tensor;
        arg_buf_->temp_array.emplace_back(std::move(temp_array_));
      }
      ++arg_index_;
      arg_recv_stage_ = 0;
      this->SwitchToState(kRecvPackedSeqArg);
    }
  }
557 558 559 560
  // handler for initial header read
  void HandleInitHeader() {
    if (init_header_step_ == 0) {
      int32_t len;
tqchen committed
561
      this->Read(&len);
562 563 564 565 566 567
      remote_key_->resize(len);
      init_header_step_ = 1;
      this->RequestBytes(len);
      return;
    } else {
      CHECK_EQ(init_header_step_, 1);
tqchen committed
568
      this->ReadArray(dmlc::BeginPtr(*remote_key_), remote_key_->length());
569 570 571
      this->SwitchToState(kRecvCode);
    }
  }
572 573
  // Handler for read code.
  void HandleRecvCode() {
tqchen committed
574
    this->Read(&code_);
575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611
    if (code_ > RPCCode::kSystemFuncStart) {
      SwitchToState(kRecvPackedSeqNumArgs);
      return;
    }
    // invariant.
    CHECK_EQ(arg_recv_stage_, 0);
    switch (code_) {
      case RPCCode::kCallFunc: {
        SwitchToState(kRecvCallHandle);
        break;
      }
      case RPCCode::kException:
      case RPCCode::kReturn: {
        SwitchToState(kRecvPackedSeqNumArgs);
        break;
      }
      case RPCCode::kCopyFromRemote: {
        SwitchToState(kDoCopyFromRemote);
        break;
      }
      case RPCCode::kCopyToRemote: {
        SwitchToState(kDoCopyToRemote);
        break;
      }
      case RPCCode::kShutdown: {
        SwitchToState(kShutdownReceived);
        break;
      }
      case RPCCode::kCopyAck: {
        SwitchToState(kCopyAckReceived);
        break;
      }
      default: LOG(FATAL) << "Unknown event "  << static_cast<int>(code_);
    }
  }

  void HandleCopyFromRemote() {
612
    uint64_t handle, offset, num_bytes;
613
    TVMContext ctx;
614
    TVMType type_hint;
tqchen committed
615 616
    this->Read(&handle);
    this->Read(&offset);
617
    this->Read(&num_bytes);
tqchen committed
618
    this->Read(&ctx);
619 620 621
    this->Read(&type_hint);
    size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8;

622
    if (ctx.device_type == kDLCPU) {
623
      RPCCode code = RPCCode::kCopyAck;
tqchen committed
624
      this->Write(code);
625 626 627 628 629 630 631 632 633
      char* dptr = reinterpret_cast<char*>(handle) + offset;
      if (!DMLC_IO_NO_ENDIAN_SWAP) {
        temp_data_.resize(0);
        temp_data_.insert(temp_data_.end(), dptr, dptr + num_bytes);
        dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes);
        this->WriteArray(temp_data_.data(), num_bytes);
      } else {
        this->WriteArray(dptr, num_bytes);
      }
634
    } else {
635
      temp_data_.resize(num_bytes + 1);
636 637
      try {
        TVMContext cpu_ctx;
638
        cpu_ctx.device_type = kDLCPU;
639 640 641 642
        cpu_ctx.device_id = 0;
        DeviceAPI::Get(ctx)->CopyDataFromTo(
            reinterpret_cast<void*>(handle), offset,
            dmlc::BeginPtr(temp_data_), 0,
643
            num_bytes, ctx, cpu_ctx, type_hint, nullptr);
644
        RPCCode code = RPCCode::kCopyAck;
tqchen committed
645
        this->Write(code);
646 647 648 649
        if (!DMLC_IO_NO_ENDIAN_SWAP) {
          dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, num_bytes / elem_bytes);
        }
        this->WriteArray(&temp_data_[0], num_bytes);
650 651
      } catch (const std::runtime_error &e) {
        RPCCode code = RPCCode::kException;
tqchen committed
652
        this->Write(code);
653 654 655 656 657 658 659 660 661 662 663 664 665
        TVMValue ret_value;
        ret_value.v_str = e.what();
        int ret_tcode = kStr;
        SendPackedSeq(&ret_value, &ret_tcode, 1);
      }
    }
    this->SwitchToState(kRecvCode);
  }

  void HandleCopyToRemote() {
    // use static variable to persist state.
    // This only works if next stage is immediately after this.
    if (arg_recv_stage_ == 0) {
tqchen committed
666 667 668 669
      CHECK(this->Read(&copy_handle_));
      CHECK(this->Read(&copy_offset_));
      CHECK(this->Read(&copy_size_));
      CHECK(this->Read(&copy_ctx_));
670
      CHECK(this->Read(&copy_dtype_));
671 672 673 674 675 676 677 678 679 680
      arg_recv_stage_ = 1;
      CHECK_EQ(pending_request_bytes_, 0U);
      this->RequestBytes(copy_size_);
    } else {
      CHECK_EQ(arg_recv_stage_, 1);
      TVMValue ret_value;
      ret_value.v_handle = nullptr;
      int ret_tcode = kNull;
      RPCCode code = RPCCode::kReturn;
      std::string errmsg;
681 682

      size_t elem_bytes = (copy_dtype_.bits * copy_dtype_.lanes + 7) / 8;
683
      if (copy_ctx_.device_type == kDLCPU) {
684 685 686 687 688
        char* dptr = reinterpret_cast<char*>(copy_handle_) + copy_offset_;
        this->ReadArray(dptr, copy_size_);
        if (!DMLC_IO_NO_ENDIAN_SWAP) {
          dmlc::ByteSwap(dptr, elem_bytes, copy_size_ / elem_bytes);
        }
689 690
      } else {
        temp_data_.resize(copy_size_ + 1);
tqchen committed
691
        this->ReadArray(&temp_data_[0], copy_size_);
692 693 694
        if (!DMLC_IO_NO_ENDIAN_SWAP) {
          dmlc::ByteSwap(dmlc::BeginPtr(temp_data_), elem_bytes, copy_size_ / elem_bytes);
        }
695 696
        try {
          TVMContext cpu_ctx;
697
          cpu_ctx.device_type = kDLCPU;
698 699 700 701
          cpu_ctx.device_id = 0;
          DeviceAPI::Get(copy_ctx_)->CopyDataFromTo(
              temp_data_.data(), 0,
              reinterpret_cast<void*>(copy_handle_), copy_offset_,
702
              copy_size_, cpu_ctx, copy_ctx_, copy_dtype_, nullptr);
703 704 705 706 707 708 709
        } catch (const std::runtime_error &e) {
          code = RPCCode::kException;
          errmsg = e.what();
          ret_value.v_str = errmsg.c_str();
          ret_tcode = kStr;
        }
      }
tqchen committed
710
      this->Write(code);
711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729
      SendPackedSeq(&ret_value, &ret_tcode, 1);
      arg_recv_stage_ = 0;
      this->SwitchToState(kRecvCode);
    }
  }
  // Handle for packed call.
  void HandlePackedCall();

  template<typename F>
  void CallHandler(F f) {
    TVMRetValue rv;
    TVMValue ret_value;
    int ret_tcode;
    try {
      // Need to move out, in case f itself need to call RecvPackedSeq
      // Which will override argbuf again.
      std::unique_ptr<RPCArgBuffer> args = std::move(arg_buf_);
      f(args->AsTVMArgs(), &rv);
      RPCCode code = RPCCode::kReturn;
tqchen committed
730
      this->Write(code);
731 732 733 734 735 736 737 738 739 740 741 742
      if (rv.type_code() == kStr) {
        ret_value.v_str = rv.ptr<std::string>()->c_str();
        ret_tcode = kStr;
        SendPackedSeq(&ret_value, &ret_tcode, 1);
      } else if (rv.type_code() == kBytes) {
        std::string* bytes = rv.ptr<std::string>();
        TVMByteArray arr;
        arr.data = bytes->c_str();
        arr.size = bytes->length();
        ret_value.v_handle = &arr;
        ret_tcode = kBytes;
        SendPackedSeq(&ret_value, &ret_tcode, 1);
743 744 745 746 747 748 749
      } else if (rv.type_code() == kFuncHandle ||
                 rv.type_code() == kModuleHandle) {
        // always send handle in 64 bit.
        CHECK(!client_mode_)
              << "Only server can send function and module handle back.";
        rv.MoveToCHost(&ret_value, &ret_tcode);
        SendPackedSeq(&ret_value, &ret_tcode, 1);
750 751 752 753 754 755 756 757 758 759 760 761 762 763 764
      } else if (rv.type_code() == kNDArrayContainer) {
        // always send handle in 64 bit.
        CHECK(!client_mode_)
            << "Only server can send NDArray back";
        // We follow a special protocol to return NDArray to client side
        // The first pack value is the NDArray handle as DLTensor
        // The second pack value is a customized deleter that deletes the NDArray.
        TVMValue ret_value_pack[2];
        int ret_tcode_pack[2];
        rv.MoveToCHost(&ret_value_pack[0], &ret_tcode_pack[0]);

        NDArray::Container* nd = static_cast<NDArray::Container*>(ret_value_pack[0].v_handle);
        ret_value_pack[1].v_handle = nd;
        ret_tcode_pack[1] = kHandle;
        SendPackedSeq(ret_value_pack, ret_tcode_pack, 2, true);
765 766 767 768 769 770 771
      } else {
        ret_value = rv.value();
        ret_tcode = rv.type_code();
        SendPackedSeq(&ret_value, &ret_tcode, 1);
      }
    } catch (const std::runtime_error& e) {
      RPCCode code = RPCCode::kException;
tqchen committed
772
      this->Write(code);
773 774 775 776 777 778 779 780 781
      ret_value.v_str = e.what();
      ret_tcode = kStr;
      SendPackedSeq(&ret_value, &ret_tcode, 1);
    }
  }

 private:
  // Utility functions
  // Internal read function, update pending_request_bytes_
tqchen committed
782
  size_t Read(void* data, size_t size) final {
783 784 785
    CHECK_LE(size, pending_request_bytes_);
    reader_->Read(data, size);
    pending_request_bytes_ -= size;
tqchen committed
786
    return size;
787
  }
tqchen committed
788 789
  void Write(const void* data, size_t size) final {
    writer_->Write(data, size);
790 791 792 793 794 795 796 797 798 799 800
  }
  // Number of pending bytes requests
  size_t pending_request_bytes_;
  // The ring buffer to read data from.
  common::RingBuffer* reader_;
  // The ringr buffer to write reply to.
  common::RingBuffer* writer_;
  // Session table index.
  int rpc_sess_table_index_;
  // Name of session.
  std::string name_;
801 802
  // remote key
  std::string* remote_key_;
803 804
};

805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837
struct RPCSessTable {
 public:
  static constexpr int kMaxRPCSession = 32;
  // Get global singleton
  static RPCSessTable* Global() {
    static RPCSessTable inst;
    return &inst;
  }
  // Get session from table
  std::shared_ptr<RPCSession> Get(int index) {
    CHECK(index >= 0 && index < kMaxRPCSession);
    return tbl_[index].lock();
  }
  // Insert session into table.
  int Insert(std::shared_ptr<RPCSession> ptr) {
    std::lock_guard<std::mutex> lock(mutex_);
    for (int i = 0; i < kMaxRPCSession; ++i) {
      if (tbl_[i].lock() == nullptr) {
        tbl_[i] = ptr; return i;
      }
    }
    LOG(FATAL) << "maximum number of RPC session reached";
    return 0;
  }

 private:
  // The mutex
  std::mutex mutex_;
  // Use weak_ptr intentionally
  // If the RPCSession get released, the pointer session will be released
  std::array<std::weak_ptr<RPCSession>, kMaxRPCSession> tbl_;
};

838 839
RPCCode RPCSession::HandleUntilReturnEvent(
    TVMRetValue* rv,  bool client_mode, const PackedFunc* fwrap) {
840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861
  RPCCode code = RPCCode::kCallFunc;
  while (code != RPCCode::kReturn &&
         code != RPCCode::kShutdown &&
         code != RPCCode::kCopyAck) {
    while (writer_.bytes_available() != 0) {
      writer_.ReadWithCallback([this](const void *data, size_t size) {
          return channel_->Send(data, size);
        }, writer_.bytes_available());
    }
    size_t bytes_needed = handler_->BytesNeeded();
    if (bytes_needed != 0) {
      size_t n = reader_.WriteWithCallback([this](void* data, size_t size) {
          return channel_->Recv(data, size);
        }, bytes_needed);
      if (n == 0) {
        if (handler_->CanCleanShutdown()) {
          return RPCCode::kShutdown;
        } else {
          LOG(FATAL) << "Channel closes before we get neded bytes";
        }
      }
    }
862
    code = handler_->HandleNextEvent(rv, client_mode, fwrap);
863 864 865 866
  }
  return code;
}

867
void RPCSession::Init() {
868
  // Event handler
869 870
  handler_ = std::make_shared<EventHandler>(
      &reader_, &writer_, table_index_, name_, &remote_key_);
871 872
  // Quick function to call remote.
  call_remote_ = PackedFunc([this](TVMArgs args, TVMRetValue* rv) {
873
      handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
874
      RPCCode code = HandleUntilReturnEvent(rv, true, nullptr);
875
      CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
876 877 878
    });
}

879
std::shared_ptr<RPCSession> RPCSession::Create(
880 881 882
    std::unique_ptr<RPCChannel> channel,
    std::string name,
    std::string remote_key) {
883
  std::shared_ptr<RPCSession> sess = std::make_shared<RPCSession>();
884 885
  sess->channel_ = std::move(channel);
  sess->name_ = std::move(name);
886
  sess->remote_key_ = std::move(remote_key);
887
  sess->table_index_ = RPCSessTable::Global()->Insert(sess);
888 889 890 891 892 893 894 895 896 897 898 899 900
  sess->Init();
  return sess;
}

std::shared_ptr<RPCSession> RPCSession::Get(int table_index) {
  return RPCSessTable::Global()->Get(table_index);
}

// Destroying a session performs the same shutdown as an explicit Shutdown().
RPCSession::~RPCSession() {
  Shutdown();
}

void RPCSession::Shutdown() {
901
  if (channel_ != nullptr) {
902
    RPCCode code = RPCCode::kShutdown;
tqchen committed
903
    handler_->Write(code);
904 905 906 907 908 909 910 911 912 913 914
    // flush all writing buffer to output channel.
    try {
      while (writer_.bytes_available() != 0) {
        size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) {
            return channel_->Send(data, size);
          }, writer_.bytes_available());
        if (n == 0) break;
      }
    } catch (const dmlc::Error& e) {
    }
    channel_.reset(nullptr);
915 916 917 918 919
  }
}

void RPCSession::ServerLoop() {
  std::lock_guard<std::recursive_mutex> lock(mutex_);
920
  if (const auto* f = Registry::Get("tvm.rpc.server.start")) {
921 922
    (*f)();
  }
923
  TVMRetValue rv;
924
  CHECK(HandleUntilReturnEvent(&rv, false, nullptr) == RPCCode::kShutdown);
925
  if (const auto* f = Registry::Get("tvm.rpc.server.shutdown")) {
926 927
    (*f)();
  }
928 929 930
  channel_.reset(nullptr);
}

931
int RPCSession::ServerEventHandler(const std::string& bytes, int event_flag) {
932
  std::lock_guard<std::recursive_mutex> lock(mutex_);
933 934 935 936 937 938 939
  RPCCode code = RPCCode::kNone;
  if (bytes.length() != 0) {
    reader_.Write(bytes.c_str(), bytes.length());
    TVMRetValue rv;
    code = handler_->HandleNextEvent(&rv, false, nullptr);
  }
  if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) {
940 941 942
    writer_.ReadWithCallback([this](const void *data, size_t size) {
        return channel_->Send(data, size);
      }, writer_.bytes_available());
943
  }
944
  CHECK(code != RPCCode::kReturn && code != RPCCode::kCopyAck);
945 946 947
  if (code == RPCCode::kShutdown) return 0;
  if (writer_.bytes_available() != 0) return 2;
  return 1;
948 949 950
}

// Get remote function with name
951 952 953 954
void RPCSession::CallFunc(void* h,
                          TVMArgs args,
                          TVMRetValue* rv,
                          const PackedFunc* fwrap) {
955 956
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  RPCCode code = RPCCode::kCallFunc;
tqchen committed
957
  handler_->Write(code);
958
  uint64_t handle = reinterpret_cast<uint64_t>(h);
tqchen committed
959
  handler_->Write(handle);
960 961 962
  handler_->SendPackedSeq(args.values, args.type_codes, args.num_args);
  code = HandleUntilReturnEvent(rv, true, fwrap);
  CHECK(code == RPCCode::kReturn) << "code=" << static_cast<int>(code);
963 964 965 966 967 968 969
}

void RPCSession::CopyToRemote(void* from,
                              size_t from_offset,
                              void* to,
                              size_t to_offset,
                              size_t data_size,
970 971
                              TVMContext ctx_to,
                              TVMType type_hint) {
972
  std::lock_guard<std::recursive_mutex> lock(mutex_);
973
  ctx_to = handler_->StripSessMask(ctx_to);
974
  RPCCode code = RPCCode::kCopyToRemote;
tqchen committed
975
  handler_->Write(code);
976
  uint64_t handle = reinterpret_cast<uint64_t>(to);
tqchen committed
977
  handler_->Write(handle);
978
  uint64_t offset = static_cast<uint64_t>(to_offset);
tqchen committed
979
  handler_->Write(offset);
980
  uint64_t size = static_cast<uint64_t>(data_size);
tqchen committed
981 982
  handler_->Write(size);
  handler_->Write(ctx_to);
983
  handler_->Write(type_hint);
tqchen committed
984
  handler_->WriteArray(reinterpret_cast<char*>(from) + from_offset, data_size);
985
  TVMRetValue rv;
986
  CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kReturn);
987 988 989 990 991 992 993
}

void RPCSession::CopyFromRemote(void* from,
                                size_t from_offset,
                                void* to,
                                size_t to_offset,
                                size_t data_size,
994 995
                                TVMContext ctx_from,
                                TVMType type_hint) {
996
  std::lock_guard<std::recursive_mutex> lock(mutex_);
997
  ctx_from = handler_->StripSessMask(ctx_from);
998
  RPCCode code = RPCCode::kCopyFromRemote;
tqchen committed
999
  handler_->Write(code);
1000
  uint64_t handle = reinterpret_cast<uint64_t>(from);
tqchen committed
1001
  handler_->Write(handle);
1002
  uint64_t offset = static_cast<uint64_t>(from_offset);
tqchen committed
1003
  handler_->Write(offset);
1004
  uint64_t size = static_cast<uint64_t>(data_size);
tqchen committed
1005 1006
  handler_->Write(size);
  handler_->Write(ctx_from);
1007
  handler_->Write(type_hint);
1008
  TVMRetValue rv;
1009
  CHECK(HandleUntilReturnEvent(&rv, true, nullptr) == RPCCode::kCopyAck);
1010
  reader_.Reserve(data_size);
tqchen committed
1011 1012 1013
  handler_->RequestBytes(data_size);
  while (!handler_->Ready()) {
    size_t bytes_needed = handler_->BytesNeeded();
1014 1015 1016 1017 1018
    reader_.WriteWithCallback([this](void* data, size_t size) {
        size_t n = channel_->Recv(data, size);
        CHECK_NE(n, 0U) << "Channel closes before we get neded bytes";
        return n;
      }, bytes_needed);
1019
  }
tqchen committed
1020
  handler_->ReadArray(reinterpret_cast<char*>(to) + to_offset, data_size);
1021
  handler_->FinishCopyAck();
1022 1023
}

1024
RPCFuncHandle RPCSession::GetTimeEvaluator(
1025
    RPCFuncHandle fhandle, TVMContext ctx, int number, int repeat, int min_repeat_ms) {
1026
  return this->CallRemote(
1027
      RPCCode::kGetTimeEvaluator, fhandle, ctx, number, repeat, min_repeat_ms);
1028 1029
}

1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068
// Event handler functions
void RPCGetGlobalFunc(TVMArgs args, TVMRetValue* rv) {
  std::string name = args[0];
  auto *fp = tvm::runtime::Registry::Get(name);
  if (fp != nullptr) {
    *rv = static_cast<void*>(new tvm::runtime::PackedFunc(*fp));
  } else {
    *rv = nullptr;
  }
}

// Delete the PackedFunc behind the handle in args[0].
void RPCFreeFunc(TVMArgs args, TVMRetValue *rv) {
  PackedFunc* fn = static_cast<PackedFunc*>(args[0].operator void*());
  delete fn;
}

// Forward to the device API's SetDevice for the context in args[0].
void RPCDevSetDevice(TVMArgs args, TVMRetValue *rv) {
  TVMContext ctx = args[0];
  DeviceAPI* api = DeviceAPI::Get(ctx);
  api->SetDevice(ctx);
}

// Query attribute args[1] of device args[0]. For kExist, a missing device
// API is answered with 0 instead of failing.
void RPCDevGetAttr(TVMArgs args, TVMRetValue *rv) {
  TVMContext ctx = args[0];
  DeviceAttrKind kind = static_cast<DeviceAttrKind>(args[1].operator int());
  if (kind != kExist) {
    DeviceAPI::Get(ctx)->GetAttr(ctx, kind, rv);
    return;
  }
  // This lookup may return nullptr (checked below).
  DeviceAPI* api = DeviceAPI::Get(ctx, true);
  if (api == nullptr) {
    *rv = 0;
  } else {
    api->GetAttr(ctx, kind, rv);
  }
}

// Allocate device memory (ctx, nbytes, alignment, type hint taken from
// args[0..3]) and reply with the raw data pointer.
void RPCDevAllocData(TVMArgs args, TVMRetValue *rv) {
  TVMContext ctx = args[0];
  uint64_t nbytes = args[1];
  uint64_t alignment = args[2];
  TVMType type_hint = args[3];
  DeviceAPI* api = DeviceAPI::Get(ctx);
  *rv = api->AllocDataSpace(ctx, nbytes, alignment, type_hint);
}

// Free the device pointer args[1] on context args[0] (presumably one that
// RPCDevAllocData returned).
void RPCDevFreeData(TVMArgs args, TVMRetValue *rv) {
  TVMContext ctx = args[0];
  void* ptr = args[1];
  DeviceAPI* api = DeviceAPI::Get(ctx);
  api->FreeDataSpace(ctx, ptr);
}

// Synchronize stream args[1] on the device of context args[0].
void RPCDevStreamSync(TVMArgs args, TVMRetValue *rv) {
  TVMContext ctx = args[0];
  TVMStreamHandle handle = args[1];
  DeviceAPI* api = DeviceAPI::Get(ctx);
  api->StreamSync(ctx, handle);
}

// Copy between two buffers that both live on this server.
// args: from, from_offset, to, to_offset, size, ctx_from, ctx_to,
//       type_hint, stream.
void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) {
  void* from = args[0];
  uint64_t from_offset = args[1];
  void* to = args[2];
  uint64_t to_offset = args[3];
  uint64_t size = args[4];
  TVMContext ctx_from = args[5];
  TVMContext ctx_to = args[6];
  TVMType type_hint = args[7];
  TVMStreamHandle stream = args[8];
  // The copy is dispatched through the source device API unless the source
  // is CPU, in which case the destination's API is used. Two different
  // non-CPU device types are rejected.
  bool from_is_cpu = ctx_from.device_type == kDLCPU;
  if (!from_is_cpu) {
    CHECK(ctx_to.device_type == kDLCPU ||
          ctx_to.device_type == ctx_from.device_type)
        << "Can not copy across different ctx types directly";
  }
  TVMContext ctx = from_is_cpu ? ctx_to : ctx_from;
  DeviceAPI::Get(ctx)->CopyDataFromTo(
      from, from_offset, to, to_offset, size,
      ctx_from, ctx_to, type_hint, stream);
}

// Load the module file named by args[0] through the registered
// "tvm.rpc.server.load_module" hook, and reply with a heap-allocated copy of
// the resulting Module (presumably released later via RPCModuleFree).
void RPCModuleLoad(TVMArgs args, TVMRetValue *rv) {
  // Block-scope static initialization is thread-safe since C++11, unlike the
  // previous check-then-assign on a plain static pointer, which raced when
  // several server threads loaded modules concurrently. If the CHECK throws,
  // the initialization is not complete and is retried on the next call,
  // which matches the old behavior.
  static const PackedFunc* fsys_load = []() {
    const PackedFunc* fsys_load_ =
        runtime::Registry::Get("tvm.rpc.server.load_module");
    CHECK(fsys_load_ != nullptr);
    return fsys_load_;
  }();
  std::string file_name = args[0];
  TVMRetValue ret = (*fsys_load)(file_name);
  Module m = ret;
  *rv = static_cast<void*>(new Module(m));
}

1125 1126 1127 1128 1129 1130 1131
void RPCModuleImport(TVMArgs args, TVMRetValue *rv) {
  void* pmod = args[0];
  void* cmod = args[1];
  static_cast<Module*>(pmod)->Import(
      *static_cast<Module*>(cmod));
}

1132 1133 1134 1135 1136 1137 1138 1139 1140
void RPCModuleFree(TVMArgs args, TVMRetValue *rv) {
  void* mhandle = args[0];
  delete static_cast<Module*>(mhandle);
}

// Look up function args[1] in module args[0] (not searching imports).
// Replies with a heap-allocated PackedFunc handle, or nullptr if not found.
void RPCModuleGetFunc(TVMArgs args, TVMRetValue *rv) {
  Module* mod = static_cast<Module*>(args[0].operator void*());
  PackedFunc pf = mod->GetFunction(args[1], false);
  if (pf == nullptr) {
    *rv = nullptr;
    return;
  }
  *rv = static_cast<void*>(new PackedFunc(pf));
}

// Reply with the source of module args[0] in format args[1].
void RPCModuleGetSource(TVMArgs args, TVMRetValue *rv) {
  Module* mod = static_cast<Module*>(args[0].operator void*());
  std::string fmt = args[1];
  *rv = (*mod)->GetSource(fmt);
}

1154 1155 1156 1157 1158
void RPCNDArrayFree(TVMArgs args, TVMRetValue *rv) {
  void* handle = args[0];
  static_cast<NDArray::Container*>(handle)->DecRef();
}

1159 1160
void RPCGetTimeEvaluator(TVMArgs args, TVMRetValue *rv) {
  PackedFunc *pf = static_cast<PackedFunc*>(args[0].operator void*());
1161
  void *fhandle = new PackedFunc(WrapTimeEvaluator(*pf, args[1], args[2], args[3], args[4]));
1162 1163 1164 1165
  delete pf;
  *rv = fhandle;
}

1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191
void RPCSession::EventHandler::HandlePackedCall() {
  CHECK_EQ(pending_request_bytes_, 0U);
  if (code_ == RPCCode::kReturn) {
    state_ = kReturnReceived; return;
  }
  // reset state to clean init state
  state_ = kRecvCode;
  this->RequestBytes(sizeof(RPCCode));
  // Event handler sit at clean state at this point.
  switch (code_) {
    case RPCCode::kCallFunc: {
      PackedFunc* pf = reinterpret_cast<PackedFunc*>(call_handle_);
      CallHandler([pf](TVMArgs args, TVMRetValue* rv) {
          pf->CallPacked(args, rv);
        });
      break;
    }
    case RPCCode::kException: {
      CHECK_EQ(arg_buf_->value.size(), 1U);
      CHECK_EQ(arg_buf_->tcode[0], kStr);
      std::ostringstream os;
      os << "Except caught from RPC call: " << arg_buf_->value[0].v_str;
      arg_buf_.reset();
      throw dmlc::Error(os.str());
      break;
    }
1192
    // system functions
1193
    case RPCCode::kGetTimeEvaluator: CallHandler(RPCGetTimeEvaluator); break;
1194 1195 1196 1197 1198 1199 1200 1201 1202
    case RPCCode::kFreeFunc: CallHandler(RPCFreeFunc); break;
    case RPCCode::kGetGlobalFunc: CallHandler(RPCGetGlobalFunc); break;
    case RPCCode::kDevSetDevice: CallHandler(RPCDevSetDevice); break;
    case RPCCode::kDevGetAttr: CallHandler(RPCDevGetAttr); break;
    case RPCCode::kDevAllocData: CallHandler(RPCDevAllocData); break;
    case RPCCode::kDevFreeData: CallHandler(RPCDevFreeData); break;
    case RPCCode::kDevStreamSync: CallHandler(RPCDevStreamSync); break;
    case RPCCode::kCopyAmongRemote: CallHandler(RPCCopyAmongRemote); break;
    case RPCCode::kModuleLoad: CallHandler(RPCModuleLoad); break;
1203
    case RPCCode::kModuleImport: CallHandler(RPCModuleImport); break;
1204 1205 1206
    case RPCCode::kModuleFree: CallHandler(RPCModuleFree); break;
    case RPCCode::kModuleGetFunc: CallHandler(RPCModuleGetFunc); break;
    case RPCCode::kModuleGetSource: CallHandler(RPCModuleGetSource); break;
1207
    case RPCCode::kNDArrayFree: CallHandler(RPCNDArrayFree); break;
1208
    default: LOG(FATAL) << "Unknown event " << static_cast<int>(code_);
1209
  }
1210
  CHECK_EQ(state_, kRecvCode);
1211
}
1212

1213 1214 1215 1216 1217 1218
PackedFunc WrapTimeEvaluator(PackedFunc pf,
                             TVMContext ctx,
                             int number,
                             int repeat,
                             int min_repeat_ms) {
  auto ftimer = [pf, ctx, number, repeat, min_repeat_ms](TVMArgs args, TVMRetValue *rv) mutable {
1219
    TVMRetValue temp;
1220
    std::ostringstream os;
1221 1222 1223
    // skip first time call, to activate lazy compilation components.
    pf.CallPacked(args, &temp);
    DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
1224

1225
    for (int i = 0; i < repeat; ++i) {
1226 1227
      std::chrono::time_point<
        std::chrono::high_resolution_clock, std::chrono::nanoseconds> tbegin, tend;
1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248
      double duration_ms = 0.0;

      do {
        if (duration_ms > 0.0) {
          number = static_cast<int>(
              std::max((min_repeat_ms / (duration_ms / number) + 1),
                       number * 1.618));   // 1.618 is chosen by random
        }

        tbegin = std::chrono::high_resolution_clock::now();
        // start timing
        for (int i = 0; i < number; ++i) {
          pf.CallPacked(args, &temp);
        }
        DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr);
        tend = std::chrono::high_resolution_clock::now();

        duration_ms = std::chrono::duration_cast<std::chrono::duration<double> >
            (tend - tbegin).count() * 1000;
      } while (duration_ms < min_repeat_ms);

1249 1250 1251
      double speed = std::chrono::duration_cast<std::chrono::duration<double> >(
          tend - tbegin).count() / number;
      os.write(reinterpret_cast<char*>(&speed), sizeof(speed));
1252
    }
1253 1254 1255 1256
    std::string blob = os.str();
    TVMByteArray arr;
    arr.size = blob.length();
    arr.data = blob.data();
1257
    // return the time.
1258
    *rv = arr;
1259 1260 1261
  };
  return PackedFunc(ftimer);
}
1262

1263 1264
}  // namespace runtime
}  // namespace tvm