Commit c6d4f5af by ziheng Committed by GitHub

[EXECUTOR] Enable load executor remotely (#245)

* [EXECUTOR] Enable load executor remotely

* [EXECUTOR] Pipeline

* Pass bytearray directly

* Enable load dynamic library in rpc_server.py

* Fix

* lint

* Return Module from remote side directly

* Remove unused header file

* Fix

* fix
parent 3f3bf29d
......@@ -4,6 +4,6 @@ import tvm
from . import _base
from nnvm.symbol import *
from . import op_tvm_def
from .build import build, bind, save_params
from .build import build, bind, save_params, compile_graph, remote_load_exec
......@@ -32,6 +32,14 @@ def bind(g, ctx):
m = _create_exec(g.handle, ctx.device_type, ctx.device_id)
return m
_get_module = tvm.get_global_func("tvm_graph._get_module_from_graph")
def compile_graph(lib_fname, sym, target, shape, dtype="float32"):
g = build(sym, target, shape, dtype)
m = _get_module(g.handle)
m.save(lib_fname)
json_str = g.apply('SaveJSON').json_attr('json')
return json_str
@tvm.register_func("tvm_graph.lower")
def _lower(sch, inputs, func_name):
......@@ -55,3 +63,33 @@ def save_params(fname, params):
args.append(kv[0])
args.append(kv[1])
_save_param_dict(*args)
def remote_load_exec(sess, sym_json, remote_module_name, param_blob, ctx):
"""Load a remote graph executor, with the local files.
Parameters
----------
sym_json : str
The symbol json file.
remote_module_fname : str
The relative library location to remote temp folder. The
library need to be uploaded first.
param_blob : bytes or bytearray
The binary file to the local parameters.
Returns
-------
exec : GraphExecutor
The remote graph executor containing remote function.
"""
if "load_executor" not in sess._remote_funcs:
sess._remote_funcs["load_executor"] = sess.get_function("tvm_graph._load_executor")
assert ctx.device_type / tvm.contrib.rpc.RPC_SESS_MASK == sess._tbl_index + 1
device_type = ctx.device_type % tvm.contrib.rpc.RPC_SESS_MASK
return sess._remote_funcs["load_executor"](sym_json,
remote_module_name,
bytearray(param_blob),
device_type,
ctx.device_id)
......@@ -3,13 +3,16 @@
* \file NNVM Graph executor.
*/
#include <dmlc/io.h>
#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <nnvm/graph.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/tuple.h>
#include <nnvm/pass.h>
#include <numeric>
#include <string>
namespace tvm {
namespace contrib {
......@@ -53,8 +56,10 @@ class GraphExecutor : public runtime::ModuleNode {
void SetInput(int index, DLTensor* data_in);
// Copy index-th output to data_out
void GetOutput(int index, DLTensor* data_out);
// Load parameters from file
void LoadParams(std::string fname);
// Load parameters from stream
void LoadParams(dmlc::Stream* strm);
// Load parameters from binary file blob
void LoadParamsFromBlob(std::string param_blob);
// Execute the graph.
void Run();
......@@ -98,7 +103,7 @@ PackedFunc GraphExecutor::GetFunction(
});
} else if (name == "load_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->LoadParams(args[0]);
this->LoadParamsFromBlob(args[0]);
});
} else {
return PackedFunc();
......@@ -233,19 +238,17 @@ TVM_REGISTER_GLOBAL("tvm_graph._save_param_dict")
}
});
void GraphExecutor::LoadParams(std::string fname) {
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r"));
void GraphExecutor::LoadParams(dmlc::Stream *strm) {
uint64_t header, reserved;
CHECK(fi->Read(&header))
CHECK(strm->Read(&header))
<< "Invalid parameters file format";
CHECK(header == kTVMNDArrayListMagic)
<< "Invalid parameters file format";
CHECK(fi->Read(&reserved))
CHECK(strm->Read(&reserved))
<< "Invalid parameters file format";
std::vector<std::string> names;
CHECK(fi->Read(&names))
CHECK(strm->Read(&names))
<< "Invalid parameters file format";
nnvm::Symbol s;
......@@ -257,20 +260,22 @@ void GraphExecutor::LoadParams(std::string fname) {
name_index.emplace(input_names[i], i);
}
{
uint64_t sz;
fi->Read(&sz, sizeof(sz));
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size())
uint64_t sz;
strm->Read(&sz, sizeof(sz));
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size())
<< "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) {
size_t idx = name_index.at(names[i]);
CHECK(LoadDLTensor(strm, &data_entry_[idx]))
<< "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) {
size_t idx = name_index.at(names[i]);
CHECK(LoadDLTensor(fi.get(), &data_entry_[idx]))
<< "Invalid parameters file format";
}
}
}
void GraphExecutor::LoadParamsFromBlob(std::string param_blob) {
dmlc::MemoryStringStream strm(&param_blob);
this->LoadParams(&strm);
}
void GraphExecutor::SetupStorage() {
const auto& idx = graph_.indexed_graph();
......@@ -383,7 +388,8 @@ FOpExec GraphExecutor::CreateTVMOp(const nnvm::NodeAttrs& attrs,
}
// get compiled function from module.
runtime::PackedFunc pf = module_.GetFunction(it->second, false);
auto fexec = [arg_ptr, pf] () {
CHECK(pf != nullptr) << "no such function in module: " << it->second;
auto fexec = [arg_ptr, pf] () {
runtime::TVMRetValue rv;
runtime::TVMArgs targs(arg_ptr->arg_values.data(),
arg_ptr->arg_tcodes.data(),
......@@ -410,5 +416,47 @@ TVM_REGISTER_GLOBAL("tvm_graph._create_executor")
nnvm::Graph g = static_cast<nnvm::Graph*>(graph_handle)[0];
*rv = CreateExecutor(g, ctx);
});
TVM_REGISTER_GLOBAL("tvm_graph._get_module_from_graph")
.set_body([](TVMArgs args, TVMRetValue *rv) {
void* graph_handle = args[0];
nnvm::Graph* g = static_cast<nnvm::Graph*>(graph_handle);
*rv = g->MoveCopyAttr<tvm::runtime::Module>("module");
});
TVM_REGISTER_GLOBAL("tvm_graph._load_executor")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string sym_json = args[0];
std::string lib_fname = args[1];
std::string param_blob = args[2];
TVMContext ctx;
ctx.device_type = static_cast<DLDeviceType>(args[3].operator int());
ctx.device_id = args[4];
// load graph from json string
nnvm::Graph g;
g.attrs["json"] = std::make_shared<nnvm::any>(sym_json);
g = nnvm::ApplyPass(std::move(g), "LoadJSON");
// load module from file
static const PackedFunc* fsys_load_ = nullptr;
if (fsys_load_ == nullptr) {
fsys_load_ = runtime::Registry::Get("tvm.contrib.rpc.server.load_module");
CHECK(fsys_load_ != nullptr);
}
runtime::Module m = (*fsys_load_)(lib_fname);
g.attrs["module"] = std::make_shared<nnvm::any>(m);
std::shared_ptr<GraphExecutor> exec =
std::make_shared<GraphExecutor>();
exec->Init(g, ctx);
// load params form stream of string
exec->LoadParamsFromBlob(std::move(param_blob));
*rv = tvm::runtime::Module(exec);
});
} // namespace contrib
} // namespace tvm
......@@ -2,6 +2,7 @@
* Copyright (c) 2017 by Contributors
* \file Additional optimization pass of NNVM.
*/
#include <dmlc/json.h>
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
......@@ -514,3 +515,27 @@ NNVM_REGISTER_OP(layout_transform)
.set_num_outputs(1);
} // namespace contrib
} // namespace tvm
namespace dmlc {
namespace json {
template<>
struct Handler<DLDataType> {
static void Write(JSONWriter *writer, const DLDataType& data) {
std::vector<int> tmp({data.code, data.bits, data.lanes});
writer->Write(tmp);
}
static void Read(JSONReader *reader, DLDataType* data) {
std::vector<int> tmp;
reader->Read(&tmp);
data->code = tmp[0];
data->bits = tmp[1];
data->lanes = tmp[2];
}
};
DMLC_JSON_ENABLE_ANY(std::vector<DLDataType>, list_dltype);
} // namespace dmlc
} // namespace json
import tvm
from tvm.contrib import util, rpc
import tvm_graph as tg
import numpy as np
import os
def test_rpc_executor():
host = 'localhost'
port = 9091
server = rpc.Server(host, port)
tmp = util.tempdir()
sym_fname = tmp.relpath('net.json')
lib_fname = tmp.relpath('net.o')
param_fname = tmp.relpath('net.param')
x = tg.Variable('x')
y = tg.Variable('y')
sym = tg.exp(y + x)
shape = (10, 128)
dtype = tvm.float32
na = tvm.nd.array(np.ones(shape).astype(dtype))
nb = tvm.nd.array(np.ones(shape).astype(dtype))
tg.save_params(param_fname, {'x': na, 'y': nb})
remote = rpc.connect(host, port)
ctx = remote.cpu(0)
target = "llvm"
shapes = {'x': shape, 'y': shape}
sym_json = tg.compile_graph(lib_fname, sym, target, shapes)
remote.upload(lib_fname)
param_blob = bytearray(open(param_fname, "rb").read())
rm = tg.remote_load_exec(remote,
sym_json,
os.path.basename(lib_fname),
param_blob,
ctx)
run, get_output = rm['run'], rm['get_output']
nc = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx)
run()
get_output(0, nc)
np.testing.assert_allclose(
nc.asnumpy(), np.exp(na.asnumpy() + nb.asnumpy()))
server.terminate()
if __name__ == "__main__":
test_rpc_executor()
......@@ -26,7 +26,7 @@ def test_save_load():
# create another executor
m1 = tg.bind(g, tvm.cpu(0))
load_params1 = m1['load_params']
load_params1('test.params')
load_params1(bytearray(open('test.params', 'rb').read()))
run1, get_output1 = m1['run'], m1['get_output']
run1()
......
......@@ -5,7 +5,7 @@ import os
import platform
def find_lib_path(name=None):
def find_lib_path(name=None, search_path=None):
"""Find dynamic library files.
Parameters
......@@ -35,6 +35,11 @@ def find_lib_path(name=None):
elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
dll_path = [os.path.abspath(x) for x in dll_path]
if search_path is not None:
if search_path is list:
dll_path = dll_path + search_path
else:
dll_path.append(search_path)
if name is not None:
lib_dll_path = [os.path.join(p, name) for p in dll_path]
runtime_dll_path = []
......
......@@ -139,6 +139,7 @@ class Server(object):
def __init__(self, host, port=9091, port_end=9199, is_proxy=False, key=""):
self.host = host
self.port = port
self.libs = []
if not is_proxy:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
......@@ -186,8 +187,7 @@ class RPCSession(object):
def __init__(self, sess):
self._sess = sess
self._tbl_index = _SessTableIndex(sess)
self._upload_func = None
self._download_func = None
self._remote_funcs = {}
def get_function(self, name):
"""Get function from the session.
......@@ -259,10 +259,10 @@ class RPCSession(object):
if not target:
target = os.path.basename(data)
if not self._upload_func:
self._upload_func = self.get_function(
if "upload" not in self._remote_funcs:
self._remote_funcs["upload"] = self.get_function(
"tvm.contrib.rpc.server.upload")
self._upload_func(target, blob)
self._remote_funcs["upload"](target, blob)
def download(self, path):
"""Download file from remote temp folder.
......@@ -277,10 +277,10 @@ class RPCSession(object):
blob : bytearray
The result blob from the file.
"""
if not self._download_func:
self._download_func = self.get_function(
if "download" not in self._remote_funcs:
self._remote_funcs["download"] = self.get_function(
"tvm.contrib.rpc.server.download")
return self._download_func(path)
return self._remote_funcs["download"](path)
def load_module(self, path):
"""Load a remote module, the file need to be uploaded first.
......
......@@ -3,7 +3,10 @@ from __future__ import absolute_import
import logging
import argparse
import os
import ctypes
from ..contrib import rpc
from .._ffi.libinfo import find_lib_path
def main():
"""Main funciton"""
......@@ -12,11 +15,21 @@ def main():
help='the hostname of the server')
parser.add_argument('--port', type=int, default=9090,
help='The port of the PRC')
parser.add_argument('--port_end', type=int, default=9199,
parser.add_argument('--port-end', type=int, default=9199,
help='The end search port of the PRC')
parser.add_argument('--with-executor', type=bool, default=False,
help="Whether to load executor runtime")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
server = rpc.Server(args.host, args.port, args.port_end)
if args.with_executor:
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
apps_path = os.path.join(curr_path, "../../../apps/graph_executor/lib/")
lib_path = find_lib_path('libtvm_graph_exec.so', apps_path)
server.libs.append(ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL))
server.proc.join()
if __name__ == "__main__":
......
......@@ -145,7 +145,7 @@ class RPCSession::EventHandler {
arg_recv_stage_ = 0;
arg_buf_.reset();
}
// strip sessionon mask
// strip session on mask
TVMContext StripSessMask(TVMContext ctx) {
int dev_type = ctx.device_type;
CHECK_EQ(dev_type / kRPCSessMask, rpc_sess_table_index_ + 1)
......
......@@ -46,7 +46,7 @@ enum class RPCCode : int {
kModuleLoad,
kModuleFree,
kModuleGetFunc,
kModuleGetSource
kModuleGetSource,
};
/*!
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment