Commit 619e529a by ziheng Committed by Tianqi Chen

Enable identity layout (#238)

parent eb3fc6c6
...@@ -388,8 +388,13 @@ NNVM_REGISTER_OP(tvm_op) ...@@ -388,8 +388,13 @@ NNVM_REGISTER_OP(tvm_op)
return param.num_outputs; return param.num_outputs;
}); });
inline bool IsIdentityLayout(const LayoutInfo& layout) {
if (layout.src == "" && layout.dst == "") return true;
return false;
}
inline bool IsPair(LayoutInfo in, LayoutInfo out) { inline bool IsPairedLayouts(const LayoutInfo& in,
const LayoutInfo& out) {
if (in.src == out.dst && in.dst == out.src) return true; if (in.src == out.dst && in.dst == out.src) return true;
return false; return false;
} }
...@@ -399,7 +404,8 @@ inline LayoutInfo GetLayout(const nnvm::OpMap<FTVMLayoutInfo>& layouts, ...@@ -399,7 +404,8 @@ inline LayoutInfo GetLayout(const nnvm::OpMap<FTVMLayoutInfo>& layouts,
return layouts[n->op()](n->attrs)[idx]; return layouts[n->op()](n->attrs)[idx];
} }
nnvm::NodePtr CreateLayoutTransformNode(std::string src, std::string dst) { nnvm::NodePtr CreateLayoutTransformNode(const std::string& src,
const std::string& dst) {
static const nnvm::Op* trans_op = nnvm::Op::Get("layout_transform"); static const nnvm::Op* trans_op = nnvm::Op::Get("layout_transform");
static int count = 0; static int count = 0;
nnvm::NodePtr n = nnvm::Node::Create(); nnvm::NodePtr n = nnvm::Node::Create();
...@@ -434,10 +440,13 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) { ...@@ -434,10 +440,13 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {
if (olayouts.count(n->op())) { if (olayouts.count(n->op())) {
std::vector<nnvm::NodePtr> tnodes(n->num_outputs(), nullptr); std::vector<nnvm::NodePtr> tnodes(n->num_outputs(), nullptr);
std::vector<LayoutInfo> layouts = olayouts[n->op()](n->attrs);
for (uint32_t i = 0; i < n->num_outputs(); ++i) { for (uint32_t i = 0; i < n->num_outputs(); ++i) {
LayoutInfo layout = GetLayout(olayouts, n, i); const LayoutInfo& layout = layouts[i];
tnodes[i] = CreateLayoutTransformNode(layout.src, layout.dst); if (!IsIdentityLayout(layout)) {
tnodes[i]->inputs.emplace_back(nnvm::NodeEntry{new_node, i, 0}); tnodes[i] = CreateLayoutTransformNode(layout.src, layout.dst);
tnodes[i]->inputs.emplace_back(nnvm::NodeEntry{new_node, i, 0});
}
} }
transformed.emplace(n.get(), std::move(tnodes)); transformed.emplace(n.get(), std::move(tnodes));
} }
...@@ -451,25 +460,29 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) { ...@@ -451,25 +460,29 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {
bool otrans = olayouts.count(in->op()); bool otrans = olayouts.count(in->op());
bool itrans = ilayouts.count(n->op()); bool itrans = ilayouts.count(n->op());
if (otrans && itrans) { if (otrans && itrans) {
LayoutInfo olayout = GetLayout(olayouts, in, e.index);
LayoutInfo ilayout = GetLayout(ilayouts, n, idx); LayoutInfo ilayout = GetLayout(ilayouts, n, idx);
if (IsPair(olayout, ilayout)) { LayoutInfo olayout = GetLayout(olayouts, in, e.index);
if (IsPairedLayouts(olayout, ilayout)) {
continue; continue;
} }
} }
if (otrans) { if (otrans) {
const auto& tnodes = transformed.at(in.get()); nnvm::NodePtr tnode = transformed.at(in.get())[e.index];
new_node->inputs[idx] = if (tnode.get()) {
nnvm::NodeEntry{tnodes[e.index], 0, 0}; new_node->inputs[idx] =
nnvm::NodeEntry{tnode, 0, 0};
}
} }
if (itrans) { if (itrans) {
LayoutInfo layout = GetLayout(ilayouts, n, idx); LayoutInfo layout = GetLayout(ilayouts, n, idx);
nnvm::NodePtr tnode = if (!IsIdentityLayout(layout)) {
CreateLayoutTransformNode(layout.src, layout.dst); nnvm::NodePtr tnode =
tnode->inputs.emplace_back(new_node->inputs[idx]); CreateLayoutTransformNode(layout.src, layout.dst);
new_node->inputs[idx] = nnvm::NodeEntry{tnode, 0, 0}; tnode->inputs.emplace_back(new_node->inputs[idx]);
new_node->inputs[idx] = nnvm::NodeEntry{tnode, 0, 0};
}
} }
} }
mirror_map[n.get()] = std::move(new_node); mirror_map[n.get()] = std::move(new_node);
...@@ -478,12 +491,14 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) { ...@@ -478,12 +491,14 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {
std::vector<nnvm::NodeEntry> outputs; std::vector<nnvm::NodeEntry> outputs;
for (const auto& e : src.outputs) { for (const auto& e : src.outputs) {
if (olayouts.count(e.node->op())) { if (olayouts.count(e.node->op())) {
const auto& tnodes = transformed.at(e.node.get()); nnvm::NodePtr tnode = transformed.at(e.node.get())[e.index];
outputs.emplace_back(nnvm::NodeEntry{tnodes[e.index], 0, 0}); if (tnode.get()) {
} else { outputs.emplace_back(nnvm::NodeEntry{tnode, 0, 0});
outputs.emplace_back( continue;
nnvm::NodeEntry{mirror_map.at(e.node.get()), e.index, e.version}); }
} }
outputs.emplace_back(
nnvm::NodeEntry{mirror_map.at(e.node.get()), e.index, e.version});
} }
nnvm::Graph ret; nnvm::Graph ret;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment