rpc_module.cc 6.05 KB
Newer Older
1 2 3 4 5 6 7
/*!
 *  Copyright (c) 2017 by Contributors
 * \file rpc_module.cc
 * \brief RPC module.
 */
#include <tvm/runtime/registry.h>
#include <memory>
8
#include <cstring>
9 10 11 12 13 14 15 16
#include "./rpc_session.h"

namespace tvm {
namespace runtime {

// Wrapped remote function to packed func.
struct RPCWrappedFunc {
 public:
17 18 19 20 21 22 23 24
  RPCWrappedFunc(void* handle,
                 std::shared_ptr<RPCSession> sess)
      : handle_(handle), sess_(sess) {
    fwrap_ = PackedFunc([sess](TVMArgs args, TVMRetValue* rv) {
        WrapRemote(sess, args.values[0].v_handle, args.type_codes[0], rv);
      });
  }

25
  void operator()(TVMArgs args, TVMRetValue *rv) const {
26
    sess_->CallFunc(handle_, args, rv, &fwrap_);
27 28 29 30 31
  }
  ~RPCWrappedFunc() {
    sess_->CallRemote(RPCCode::kFreeFunc, handle_);
  }

32 33 34 35 36
  static void WrapRemote(std::shared_ptr<RPCSession> sess,
                         void* handle,
                         int tcode,
                         TVMRetValue* rv);

37
 private:
38
  PackedFunc fwrap_;
39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
  void* handle_{nullptr};
  std::shared_ptr<RPCSession> sess_;
};

/*!
 * \brief Module node that represents a module on a remote RPC server.
 *
 *  A null module handle denotes the session root: function lookups on it go
 *  to the remote global function registry instead of a remote module.
 */
class RPCModuleNode final : public ModuleNode {
 public:
  /*!
   * \param module_handle Remote module handle, or nullptr for the session root.
   * \param sess The RPC session the module belongs to.
   */
  RPCModuleNode(void* module_handle, std::shared_ptr<RPCSession> sess)
      : module_handle_(module_handle), sess_(sess) {
  }
  /*! \brief Free the remote module, if this node refers to one. */
  ~RPCModuleNode() {
    if (module_handle_ != nullptr) {
      sess_->CallRemote(RPCCode::kModuleFree, module_handle_);
    }
  }

  const char* type_key() const final {
    return "rpc";
  }

  /*!
   * \brief Look up a remote function by name.
   * \return A PackedFunc forwarding to the remote function, or an empty
   *  PackedFunc when the remote side returns a null handle.
   */
  PackedFunc GetFunction(
      const std::string& name,
      const std::shared_ptr<ModuleNode>& sptr_to_self) final {
    RPCFuncHandle handle = GetFuncHandle(name);
    return WrapRemote(handle);
  }

  /*!
   * \brief Fetch the source of the remote module in the requested format.
   * \return The source reported by the remote side, or "" for the session root.
   */
  std::string GetSource(const std::string& format) final {
    if (module_handle_ != nullptr) {
      std::string ret = sess_->CallRemote(
          RPCCode::kModuleGetSource, module_handle_, format);
      return ret;
    }
    return "";
  }

  /*! \return Mutable reference to the session this module belongs to. */
  std::shared_ptr<RPCSession>& sess() {
    return sess_;
  }

  /*!
   * \brief Build a time evaluator for a remote function.
   * \param name Name of the function to time.
   * \param ctx Context to run on.
   * \param nstep Forwarded to RPCSession::GetTimeEvaluator (presumably the
   *  number of timed runs — TODO confirm).
   * \return The wrapped evaluator, or an empty PackedFunc if name is not found.
   */
  PackedFunc GetTimeEvaluator(const std::string& name,
                              TVMContext ctx,
                              int nstep) {
    RPCFuncHandle handle = GetFuncHandle(name);
    if (handle == nullptr) return PackedFunc();
    handle = sess_->GetTimeEvaluator(handle, ctx, nstep);
    return WrapRemote(handle);
  }

  /*! \return The remote module handle (nullptr for the session root). */
  void* module_handle() const {
    return module_handle_;
  }

 private:
  // Wrap a remote function handle as a local PackedFunc. The wrapper is held
  // through a shared_ptr so the remote handle is released only when the last
  // copy of the returned PackedFunc is destroyed.
  PackedFunc WrapRemote(RPCFuncHandle handle) {
    if (handle == nullptr) return PackedFunc();
    auto wf = std::make_shared<RPCWrappedFunc>(handle, sess_);
    return PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
        return wf->operator()(args, rv);
      });
  }

  // Resolve name to a remote function handle: the global registry for the
  // session root, otherwise the functions of the remote module.
  RPCFuncHandle GetFuncHandle(const std::string& name) {
    RPCFuncHandle handle = nullptr;
    if (module_handle_ == nullptr) {
      handle = sess_->CallRemote(RPCCode::kGetGlobalFunc, name);
    } else {
      handle = sess_->CallRemote(
          RPCCode::kModuleGetFunc, module_handle_, name);
    }
    return handle;
  }

  // The remote module handle; nullptr denotes the session root.
  void* module_handle_{nullptr};
  // The local channel.
  std::shared_ptr<RPCSession> sess_;
};

118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
                                void* handle,
                                int tcode,
                                TVMRetValue *rv) {
  if (handle == nullptr) return;
  if (tcode == kFuncHandle) {
    auto wf = std::make_shared<RPCWrappedFunc>(handle, sess);
    *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
        return wf->operator()(args, rv);
      });
  } else {
    CHECK_EQ(tcode, kModuleHandle);
    std::shared_ptr<RPCModuleNode> n =
        std::make_shared<RPCModuleNode>(handle, sess);
    *rv = Module(n);
  }
}

136
/*!
 * \brief Create the root module of an RPC session.
 *
 *  The root has a null remote module handle, so function lookups on it are
 *  resolved against the remote global function registry.
 * \param sess The RPC session.
 * \return The session's root module.
 */
Module CreateRPCModule(std::shared_ptr<RPCSession> sess) {
  return Module(std::make_shared<RPCModuleNode>(nullptr, sess));
}

142 143 144 145 146 147 148 149 150 151 152 153
// Build a time evaluator for function args[1] of module args[0] on the context
// (device_type = args[2], device_id = args[3]), forwarding args[4] as nstep.
// RPC modules build the evaluator on the remote side; any other module is
// wrapped locally with WrapTimeEvaluator.
TVM_REGISTER_GLOBAL("module._RPCTimeEvaluator")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    TVMContext ctx;
    ctx.device_type = static_cast<DLDeviceType>(args[2].operator int());
    ctx.device_id = args[3];
    std::string func_name = args[1];
    int nstep = args[4];
    if (std::string(m->type_key()) != "rpc") {
      *rv = WrapTimeEvaluator(m.GetFunction(func_name, false), ctx, nstep);
      return;
    }
    RPCModuleNode* remote = static_cast<RPCModuleNode*>(m.operator->());
    *rv = remote->GetTimeEvaluator(func_name, ctx, nstep);
  });

158 159 160 161 162 163 164 165 166 167 168 169
// Load a module on the remote server. args[0] must be an RPC module whose
// session is reused; args[1] (presumably the remote module path — TODO
// confirm) is sent with RPCCode::kModuleLoad, and the returned handle is
// wrapped as a new RPC module on the same session.
TVM_REGISTER_GLOBAL("contrib.rpc._LoadRemoteModule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    CHECK_EQ(tkey, "rpc");
    RPCModuleNode* owner = static_cast<RPCModuleNode*>(m.operator->());
    std::shared_ptr<RPCSession>& session = owner->sess();
    void* loaded = session->CallRemote(RPCCode::kModuleLoad, args[1]);
    *rv = Module(std::make_shared<RPCModuleNode>(loaded, session));
  });

170 171 172 173 174 175 176 177 178 179 180 181 182 183 184
// Import remote module args[1] (child) into remote module args[0] (parent) on
// the server side via RPCCode::kModuleImport. Both must be RPC modules that
// share one session. Nothing is written to rv.
TVM_REGISTER_GLOBAL("contrib.rpc._ImportRemoteModule")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module parent = args[0];
    Module child = args[1];
    // Both modules must be RPC modules for the static_casts below to be valid.
    CHECK(!std::strcmp(parent->type_key(), "rpc") &&
          !std::strcmp(child->type_key(), "rpc"));
    auto* pmod = static_cast<RPCModuleNode*>(parent.operator->());
    auto* cmod = static_cast<RPCModuleNode*>(child.operator->());
    // Reject imports across different RPC sessions.
    CHECK(pmod->sess().get() == cmod->sess().get())
        << "Import of remote module need to belong to same session.";
    pmod->sess()->CallRemote(RPCCode::kModuleImport,
                             pmod->module_handle(),
                             cmod->module_handle());
  });

185 186 187 188 189 190 191 192
// Return the raw remote module handle of RPC module args[0]
// (nullptr for a session root module).
TVM_REGISTER_GLOBAL("contrib.rpc._ModuleHandle")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    CHECK_EQ(tkey, "rpc");
    const RPCModuleNode* node = static_cast<RPCModuleNode*>(m.operator->());
    void* remote_handle = node->module_handle();
    *rv = remote_handle;
  });

193 194 195 196 197 198 199 200 201
// Return table_index() of the session that RPC module args[0] belongs to.
TVM_REGISTER_GLOBAL("contrib.rpc._SessTableIndex")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    Module m = args[0];
    std::string tkey = m->type_key();
    CHECK_EQ(tkey, "rpc");
    RPCModuleNode* node = static_cast<RPCModuleNode*>(m.operator->());
    const std::shared_ptr<RPCSession>& session = node->sess();
    *rv = session->table_index();
  });
}  // namespace runtime
}  // namespace tvm