Commit ebbec6de by Tianqi Chen

updates (#25)

* [FIX] Remove extra move

* [MEMORY] Add inplace index
parent be199635
@@ -160,7 +160,7 @@ int NNSymbolListAttrs(SymbolHandle symbol,
   NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
   API_BEGIN();
   std::unordered_map<std::string, std::string> attr =
-      std::move(s->ListAttrs(static_cast<Symbol::ListAttrOption>(option)));  // NOLINT(*)
+      s->ListAttrs(static_cast<Symbol::ListAttrOption>(option));  // NOLINT(*)
   std::vector<std::string>& attr_list = ret->ret_vec_str;
   attr_list.clear();
@@ -184,8 +184,8 @@ int NNSymbolListInputNames(SymbolHandle symbol,
   Symbol *s = static_cast<Symbol*>(symbol);
   NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
   API_BEGIN();
-  ret->ret_vec_str = std::move(
-      s->ListInputNames(Symbol::ListInputOption(option)));
+  ret->ret_vec_str =
+      s->ListInputNames(Symbol::ListInputOption(option));
   ret->ret_vec_charp.clear();
   for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
     ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
@@ -201,7 +201,7 @@ int NNSymbolListOutputNames(SymbolHandle symbol,
   Symbol *s = static_cast<Symbol*>(symbol);
   NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
   API_BEGIN();
-  ret->ret_vec_str = std::move(s->ListOutputNames());
+  ret->ret_vec_str = s->ListOutputNames();
   ret->ret_vec_charp.clear();
   for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
     ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
...
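Note on the "[FIX] Remove extra move" hunks above: wrapping a function's return value in std::move is redundant and can suppress copy elision (clang reports this family of patterns under -Wpessimizing-move). A minimal standalone sketch, not taken from this repository, that illustrates the point:

```cpp
// Minimal sketch (not project code): the names below are illustrative only.
#include <string>
#include <vector>

std::vector<std::string> MakeNames() {
  return {"data", "weight", "bias"};
}

int main() {
  // Redundant: std::vector<std::string> v = std::move(MakeNames());
  // The call result is already a temporary, so the plain form below moves
  // (or elides the copy entirely) without any help from std::move.
  std::vector<std::string> v = MakeNames();
  return static_cast<int>(v.size());
}
```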
@@ -22,10 +22,7 @@ inline T get_with_default(const std::unordered_map<Node*, T> &map,
 }
 
 inline bool IsMutate(const std::vector<uint32_t>& mutate_inputs, uint32_t i) {
-  if (mutate_inputs.size() == 0) return false;
-  auto it = std::lower_bound(
-      mutate_inputs.begin(), mutate_inputs.end(), i);
-  return (it != mutate_inputs.end()) && (*it == i);
+  return std::binary_search(mutate_inputs.begin(), mutate_inputs.end(), i);
 }
 
 Graph OrderMutation(const Graph& src) {
...
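Note on IsMutate: the rewritten body relies on std::binary_search, which requires mutate_inputs to be sorted (as the previous lower_bound version already did) and returns false for an empty range, so the explicit size check is no longer needed. A small self-contained sketch with made-up values:

```cpp
// Self-contained sketch with hypothetical values, not real graph data.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<uint32_t> mutate_inputs = {1, 4, 7};  // must be kept sorted
  assert(std::binary_search(mutate_inputs.begin(), mutate_inputs.end(), 4u));
  assert(!std::binary_search(mutate_inputs.begin(), mutate_inputs.end(), 5u));

  std::vector<uint32_t> empty;
  // An empty range simply yields false, matching the old size() == 0 check.
  assert(!std::binary_search(empty.begin(), empty.end(), 0u));
  return 0;
}
```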
@@ -150,6 +150,7 @@ Graph PlanMemory(Graph ret) {
   }
   // step 2: allocate memory.
   StorageVector storage(idx.num_node_entries(), -1);
+  std::vector<int> storage_inplace_index(idx.num_node_entries(), -1);
   const ShapeVector& shape_vec = ret.GetAttr<ShapeVector>("shape");
   const DTypeVector& dtype_vec = ret.GetAttr<DTypeVector>("dtype");
   const DeviceVector* device_vec = nullptr;
@@ -173,8 +174,10 @@ Graph PlanMemory(Graph ret) {
         uint32_t eid_out = idx.entry_id(nid, kv.second);
         uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]);
         if (ref_count[eid_in] == 1 && storage[eid_in] != GraphAllocator::kBadStorageID) {
+          // inplace optimization
           storage[eid_out] = storage[eid_in];
           ref_count[eid_in] = 0;
+          storage_inplace_index[eid_out] = kv.first;
         }
       }
     }
@@ -209,8 +212,8 @@ Graph PlanMemory(Graph ret) {
       }
     }
   }
   ret.attrs["storage_id"] = std::make_shared<any>(std::move(storage));
+  ret.attrs["storage_inplace_index"] = std::make_shared<any>(std::move(storage_inplace_index));
   ret.attrs["storage_allocated_bytes"] = std::make_shared<any>(allocator.TotalAllocBytes());
   ret.attrs["storage_num_not_allocated"] = std::make_shared<any>(num_not_allocated);
   return ret;
@@ -222,7 +225,8 @@ NNVM_REGISTER_PASS(PlanMemory)
 .set_change_graph(false)
 .depend_graph_attr("dtype")
 .depend_graph_attr("shape")
-.provide_graph_attr("storage_id");
+.provide_graph_attr("storage_id")
+.provide_graph_attr("storage_inplace_index");
 
 }  // namespace
 }  // namespace pass
...
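Note on the new "[MEMORY] Add inplace index" attribute: storage_inplace_index has one slot per node entry, initialized to -1; when the planner reuses an input's storage for an output (the inplace branch above), the slot records that input's position (kv.first). In the actual pass the vector is exposed through the graph's attrs, like "shape" and "dtype" above; the sketch below only illustrates how the recorded values would be interpreted, using hypothetical data:

```cpp
// Hypothetical data, for illustration only: -1 means a fresh allocation,
// k >= 0 means the entry shares storage with input k of the producing node.
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> storage_inplace_index = {-1, -1, 0, -1, 1};
  for (std::size_t eid = 0; eid < storage_inplace_index.size(); ++eid) {
    if (storage_inplace_index[eid] >= 0) {
      std::printf("entry %zu reuses the storage of input %d\n",
                  eid, storage_inplace_index[eid]);
    }
  }
  return 0;
}
```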
@@ -89,7 +89,7 @@ struct JSONNode {
   }
 
   void Load(dmlc::JSONReader *reader) {
-    node = std::move(Node::Create());
+    node = Node::Create();
     control_deps.clear();
     dmlc::JSONObjectReadHelper helper;
     std::string op_type_str;
...