Commit 06c06f58 by Eric Junyuan Xie Committed by Tianqi Chen

fix plan memory add inplace_identity option (#124)

* fix plan memory add inplace_identity option

* comment
parent c00100f8
......@@ -107,8 +107,6 @@ using TIsBackward = bool;
* \brief Get possible inplace options.
* This function enables optimization to reuse memory of inputs in output.
* \param attrs The attributes of the node
* \param in_data The input data.
* \param out_data The output data.
* \return list of pairs that map input->output,
* indicating possible in-place operations.
*
......@@ -118,6 +116,18 @@ using FInplaceOption = std::function<
std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;
/*!
* \brief Report which in-place pairs of an operator are identity mappings.
* For an identity pair, memory planning may let the output share the input's
* storage even when the input reference count is greater than one.
* \param attrs The attributes of the node
* \return list of bool, one per pair returned by FInplaceOption (same order and
* same length), where true marks the corresponding pair as an identity
*
* \note Register under "FInplaceIdentity", by default no identities.
*/
using FInplaceIdentity = std::function<std::vector<bool> (const NodeAttrs& attrs)>;
/*!
* \brief Get list of inputs in the op whose content are actually not used by the operator
* These are dummy input that can be used for example in zeros_like, ones_like.
*
......
......@@ -142,8 +142,12 @@ class GraphAllocator {
* Internal method to perform the memory allocation for a graph
* */
size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* storage_ptr,
std::vector<int>* storage_inplace_index_ptr, std::vector<uint32_t> ref_count,
std::vector<int>* storage_inplace_index_ptr,
const std::vector<uint32_t>& entry_ref_count,
GraphAllocator* allocator) {
static auto& finplace_option = Op::GetAttr<FInplaceOption>("FInplaceOption");
static auto& finplace_identity = Op::GetAttr<FInplaceIdentity>("FInplaceIdentity");
// Get reference
auto &storage = *storage_ptr;
auto &storage_inplace_index = *storage_inplace_index_ptr;
......@@ -152,12 +156,12 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* sto
const ShapeVector& shape_vec = ret.GetAttr<ShapeVector>("shape");
const DTypeVector& dtype_vec = ret.GetAttr<DTypeVector>("dtype");
const DeviceVector* device_vec = nullptr;
static auto& finplace_option = Op::GetAttr<FInplaceOption>("FInplaceOption");
if (ret.attrs.count("device") != 0) {
device_vec = &(ret.GetAttr<DeviceVector>("device"));
}
size_t num_not_allocated = 0;
std::vector<GraphAllocator::StorageID> storage_ref_count(idx.num_node_entries(), 0);
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
......@@ -165,18 +169,36 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* sto
// check inplace option
if (finplace_option.count(inode.source->op()) != 0) {
auto inplace_pairs = finplace_option[inode.source->op()](inode.source->attrs);
for (auto& kv : inplace_pairs) {
std::vector<bool> identity;
if (finplace_identity.count(inode.source->op()) != 0) {
identity = finplace_identity[inode.source->op()](inode.source->attrs);
CHECK_EQ(identity.size(), inplace_pairs.size())
<< "FInplaceOption and FInplaceIdentity returned vectors of different "
<< "size for operator " << inode.source->op()->name;
} else {
identity = std::vector<bool>(inplace_pairs.size(), false);
}
std::vector<bool> taken(inode.inputs.size(), false);
for (size_t ipair = 0; ipair < inplace_pairs.size(); ++ipair) {
const auto& kv = inplace_pairs[ipair];
uint32_t eid_out = idx.entry_id(nid, kv.second);
uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]);
if (ref_count[eid_in] == 1 &&
ref_count[eid_out] != 0 &&
storage[eid_out] == GraphAllocator::kBadStorageID &&
storage[eid_in] >= 0 &&
auto sid_out = storage[eid_out];
auto sid_in = storage[eid_in];
if (taken[kv.first] == false &&
sid_out == GraphAllocator::kBadStorageID &&
sid_in >= 0 &&
(storage_ref_count[sid_in] == 1 || identity[ipair]) &&
entry_ref_count[eid_out] > 0 &&
shape_vec[eid_out].Size() == shape_vec[eid_in].Size() &&
dtype_vec[eid_out] == dtype_vec[eid_in]) {
// inplace optimization
storage[eid_out] = storage[eid_in];
ref_count[eid_in] = 0;
taken[kv.first] = true;
storage[eid_out] = sid_in;
// Reuse storage for output and add ref count of output
// to storage. This will get subtracted later in the free
// input section.
storage_ref_count[sid_in] += entry_ref_count[eid_out];
storage_inplace_index[eid_out] = kv.first;
}
}
......@@ -196,7 +218,9 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* sto
}
for (auto rit = eids.rbegin(); rit != eids.rend(); ++rit) {
uint32_t eid = rit->second;
storage[eid] = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid);
auto sid = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid);
storage_ref_count[sid] = entry_ref_count[eid];
storage[eid] = sid;
}
// check if certain inputs are ignored.
......@@ -212,20 +236,22 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* sto
if (std::binary_search(ignore_inputs.begin(), ignore_inputs.end(), i)) continue;
const auto& e = inode.inputs[i];
uint32_t eid = idx.entry_id(e);
// temp_ref_count == 0 means it is taken by inplace op
if (ref_count[eid] == 0) continue;
auto sid = storage[eid];
// storage_ref_count == 0 means it is taken by inplace op
if (sid < 0) continue;
// if we decrease it to zero, it means we are ready to release
--ref_count[eid];
if (ref_count[eid] == 0 && storage[eid] != GraphAllocator::kBadStorageID) {
allocator->Release(storage[eid], nid);
--storage_ref_count[sid];
if (storage_ref_count[sid] == 0) {
allocator->Release(sid, nid);
}
}
// check if there are outputs that can be freed immediately
// these outputs are not referenced by any operator.
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
uint32_t eid = idx.entry_id(nid, index);
if (ref_count[eid] == 0 && storage[eid] != GraphAllocator::kBadStorageID) {
allocator->Release(storage[eid], nid);
auto sid = storage[eid];
if (sid >= 0 && storage_ref_count[sid] == 0) {
allocator->Release(sid, nid);
// use -2 to indicate that the node was never touched.
storage_inplace_index[eid] = -2;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment