rpc_module.cc 8.73 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
9
 *
10
 *   http://www.apache.org/licenses/LICENSE-2.0
11
 *
12 13 14 15 16 17 18 19
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file rpc_module.cc
 * \brief RPC module: local proxies for modules, functions and NDArrays
 *        that live in a remote RPCSession.
 */
#include <tvm/runtime/registry.h>
#include <memory>
26
#include <cstring>
27
#include "rpc_session.h"
28 29 30 31 32

namespace tvm {
namespace runtime {

// Wrapped remote function to packed func.
33
class RPCWrappedFunc {
34
 public:
35 36 37 38
  RPCWrappedFunc(void* handle,
                 std::shared_ptr<RPCSession> sess)
      : handle_(handle), sess_(sess) {
    fwrap_ = PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
39
        WrapRemote(sess, args, rv);
40 41 42
      });
  }

43
  void operator()(TVMArgs args, TVMRetValue *rv) const {
44
    sess_->CallFunc(handle_, args, rv, &fwrap_);
45 46
  }
  ~RPCWrappedFunc() {
47 48 49 50 51
    try {
      sess_->CallRemote(RPCCode::kFreeFunc, handle_);
    } catch (const dmlc::Error& e) {
      // fault tolerance to remote close
    }
52 53
  }

54
  static void WrapRemote(std::shared_ptr<RPCSession> sess,
55
                         TVMArgs args,
56 57
                         TVMRetValue* rv);

58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
  // deleter of RPC remote array
  static void RemoteNDArrayDeleter(NDArray::Container* ptr) {
    RemoteSpace* space = static_cast<RemoteSpace*>(ptr->dl_tensor.data);
    space->sess->CallRemote(RPCCode::kNDArrayFree, ptr->manager_ctx);
    delete space;
    delete ptr;
  }
  // wrap return value as remote NDArray.
  static NDArray WrapRemoteNDArray(std::shared_ptr<RPCSession> sess,
                                   DLTensor* tensor,
                                   void* nd_handle) {
    NDArray::Container* data = new NDArray::Container();
    data->manager_ctx = nd_handle;
    data->deleter = RemoteNDArrayDeleter;
    RemoteSpace* space = new RemoteSpace();
    space->sess = sess;
    space->data = tensor->data;
    data->dl_tensor.data = space;
    NDArray ret(data);
    // RAII now in effect
    data->shape_ = std::vector<int64_t>(
        tensor->shape, tensor->shape + tensor->ndim);
    data->dl_tensor.shape = dmlc::BeginPtr(data->shape_);
    data->dl_tensor.ndim = static_cast<int>(data->shape_.size());
    // setup dtype
    data->dl_tensor.dtype = tensor->dtype;
    // setup ctx, encode as remote session
    data->dl_tensor.ctx.device_id = tensor->ctx.device_id;
    data->dl_tensor.ctx.device_type = static_cast<DLDeviceType>(
        static_cast<int>(tensor->ctx.device_type) +
        kRPCSessMask * (sess->table_index() + 1));
    // check strides.
    CHECK(tensor->strides == nullptr);
    // setup byteoffset
    data->dl_tensor.byte_offset = tensor->byte_offset;
    return ret;
  }

96
 private:
97
  PackedFunc fwrap_;
98 99 100 101 102 103 104 105 106 107 108 109
  void* handle_{nullptr};
  std::shared_ptr<RPCSession> sess_;
};

// RPC that represents a remote module session.
class RPCModuleNode final : public ModuleNode {
 public:
  RPCModuleNode(void* module_handle, std::shared_ptr<RPCSession> sess)
      : module_handle_(module_handle), sess_(sess) {
  }
  ~RPCModuleNode() {
    if (module_handle_ != nullptr) {
110 111 112 113 114
      try {
        sess_->CallRemote(RPCCode::kModuleFree, module_handle_);
      } catch (const dmlc::Error& e) {
        // fault tolerance to remote close
      }
115
      module_handle_ = nullptr;
116 117 118 119 120 121 122 123 124
    }
  }

  const char* type_key() const final {
    return "rpc";
  }

  PackedFunc GetFunction(
      const std::string& name,
125
      const ObjectPtr<Object>& sptr_to_self) final {
126 127
    RPCFuncHandle handle = GetFuncHandle(name);
    return WrapRemote(handle);
128 129 130 131 132 133 134 135 136 137 138 139 140 141
  }

  std::string GetSource(const std::string& format) final {
    if (module_handle_ != nullptr) {
      std::string ret =  sess_->CallRemote(
          RPCCode::kModuleGetSource, module_handle_, format);
    }
    return "";
  }

  std::shared_ptr<RPCSession>& sess() {
    return sess_;
  }

142 143
  PackedFunc GetTimeEvaluator(const std::string& name,
                              TVMContext ctx,
144
                              int number,
145 146
                              int repeat,
                              int min_repeat_ms) {
147 148
    RPCFuncHandle handle = GetFuncHandle(name);
    if (handle == nullptr) return PackedFunc();
149
    handle = sess_->GetTimeEvaluator(handle, ctx, number, repeat, min_repeat_ms);
150 151 152
    return WrapRemote(handle);
  }

153 154 155 156
  void* module_handle() const {
    return module_handle_;
  }

157
 private:
158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175
  PackedFunc WrapRemote(RPCFuncHandle handle) {
    if (handle == nullptr) return PackedFunc();
    auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
    return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
        return wf->operator()(args, rv);
      });
  }

  RPCFuncHandle GetFuncHandle(const std::string& name) {
    RPCFuncHandle handle = nullptr;
    if (module_handle_ == nullptr) {
      handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name);
    } else {
      handle = sess_->CallRemote(
          RPCCode::kModuleGetFunc, module_handle_, name);
    }
    return handle;
  }
176 177 178 179
  // The module handle
  void* module_handle_{nullptr};
  // The local channel
  std::shared_ptr<RPCSession> sess_;
180 181
  // Wrap function to wrap remote module/function.
  PackedFunc fwrap_;
182 183
};

184
void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
185
                                TVMArgs args,
186
                                TVMRetValue *rv) {
187 188 189
  void* handle = args.values[0].v_handle;
  int tcode = args.type_codes[0];

190 191 192 193 194 195
  if (handle == nullptr) return;
  if (tcode == kFuncHandle) {
    auto wf = std::make_shared<RPCWrappedFunc>(handle, sess);
    *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
        return wf->operator()(args, rv);
      });
196
  } else if (tcode == kModuleHandle) {
197
    auto n = make_object<RPCModuleNode>(handle, sess);
198
    *rv = Module(n);
199 200 201 202 203 204 205
  } else if (tcode == kArrayHandle || tcode == kNDArrayContainer) {
    CHECK_EQ(args.size(), 2);
    DLTensor* tensor = args[0];
    void* nd_handle = args[1];
    *rv = WrapRemoteNDArray(sess, tensor, nd_handle);
  } else {
    LOG(FATAL) << "Cannot wrap tcode=" << tcode;
206 207 208
  }
}

209
Module CreateRPCModule(std::shared_ptr<RPCSession> sess) {
210
  auto n = make_object<RPCModuleNode>(nullptr, sess);
211 212 213
  return Module(n);
}

214 215 216 217 218 219 220 221 222
TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    TVMContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(args[2].operator int());
    ctx.device_id = args[3];
    if (tkey == "rpc") {
      *rv = static_cast<RPCModuleNode*>(m.operator->())
223
          ->GetTimeEvaluator(args[1], ctx, args[4], args[5], args[6]);
224 225
    } else {
      *rv = WrapTimeEvaluator(
226
          m.GetFunction(args[1], false), ctx, args[4], args[5], args[6]);
227 228 229
    }
  });

230
TVM_REGISTER_GLOBAL("rpc._LoadRemoteModule")
231 232 233 234 235 236
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    CHECK_EQ(tkey, "rpc");
    auto& sess = static_cast<RPCModuleNode*>(m.operator->())->sess();
    void* mhandle = sess->CallRemote(RPCCode::kModuleLoad, args[1]);
237
    auto n = make_object<RPCModuleNode>(mhandle, sess);
238 239 240
    *rv = Module(n);
  });

241
TVM_REGISTER_GLOBAL("rpc._ImportRemoteModule")
242 243 244 245 246 247 248 249 250 251 252 253 254 255
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module parent = args[0];
    Module child = args[1];
    CHECK(!std::strcmp(parent->type_key(), "rpc") &&
          !std::strcmp(child->type_key(), "rpc"));
    auto* pmod = static_cast<RPCModuleNode*>(parent.operator->());
    auto* cmod = static_cast<RPCModuleNode*>(child.operator->());
    CHECK(pmod->sess().get() == cmod->sess().get())
        << "Import of remote module need to belong to same session.";
    pmod->sess()->CallRemote(RPCCode::kModuleImport,
                             pmod->module_handle(),
                             cmod->module_handle());
  });

256
TVM_REGISTER_GLOBAL("rpc._ModuleHandle")
257 258 259 260 261 262 263
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    CHECK_EQ(tkey, "rpc");
    *rv = static_cast<RPCModuleNode*>(m.operator->())->module_handle();
  });

264
TVM_REGISTER_GLOBAL("rpc._SessTableIndex")
265 266 267 268 269 270
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    CHECK_EQ(tkey, "rpc");
    *rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index();
  });
271

272 273
}  // namespace runtime
}  // namespace tvm