Commit 204e9cb4 by ziheng (committed by GitHub)

[EXECUTOR] Fix bug and improve (#252)

* [EXECUTOR] Fix bug and improve

* [EXECUTOR] Enhance test case
parent f433373d
@@ -77,6 +77,8 @@ class GraphExecutor : public runtime::ModuleNode {
   TVMContext ctx_;
   // Common storage pool
   std::vector<DLTensor*> storage_pool_;
+  // The data shape
+  std::vector<TShape> data_shape_;
   // The data entry
   std::vector<DLTensor> data_entry_;
   // The operation lambda on each node
@@ -251,13 +253,10 @@ void GraphExecutor::LoadParams(dmlc::Stream *strm) {
   CHECK(strm->Read(&names))
       << "Invalid parameters file format";
-  nnvm::Symbol s;
-  s.outputs = graph_.outputs;
-  std::vector<std::string> input_names =
-      s.ListInputNames(nnvm::Symbol::ListInputOption::kAll);
-  std::unordered_map<std::string, size_t> name_index;
-  for (size_t i = 0; i < input_names.size(); ++i) {
-    name_index.emplace(input_names[i], i);
+  std::unordered_map<std::string, size_t> name_eid;
+  const auto& idx = graph_.indexed_graph();
+  for (int nid : idx.input_nodes()) {
+    name_eid.emplace(idx[nid].source->attrs.name, idx.entry_id(nid, 0));
   }
   uint64_t sz;
@@ -266,8 +265,9 @@ void GraphExecutor::LoadParams(dmlc::Stream *strm) {
   CHECK(size == names.size())
       << "Invalid parameters file format";
   for (size_t i = 0; i < size; ++i) {
-    size_t idx = name_index.at(names[i]);
-    CHECK(LoadDLTensor(strm, &data_entry_[idx]))
+    auto iter = name_eid.find(names[i]);
+    CHECK(iter != name_eid.end());
+    CHECK(LoadDLTensor(strm, &data_entry_[iter->second]))
         << "Invalid parameters file format";
   }
 }
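The substance of the fix is in the two hunks above: LoadParams now resolves each saved parameter name to the entry id recorded in the indexed graph, because data_entry_ is indexed by entry id over all node entries, whereas the old code used the name's position in the flat input-name list. A toy, hypothetical Python illustration of that distinction (the names and ids below are made up, not the executor's actual data):

# Hypothetical values: entry ids in a graph need not equal input positions.
input_names = ["data", "weight", "bias"]            # positions 0, 1, 2 (old lookup)
name_to_eid = {"data": 0, "weight": 3, "bias": 4}   # ids among all node entries (new lookup)

params = {"weight": "W", "bias": "b"}               # what a params file provides
data_entry = [None] * 6                             # one slot per graph entry id

for name, value in params.items():
    # Indexing by position would write to slots 1 and 2; the entries live at 3 and 4.
    data_entry[name_to_eid[name]] = value

print(data_entry)  # [None, None, None, 'W', 'b', None]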
@@ -281,13 +281,13 @@ void GraphExecutor::SetupStorage() {
   const auto& idx = graph_.indexed_graph();
   // Grab saved optimization plan from graph.
   auto vstorage = graph_.MoveCopyAttr<StorageVector>("storage_id");
-  const auto& vshape = graph_.GetAttr<ShapeVector>("shape");
   const auto& vtype = graph_.GetAttr<DLTypeVector>("dltype");
+  data_shape_ = graph_.GetAttr<ShapeVector>("shape");
   data_entry_.resize(idx.num_node_entries());
   // Find the maximum space size.
   int max_id = 0;
-  for (size_t i = 0; i < vshape.size(); ++i) {
+  for (size_t i = 0; i < data_shape_.size(); ++i) {
     max_id = std::max(vstorage[i] + 1, max_id);
   }
   for (const auto& e : idx.input_nodes()) {
@@ -296,9 +296,9 @@ void GraphExecutor::SetupStorage() {
   // size of each storage pool entry
   std::vector<size_t> pool_entry_bytes;
   // Find the maximum space size.
-  for (size_t i = 0; i < vshape.size(); ++i) {
+  for (size_t i = 0; i < data_shape_.size(); ++i) {
     int storage_id = vstorage[i];
-    size_t size = vshape[i].Size();
+    size_t size = data_shape_[i].Size();
     CHECK_GE(storage_id, 0) << "Do not support runtime shape op";
     DLDataType t = vtype[i];
@@ -324,8 +324,8 @@ void GraphExecutor::SetupStorage() {
   for (size_t i = 0; i < data_entry_.size(); ++i) {
     int storage_id = vstorage[i];
     data_entry_[i] = *storage_pool_[storage_id];
-    data_entry_[i].shape = const_cast<int64_t*>(vshape[i].data());
-    data_entry_[i].ndim = vshape[i].ndim();
+    data_entry_[i].shape = const_cast<int64_t*>(data_shape_[i].data());
+    data_entry_[i].ndim = data_shape_[i].ndim();
     data_entry_[i].dtype = vtype[i];
   }
 }
@@ -16,7 +16,7 @@ def test_rpc_executor():
     x = tg.Variable('x')
     y = tg.Variable('y')
-    sym = tg.exp(y + x)
+    sym = tg.exp(y + x) + tg.exp(x + y)
     shape = (10, 128)
     dtype = tvm.float32
@@ -46,8 +46,10 @@ def test_rpc_executor():
     run()
     get_output(0, nc)
-    np.testing.assert_allclose(
-        nc.asnumpy(), np.exp(na.asnumpy() + nb.asnumpy()))
+    npa = na.asnumpy()
+    npb = nb.asnumpy()
+    np.testing.assert_allclose(nc.asnumpy(),
+                               np.exp(npa + npb) + np.exp(npb + npa))
     server.terminate()
 if __name__ == "__main__":
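For reference, the reworked assertion amounts to the numpy-only sketch below; the executor output nc is mocked with the same numpy expression, since running the remote graph executor is outside the scope of this snippet, and the random inputs are stand-ins for the test's na/nb arrays:

import numpy as np

npa = np.random.uniform(size=(10, 128)).astype("float32")   # stand-in for na
npb = np.random.uniform(size=(10, 128)).astype("float32")   # stand-in for nb

# Mocked executor output for sym = exp(y + x) + exp(x + y).
nc_mock = np.exp(npa + npb) + np.exp(npb + npa)

# The enhanced test compares the fetched output against this reference expression.
np.testing.assert_allclose(nc_mock, np.exp(npa + npb) + np.exp(npb + npa))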
@@ -22,14 +22,15 @@ def main():
     args = parser.parse_args()
     logging.basicConfig(level=logging.INFO)
-    server = rpc.Server(args.host, args.port, args.port_end)
     if args.with_executor:
         curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
         apps_path = os.path.join(curr_path, "../../../apps/graph_executor/lib/")
         lib_path = find_lib_path('libtvm_graph_exec.so', apps_path)
-        server.libs.append(ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL))
+        lib = ctypes.CDLL(lib_path[0])
+    server = rpc.Server(args.host, args.port, args.port_end)
+    server.libs.append(lib)
     server.proc.join()
 if __name__ == "__main__":
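For completeness, a minimal client-side sketch of talking to this server, assuming the tvm.contrib.rpc API of this era (rpc.connect) and placeholder host/port values; the command line in the comment also assumes args.with_executor is exposed as a --with-executor flag, which is not confirmed by the diff:

# Server side (assumed flag spelling for args.with_executor):
#   python rpc_server.py --host 0.0.0.0 --port 9090 --with-executor
from tvm.contrib import rpc

# Placeholder host/port; use whatever address the server actually bound.
remote = rpc.connect("localhost", 9090)
# The session can now call functions registered by libtvm_graph_exec.so,
# since the server process appended that library to server.libs.
print(remote)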