/*!
 *  Copyright (c) 2017 by Contributors
 * \file rpc_module.cc
 * \brief RPC module.
 */
#include <tvm/runtime/registry.h>
#include <cstring>
#include <memory>
#include <string>
#include <vector>
#include "rpc_session.h"

namespace tvm {
namespace runtime {

// Wrapped remote function to packed func.
15
class RPCWrappedFunc {
16
 public:
17 18 19 20
  RPCWrappedFunc(void* handle,
                 std::shared_ptr<RPCSession> sess)
      : handle_(handle), sess_(sess) {
    fwrap_ = PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
21
        WrapRemote(sess, args, rv);
22 23 24
      });
  }

25
  void operator()(TVMArgs args, TVMRetValue *rv) const {
26
    sess_->CallFunc(handle_, args, rv, &fwrap_);
27 28
  }
  ~RPCWrappedFunc() {
29 30 31 32 33
    try {
      sess_->CallRemote(RPCCode::kFreeFunc, handle_);
    } catch (const dmlc::Error& e) {
      // fault tolerance to remote close
    }
34 35
  }

36
  static void WrapRemote(std::shared_ptr<RPCSession> sess,
37
                         TVMArgs args,
38 39
                         TVMRetValue* rv);

40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
  // deleter of RPC remote array
  static void RemoteNDArrayDeleter(NDArray::Container* ptr) {
    RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
    space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx);
    delete space;
    delete ptr;
  }
  // wrap return value as remote NDArray.
  static NDArray WrapRemoteNDArray(std::shared_ptr<RPCSession> sess,
                                   DLTensor* tensor,
                                   void* nd_handle) {
    NDArray::Container* data = new NDArray::Container();
    data->manager_ctx = nd_handle;
    data->deleter = RemoteNDArrayDeleter;
    RemoteSpace* space = new RemoteSpace();
    space->sess = sess;
    space->data = tensor->data;
    data->dl_tensor.data = space;
    NDArray ret(data);
    // RAII now in effect
    data->shape_ = std::vector<int64_t>(
        tensor->shape, tensor->shape + tensor->ndim);
    data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);
    data->dl_tensor.ndim = static_cast<int>(data->shape_.size());
    // setup dtype
    data->dl_tensor.dtype = tensor->dtype;
    // setup ctx, encode as remote session
    data->dl_tensor.ctx.device_id = tensor->ctx.device_id;
    data->dl_tensor.ctx.device_type = static_cast<DLDeviceType>(
        static_cast<int>(tensor->ctx.device_type) +
        kRPCSessMask * (sess->table_index() + 1));
    // check strides.
    CHECK(tensor->strides == nullptr);
    // setup byteoffset
    data->dl_tensor.byte_offset = tensor->byte_offset;
    return ret;
  }

78
 private:
79
  PackedFunc fwrap_;
80 81 82 83 84 85 86 87 88 89 90 91
  void* handle_{nullptr};
  std::shared_ptr<RPCSession> sess_;
};

// RPC that represents a remote module session.
class RPCModuleNode final : public ModuleNode {
 public:
  RPCModuleNode(void* module_handle, std::shared_ptr<RPCSession> sess)
      : module_handle_(module_handle), sess_(sess) {
  }
  ~RPCModuleNode() {
    if (module_handle_ != nullptr) {
92 93 94 95 96
      try {
        sess_->CallRemote(RPCCode::kModuleFree, module_handle_);
      } catch (const dmlc::Error& e) {
        // fault tolerance to remote close
      }
97
      module_handle_ = nullptr;
98 99 100 101 102 103 104 105 106 107
    }
  }

  const char* type_key() const final {
    return "rpc";
  }

  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
108 109
    RPCFuncHandle handle = GetFuncHandle(name);
    return WrapRemote(handle);
110 111 112 113 114 115 116 117 118 119 120 121 122 123
  }

  std::string GetSource(const std::string& format) final {
    if (module_handle_ != nullptr) {
      std::string ret =  sess_->CallRemote(
          RPCCode::kModuleGetSource, module_handle_, format);
    }
    return "";
  }

  std::shared_ptr<RPCSession>& sess() {
    return sess_;
  }

124 125
  PackedFunc GetTimeEvaluator(const std::string& name,
                              TVMContext ctx,
126
                              int number,
127 128
                              int repeat,
                              int min_repeat_ms) {
129 130
    RPCFuncHandle handle = GetFuncHandle(name);
    if (handle == nullptr) return PackedFunc();
131
    handle = sess_->GetTimeEvaluator(handle, ctx, number, repeat, min_repeat_ms);
132 133 134
    return WrapRemote(handle);
  }

135 136 137 138
  void* module_handle() const {
    return module_handle_;
  }

139
 private:
140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
  PackedFunc WrapRemote(RPCFuncHandle handle) {
    if (handle == nullptr) return PackedFunc();
    auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
    return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
        return wf->operator()(args, rv);
      });
  }

  RPCFuncHandle GetFuncHandle(const std::string& name) {
    RPCFuncHandle handle = nullptr;
    if (module_handle_ == nullptr) {
      handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name);
    } else {
      handle = sess_->CallRemote(
          RPCCode::kModuleGetFunc, module_handle_, name);
    }
    return handle;
  }
158 159 160 161
  // The module handle
  void* module_handle_{nullptr};
  // The local channel
  std::shared_ptr<RPCSession> sess_;
162 163
  // Wrap function to wrap remote module/function.
  PackedFunc fwrap_;
164 165
};

166
void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
167
                                TVMArgs args,
168
                                TVMRetValue *rv) {
169 170 171
  void* handle = args.values[0].v_handle;
  int tcode = args.type_codes[0];

172 173 174 175 176 177
  if (handle == nullptr) return;
  if (tcode == kFuncHandle) {
    auto wf = std::make_shared<RPCWrappedFunc>(handle, sess);
    *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
        return wf->operator()(args, rv);
      });
178
  } else if (tcode == kModuleHandle) {
179 180 181
    std::shared_ptr<RPCModuleNode> n =
        std::make_shared<RPCModuleNode>(handle, sess);
    *rv = Module(n);
182 183 184 185 186 187 188
  } else if (tcode == kArrayHandle || tcode == kNDArrayContainer) {
    CHECK_EQ(args.size(), 2);
    DLTensor* tensor = args[0];
    void* nd_handle = args[1];
    *rv = WrapRemoteNDArray(sess, tensor, nd_handle);
  } else {
    LOG(FATAL) << "Cannot wrap tcode=" << tcode;
189 190 191
  }
}

192
Module CreateRPCModule(std::shared_ptr<RPCSession> sess) {
193
  std::shared_ptr<RPCModuleNode> n =
194
      std::make_shared<RPCModuleNode>(nullptr, sess);
195 196 197
  return Module(n);
}

198 199 200 201 202 203 204 205 206
TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    TVMContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(args[2].operator int());
    ctx.device_id = args[3];
    if (tkey == "rpc") {
      *rv = static_cast<RPCModuleNode*>(m.operator->())
207
          ->GetTimeEvaluator(args[1], ctx, args[4], args[5], args[6]);
208 209
    } else {
      *rv = WrapTimeEvaluator(
210
          m.GetFunction(args[1], false), ctx, args[4], args[5], args[6]);
211 212 213
    }
  });

214
TVM_REGISTER_GLOBAL("rpc._LoadRemoteModule")
215 216 217 218 219 220 221 222 223 224 225
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    CHECK_EQ(tkey, "rpc");
    auto& sess = static_cast<RPCModuleNode*>(m.operator->())->sess();
    void* mhandle = sess->CallRemote(RPCCode::kModuleLoad, args[1]);
    std::shared_ptr<RPCModuleNode> n =
        std::make_shared<RPCModuleNode>(mhandle, sess);
    *rv = Module(n);
  });

226
TVM_REGISTER_GLOBAL("rpc._ImportRemoteModule")
227 228 229 230 231 232 233 234 235 236 237 238 239 240
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module parent = args[0];
    Module child = args[1];
    CHECK(!std::strcmp(parent->type_key(), "rpc") &&
          !std::strcmp(child->type_key(), "rpc"));
    auto* pmod = static_cast<RPCModuleNode*>(parent.operator->());
    auto* cmod = static_cast<RPCModuleNode*>(child.operator->());
    CHECK(pmod->sess().get() == cmod->sess().get())
        << "Import of remote module need to belong to same session.";
    pmod->sess()->CallRemote(RPCCode::kModuleImport,
                             pmod->module_handle(),
                             cmod->module_handle());
  });

241
TVM_REGISTER_GLOBAL("rpc._ModuleHandle")
242 243 244 245 246 247 248
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    CHECK_EQ(tkey, "rpc");
    *rv = static_cast<RPCModuleNode*>(m.operator->())->module_handle();
  });

249
TVM_REGISTER_GLOBAL("rpc._SessTableIndex")
250 251 252 253 254 255
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    CHECK_EQ(tkey, "rpc");
    *rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index();
  });
256

257 258
}  // namespace runtime
}  // namespace tvm