Commit 50ddb76b by ziheng Committed by GitHub

[PASS] Improve graph fusion (#286)

* [PASS] Improve graph fusion

* Change fusion center to segment head

* Use 'master' to identity the schedule node

* Make things compact

* Fix
parent 7e82eb61
...@@ -21,7 +21,7 @@ using nnvm::IndexedGraph; ...@@ -21,7 +21,7 @@ using nnvm::IndexedGraph;
// The single fuse rule. // The single fuse rule.
enum class FuseRule { enum class FuseRule {
kUknown, kUknown,
kFuseToParent, kFuseToMaster,
kRealize kRealize
}; };
...@@ -57,10 +57,16 @@ nnvm::Graph GraphPartition(nnvm::Graph g) { ...@@ -57,10 +57,16 @@ nnvm::Graph GraphPartition(nnvm::Graph g) {
++ref_count[e.node_id]; ++ref_count[e.node_id];
} }
} }
for (const auto& e : idx.outputs()) {
// this line will realize all the outputs
ref_count[e.node_id] += 2;
}
// Pattern fo the subgraph // Pattern fo the subgraph
std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kExtern); std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kExtern);
// Whether node can be fused to parent. // Whether node can be fused to parent.
std::vector<FuseRule> fuse_vec(idx.num_nodes(), FuseRule::kUknown); std::vector<FuseRule> fuse_vec(idx.num_nodes(), FuseRule::kUknown);
// Master node id of fusion segment.
std::vector<int> master_vec(idx.num_nodes(), -1);
// Operator pattern // Operator pattern
static auto& op_pattern = nnvm::Op::GetAttr<TOpPattern>("TOpPattern"); static auto& op_pattern = nnvm::Op::GetAttr<TOpPattern>("TOpPattern");
...@@ -70,38 +76,58 @@ nnvm::Graph GraphPartition(nnvm::Graph g) { ...@@ -70,38 +76,58 @@ nnvm::Graph GraphPartition(nnvm::Graph g) {
fuse_vec[nid] = FuseRule::kRealize; continue; fuse_vec[nid] = FuseRule::kRealize; continue;
} }
TOpPattern pt = op_pattern.get(inode.source->op(), kExtern); TOpPattern pt = op_pattern.get(inode.source->op(), kExtern);
if (pt <= kBroadcast) { if (pt <= kBroadcast) {
// Looking for fusable bcast pattern int chosen_master = -1;
bool ewise = inode.source->num_outputs() == 1; bool ewise = inode.source->num_outputs() == 1;
for (const auto& e : inode.inputs) { for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kUknown) { if (fuse_vec[e.node_id] == FuseRule::kUknown) {
if (pattern_vec[e.node_id] == kBroadcast) { TOpPattern ipt = pattern_vec[e.node_id];
ewise = false; if (ipt != kElemWise) ewise = false;
fuse_vec[e.node_id] = FuseRule::kFuseToParent; if (ipt <= kBroadcast) {
} else if (pattern_vec[e.node_id] == kElemWise) { fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
fuse_vec[e.node_id] = FuseRule::kFuseToParent; } else if (ipt == kComplex && chosen_master == -1 &&
shape_vec[idx.entry_id(nid, 0)] == shape_vec[idx.entry_id(e)]) {
chosen_master = master_vec[e.node_id];
fuse_vec[e.node_id] = FuseRule::kFuseToMaster;
} else {
fuse_vec[e.node_id] = FuseRule::kRealize;
} }
} }
if (ewise) { if (ewise) {
TShape oshape = shape_vec[idx.entry_id(nid, 0)]; if (shape_vec[idx.entry_id(nid, 0)] != shape_vec[idx.entry_id(e)]) {
if (oshape != shape_vec[idx.entry_id(e)]) ewise = false; ewise = false;
}
} }
} }
pt = ewise ? kElemWise : kBroadcast; master_vec[nid] = chosen_master;
} else if (pt == kComplex) { if (chosen_master != -1) {
pt = kComplex;
} else {
pt = ewise ? kElemWise : kBroadcast;
}
} else {
master_vec[nid] = nid;
for (const auto& e : inode.inputs) { for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kUknown) { if (fuse_vec[e.node_id] == FuseRule::kUknown) {
if (pattern_vec[e.node_id] <= kBroadcast) { fuse_vec[e.node_id] = FuseRule::kRealize;
fuse_vec[e.node_id] = FuseRule::kFuseToParent; if (master_vec[e.node_id] == -1) {
master_vec[e.node_id] = e.node_id;
} }
} }
} }
} }
pattern_vec[nid] = pt; pattern_vec[nid] = pt;
if (ref_count[nid] > 1) { if (ref_count[nid] > 1) {
fuse_vec[nid] = FuseRule::kRealize; fuse_vec[nid] = FuseRule::kRealize;
if (master_vec[nid] == -1) {
master_vec[nid] = nid;
}
} }
} }
// point to the group root id of each node // point to the group root id of each node
std::vector<int> group_vec(idx.num_nodes(), -1); std::vector<int> group_vec(idx.num_nodes(), -1);
for (uint32_t i = idx.num_nodes(); i != 0; --i) { for (uint32_t i = idx.num_nodes(); i != 0; --i) {
...@@ -112,7 +138,7 @@ nnvm::Graph GraphPartition(nnvm::Graph g) { ...@@ -112,7 +138,7 @@ nnvm::Graph GraphPartition(nnvm::Graph g) {
} }
// propagate the group id. // propagate the group id.
for (const auto& e : inode.inputs) { for (const auto& e : inode.inputs) {
if (fuse_vec[e.node_id] == FuseRule::kFuseToParent) { if (fuse_vec[e.node_id] == FuseRule::kFuseToMaster) {
CHECK(group_vec[e.node_id] == -1|| CHECK(group_vec[e.node_id] == -1||
group_vec[e.node_id] == group_vec[nid]); group_vec[e.node_id] == group_vec[nid]);
group_vec[e.node_id] = group_vec[nid]; group_vec[e.node_id] = group_vec[nid];
...@@ -120,6 +146,7 @@ nnvm::Graph GraphPartition(nnvm::Graph g) { ...@@ -120,6 +146,7 @@ nnvm::Graph GraphPartition(nnvm::Graph g) {
} }
} }
g.attrs["group_root"] = std::make_shared<any>(std::move(group_vec)); g.attrs["group_root"] = std::make_shared<any>(std::move(group_vec));
g.attrs["group_master"] = std::make_shared<any>(std::move(master_vec));
g.attrs["pattern"] = std::make_shared<any>(std::move(pattern_vec)); g.attrs["pattern"] = std::make_shared<any>(std::move(pattern_vec));
g.attrs["dltype"] = std::make_shared<any>(std::move(dltype_vec)); g.attrs["dltype"] = std::make_shared<any>(std::move(dltype_vec));
return g; return g;
...@@ -172,6 +199,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) { ...@@ -172,6 +199,7 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
const DLTypeVector& dltype_vec = g.GetAttr<DLTypeVector>("dltype"); const DLTypeVector& dltype_vec = g.GetAttr<DLTypeVector>("dltype");
const DTypeVector& dtype_vec = g.GetAttr<DTypeVector>("dtype"); const DTypeVector& dtype_vec = g.GetAttr<DTypeVector>("dtype");
const std::vector<int>& group_vec = g.GetAttr<std::vector<int> >("group_root"); const std::vector<int>& group_vec = g.GetAttr<std::vector<int> >("group_root");
const std::vector<int>& master_vec = g.GetAttr<std::vector<int> >("group_master");
const std::vector<TOpPattern>& pattern_vec = const std::vector<TOpPattern>& pattern_vec =
g.GetAttr<std::vector<TOpPattern> >("pattern"); g.GetAttr<std::vector<TOpPattern> >("pattern");
std::string target = g.GetAttr<std::string>("target"); std::string target = g.GetAttr<std::string>("target");
...@@ -239,15 +267,18 @@ nnvm::Graph GraphFuse(nnvm::Graph g) { ...@@ -239,15 +267,18 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
Array<Tensor> out = fcompute[inode.source->op()]( Array<Tensor> out = fcompute[inode.source->op()](
inode.source->attrs, inputs); inode.source->attrs, inputs);
CHECK_EQ(out.size(), inode.source->num_outputs()); CHECK_EQ(out.size(), inode.source->num_outputs());
// schedule on root node, and use master's schedule
if (nid != root_id) { if (nid != root_id) {
for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) { for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
uint32_t eid = idx.entry_id(nid, index); uint32_t eid = idx.entry_id(nid, index);
tensor_vec[eid] = out[index]; tensor_vec[eid] = out[index];
} }
} else { } else {
// Work on schedule
fe.outputs = out; fe.outputs = out;
fe.schedule = fschedule[inode.source->op()]( int master = master_vec[root_id];
CHECK_GE(master, 0);
fe.schedule = fschedule[idx[master].source->op()](
inode.source->attrs, fe.outputs, target); inode.source->attrs, fe.outputs, target);
std::ostringstream os; std::ostringstream os;
os << inode.source->attrs.name + "_id" << nid; os << inode.source->attrs.name + "_id" << nid;
...@@ -307,10 +338,12 @@ nnvm::Graph GraphFuse(nnvm::Graph g) { ...@@ -307,10 +338,12 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
old_new[nid] = np; old_new[nid] = np;
} }
} }
nnvm::Graph ret; nnvm::Graph ret;
for (const auto& e : idx.outputs()) { for (const auto& e : idx.outputs()) {
auto it = old_new.find(e.node_id); auto it = old_new.find(group_vec[e.node_id]);
CHECK(it != old_new.end()); CHECK(it != old_new.end())
<< "cannot find node_id=" << e.node_id;
ret.outputs.emplace_back( ret.outputs.emplace_back(
nnvm::NodeEntry{it->second, e.index, e.version}); nnvm::NodeEntry{it->second, e.index, e.version});
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment