Commit 79ceb9f7 by Tianqi Chen

[PASS] PrintGraphIR Join attributes when print ir (#20)

parent c829bd86
...@@ -177,8 +177,23 @@ class Graph(object): ...@@ -177,8 +177,23 @@ class Graph(object):
self._index = GraphIndex(self) self._index = GraphIndex(self)
return self._index return self._index
def graphir(self): def ir(self, join_entry_attrs=None, join_node_attrs=None):
"""Get text form of graph ir.""" """Get text form of graph ir.
Parameters
----------
join_entry_attrs : list of str
List of graph NodeEntry attribute to be
printed along each operator.
join_node_attrs : list of str
List of graph node attribute to be
printed along each operator.
"""
if join_entry_attrs:
self._set_json_attr("join_entry_attrs", join_entry_attrs, "list_str")
if join_node_attrs:
self._set_json_attr("join_node_attrs", join_node_attrs, "list_str")
return self.apply("PrintGraphIR").json_attr("graphir") return self.apply("PrintGraphIR").json_attr("graphir")
def apply(self, passes): def apply(self, passes):
......
...@@ -67,6 +67,8 @@ Graph InferAttr(Graph &&ret, ...@@ -67,6 +67,8 @@ Graph InferAttr(Graph &&ret,
shape_attr_key = ret.GetAttr<std::string>(attr_key_name); shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
// erase the provided arguments // erase the provided arguments
ret.attrs.erase(attr_key_name); ret.attrs.erase(attr_key_name);
} else {
shape_attr_key = attr_name;
} }
// Temp space for shape inference. // Temp space for shape inference.
std::vector<AttrType> ishape, oshape; std::vector<AttrType> ishape, oshape;
......
...@@ -5,14 +5,80 @@ ...@@ -5,14 +5,80 @@
*/ */
#include <nnvm/graph.h> #include <nnvm/graph.h>
#include <nnvm/pass.h> #include <nnvm/pass.h>
#include <nnvm/tuple.h>
#include <iostream> #include <iostream>
namespace nnvm { namespace nnvm {
namespace pass { namespace pass {
using AttrPrinter = std::function<void(uint32_t index, std::ostream& os)>; // NOLINT(*)
template<typename T>
AttrPrinter GetVectorPrinter_(const T& vec) {
return [&vec](uint32_t index, std::ostream& os) { // NOLINT(*)
os << vec[index];
};
}
AttrPrinter GetVectorPrinter(const Graph& graph,
const std::string& key) {
auto it = graph.attrs.find(key);
CHECK(it != graph.attrs.end())
<< "Cannot find " << key << " in graph attr";
const any& value = *(it->second);
if (value.type() == typeid(std::vector<TShape>)) {
return GetVectorPrinter_(
nnvm::get<std::vector<TShape> >(value));
} else if (value.type() == typeid(std::vector<int>)) {
return GetVectorPrinter_(
nnvm::get<std::vector<int> >(value));
} else if (value.type() == typeid(std::vector<std::string>)) {
return GetVectorPrinter_(
nnvm::get<std::vector<std::string> >(value));
} else {
LOG(FATAL) << "Cannot handle type " << value.type().name();
return nullptr;
}
}
// print the graph ir in readable format // print the graph ir in readable format
void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*) void PrintGraphIR_(Graph src,
const std::vector<std::string>& join_entry_attrs,
const std::vector<std::string>& join_node_attrs,
std::ostream& os) { // NOLINT(*)
const IndexedGraph& idx = src.indexed_graph(); const IndexedGraph& idx = src.indexed_graph();
std::vector<std::function<void(uint32_t, std::ostream&)> > trigger; // NOLINT(*)
for (const std::string& key : join_entry_attrs) {
AttrPrinter fp = GetVectorPrinter(src, key);
auto fprint = [&idx, key, fp](
uint32_t nid, std::ostream& os) { // NOLINT(*)
const IndexedGraph::Node& inode = idx[nid];
os << ", " << key << "=";
if (inode.source->num_outputs() != 1) {
os << '[';
for (uint32_t i = 0; i < inode.source->num_outputs(); ++i) {
if (i != 0) os << ", ";
fp(idx.entry_id(nid, i), os);
}
os << ']';
} else {
fp(idx.entry_id(nid, 0), os);
}
};
trigger.push_back(fprint);
}
for (const std::string& key : join_node_attrs) {
AttrPrinter fp = GetVectorPrinter(src, key);
auto fprint = [&idx, key, fp](
uint32_t nid, std::ostream& os) { // NOLINT(*)
os << key << "=";
fp(idx.entry_id(nid, 0), os);
};
trigger.push_back(fprint);
}
os << "Graph("; os << "Graph(";
if (idx.input_nodes().size() < 4) { if (idx.input_nodes().size() < 4) {
for (size_t i = 0; i < idx.input_nodes().size(); ++i) { for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
...@@ -79,6 +145,10 @@ void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*) ...@@ -79,6 +145,10 @@ void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*)
} }
os << "]"; os << "]";
} }
// additional attribute trigger
for (const auto& fp : trigger) {
fp(nid, os);
}
os << "\n"; os << "\n";
} }
os << " ret "; os << " ret ";
...@@ -112,7 +182,16 @@ void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*) ...@@ -112,7 +182,16 @@ void PrintGraphIR_(Graph src, std::ostream& os) { // NOLINT(*)
// save a graph to json // save a graph to json
Graph PrintGraphIR(Graph src) { Graph PrintGraphIR(Graph src) {
std::ostringstream os; std::ostringstream os;
PrintGraphIR_(src, os); std::vector<std::string> join_entry_attrs, join_node_attrs;
if (src.attrs.count("join_entry_attrs") != 0) {
join_entry_attrs = src.MoveCopyAttr<std::vector<std::string> >(
"join_entry_attrs");
}
if (src.attrs.count("join_node_attrs") != 0) {
join_node_attrs = src.MoveCopyAttr<std::vector<std::string> >(
"join_node_attrs");
}
PrintGraphIR_(src, join_entry_attrs, join_node_attrs, os);
Graph ret; Graph ret;
ret.attrs["graphir"] = std::make_shared<any>(os.str()); ret.attrs["graphir"] = std::make_shared<any>(os.str());
return ret; return ret;
......
...@@ -38,7 +38,7 @@ def test_simplify_batchnorm(): ...@@ -38,7 +38,7 @@ def test_simplify_batchnorm():
graph_attr.set_shape_inputs(g, ishape) graph_attr.set_shape_inputs(g, ishape)
g1 = g.apply("InferShape").apply("SimplifyBatchNormInference") g1 = g.apply("InferShape").apply("SimplifyBatchNormInference")
# Some prints for debug # Some prints for debug
# print(g1.graphir()) # print(g1.ir())
# assert graph equals as expected # assert graph equals as expected
graph_pass.check_graph_equal(g1, g2) graph_pass.check_graph_equal(g1, g2)
......
...@@ -99,8 +99,19 @@ def test_plan_memory(): ...@@ -99,8 +99,19 @@ def test_plan_memory():
assert (storage_id[jnode_row_ptr[nindex["add2"]]] == assert (storage_id[jnode_row_ptr[nindex["add2"]]] ==
storage_id[jnode_row_ptr[nindex["reshapek"]]]) storage_id[jnode_row_ptr[nindex["reshapek"]]])
def test_print_graph_ir():
x = sym.Variable("x", shape=(1, 1, 10, 20))
y = sym.conv2d(x + 1, name="y", channels=10, kernel_size=(3,3))
g = graph.create(y)
g = g.apply("InferShape")
ir1 = g.ir()
ir2 = g.ir(join_entry_attrs=["shape"])
assert("y_bias" in ir1)
assert("shape=" in ir2)
if __name__ == "__main__": if __name__ == "__main__":
test_print_graph_ir()
test_json_pass_with_attr() test_json_pass_with_attr()
test_graph_json_attr() test_graph_json_attr()
test_json_pass() test_json_pass()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment