Commit 581be165 by ziheng Committed by Tianqi Chen

[PASS] Enhance LayoutTransform pass (#293)

* [PASS] Enhance LayoutTransform pass

* Fix

* Fix Compilation

* Refactor

* Refactor

* doc

* fix

* add file
parent eefcfe19
...@@ -373,21 +373,9 @@ nnvm::Graph GraphFuse(nnvm::Graph g) { ...@@ -373,21 +373,9 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
NNVM_REGISTER_PASS(GraphFuse) NNVM_REGISTER_PASS(GraphFuse)
.set_body(GraphFuse); .set_body(GraphFuse);
const TLayoutInfo& GetDefaultLayout() {
inline bool IsIdentityLayout(const LayoutInfo& layout) { static TLayoutInfo default_layout = "default";
if (layout.src == "" && layout.dst == "") return true; return default_layout;
return false;
}
inline bool IsPairedLayouts(const LayoutInfo& in,
const LayoutInfo& out) {
if (in.src == out.dst && in.dst == out.src) return true;
return false;
}
inline LayoutInfo GetLayout(const nnvm::OpMap<FTVMLayoutInfo>& layouts,
const nnvm::NodePtr& n, int idx) {
return layouts[n->op()](n->attrs)[idx];
} }
nnvm::NodePtr CreateLayoutTransformNode(const std::string& src, nnvm::NodePtr CreateLayoutTransformNode(const std::string& src,
...@@ -408,92 +396,117 @@ nnvm::NodePtr CreateLayoutTransformNode(const std::string& src, ...@@ -408,92 +396,117 @@ nnvm::NodePtr CreateLayoutTransformNode(const std::string& src,
* insert layout transform nodes automatically. * insert layout transform nodes automatically.
*/ */
nnvm::Graph LayoutTransform(nnvm::Graph src) { nnvm::Graph LayoutTransform(nnvm::Graph src) {
static auto& ilayouts = static auto& op_layout_request =
nnvm::Op::GetAttr<FTVMInputsLayoutInfo>("FTVMInputsLayoutInfo"); nnvm::Op::GetAttr<FTVMLayoutRequest>("FTVMLayoutRequest");
static auto& olayouts = static auto& op_vecop =
nnvm::Op::GetAttr<FTVMOutputsLayoutInfo>("FTVMOutputsLayoutInfo");
static auto& vec_op =
nnvm::Op::GetAttr<FTVMVectorizedOp>("FTVMVectorizedOp"); nnvm::Op::GetAttr<FTVMVectorizedOp>("FTVMVectorizedOp");
static auto& op_pattern = nnvm::Op::GetAttr<TOpPattern>("TOpPattern");
std::unordered_map<nnvm::Node*, nnvm::NodePtr> mirror_map; const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
std::unordered_map<nnvm::Node*, std::vector<nnvm::NodePtr> > transformed; const std::vector<TLayoutInfo>& input_layouts =
src.GetAttr<std::vector<TLayoutInfo> >("layout");
DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) { const IndexedGraph& idx = src.indexed_graph();
std::vector<TLayoutInfo> produce_vec(idx.num_node_entries(), GetDefaultLayout());
std::vector<nnvm::NodePtr> mirror_vec(idx.num_nodes(), nullptr);
// use op pattern to decide whether an op is map
auto is_map_op = [&](size_t nid) {
TOpPattern pt = op_pattern.get(idx[nid].source->op(), kExtern);
bool is_map = (pt <= kBroadcast);
if (pt == kBroadcast) {
for (const auto& e : idx[nid].inputs) {
if (shape_vec[idx.entry_id(nid, 0)] != shape_vec[idx.entry_id(e)]) {
is_map = false;
break;
}
}
}
return is_map;
};
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
nnvm::NodePtr new_node = nnvm::Node::Create(); nnvm::NodePtr new_node = nnvm::Node::Create();
*new_node = *n; *new_node = *(inode.source);
if (new_node->is_variable()) { if (new_node->is_variable()) {
mirror_map[n.get()] = new_node; auto input_iter = std::find(
return; idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid);
CHECK(input_iter != idx.input_nodes().cend());
size_t input_id = std::distance(idx.input_nodes().cbegin(), input_iter);
produce_vec[idx.entry_id(nid, 0)] = input_layouts[input_id];
mirror_vec[nid] = new_node;
continue;
} }
if (vec_op.count(n->op())) { if (op_vecop.count(inode.source->op())) {
new_node = vec_op[n->op()](n); new_node = op_vecop[inode.source->op()](inode.source);
new_node->inputs.resize(new_node->num_inputs()); new_node->inputs.resize(new_node->num_inputs());
} }
if (olayouts.count(new_node->op())) { // set up output and input layouts
std::vector<nnvm::NodePtr> tnodes(n->num_outputs(), nullptr); std::vector<TLayoutInfo> request_ilayouts(new_node->num_inputs(), GetDefaultLayout());
std::vector<LayoutInfo> layouts = olayouts[new_node->op()](new_node->attrs); if (op_layout_request.count(new_node->op())) {
for (uint32_t i = 0; i < n->num_outputs(); ++i) { std::vector<TLayoutInfo> produce_olayouts(new_node->num_outputs(), GetDefaultLayout());
const LayoutInfo& layout = layouts[i]; CHECK(op_layout_request[new_node->op()](new_node->attrs, &request_ilayouts, &produce_olayouts))
if (!IsIdentityLayout(layout)) { << "Layout request fail";
tnodes[i] = CreateLayoutTransformNode(layout.src, layout.dst);
tnodes[i]->attrs.name = new_node->attrs.name + "_" + layout.dst; CHECK_EQ(request_ilayouts.size(), new_node->num_inputs());
tnodes[i]->inputs.emplace_back(nnvm::NodeEntry{new_node, i, 0}); CHECK_EQ(produce_olayouts.size(), new_node->num_outputs());
} for (size_t i = 0; i < new_node->num_outputs(); ++i) {
produce_vec[idx.entry_id(nid, i)] = produce_olayouts[i];
} }
transformed.emplace(n.get(), std::move(tnodes));
} }
for (size_t idx = 0; idx < n->inputs.size(); ++idx) { bool map_layout = is_map_op(nid);
const nnvm::NodeEntry& e = n->inputs[idx]; if (map_layout) {
const nnvm::NodePtr& in = mirror_map.at(e.node.get()); const TLayoutInfo& layout = produce_vec[idx.entry_id(inode.inputs[0])];
new_node->inputs[idx] = for (const auto& e : inode.inputs) {
nnvm::NodeEntry{in, e.index, e.version}; if (produce_vec[idx.entry_id(e)] != layout) {
map_layout = false;
bool otrans = olayouts.count(in->op()); break;
bool itrans = ilayouts.count(new_node->op());
if (otrans && itrans) {
LayoutInfo ilayout = GetLayout(ilayouts, new_node, idx);
LayoutInfo olayout = GetLayout(olayouts, in, e.index);
if (IsPairedLayouts(olayout, ilayout)) {
continue;
} }
} }
if (map_layout) {
if (otrans) { for (size_t i = 0; i < inode.source->num_outputs(); ++i) {
nnvm::NodePtr tnode = transformed.at(in.get())[e.index]; produce_vec[idx.entry_id(nid, i)] = layout;
if (tnode.get()) {
new_node->inputs[idx] =
nnvm::NodeEntry{tnode, 0, 0};
} }
} }
}
for (size_t i = 0; i < inode.inputs.size(); ++i) {
const auto& e = inode.inputs[i];
const nnvm::NodePtr& in = mirror_vec[e.node_id];
new_node->inputs[i] =
nnvm::NodeEntry{in, e.index, e.version};
if (itrans) { TLayoutInfo produce = produce_vec[idx.entry_id(e)];
LayoutInfo layout = GetLayout(ilayouts, new_node, idx); TLayoutInfo request = request_ilayouts[i];
if (!IsIdentityLayout(layout)) { if (!map_layout && (produce != request)) {
nnvm::NodePtr tnode = nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request);
CreateLayoutTransformNode(layout.src, layout.dst); tnode->attrs.name =
tnode->attrs.name = n->inputs[idx].node->attrs.name + "_" + layout.dst; idx[e.node_id].source->attrs.name + "_" + request;
tnode->inputs.emplace_back(new_node->inputs[idx]); tnode->inputs.emplace_back(new_node->inputs[i]);
new_node->inputs[idx] = nnvm::NodeEntry{tnode, 0, 0}; new_node->inputs[i] = nnvm::NodeEntry{tnode, 0, 0};
} }
} }
mirror_vec[nid] = new_node;
} }
mirror_map[n.get()] = std::move(new_node);
});
std::vector<nnvm::NodeEntry> outputs; std::vector<nnvm::NodeEntry> outputs;
for (const auto& e : src.outputs) { for (const auto& e : idx.outputs()) {
if (olayouts.count(e.node->op())) { TLayoutInfo produce = produce_vec[idx.entry_id(e)];
nnvm::NodePtr tnode = transformed.at(e.node.get())[e.index]; if (produce != GetDefaultLayout()) {
if (tnode.get()) { nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, GetDefaultLayout());
tnode->attrs.name =
idx[e.node_id].source->attrs.name + "_default";
tnode->inputs.emplace_back(
nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
outputs.emplace_back(nnvm::NodeEntry{tnode, 0, 0}); outputs.emplace_back(nnvm::NodeEntry{tnode, 0, 0});
continue; } else {
}
}
outputs.emplace_back( outputs.emplace_back(
nnvm::NodeEntry{mirror_map.at(e.node.get()), e.index, e.version}); nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
}
} }
nnvm::Graph ret; nnvm::Graph ret;
...@@ -535,7 +548,7 @@ NNVM_REGISTER_OP(layout_transform) ...@@ -535,7 +548,7 @@ NNVM_REGISTER_OP(layout_transform)
nnvm::Graph PruneGraph(nnvm::Graph src) { nnvm::Graph PruneGraph(nnvm::Graph src) {
const auto& params = src.GetAttr<std::unordered_set<std::string>>("params"); const auto& params = src.GetAttr<std::unordered_set<std::string> >("params");
std::unordered_set<nnvm::Node*> pruned; std::unordered_set<nnvm::Node*> pruned;
nnvm::NodeEntryMap<nnvm::NodePtr> entry_var; nnvm::NodeEntryMap<nnvm::NodePtr> entry_var;
......
...@@ -52,35 +52,22 @@ using FTVMSchedule = std::function< ...@@ -52,35 +52,22 @@ using FTVMSchedule = std::function<
const Array<Tensor>& outs, const Array<Tensor>& outs,
const std::string& target)>; const std::string& target)>;
/*! /*! \brief Layout Information. */
* \brief Layout transform information, using TLayoutInfo = std::string;
* from source layout to destination layout.
*/
struct LayoutInfo {
using Layout = std::string;
Layout src;
Layout dst;
};
/*! /*!
* \brief Layout info of the node. * \brief The producer consumer function of node layout
* \param attrs The attribute of the node.
* \return layouts A vector of inputs/outputs layout info.
*/
using FTVMLayoutInfo = std::function<
std::vector<LayoutInfo>(const NodeAttrs& attrs)>;
/*!
* \brief Inputs layout info of the node.
* \param attrs The attribute of the node. * \param attrs The attribute of the node.
* \return layouts A vector of inputs layout info. * \param ilayouts The input layouts that the node request.
* \param olayouts The output layouts that the node produce.
* \return bool The success flag.
*/ */
using FTVMInputsLayoutInfo = FTVMLayoutInfo; using FTVMLayoutRequest = std::function<bool (const NodeAttrs& attrs,
/*! std::vector<TLayoutInfo> *ilayouts,
* \brief Outputs layout info of the node. std::vector<TLayoutInfo> *olayouts)>;
* \param attrs The attribute of the node.
* \return layouts A vector of outputs layout info. /*! \brief The default layout. */
*/ const TLayoutInfo& GetDefaultLayout();
using FTVMOutputsLayoutInfo = FTVMLayoutInfo;
/*! \brief Parameters of layout transform operator */ /*! \brief Parameters of layout transform operator */
struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> { struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
...@@ -93,7 +80,7 @@ struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> { ...@@ -93,7 +80,7 @@ struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
}; };
/*! \brief Transform from normal operator to vectorized operator */ /*! \brief Transform from normal operator to vectorized operator */
using FTVMVectorizedOp = std::function<nnvm::NodePtr (nnvm::NodePtr)>; using FTVMVectorizedOp = std::function<nnvm::NodePtr (const nnvm::Node*)>;
// The storage result of op // The storage result of op
enum OpPatternKind : int { enum OpPatternKind : int {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment