Commit b931f8e2 by Eric Junyuan Xie Committed by Tianqi Chen

add range to plan memory (#147)

parent 00f8165c
...@@ -143,7 +143,9 @@ class GraphAllocator { ...@@ -143,7 +143,9 @@ class GraphAllocator {
/* /*
* Internal method to perform the memory allocation for a graph * Internal method to perform the memory allocation for a graph
* */ * */
size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* storage_ptr, size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
const std::pair<uint32_t, uint32_t>& node_range,
StorageVector* storage_ptr,
std::vector<int>* storage_inplace_index_ptr, std::vector<int>* storage_inplace_index_ptr,
const std::vector<uint32_t>& entry_ref_count, const std::vector<uint32_t>& entry_ref_count,
GraphAllocator* allocator) { GraphAllocator* allocator) {
...@@ -165,7 +167,7 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* sto ...@@ -165,7 +167,7 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* sto
size_t num_not_allocated = 0; size_t num_not_allocated = 0;
std::vector<GraphAllocator::StorageID> storage_ref_count(idx.num_node_entries(), 0); std::vector<GraphAllocator::StorageID> storage_ref_count(idx.num_node_entries(), 0);
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { for (uint32_t nid = node_range.first; nid < node_range.second; ++nid) {
const auto& inode = idx[nid]; const auto& inode = idx[nid];
if (inode.source->is_variable()) continue; if (inode.source->is_variable()) continue;
// check inplace option // check inplace option
...@@ -272,26 +274,35 @@ Graph PlanMemory(Graph ret) { ...@@ -272,26 +274,35 @@ Graph PlanMemory(Graph ret) {
// setup ref counter // setup ref counter
const IndexedGraph& idx = ret.indexed_graph(); const IndexedGraph& idx = ret.indexed_graph();
static auto& fignore_inputs = Op::GetAttr<FIgnoreInputs>("FIgnoreInputs"); static auto& fignore_inputs = Op::GetAttr<FIgnoreInputs>("FIgnoreInputs");
std::pair<uint32_t, uint32_t> node_range = {0, idx.num_nodes()};
if (ret.attrs.count("node_range")) {
node_range = ret.MoveCopyAttr<std::pair<uint32_t, uint32_t> >("node_range");
}
// reference counter of each node // reference counter of each node
std::vector<uint32_t> ref_count(idx.num_node_entries(), 0); std::vector<uint32_t> ref_count;
// step 1: initialize reference count // step 1: initialize reference count
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { if (ret.attrs.count("ref_count") != 0) {
const auto& inode = idx[nid]; ref_count = ret.MoveCopyAttr<std::vector<uint32_t> >("ref_count");
if (inode.source->is_variable()) continue; } else {
for (const auto& e : inode.inputs) { ref_count.resize(idx.num_node_entries(), 0);
++ref_count[idx.entry_id(e)]; for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
} const auto& inode = idx[nid];
// no dataflow dependency is needed for those are ignored. if (inode.source->is_variable()) continue;
// revoke the dependency counter. for (const auto& e : inode.inputs) {
if (fignore_inputs.count(inode.source->op()) != 0) { ++ref_count[idx.entry_id(e)];
auto ignore_inputs = fignore_inputs[inode.source->op()](inode.source->attrs); }
for (uint32_t i : ignore_inputs) { // no dataflow dependency is needed for those are ignored.
--ref_count[idx.entry_id(inode.inputs[i])]; // revoke the dependency counter.
if (fignore_inputs.count(inode.source->op()) != 0) {
auto ignore_inputs = fignore_inputs[inode.source->op()](inode.source->attrs);
for (uint32_t i : ignore_inputs) {
--ref_count[idx.entry_id(inode.inputs[i])];
}
} }
} }
} for (const auto& e : idx.outputs()) {
for (const auto& e : idx.outputs()) { ++ref_count[idx.entry_id(e)];
++ref_count[idx.entry_id(e)]; }
} }
// step 2: allocate memory. // step 2: allocate memory.
StorageVector storage; StorageVector storage;
...@@ -316,7 +327,8 @@ Graph PlanMemory(Graph ret) { ...@@ -316,7 +327,8 @@ Graph PlanMemory(Graph ret) {
// number of entries that are not statically allocated. // number of entries that are not statically allocated.
size_t storage_num_not_allocated = size_t storage_num_not_allocated =
AllocMemory(ret, idx, &storage_vec, &storage_inplace_index, ref_count, &allocator); AllocMemory(ret, idx, node_range, &storage_vec, &storage_inplace_index,
ref_count, &allocator);
size_t storage_allocated_bytes = allocator.TotalAllocBytes(); size_t storage_allocated_bytes = allocator.TotalAllocBytes();
// Choose the plan which leads to minimal memory usage // Choose the plan which leads to minimal memory usage
if (min_allocated_bytes > storage_allocated_bytes) { if (min_allocated_bytes > storage_allocated_bytes) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment