Commit 8518c7dd by Przemyslaw Tredak Committed by Tianqi Chen

Fix the FInplaceIdentity (#2572)

parent fef72827
...@@ -218,10 +218,14 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx, ...@@ -218,10 +218,14 @@ size_t AllocMemory(const Graph& ret, const IndexedGraph& idx,
bool ignore_all_inputs = (fignore_inputs.count(inode.source->op()) != 0 && bool ignore_all_inputs = (fignore_inputs.count(inode.source->op()) != 0 &&
fignore_inputs[inode.source->op()]( fignore_inputs[inode.source->op()](
inode.source->attrs).size() == inode.source->num_inputs()); inode.source->attrs).size() == inode.source->num_inputs());
// Identity should only be true if shape.Size() and types match
bool real_identity = identity[ipair] &&
shape_vec[eid_out].Size() == shape_vec[eid_in].Size() &&
dtype_vec[eid_out] == dtype_vec[eid_in];
if (taken[kv.first] == false && if (taken[kv.first] == false &&
sid_out == GraphAllocator::kBadStorageID && sid_out == GraphAllocator::kBadStorageID &&
sid_in >= 0 && sid_in >= 0 &&
((storage_ref_count[sid_in] == 1 && !ignore_all_inputs) || identity[ipair]) && ((storage_ref_count[sid_in] == 1 && !ignore_all_inputs) || real_identity) &&
entry_ref_count[eid_out] > 0 && entry_ref_count[eid_out] > 0 &&
shape_vec[eid_out].Size() == shape_vec[eid_in].Size() && shape_vec[eid_out].Size() == shape_vec[eid_in].Size() &&
(dtype_vec[eid_out] == dtype_vec[eid_in] || (dtype_vec[eid_out] == dtype_vec[eid_in] ||
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment