Commit 204e9cb4 by ziheng, committed by GitHub

[EXECUTOR] Fix bug and improve (#252)

* [EXECUTOR] Fix bug and improve

* [EXECUTOR] Enhance test case
parent f433373d
...@@ -77,6 +77,8 @@ class GraphExecutor : public runtime::ModuleNode { ...@@ -77,6 +77,8 @@ class GraphExecutor : public runtime::ModuleNode {
TVMContext ctx_; TVMContext ctx_;
// Common storage pool // Common storage pool
std::vector<DLTensor*> storage_pool_; std::vector<DLTensor*> storage_pool_;
// The data shape
std::vector<TShape> data_shape_;
// The data entry // The data entry
std::vector<DLTensor> data_entry_; std::vector<DLTensor> data_entry_;
// The operation lambda on each node // The operation lambda on each node
...@@ -251,13 +253,10 @@ void GraphExecutor::LoadParams(dmlc::Stream *strm) { ...@@ -251,13 +253,10 @@ void GraphExecutor::LoadParams(dmlc::Stream *strm) {
CHECK(strm->Read(&names)) CHECK(strm->Read(&names))
<< "Invalid parameters file format"; << "Invalid parameters file format";
nnvm::Symbol s; std::unordered_map<std::string, size_t> name_eid;
s.outputs = graph_.outputs; const auto& idx = graph_.indexed_graph();
std::vector<std::string> input_names = for (int nid : idx.input_nodes()) {
s.ListInputNames(nnvm::Symbol::ListInputOption::kAll); name_eid.emplace(idx[nid].source->attrs.name, idx.entry_id(nid, 0));
std::unordered_map<std::string, size_t> name_index;
for (size_t i = 0; i < input_names.size(); ++i) {
name_index.emplace(input_names[i], i);
} }
uint64_t sz; uint64_t sz;
...@@ -266,8 +265,9 @@ void GraphExecutor::LoadParams(dmlc::Stream *strm) { ...@@ -266,8 +265,9 @@ void GraphExecutor::LoadParams(dmlc::Stream *strm) {
CHECK(size == names.size()) CHECK(size == names.size())
<< "Invalid parameters file format"; << "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) { for (size_t i = 0; i < size; ++i) {
size_t idx = name_index.at(names[i]); auto iter = name_eid.find(names[i]);
CHECK(LoadDLTensor(strm, &data_entry_[idx])) CHECK(iter != name_eid.end());
CHECK(LoadDLTensor(strm, &data_entry_[iter->second]))
<< "Invalid parameters file format"; << "Invalid parameters file format";
} }
} }
...@@ -281,13 +281,13 @@ void GraphExecutor::SetupStorage() { ...@@ -281,13 +281,13 @@ void GraphExecutor::SetupStorage() {
const auto& idx = graph_.indexed_graph(); const auto& idx = graph_.indexed_graph();
// Grab saved optimization plan from graph. // Grab saved optimization plan from graph.
auto vstorage = graph_.MoveCopyAttr<StorageVector>("storage_id"); auto vstorage = graph_.MoveCopyAttr<StorageVector>("storage_id");
const auto& vshape = graph_.GetAttr<ShapeVector>("shape");
const auto& vtype = graph_.GetAttr<DLTypeVector>("dltype"); const auto& vtype = graph_.GetAttr<DLTypeVector>("dltype");
data_shape_ = graph_.GetAttr<ShapeVector>("shape");
data_entry_.resize(idx.num_node_entries()); data_entry_.resize(idx.num_node_entries());
// Find the maximum space size. // Find the maximum space size.
int max_id = 0; int max_id = 0;
for (size_t i = 0; i < vshape.size(); ++i) { for (size_t i = 0; i < data_shape_.size(); ++i) {
max_id = std::max(vstorage[i] + 1, max_id); max_id = std::max(vstorage[i] + 1, max_id);
} }
for (const auto& e : idx.input_nodes()) { for (const auto& e : idx.input_nodes()) {
...@@ -296,9 +296,9 @@ void GraphExecutor::SetupStorage() { ...@@ -296,9 +296,9 @@ void GraphExecutor::SetupStorage() {
// size of each storage pool entry // size of each storage pool entry
std::vector<size_t> pool_entry_bytes; std::vector<size_t> pool_entry_bytes;
// Find the maximum space size. // Find the maximum space size.
for (size_t i = 0; i < vshape.size(); ++i) { for (size_t i = 0; i < data_shape_.size(); ++i) {
int storage_id = vstorage[i]; int storage_id = vstorage[i];
size_t size = vshape[i].Size(); size_t size = data_shape_[i].Size();
CHECK_GE(storage_id, 0) << "Do not support runtime shape op"; CHECK_GE(storage_id, 0) << "Do not support runtime shape op";
DLDataType t = vtype[i]; DLDataType t = vtype[i];
...@@ -324,8 +324,8 @@ void GraphExecutor::SetupStorage() { ...@@ -324,8 +324,8 @@ void GraphExecutor::SetupStorage() {
for (size_t i = 0; i < data_entry_.size(); ++i) { for (size_t i = 0; i < data_entry_.size(); ++i) {
int storage_id = vstorage[i]; int storage_id = vstorage[i];
data_entry_[i] = *storage_pool_[storage_id]; data_entry_[i] = *storage_pool_[storage_id];
data_entry_[i].shape = const_cast<int64_t*>(vshape[i].data()); data_entry_[i].shape = const_cast<int64_t*>(data_shape_[i].data());
data_entry_[i].ndim = vshape[i].ndim(); data_entry_[i].ndim = data_shape_[i].ndim();
data_entry_[i].dtype = vtype[i]; data_entry_[i].dtype = vtype[i];
} }
} }
......
...@@ -16,7 +16,7 @@ def test_rpc_executor(): ...@@ -16,7 +16,7 @@ def test_rpc_executor():
x = tg.Variable('x') x = tg.Variable('x')
y = tg.Variable('y') y = tg.Variable('y')
sym = tg.exp(y + x) sym = tg.exp(y + x) + tg.exp(x + y)
shape = (10, 128) shape = (10, 128)
dtype = tvm.float32 dtype = tvm.float32
...@@ -46,8 +46,10 @@ def test_rpc_executor(): ...@@ -46,8 +46,10 @@ def test_rpc_executor():
run() run()
get_output(0, nc) get_output(0, nc)
np.testing.assert_allclose( npa = na.asnumpy()
nc.asnumpy(), np.exp(na.asnumpy() + nb.asnumpy())) npb = nb.asnumpy()
np.testing.assert_allclose(nc.asnumpy(),
np.exp(npa + npb) + np.exp(npb + npa))
server.terminate() server.terminate()
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -22,14 +22,15 @@ def main(): ...@@ -22,14 +22,15 @@ def main():
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
server = rpc.Server(args.host, args.port, args.port_end)
if args.with_executor: if args.with_executor:
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
apps_path = os.path.join(curr_path, "../../../apps/graph_executor/lib/") apps_path = os.path.join(curr_path, "../../../apps/graph_executor/lib/")
lib_path = find_lib_path('libtvm_graph_exec.so', apps_path) lib_path = find_lib_path('libtvm_graph_exec.so', apps_path)
server.libs.append(ctypes.CDLL(lib_path[0], ctypes.RTLD_GLOBAL)) lib = ctypes.CDLL(lib_path[0])
server = rpc.Server(args.host, args.port, args.port_end)
server.libs.append(lib)
server.proc.join() server.proc.join()
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment