/*!
 *  Copyright (c) 2017 by Contributors
 * \file rpc_module.cc
 * \brief RPC module: exposes functions and modules that live on the remote
 *        side of an RPCSession as local PackedFunc / Module objects.
 */
#include <tvm/runtime/registry.h>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include "./rpc_session.h"

namespace tvm {
namespace runtime {

// Wrapped remote function to packed func.
struct RPCWrappedFunc {
 public:
17 18 19 20 21 22 23 24
  RPCWrappedFunc(void* handle,
                 std::shared_ptr<RPCSession> sess)
      : handle_(handle), sess_(sess) {
    fwrap_ = PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
        WrapRemote(sess, args.values[0].v_handle, args.type_codes[0], rv);
      });
  }

25
  void operator()(TVMArgs args, TVMRetValue *rv) const {
26
    sess_->CallFunc(handle_, args, rv, &fwrap_);
27 28
  }
  ~RPCWrappedFunc() {
29 30 31 32 33
    try {
      sess_->CallRemote(RPCCode::kFreeFunc, handle_);
    } catch (const dmlc::Error& e) {
      // fault tolerance to remote close
    }
34 35
  }

36 37 38 39 40
  static void WrapRemote(std::shared_ptr<RPCSession> sess,
                         void* handle,
                         int tcode,
                         TVMRetValue* rv);

41
 private:
42
  PackedFunc fwrap_;
43 44 45 46 47 48 49 50 51 52 53 54
  void* handle_{nullptr};
  std::shared_ptr<RPCSession> sess_;
};

// RPC that represents a remote module session.
class RPCModuleNode final : public ModuleNode {
 public:
  RPCModuleNode(void* module_handle, std::shared_ptr<RPCSession> sess)
      : module_handle_(module_handle), sess_(sess) {
  }
  ~RPCModuleNode() {
    if (module_handle_ != nullptr) {
55 56 57 58 59
      try {
        sess_->CallRemote(RPCCode::kModuleFree, module_handle_);
      } catch (const dmlc::Error& e) {
        // fault tolerance to remote close
      }
60
      module_handle_ = nullptr;
61 62 63 64 65 66 67 68 69 70
    }
  }

  const char* type_key() const final {
    return "rpc";
  }

  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
71 72
    RPCFuncHandle handle = GetFuncHandle(name);
    return WrapRemote(handle);
73 74 75 76 77 78 79 80 81 82 83 84 85 86
  }

  std::string GetSource(const std::string& format) final {
    if (module_handle_ != nullptr) {
      std::string ret =  sess_->CallRemote(
          RPCCode::kModuleGetSource, module_handle_, format);
    }
    return "";
  }

  std::shared_ptr<RPCSession>& sess() {
    return sess_;
  }

87 88
  PackedFunc GetTimeEvaluator(const std::string& name,
                              TVMContext ctx,
89 90
                              int number,
                              int repeat) {
91 92
    RPCFuncHandle handle = GetFuncHandle(name);
    if (handle == nullptr) return PackedFunc();
93
    handle = sess_->GetTimeEvaluator(handle, ctx, number, repeat);
94 95 96
    return WrapRemote(handle);
  }

97 98 99 100
  void* module_handle() const {
    return module_handle_;
  }

101
 private:
102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
  PackedFunc WrapRemote(RPCFuncHandle handle) {
    if (handle == nullptr) return PackedFunc();
    auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
    return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
        return wf->operator()(args, rv);
      });
  }

  RPCFuncHandle GetFuncHandle(const std::string& name) {
    RPCFuncHandle handle = nullptr;
    if (module_handle_ == nullptr) {
      handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name);
    } else {
      handle = sess_->CallRemote(
          RPCCode::kModuleGetFunc, module_handle_, name);
    }
    return handle;
  }
120 121 122 123
  // The module handle
  void* module_handle_{nullptr};
  // The local channel
  std::shared_ptr<RPCSession> sess_;
124 125
  // Wrap function to wrap remote module/function.
  PackedFunc fwrap_;
126 127
};

128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
                                void* handle,
                                int tcode,
                                TVMRetValue *rv) {
  if (handle == nullptr) return;
  if (tcode == kFuncHandle) {
    auto wf = std::make_shared<RPCWrappedFunc>(handle, sess);
    *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
        return wf->operator()(args, rv);
      });
  } else {
    CHECK_EQ(tcode, kModuleHandle);
    std::shared_ptr<RPCModuleNode> n =
        std::make_shared<RPCModuleNode>(handle, sess);
    *rv = Module(n);
  }
}

146
Module CreateRPCModule(std::shared_ptr<RPCSession> sess) {
147
  std::shared_ptr<RPCModuleNode> n =
148
      std::make_shared<RPCModuleNode>(nullptr, sess);
149 150 151
  return Module(n);
}

152 153 154 155 156 157 158 159 160
TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    TVMContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(args[2].operator int());
    ctx.device_id = args[3];
    if (tkey == "rpc") {
      *rv = static_cast<RPCModuleNode*>(m.operator->())
161
          ->GetTimeEvaluator(args[1], ctx, args[4], args[5]);
162 163
    } else {
      *rv = WrapTimeEvaluator(
164
          m.GetFunction(args[1], false), ctx, args[4], args[5]);
165 166 167
    }
  });

168 169 170 171 172 173 174 175 176 177 178 179
TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    CHECK_EQ(tkey, "rpc");
    auto& sess = static_cast<RPCModuleNode*>(m.operator->())->sess();
    void* mhandle = sess->CallRemote(RPCCode::kModuleLoad, args[1]);
    std::shared_ptr<RPCModuleNode> n =
        std::make_shared<RPCModuleNode>(mhandle, sess);
    *rv = Module(n);
  });

180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
TVM_REGISTER_GLOBAL("contrib.rpc._ImportRemoteModule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module parent = args[0];
    Module child = args[1];
    CHECK(!std::strcmp(parent->type_key(), "rpc") &&
          !std::strcmp(child->type_key(), "rpc"));
    auto* pmod = static_cast<RPCModuleNode*>(parent.operator->());
    auto* cmod = static_cast<RPCModuleNode*>(child.operator->());
    CHECK(pmod->sess().get() == cmod->sess().get())
        << "Import of remote module need to belong to same session.";
    pmod->sess()->CallRemote(RPCCode::kModuleImport,
                             pmod->module_handle(),
                             cmod->module_handle());
  });

195 196 197 198 199 200 201 202
TVM_REGISTER_GLOBAL("contrib.rpc._ModuleHandle")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    CHECK_EQ(tkey, "rpc");
    *rv = static_cast<RPCModuleNode*>(m.operator->())->module_handle();
  });

203 204 205 206 207 208 209
TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    CHECK_EQ(tkey, "rpc");
    *rv = static_cast<RPCModuleNode*>(m.operator->())->sess()->table_index();
  });
210

211 212
}  // namespace runtime
}  // namespace tvm