Commit 9a4e1339 by Tianqi Chen

Bugfix plan memory, fully support mxnet executor (#32)

* [PASS] include knullop info in plan memory

* Bugfix plan memory, fully support mxnet
parent 7c3d18c6
......@@ -31,6 +31,18 @@ namespace nnvm {
using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
/*!
* \brief Return number of visible outputs by the user.
*
* \param attrs The attributes of the node.
*
* \note Register under "FNumVisibleOutputs", default not registered.
* This can be used to hide certain output from the user,
* but the additional outputs can be used to pass information from
* forward to gradient pass.
*/
using FNumVisibleOutputs = std::function<uint32_t (const NodeAttrs& attrs)>;
/*!
* \brief Return list of output arguments names of each operator.
*
* \param attrs The attributes of the node.
......
......@@ -87,7 +87,7 @@ inline std::vector<std::string> GetKeys(
// whether the symbol is atomic functor
inline bool IsAtomic(const std::vector<NodeEntry>& outputs) {
return outputs.size() == 1 && outputs[0].node->inputs.size() == 0;
return outputs[0].node->inputs.size() == 0;
}
// public functions
......@@ -222,6 +222,7 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
std::vector<std::string> Symbol::ListOutputNames() const {
static auto& flist_ouputs = Op::GetAttr<FListOutputNames>("FListOutputNames");
std::vector<std::string> ret;
for (auto &head : outputs) {
if (head.node->is_variable()) {
......@@ -256,8 +257,6 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
const std::string& name) {
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
CHECK_EQ(outputs.size(), 1)
<< "Only composition of value function is supported currently";
CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed";
// parameter check.
for (size_t i = 0; i < args.size(); ++i) {
......@@ -400,6 +399,7 @@ void Symbol::AddControlDeps(const Symbol& src) {
}
Symbol Symbol::GetInternals() const {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol ret;
DFSVisit(this->outputs, [&ret](const NodePtr& node) {
Node* n = node.get();
......@@ -409,6 +409,9 @@ Symbol Symbol::GetInternals() const {
ret.outputs.emplace_back(NodeEntry{node, 0, param.version});
} else {
uint32_t nout = n->num_outputs();
if (fnum_vis_output.count(n->op())) {
nout = fnum_vis_output[n->op()](n->attrs);
}
for (uint32_t i = 0; i < nout; ++i) {
ret.outputs.emplace_back(NodeEntry{node, i, 0});
}
......@@ -467,6 +470,7 @@ std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption op
Symbol Symbol::CreateFunctor(const Op* op,
std::unordered_map<std::string, std::string> attrs) {
static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
Symbol s;
NodePtr n = Node::Create();
n->attrs.op = op;
......@@ -474,7 +478,14 @@ Symbol Symbol::CreateFunctor(const Op* op,
if (n->op()->attr_parser != nullptr) {
n->op()->attr_parser(&(n->attrs));
}
s.outputs.emplace_back(NodeEntry{std::move(n), 0, 0});
uint32_t nout = n->num_outputs();
if (fnum_vis_output.count(n->op())) {
nout = fnum_vis_output[n->op()](n->attrs);
}
for (uint32_t i = 0; i < nout; ++i) {
s.outputs.emplace_back(NodeEntry{n, i, 0});
}
return s;
}
......
......@@ -12,7 +12,6 @@ namespace nnvm {
namespace pass {
namespace {
// simply logic to place device according to device_group hint
// insert copy node when there is
Graph PlaceDevice(Graph src) {
......
......@@ -142,11 +142,11 @@ Graph PlanMemory(Graph ret) {
// step 1: initialize reference count
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
for (const auto& e : idx[nid].inputs) {
++ref_count[e.node_id];
++ref_count[idx.entry_id(e)];
}
}
for (const auto& e : idx.outputs()) {
++ref_count[e.node_id];
++ref_count[idx.entry_id(e)];
}
// step 2: allocate memory.
StorageVector storage(idx.num_node_entries(), -1);
......@@ -202,10 +202,13 @@ Graph PlanMemory(Graph ret) {
}
}
// check if there are outputs that can be freeded immediately
// these output are not referenced by any operator.
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
uint32_t eid = idx.entry_id(nid, index);
if (ref_count[eid] == 0 && storage[eid] != GraphAllocator::kBadStorageID) {
allocator.Release(storage[eid], nid);
// use -2 to indicate that the node was never touched.
storage_inplace_index[eid] = -2;
}
if (storage[eid] == GraphAllocator::kBadStorageID) {
++num_not_allocated;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment