Commit 06c06f58 by Eric Junyuan Xie Committed by Tianqi Chen

fix plan memory add inplace_identity option (#124)

* fix plan memory add inplace_identity option

* comment
parent c00100f8
...@@ -107,8 +107,6 @@ using TIsBackward = bool; ...@@ -107,8 +107,6 @@ using TIsBackward = bool;
* \brief Get possible inplace options. * \brief Get possible inplace options.
* This function enables optimization to reuse memory of inputs in output. * This function enables optimization to reuse memory of inputs in output.
* \param attrs The attributes of the node * \param attrs The attributes of the node
* \param in_data The input data.
* \param out_data The output data.
* \return list of pair of that maps input->output, * \return list of pair of that maps input->output,
* indicating possible in place operations. * indicating possible in place operations.
* *
...@@ -118,6 +116,18 @@ using FInplaceOption = std::function< ...@@ -118,6 +116,18 @@ using FInplaceOption = std::function<
std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>; std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;
/*! /*!
* \brief Get if the inplace option is an identity
* This function enables inplace optimization even when input reference count
* is greater than one.
* \param attrs The attributes of the node
* \return list of bool indicating whether corresponding pair from FInplaceOption
* is an identity
*
* \note Register under "FInplaceIdentity", by default no identities.
*/
using FInplaceIdentity = std::function<std::vector<bool> (const NodeAttrs& attrs)>;
/*!
* \brief Get list of inputs in the op whose content are actually not used by the operator * \brief Get list of inputs in the op whose content are actually not used by the operator
* These are dummy input that can be used for example in zeros_like, ones_like. * These are dummy input that can be used for example in zeros_like, ones_like.
* *
......
...@@ -142,8 +142,12 @@ class GraphAllocator { ...@@ -142,8 +142,12 @@ class GraphAllocator {
* Internal method to perform the memory allocation for a graph * Internal method to perform the memory allocation for a graph
* */ * */
size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* storage_ptr, size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* storage_ptr,
std::vector<int>* storage_inplace_index_ptr, std::vector<uint32_t> ref_count, std::vector<int>* storage_inplace_index_ptr,
const std::vector<uint32_t>& entry_ref_count,
GraphAllocator* allocator) { GraphAllocator* allocator) {
static auto& finplace_option = Op::GetAttr<FInplaceOption>("FInplaceOption");
static auto& finplace_identity = Op::GetAttr<FInplaceIdentity>("FInplaceIdentity");
// Get reference // Get reference
auto &storage = *storage_ptr; auto &storage = *storage_ptr;
auto &storage_inplace_index = *storage_inplace_index_ptr; auto &storage_inplace_index = *storage_inplace_index_ptr;
...@@ -152,12 +156,12 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* sto ...@@ -152,12 +156,12 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* sto
const ShapeVector& shape_vec = ret.GetAttr<ShapeVector>("shape"); const ShapeVector& shape_vec = ret.GetAttr<ShapeVector>("shape");
const DTypeVector& dtype_vec = ret.GetAttr<DTypeVector>("dtype"); const DTypeVector& dtype_vec = ret.GetAttr<DTypeVector>("dtype");
const DeviceVector* device_vec = nullptr; const DeviceVector* device_vec = nullptr;
static auto& finplace_option = Op::GetAttr<FInplaceOption>("FInplaceOption");
if (ret.attrs.count("device") != 0) { if (ret.attrs.count("device") != 0) {
device_vec = &(ret.GetAttr<DeviceVector>("device")); device_vec = &(ret.GetAttr<DeviceVector>("device"));
} }
size_t num_not_allocated = 0; size_t num_not_allocated = 0;
std::vector<GraphAllocator::StorageID> storage_ref_count(idx.num_node_entries(), 0);
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid]; const auto& inode = idx[nid];
...@@ -165,18 +169,36 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* sto ...@@ -165,18 +169,36 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* sto
// check inplace option // check inplace option
if (finplace_option.count(inode.source->op()) != 0) { if (finplace_option.count(inode.source->op()) != 0) {
auto inplace_pairs = finplace_option[inode.source->op()](inode.source->attrs); auto inplace_pairs = finplace_option[inode.source->op()](inode.source->attrs);
for (auto& kv : inplace_pairs) { std::vector<bool> identity;
if (finplace_identity.count(inode.source->op()) != 0) {
identity = finplace_identity[inode.source->op()](inode.source->attrs);
CHECK_EQ(identity.size(), inplace_pairs.size())
<< "FInplaceOption and FInplaceIdentity returned vectors of different "
<< "size for operator " << inode.source->op()->name;
} else {
identity = std::vector<bool>(inplace_pairs.size(), false);
}
std::vector<bool> taken(inode.inputs.size(), false);
for (size_t ipair = 0; ipair < inplace_pairs.size(); ++ipair) {
const auto& kv = inplace_pairs[ipair];
uint32_t eid_out = idx.entry_id(nid, kv.second); uint32_t eid_out = idx.entry_id(nid, kv.second);
uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]); uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]);
if (ref_count[eid_in] == 1 && auto sid_out = storage[eid_out];
ref_count[eid_out] != 0 && auto sid_in = storage[eid_in];
storage[eid_out] == GraphAllocator::kBadStorageID && if (taken[kv.first] == false &&
storage[eid_in] >= 0 && sid_out == GraphAllocator::kBadStorageID &&
sid_in >= 0 &&
(storage_ref_count[sid_in] == 1 || identity[ipair]) &&
entry_ref_count[eid_out] > 0 &&
shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && shape_vec[eid_out].Size() == shape_vec[eid_in].Size() &&
dtype_vec[eid_out] == dtype_vec[eid_in]) { dtype_vec[eid_out] == dtype_vec[eid_in]) {
// inplace optimization // inplace optimization
storage[eid_out] = storage[eid_in]; taken[kv.first] = true;
ref_count[eid_in] = 0; storage[eid_out] = sid_in;
// Reuse storage for output and add ref count of output
// to storage. This will get substracted later in free
// input section.
storage_ref_count[sid_in] += entry_ref_count[eid_out];
storage_inplace_index[eid_out] = kv.first; storage_inplace_index[eid_out] = kv.first;
} }
} }
...@@ -196,7 +218,9 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* sto ...@@ -196,7 +218,9 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* sto
} }
for (auto rit = eids.rbegin(); rit != eids.rend(); ++rit) { for (auto rit = eids.rbegin(); rit != eids.rend(); ++rit) {
uint32_t eid = rit->second; uint32_t eid = rit->second;
storage[eid] = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid); auto sid = allocator->Request(dev_id, dtype_vec[eid], shape_vec[eid], nid);
storage_ref_count[sid] = entry_ref_count[eid];
storage[eid] = sid;
} }
// check if certain inputs is ignored. // check if certain inputs is ignored.
...@@ -212,20 +236,22 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* sto ...@@ -212,20 +236,22 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, StorageVector* sto
if (std::binary_search(ignore_inputs.begin(), ignore_inputs.end(), i)) continue; if (std::binary_search(ignore_inputs.begin(), ignore_inputs.end(), i)) continue;
const auto& e = inode.inputs[i]; const auto& e = inode.inputs[i];
uint32_t eid = idx.entry_id(e); uint32_t eid = idx.entry_id(e);
// temp_ref_count == 0 means it is taken by inplace op auto sid = storage[eid];
if (ref_count[eid] == 0) continue; // storage_ref_count == 0 means it is taken by inplace op
if (sid < 0) continue;
// if we decrease it to zero, means we are ready to relase // if we decrease it to zero, means we are ready to relase
--ref_count[eid]; --storage_ref_count[sid];
if (ref_count[eid] == 0 && storage[eid] != GraphAllocator::kBadStorageID) { if (storage_ref_count[sid] == 0) {
allocator->Release(storage[eid], nid); allocator->Release(sid, nid);
} }
} }
// check if there are outputs that can be freeded immediately // check if there are outputs that can be freeded immediately
// these output are not referenced by any operator. // these output are not referenced by any operator.
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
uint32_t eid = idx.entry_id(nid, index); uint32_t eid = idx.entry_id(nid, index);
if (ref_count[eid] == 0 && storage[eid] != GraphAllocator::kBadStorageID) { auto sid = storage[eid];
allocator->Release(storage[eid], nid); if (sid >= 0 && storage_ref_count[sid] == 0) {
allocator->Release(sid, nid);
// use -2 to indicate that the node was never touched. // use -2 to indicate that the node was never touched.
storage_inplace_index[eid] = -2; storage_inplace_index[eid] = -2;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment