Commit ebbec6de by Tianqi Chen

updates (#25)

* [FIX] Remove extra move

* [MEMORY] Add inplace index
parent be199635
......@@ -160,7 +160,7 @@ int NNSymbolListAttrs(SymbolHandle symbol,
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
std::unordered_map<std::string, std::string> attr =
std::move(s->ListAttrs(static_cast<Symbol::ListAttrOption>(option))); // NOLINT(*)
s->ListAttrs(static_cast<Symbol::ListAttrOption>(option)); // NOLINT(*)
std::vector<std::string>& attr_list = ret->ret_vec_str;
attr_list.clear();
......@@ -184,8 +184,8 @@ int NNSymbolListInputNames(SymbolHandle symbol,
Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_vec_str = std::move(
s->ListInputNames(Symbol::ListInputOption(option)));
ret->ret_vec_str =
s->ListInputNames(Symbol::ListInputOption(option));
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
......@@ -201,7 +201,7 @@ int NNSymbolListOutputNames(SymbolHandle symbol,
Symbol *s = static_cast<Symbol*>(symbol);
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
API_BEGIN();
ret->ret_vec_str = std::move(s->ListOutputNames());
ret->ret_vec_str = s->ListOutputNames();
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
......
......@@ -22,10 +22,7 @@ inline T get_with_default(const std::unordered_map<Node*, T> &map,
}
inline bool IsMutate(const std::vector<uint32_t>& mutate_inputs, uint32_t i) {
if (mutate_inputs.size() == 0) return false;
auto it = std::lower_bound(
mutate_inputs.begin(), mutate_inputs.end(), i);
return (it != mutate_inputs.end()) && (*it == i);
return std::binary_search(mutate_inputs.begin(), mutate_inputs.end(), i);
}
Graph OrderMutation(const Graph& src) {
......
......@@ -150,6 +150,7 @@ Graph PlanMemory(Graph ret) {
}
// step 2: allocate memory.
StorageVector storage(idx.num_node_entries(), -1);
std::vector<int> storage_inplace_index(idx.num_node_entries(), -1);
const ShapeVector& shape_vec = ret.GetAttr<ShapeVector>("shape");
const DTypeVector& dtype_vec = ret.GetAttr<DTypeVector>("dtype");
const DeviceVector* device_vec = nullptr;
......@@ -173,8 +174,10 @@ Graph PlanMemory(Graph ret) {
uint32_t eid_out = idx.entry_id(nid, kv.second);
uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]);
if (ref_count[eid_in] == 1 && storage[eid_in] != GraphAllocator::kBadStorageID) {
// inplace optimization
storage[eid_out] = storage[eid_in];
ref_count[eid_in] = 0;
storage_inplace_index[eid_out] = kv.first;
}
}
}
......@@ -209,8 +212,8 @@ Graph PlanMemory(Graph ret) {
}
}
}
ret.attrs["storage_id"] = std::make_shared<any>(std::move(storage));
ret.attrs["storage_inplace_index"] = std::make_shared<any>(std::move(storage_inplace_index));
ret.attrs["storage_allocated_bytes"] = std::make_shared<any>(allocator.TotalAllocBytes());
ret.attrs["storage_num_not_allocated"] = std::make_shared<any>(num_not_allocated);
return ret;
......@@ -222,7 +225,8 @@ NNVM_REGISTER_PASS(PlanMemory)
.set_change_graph(false)
.depend_graph_attr("dtype")
.depend_graph_attr("shape")
.provide_graph_attr("storage_id");
.provide_graph_attr("storage_id")
.provide_graph_attr("storage_inplace_index");
} // namespace
} // namespace pass
......
......@@ -89,7 +89,7 @@ struct JSONNode {
}
void Load(dmlc::JSONReader *reader) {
node = std::move(Node::Create());
node = Node::Create();
control_deps.clear();
dmlc::JSONObjectReadHelper helper;
std::string op_type_str;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment