Commit 558cf098 by Da Zheng Committed by Eric Junyuan Xie

add support for subgraphs. (#1221)

* add support for subgraphs.

* fix.

* fix.

* Fix compilation error

* Fix compilation error

* add comments.

* update comments.

* Sanity check on subgraphs when creating IndexedGraph

* avoid the overhead of sanity check.

* Stop using non-recursive DFS

* Trigger CI

* trigger CI
parent f3f406ab
......@@ -18,6 +18,7 @@ namespace nnvm {
// Forward declare node.
class Node;
class Symbol;
/*!
* \brief we always used NodePtr for a reference pointer
......@@ -90,6 +91,21 @@ struct NodeAttrs {
* The object can be used to quickly access attributes.
*/
any parsed;
/*!
* \brief Some operators take graphs as input. These operators include
* control flow operators and high-order functions.
* These graphs don't change when the operators are invoked for different
* mini-batches. In this sense, the subgraphs are kind of similar to
* the parameters and show be kept as node attributes.
*
* Users need to make sure the subgraphs are disjoint with the main graph.
* If a graph shares nodes with subgraphs, loading the graph from LoadJSON
* may generate a graph that has a different structure from the original graph
* (some of the nodes are duplicated). If nodes are shared between two graphs,
* shared nodes might be executed multiple times, which can be a problem for
* stateful operators.
*/
std::vector<std::shared_ptr<Symbol> > subgraphs;
};
/*!
......
......@@ -202,6 +202,18 @@ using FCorrectLayout = std::function<bool(
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts)>;
/*!
* \brief Get a list of inputs that represent graphs instead of data.
* Normally, input symbols are considered as data to the operator. However,
* control flow operators and high-order functions need to interpret symbols
* as graphs.
* \param attrs The attributes of this node.
* \return a list of input index that are interpreted as symbols by the operator.
*
* \note Register under "FInputGraph".
*/
using FInputGraph = std::function<std::vector<uint32_t>(const NodeAttrs& attrs)>;
} // namespace nnvm
#endif // NNVM_OP_ATTR_TYPES_H_
......@@ -16,15 +16,51 @@ const IndexedGraph& Graph::indexed_graph() const {
return *indexed_graph_;
}
// a subgraph should not refer to any nodes with higher level
// where "level" refers to the nested depth of the subgraph
// e.g. the main graph is level 0
// subgraphs of the main graph is level 1
// subgraphs of the subgraphs of the main graph is level 2
static void SubgraphSanityCheck(const std::vector<std::shared_ptr<Symbol>> &subgraphs) {
std::vector<const std::vector<nnvm::NodeEntry>*> curr_level;
std::vector<const std::vector<nnvm::NodeEntry>*> next_level;
std::unordered_map<nnvm::Node*, uint32_t> node2level;
for (auto &subgraph : subgraphs)
next_level.push_back(&subgraph->outputs);
for (uint32_t level = 0; !next_level.empty(); ++level) {
curr_level.swap(next_level);
next_level.clear();
for (const std::vector<NodeEntry> *graph_ptr : curr_level) {
const std::vector<NodeEntry> &graph = *graph_ptr;
DFSVisit(graph, [&next_level, &node2level, level](const NodePtr& n) {
nnvm::Node *node = n.get();
// if the node is visited, but on a different level, then check failed
// if check failed here or before, we stop doing anything, but raise an error
CHECK(!node2level.count(node) || node2level[node] == level)
<< "A subgraph should not depend on the outputs of nodes on higher levels";
// otherwise, this node belongs to the current level
node2level[node] = level;
// subgraphs of current node belongs to next level
for (const auto& subgraph : n->attrs.subgraphs) {
next_level.push_back(&subgraph->outputs);
}
});
}
}
}
// implement constructor from graph
IndexedGraph::IndexedGraph(const Graph &g) {
entry_rptr_.push_back(0);
std::vector<size_t> inputs_rptr{0}, control_rptr{0};
std::vector<std::shared_ptr<Symbol>> subgraphs;
DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr]
DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr, &subgraphs]
(const NodePtr& n) {
CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
uint32_t nid = static_cast<uint32_t>(nodes_.size());
for (const auto &subgraph : n->attrs.subgraphs)
subgraphs.push_back(subgraph);
// nodes_
IndexedGraph::Node new_node;
new_node.source = n.get();
......@@ -53,6 +89,8 @@ IndexedGraph::IndexedGraph(const Graph &g) {
}
control_rptr.push_back(control_deps_.size());
});
if (!subgraphs.empty())
SubgraphSanityCheck(subgraphs);
for (const auto& e : g.outputs) {
outputs_.emplace_back(NodeEntry{
......
......@@ -267,14 +267,36 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
const std::string& name) {
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
static auto& fset_attrs = Op::GetAttr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose");
static auto& fgraph = Op::GetAttr<FInputGraph>("FInputGraph");
// The arguments that contain graphs.
Node* n = outputs[0].node.get();
FInputGraph fng = fgraph.get(n->op(), nullptr);
std::vector<uint32_t> garg_idx;
if (fng != nullptr)
garg_idx = fng(n->attrs);
// The names of the arguments that contain graphs.
FListInputNames name_fn = flist_inputs.get(n->op(), nullptr);
auto arg_names = (name_fn == nullptr) ? std::vector<std::string>{"data"} : name_fn(n->attrs);
std::vector<std::string> garg_names(garg_idx.size());
for (size_t i = 0; i < garg_idx.size(); i++) {
size_t idx = garg_idx[i];
if (idx < arg_names.size())
garg_names[i] = arg_names[idx];
}
// parameter check.
for (size_t i = 0; i < args.size(); ++i) {
CHECK_EQ(args[i]->outputs.size(), 1U)
// If the argument isn't a graph, it should have only one output.
if (garg_idx.empty() || std::find(garg_idx.begin(), garg_idx.end(), i) == garg_idx.end())
CHECK_EQ(args[i]->outputs.size(), 1U)
<< "Argument " << i << " is a tuple, single value is required";
}
for (const auto& kv : kwargs) {
CHECK_EQ(kv.second->outputs.size(), 1U)
if (garg_names.empty()
|| std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end())
CHECK_EQ(kv.second->outputs.size(), 1U)
<< "Keyword Argument " << kv.first << " is a tuple, single value is required";
}
// assign new name
......@@ -282,28 +304,49 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
// Atomic functor composition.
if (IsAtomic(outputs)) {
Node* n = outputs[0].node.get();
uint32_t n_req = n->num_inputs();
std::vector<const Symbol *> arg_vec(args.begin(), args.end());
std::unordered_map<std::string, const Symbol*> kwarg_map(kwargs.begin(), kwargs.end());
// If one of the input arguments is a graph, we need to remove it from the
// list.
if (fng != nullptr) {
std::vector<uint32_t> idxes = fng(n->attrs);
for (auto idx : idxes) {
const Symbol *sym;
if (idx < arg_vec.size()) {
sym = arg_vec[idx];
arg_vec.erase(arg_vec.begin() + idx);
} else {
auto it = kwarg_map.find(arg_names[idx]);
CHECK(it != kwarg_map.end());
sym = it->second;
kwarg_map.erase(it);
}
if (n_req != kVarg)
n_req--;
arg_names.erase(arg_names.begin() + idx);
n->attrs.subgraphs.push_back(std::make_shared<Symbol>(*sym));
}
}
if (n_req != kVarg) {
n->inputs.resize(n_req);
CHECK_LE(args.size(), n_req)
CHECK_LE(arg_vec.size(), n_req)
<< "Incorrect number of arguments, requires " << n_req
<< ", provided " << args.size();
for (size_t i = 0; i < args.size(); ++i) {
n->inputs[i] = args[i]->outputs[0];
<< ", provided " << arg_vec.size();
for (size_t i = 0; i < arg_vec.size(); ++i) {
n->inputs[i] = arg_vec[i]->outputs[0];
}
// switch to keyword argument matching
if (args.size() != n_req) {
FListInputNames fn = flist_inputs.get(n->op(), nullptr);
auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs);
if (arg_vec.size() != n_req) {
if (arg_names.size() != n_req) {
LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op()->name;
}
size_t nmatched = 0;
for (size_t i = args.size(); i < n_req; ++i) {
auto it = kwargs.find(arg_names[i]);
if (it != kwargs.end() && it->first == arg_names[i]) {
for (size_t i = arg_vec.size(); i < n_req; ++i) {
auto it = kwarg_map.find(arg_names[i]);
if (it != kwarg_map.end() && it->first == arg_names[i]) {
n->inputs[i] = it->second->outputs[0];
++nmatched;
} else {
......@@ -314,18 +357,18 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
}
}
if (nmatched != kwargs.size()) {
if (nmatched != kwarg_map.size()) {
n->inputs.clear();
std::vector<std::string> keys = GetKeys(kwargs);
array_view<std::string> view(dmlc::BeginPtr(arg_names) + args.size(),
std::vector<std::string> keys = GetKeys(kwarg_map);
array_view<std::string> view(dmlc::BeginPtr(arg_names) + arg_vec.size(),
dmlc::BeginPtr(arg_names) + arg_names.size());
KeywordArgumentMismatch("Symbol.Compose", keys, view);
}
}
} else {
CHECK_EQ(kwargs.size(), 0U) << "Variable length function do not accept kwargs";
n->inputs.reserve(args.size());
for (const Symbol* s : args) {
CHECK_EQ(kwarg_map.size(), 0U) << "Variable length function do not accept kwargs";
n->inputs.reserve(arg_vec.size());
for (const Symbol* s : arg_vec) {
n->inputs.push_back(s->outputs[0]);
}
}
......
......@@ -29,6 +29,11 @@ namespace nnvm {
namespace pass {
namespace {
// JSONNode represents an nnvm::Node in JSON
struct JSONNode;
// JSONGraph represents an nnvm::Graph or nnvm::Symbol in JSON
struct JSONGraph;
// auxiliary node structure for serialization.
struct JSONNode {
// the node entry structure in serialized format
......@@ -36,6 +41,10 @@ struct JSONNode {
uint32_t node_id;
uint32_t index;
uint32_t version;
Entry() = default;
Entry(uint32_t node_id, uint32_t index, uint32_t version):
node_id(node_id), index(index), version(version) {
}
void Save(dmlc::JSONWriter *writer) const {
writer->BeginArray(false);
writer->WriteArrayItem(node_id);
......@@ -64,6 +73,8 @@ struct JSONNode {
std::vector<Entry> inputs;
// control flow dependencies
std::vector<uint32_t> control_deps;
// subgraphs
std::vector<JSONGraph> subgraphs;
// function to save JSON node.
void Save(dmlc::JSONWriter *writer) const {
......@@ -85,6 +96,9 @@ struct JSONNode {
if (control_deps.size() != 0) {
writer->WriteObjectKeyValue("control_deps", control_deps);
}
if (subgraphs.size() != 0) {
writer->WriteObjectKeyValue("subgraphs", subgraphs);
}
writer->EndObject();
}
......@@ -99,6 +113,7 @@ struct JSONNode {
helper.DeclareOptionalField("attrs", &(node->attrs.dict));
helper.DeclareOptionalField("attr", &(node->attrs.dict));
helper.DeclareOptionalField("control_deps", &control_deps);
helper.DeclareOptionalField("subgraphs", &subgraphs);
// backward compatible code with mxnet graph.
int backward_source_id;
std::unordered_map<std::string, std::string> param;
......@@ -154,86 +169,107 @@ struct JSONGraph {
}
};
// Load a graph from JSON file.
Graph LoadJSON(Graph src) {
CHECK_NE(src.attrs.count("json"), 0U)
<< "Load JSON require json to be presented.";
const std::string &json_str =
nnvm::get<std::string>(*src.attrs.at("json"));
bool no_parse = false;
if (src.attrs.count("load_json_no_parse")) {
no_parse = nnvm::get<bool>(*src.attrs.at("load_json_no_parse"));
void Symbol2JSONGraph(std::shared_ptr<Symbol> src, JSONGraph *jgraph) {
std::unordered_map<Node*, uint32_t> node2index;
jgraph->node_row_ptr.push_back(0);
DFSVisit(src->outputs, [&node2index, jgraph](const NodePtr& n) {
uint32_t nid = static_cast<uint32_t>(jgraph->nodes.size());
node2index[n.get()] = nid;
if (n->is_variable()) {
jgraph->arg_nodes.push_back(nid);
}
JSONNode jnode;
jnode.node = n;
jnode.inputs.reserve(n->inputs.size());
for (const NodeEntry& e : n->inputs) {
jnode.inputs.emplace_back(node2index.at(e.node.get()), e.index, e.version);
}
for (const NodePtr& c : n->control_deps) {
jnode.control_deps.push_back(node2index.at(c.get()));
}
jgraph->node_row_ptr.push_back(jgraph->node_row_ptr.back() + n->num_outputs());
jgraph->nodes.emplace_back(std::move(jnode));
});
for (const NodeEntry& e : src->outputs) {
jgraph->heads.emplace_back(node2index.at(e.node.get()), e.index, e.version);
}
std::istringstream is(json_str);
dmlc::JSONReader reader(&is);
JSONGraph jgraph;
// load in json graph.
jgraph.Load(&reader);
// connects the nodes
for (JSONNode &n : jgraph.nodes) {
// recursively construct subgraphs
for (JSONNode &jnode : jgraph->nodes) {
// construct jnode's subgraphs
const std::vector<std::shared_ptr<Symbol>> &subgraphs = jnode.node->attrs.subgraphs;
std::vector<JSONGraph> &jsubgraphs = jnode.subgraphs;
jsubgraphs.resize(subgraphs.size());
for (uint32_t i = 0; i < subgraphs.size(); ++i) {
Symbol2JSONGraph(subgraphs[i], &jsubgraphs[i]);
}
}
}
std::shared_ptr<Symbol> JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) {
for (const JSONNode &n : jgraph.nodes) {
n.node->inputs.reserve(n.inputs.size());
for (const JSONNode::Entry &e : n.inputs) {
n.node->inputs.emplace_back(
NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
n.node->inputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
}
n.node->control_deps.reserve(n.control_deps.size());
for (uint32_t nid : n.control_deps) {
n.node->control_deps.push_back(jgraph.nodes[nid].node);
}
// rebuild attribute parser
if (!no_parse && n.node->op() != nullptr &&
n.node->op()->attr_parser != nullptr) {
if (!no_parse && n.node->op() != nullptr && n.node->op()->attr_parser != nullptr) {
n.node->op()->attr_parser(&(n.node->attrs));
}
for (const JSONGraph &subgraph : n.subgraphs) {
// The "no_parse" option here, is to be compatible with
// commit cfd3075e85807dcd8f9534c37e053583dee87524
// (https://github.com/apache/incubator-mxnet/tree/cfd3075e85807dcd8f9534c37e053583dee87524),
// where the parsing of main graph is deferred until
// incubator-mxnet/src/nnvm/legacy_json_util.cc:UpgradeJSON_Parse
n.node->attrs.subgraphs.push_back(JSONGraph2Symbol(subgraph, false));
}
}
// consistent check
// consistency check
for (uint32_t nid : jgraph.arg_nodes) {
CHECK(jgraph.nodes[nid].node->is_variable());
}
std::shared_ptr<Symbol> symbol = std::make_shared<Symbol>();
symbol->outputs.reserve(jgraph.heads.size());
for (const JSONNode::Entry &e : jgraph.heads) {
symbol->outputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
}
return symbol;
}
// Load a graph from JSON file.
Graph LoadJSON(Graph src) {
CHECK_NE(src.attrs.count("json"), 0U)
<< "Load JSON require json to be presented.";
const std::string &json_str =
nnvm::get<std::string>(*src.attrs.at("json"));
bool no_parse = false;
if (src.attrs.count("load_json_no_parse")) {
no_parse = nnvm::get<bool>(*src.attrs.at("load_json_no_parse"));
}
std::istringstream is(json_str);
dmlc::JSONReader reader(&is);
JSONGraph jgraph;
// load in json graph.
jgraph.Load(&reader);
std::shared_ptr<Symbol> symbol = JSONGraph2Symbol(jgraph, no_parse);
// return the graph
Graph ret;
ret.attrs = std::move(jgraph.attrs);
ret.outputs.reserve(jgraph.heads.size());
for (const JSONNode::Entry &e : jgraph.heads) {
ret.outputs.emplace_back(
NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
}
ret.outputs = symbol->outputs;
return ret;
}
// save a graph to json
Graph SaveJSON(Graph src) {
std::shared_ptr<Symbol> src_symbol = std::make_shared<Symbol>();
src_symbol->outputs = src.outputs;
JSONGraph jgraph;
Symbol2JSONGraph(src_symbol, &jgraph);
jgraph.attrs = src.attrs;
std::unordered_map<Node*, uint32_t> node2index;
jgraph.node_row_ptr.push_back(0);
DFSVisit(src.outputs, [&node2index, &jgraph](const NodePtr& n) {
uint32_t nid = static_cast<uint32_t>(jgraph.nodes.size());
node2index[n.get()] = nid;
if (n->is_variable()) {
jgraph.arg_nodes.push_back(nid);
}
JSONNode jnode;
jnode.node = n;
jnode.inputs.reserve(n->inputs.size());
for (const NodeEntry& e : n->inputs) {
jnode.inputs.emplace_back(
JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version});
}
for (const NodePtr& c : n->control_deps) {
jnode.control_deps.push_back(node2index.at(c.get()));
}
jgraph.node_row_ptr.push_back(
jgraph.node_row_ptr.back() + n->num_outputs());
jgraph.nodes.emplace_back(std::move(jnode));
});
for (const NodeEntry& e : src.outputs) {
jgraph.heads.push_back(
JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version});
}
std::ostringstream os;
dmlc::JSONWriter writer(&os);
jgraph.Save(&writer);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment