Commit 619e529a by ziheng Committed by Tianqi Chen

Enable identity layout (#238)

parent eb3fc6c6
......@@ -388,8 +388,13 @@ NNVM_REGISTER_OP(tvm_op)
return param.num_outputs;
});
inline bool IsIdentityLayout(const LayoutInfo& layout) {
if (layout.src == "" && layout.dst == "") return true;
return false;
}
inline bool IsPair(LayoutInfo in, LayoutInfo out) {
inline bool IsPairedLayouts(const LayoutInfo& in,
const LayoutInfo& out) {
if (in.src == out.dst && in.dst == out.src) return true;
return false;
}
......@@ -399,7 +404,8 @@ inline LayoutInfo GetLayout(const nnvm::OpMap<FTVMLayoutInfo>& layouts,
return layouts[n->op()](n->attrs)[idx];
}
nnvm::NodePtr CreateLayoutTransformNode(std::string src, std::string dst) {
nnvm::NodePtr CreateLayoutTransformNode(const std::string& src,
const std::string& dst) {
static const nnvm::Op* trans_op = nnvm::Op::Get("layout_transform");
static int count = 0;
nnvm::NodePtr n = nnvm::Node::Create();
......@@ -434,11 +440,14 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {
if (olayouts.count(n->op())) {
std::vector<nnvm::NodePtr> tnodes(n->num_outputs(), nullptr);
std::vector<LayoutInfo> layouts = olayouts[n->op()](n->attrs);
for (uint32_t i = 0; i < n->num_outputs(); ++i) {
LayoutInfo layout = GetLayout(olayouts, n, i);
const LayoutInfo& layout = layouts[i];
if (!IsIdentityLayout(layout)) {
tnodes[i] = CreateLayoutTransformNode(layout.src, layout.dst);
tnodes[i]->inputs.emplace_back(nnvm::NodeEntry{new_node, i, 0});
}
}
transformed.emplace(n.get(), std::move(tnodes));
}
......@@ -451,40 +460,46 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {
bool otrans = olayouts.count(in->op());
bool itrans = ilayouts.count(n->op());
if (otrans && itrans) {
LayoutInfo olayout = GetLayout(olayouts, in, e.index);
LayoutInfo ilayout = GetLayout(ilayouts, n, idx);
if (IsPair(olayout, ilayout)) {
LayoutInfo olayout = GetLayout(olayouts, in, e.index);
if (IsPairedLayouts(olayout, ilayout)) {
continue;
}
}
if (otrans) {
const auto& tnodes = transformed.at(in.get());
nnvm::NodePtr tnode = transformed.at(in.get())[e.index];
if (tnode.get()) {
new_node->inputs[idx] =
nnvm::NodeEntry{tnodes[e.index], 0, 0};
nnvm::NodeEntry{tnode, 0, 0};
}
}
if (itrans) {
LayoutInfo layout = GetLayout(ilayouts, n, idx);
if (!IsIdentityLayout(layout)) {
nnvm::NodePtr tnode =
CreateLayoutTransformNode(layout.src, layout.dst);
tnode->inputs.emplace_back(new_node->inputs[idx]);
new_node->inputs[idx] = nnvm::NodeEntry{tnode, 0, 0};
}
}
}
mirror_map[n.get()] = std::move(new_node);
});
std::vector<nnvm::NodeEntry> outputs;
for (const auto& e : src.outputs) {
if (olayouts.count(e.node->op())) {
const auto& tnodes = transformed.at(e.node.get());
outputs.emplace_back(nnvm::NodeEntry{tnodes[e.index], 0, 0});
} else {
nnvm::NodePtr tnode = transformed.at(e.node.get())[e.index];
if (tnode.get()) {
outputs.emplace_back(nnvm::NodeEntry{tnode, 0, 0});
continue;
}
}
outputs.emplace_back(
nnvm::NodeEntry{mirror_map.at(e.node.get()), e.index, e.version});
}
}
nnvm::Graph ret;
ret.outputs = std::move(outputs);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment