Commit 9a4e1339 by Tianqi Chen

Bugfix plan memory, fully support mxnet executor (#32)

* [PASS] include knullop info in plan memory

* Bugfix plan memory, fully support mxnet
parent 7c3d18c6
...@@ -31,6 +31,18 @@ namespace nnvm { ...@@ -31,6 +31,18 @@ namespace nnvm {
using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>; using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
/*! /*!
* \brief Return number of visible outputs by the user.
*
* \param attrs The attributes of the node.
*
* \note Register under "FNumVisibleOutputs", default not registered.
* This can be used to hide certain output from the user,
* but the additional outputs can be used to pass information from
* forward to gradient pass.
*/
using FNumVisibleOutputs = std::function<uint32_t (const NodeAttrs& attrs)>;
/*!
* \brief Return list of output arguments names of each operator. * \brief Return list of output arguments names of each operator.
* *
* \param attrs The attributes of the node. * \param attrs The attributes of the node.
......
...@@ -87,7 +87,7 @@ inline std::vector<std::string> GetKeys( ...@@ -87,7 +87,7 @@ inline std::vector<std::string> GetKeys(
// whether the symbol is atomic functor // whether the symbol is atomic functor
inline bool IsAtomic(const std::vector<NodeEntry>& outputs) { inline bool IsAtomic(const std::vector<NodeEntry>& outputs) {
return outputs.size() == 1 && outputs[0].node->inputs.size() == 0; return outputs[0].node->inputs.size() == 0;
} }
// public functions // public functions
...@@ -222,6 +222,7 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const { ...@@ -222,6 +222,7 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
std::vector<std::string> Symbol::ListOutputNames() const { std::vector<std::string> Symbol::ListOutputNames() const {
static auto& flist_ouputs = Op::GetAttr<FListOutputNames>("FListOutputNames"); static auto& flist_ouputs = Op::GetAttr<FListOutputNames>("FListOutputNames");
std::vector<std::string> ret; std::vector<std::string> ret;
for (auto &head : outputs) { for (auto &head : outputs) {
if (head.node->is_variable()) { if (head.node->is_variable()) {
...@@ -256,8 +257,6 @@ void Symbol::Compose(const array_view<const Symbol*>& args, ...@@ -256,8 +257,6 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
const std::string& name) { const std::string& name) {
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames"); static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
CHECK_EQ(outputs.size(), 1)
<< "Only composition of value function is supported currently";
CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed"; CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed";
// parameter check. // parameter check.
for (size_t i = 0; i < args.size(); ++i) { for (size_t i = 0; i < args.size(); ++i) {
...@@ -400,6 +399,7 @@ void Symbol::AddControlDeps(const Symbol& src) { ...@@ -400,6 +399,7 @@ void Symbol::AddControlDeps(const Symbol& src) {
} }
Symbol Symbol::GetInternals() const { Symbol Symbol::GetInternals() const {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol ret; Symbol ret;
DFSVisit(this->outputs, [&ret](const NodePtr& node) { DFSVisit(this->outputs, [&ret](const NodePtr& node) {
Node* n = node.get(); Node* n = node.get();
...@@ -409,6 +409,9 @@ Symbol Symbol::GetInternals() const { ...@@ -409,6 +409,9 @@ Symbol Symbol::GetInternals() const {
ret.outputs.emplace_back(NodeEntry{node, 0, param.version}); ret.outputs.emplace_back(NodeEntry{node, 0, param.version});
} else { } else {
uint32_t nout = n->num_outputs(); uint32_t nout = n->num_outputs();
if (fnum_vis_output.count(n->op())) {
nout = fnum_vis_output[n->op()](n->attrs);
}
for (uint32_t i = 0; i < nout; ++i) { for (uint32_t i = 0; i < nout; ++i) {
ret.outputs.emplace_back(NodeEntry{node, i, 0}); ret.outputs.emplace_back(NodeEntry{node, i, 0});
} }
...@@ -467,6 +470,7 @@ std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption op ...@@ -467,6 +470,7 @@ std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption op
Symbol Symbol::CreateFunctor(const Op* op, Symbol Symbol::CreateFunctor(const Op* op,
std::unordered_map<std::string, std::string> attrs) { std::unordered_map<std::string, std::string> attrs) {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol s; Symbol s;
NodePtr n = Node::Create(); NodePtr n = Node::Create();
n->attrs.op = op; n->attrs.op = op;
...@@ -474,7 +478,14 @@ Symbol Symbol::CreateFunctor(const Op* op, ...@@ -474,7 +478,14 @@ Symbol Symbol::CreateFunctor(const Op* op,
if (n->op()->attr_parser != nullptr) { if (n->op()->attr_parser != nullptr) {
n->op()->attr_parser(&(n->attrs)); n->op()->attr_parser(&(n->attrs));
} }
s.outputs.emplace_back(NodeEntry{std::move(n), 0, 0});
uint32_t nout = n->num_outputs();
if (fnum_vis_output.count(n->op())) {
nout = fnum_vis_output[n->op()](n->attrs);
}
for (uint32_t i = 0; i < nout; ++i) {
s.outputs.emplace_back(NodeEntry{n, i, 0});
}
return s; return s;
} }
......
...@@ -12,7 +12,6 @@ namespace nnvm { ...@@ -12,7 +12,6 @@ namespace nnvm {
namespace pass { namespace pass {
namespace { namespace {
// simply logic to place device according to device_group hint // simply logic to place device according to device_group hint
// insert copy node when there is // insert copy node when there is
Graph PlaceDevice(Graph src) { Graph PlaceDevice(Graph src) {
......
...@@ -142,11 +142,11 @@ Graph PlanMemory(Graph ret) { ...@@ -142,11 +142,11 @@ Graph PlanMemory(Graph ret) {
// step 1: initialize reference count // step 1: initialize reference count
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
for (const auto& e : idx[nid].inputs) { for (const auto& e : idx[nid].inputs) {
++ref_count[e.node_id]; ++ref_count[idx.entry_id(e)];
} }
} }
for (const auto& e : idx.outputs()) { for (const auto& e : idx.outputs()) {
++ref_count[e.node_id]; ++ref_count[idx.entry_id(e)];
} }
// step 2: allocate memory. // step 2: allocate memory.
StorageVector storage(idx.num_node_entries(), -1); StorageVector storage(idx.num_node_entries(), -1);
...@@ -202,10 +202,13 @@ Graph PlanMemory(Graph ret) { ...@@ -202,10 +202,13 @@ Graph PlanMemory(Graph ret) {
} }
} }
// check if there are outputs that can be freeded immediately // check if there are outputs that can be freeded immediately
// these output are not referenced by any operator.
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
uint32_t eid = idx.entry_id(nid, index); uint32_t eid = idx.entry_id(nid, index);
if (ref_count[eid] == 0 && storage[eid] != GraphAllocator::kBadStorageID) { if (ref_count[eid] == 0 && storage[eid] != GraphAllocator::kBadStorageID) {
allocator.Release(storage[eid], nid); allocator.Release(storage[eid], nid);
// use -2 to indicate that the node was never touched.
storage_inplace_index[eid] = -2;
} }
if (storage[eid] == GraphAllocator::kBadStorageID) { if (storage[eid] == GraphAllocator::kBadStorageID) {
++num_not_allocated; ++num_not_allocated;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment