Commit 32121865 by ziheng Committed by GitHub

[PASS] Layout transform pass (#233)

* [PASS] Layout transform pass

* Fix according to comment

* Fix
parent adc06e6f
......@@ -387,5 +387,115 @@ NNVM_REGISTER_OP(tvm_op)
const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
return param.num_outputs;
});
inline bool IsPair(LayoutInfo in, LayoutInfo out) {
if (in.src == out.dst && in.dst == out.src) return true;
return false;
}
inline LayoutInfo GetLayout(const nnvm::OpMap<FTVMLayoutInfo>& layouts,
const nnvm::NodePtr& n, int idx) {
return layouts[n->op()](n->attrs)[idx];
}
nnvm::NodePtr CreateLayoutTransformNode(std::string src, std::string dst) {
static const nnvm::Op* trans_op = nnvm::Op::Get("layout_transform");
static int count = 0;
nnvm::NodePtr n = nnvm::Node::Create();
n->attrs.op = trans_op;
n->attrs.name = src + "_to_" + dst + std::to_string(count++);
n->attrs.dict["src"] = src;
n->attrs.dict["dst"] = dst;
return n;
}
/*!
* \brief A simple layout transform pass that will
* insert layout transform nodes automatically.
*/
nnvm::Graph LayoutTransform(nnvm::Graph src) {
static auto& ilayouts =
nnvm::Op::GetAttr<FTVMInputsLayoutInfo>("FTVMInputsLayoutInfo");
static auto& olayouts =
nnvm::Op::GetAttr<FTVMOutputsLayoutInfo>("FTVMOutputsLayoutInfo");
std::unordered_map<nnvm::Node*, nnvm::NodePtr> mirror_map;
std::unordered_map<nnvm::Node*, std::vector<nnvm::NodePtr> > transformed;
DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) {
nnvm::NodePtr new_node = nnvm::Node::Create();
*new_node = *n;
if (new_node->is_variable()) {
mirror_map[n.get()] = new_node;
return;
}
if (olayouts.count(n->op())) {
std::vector<nnvm::NodePtr> tnodes(n->num_outputs(), nullptr);
for (uint32_t i = 0; i < n->num_outputs(); ++i) {
LayoutInfo layout = GetLayout(olayouts, n, i);
tnodes[i] = CreateLayoutTransformNode(layout.src, layout.dst);
tnodes[i]->inputs.emplace_back(nnvm::NodeEntry{new_node, i, 0});
}
transformed.emplace(n.get(), std::move(tnodes));
}
for (size_t idx = 0; idx < n->inputs.size(); ++idx) {
const nnvm::NodeEntry& e = n->inputs[idx];
const nnvm::NodePtr& in = e.node;
new_node->inputs[idx] =
nnvm::NodeEntry{mirror_map.at(in.get()), e.index, e.version};
bool otrans = olayouts.count(in->op());
bool itrans = ilayouts.count(n->op());
if (otrans && itrans) {
LayoutInfo olayout = GetLayout(olayouts, in, e.index);
LayoutInfo ilayout = GetLayout(ilayouts, n, idx);
if (IsPair(olayout, ilayout)) {
continue;
}
}
if (otrans) {
const auto& tnodes = transformed.at(in.get());
new_node->inputs[idx] =
nnvm::NodeEntry{tnodes[e.index], 0, 0};
}
if (itrans) {
LayoutInfo layout = GetLayout(ilayouts, n, idx);
nnvm::NodePtr tnode =
CreateLayoutTransformNode(layout.src, layout.dst);
tnode->inputs.emplace_back(new_node->inputs[idx]);
new_node->inputs[idx] = nnvm::NodeEntry{tnode, 0, 0};
}
}
mirror_map[n.get()] = std::move(new_node);
});
std::vector<nnvm::NodeEntry> outputs;
for (const auto& e : src.outputs) {
if (olayouts.count(e.node->op())) {
const auto& tnodes = transformed.at(e.node.get());
outputs.emplace_back(nnvm::NodeEntry{tnodes[e.index], 0, 0});
} else {
outputs.emplace_back(
nnvm::NodeEntry{mirror_map.at(e.node.get()), e.index, e.version});
}
}
nnvm::Graph ret;
ret.outputs = std::move(outputs);
return ret;
}
NNVM_REGISTER_PASS(LayoutTransform)
.set_body(LayoutTransform);
NNVM_REGISTER_OP(layout_transform)
.set_num_inputs(1)
.set_num_outputs(1);
} // namespace contrib
} // namespace tvm
......@@ -52,6 +52,36 @@ using FTVMSchedule = std::function<
const Array<Tensor>& outs,
const std::string& target)>;
/*!
* \brief Layout transform information,
* from source layout to destination layout.
*/
struct LayoutInfo {
using Layout = std::string;
Layout src;
Layout dst;
};
/*!
* \brief Layout info of the node.
* \param attrs The attribute of the node.
* \return layouts A vector of inputs/outputs layout info.
*/
using FTVMLayoutInfo = std::function<
std::vector<LayoutInfo>(const NodeAttrs& attrs)>;
/*!
* \brief Inputs layout info of the node.
* \param attrs The attribute of the node.
* \return layouts A vector of inputs layout info.
*/
using FTVMInputsLayoutInfo = FTVMLayoutInfo;
/*!
* \brief Outputs layout info of the node.
* \param attrs The attribute of the node.
* \return layouts A vector of outputs layout info.
*/
using FTVMOutputsLayoutInfo = FTVMLayoutInfo;
// The storage result of op
enum OpPatternKind : int {
// Elementwise operation
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment