Commit 581be165 by ziheng Committed by Tianqi Chen

[PASS] Enhance LayoutTransform pass (#293)

* [PASS] Enhance LayoutTransform pass

* Fix

* Fix Compilation

* Refactor

* Refactor

* doc

* fix

* add file
parent eefcfe19
......@@ -373,21 +373,9 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
NNVM_REGISTER_PASS(GraphFuse)
.set_body(GraphFuse);
inline bool IsIdentityLayout(const LayoutInfo& layout) {
if (layout.src == "" && layout.dst == "") return true;
return false;
}
inline bool IsPairedLayouts(const LayoutInfo& in,
const LayoutInfo& out) {
if (in.src == out.dst && in.dst == out.src) return true;
return false;
}
inline LayoutInfo GetLayout(const nnvm::OpMap<FTVMLayoutInfo>& layouts,
const nnvm::NodePtr& n, int idx) {
return layouts[n->op()](n->attrs)[idx];
const TLayoutInfo& GetDefaultLayout() {
static TLayoutInfo default_layout = "default";
return default_layout;
}
nnvm::NodePtr CreateLayoutTransformNode(const std::string& src,
......@@ -408,92 +396,117 @@ nnvm::NodePtr CreateLayoutTransformNode(const std::string& src,
* insert layout transform nodes automatically.
*/
nnvm::Graph LayoutTransform(nnvm::Graph src) {
static auto& ilayouts =
nnvm::Op::GetAttr<FTVMInputsLayoutInfo>("FTVMInputsLayoutInfo");
static auto& olayouts =
nnvm::Op::GetAttr<FTVMOutputsLayoutInfo>("FTVMOutputsLayoutInfo");
static auto& vec_op =
static auto& op_layout_request =
nnvm::Op::GetAttr<FTVMLayoutRequest>("FTVMLayoutRequest");
static auto& op_vecop =
nnvm::Op::GetAttr<FTVMVectorizedOp>("FTVMVectorizedOp");
static auto& op_pattern = nnvm::Op::GetAttr<TOpPattern>("TOpPattern");
std::unordered_map<nnvm::Node*, nnvm::NodePtr> mirror_map;
std::unordered_map<nnvm::Node*, std::vector<nnvm::NodePtr> > transformed;
const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
const std::vector<TLayoutInfo>& input_layouts =
src.GetAttr<std::vector<TLayoutInfo> >("layout");
DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) {
const IndexedGraph& idx = src.indexed_graph();
std::vector<TLayoutInfo> produce_vec(idx.num_node_entries(), GetDefaultLayout());
std::vector<nnvm::NodePtr> mirror_vec(idx.num_nodes(), nullptr);
// use op pattern to decide whether an op is map
auto is_map_op = [&](size_t nid) {
TOpPattern pt = op_pattern.get(idx[nid].source->op(), kExtern);
bool is_map = (pt <= kBroadcast);
if (pt == kBroadcast) {
for (const auto& e : idx[nid].inputs) {
if (shape_vec[idx.entry_id(nid, 0)] != shape_vec[idx.entry_id(e)]) {
is_map = false;
break;
}
}
}
return is_map;
};
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
nnvm::NodePtr new_node = nnvm::Node::Create();
*new_node = *n;
*new_node = *(inode.source);
if (new_node->is_variable()) {
mirror_map[n.get()] = new_node;
return;
auto input_iter = std::find(
idx.input_nodes().cbegin(), idx.input_nodes().cend(), nid);
CHECK(input_iter != idx.input_nodes().cend());
size_t input_id = std::distance(idx.input_nodes().cbegin(), input_iter);
produce_vec[idx.entry_id(nid, 0)] = input_layouts[input_id];
mirror_vec[nid] = new_node;
continue;
}
if (vec_op.count(n->op())) {
new_node = vec_op[n->op()](n);
if (op_vecop.count(inode.source->op())) {
new_node = op_vecop[inode.source->op()](inode.source);
new_node->inputs.resize(new_node->num_inputs());
}
if (olayouts.count(new_node->op())) {
std::vector<nnvm::NodePtr> tnodes(n->num_outputs(), nullptr);
std::vector<LayoutInfo> layouts = olayouts[new_node->op()](new_node->attrs);
for (uint32_t i = 0; i < n->num_outputs(); ++i) {
const LayoutInfo& layout = layouts[i];
if (!IsIdentityLayout(layout)) {
tnodes[i] = CreateLayoutTransformNode(layout.src, layout.dst);
tnodes[i]->attrs.name = new_node->attrs.name + "_" + layout.dst;
tnodes[i]->inputs.emplace_back(nnvm::NodeEntry{new_node, i, 0});
}
// set up output and input layouts
std::vector<TLayoutInfo> request_ilayouts(new_node->num_inputs(), GetDefaultLayout());
if (op_layout_request.count(new_node->op())) {
std::vector<TLayoutInfo> produce_olayouts(new_node->num_outputs(), GetDefaultLayout());
CHECK(op_layout_request[new_node->op()](new_node->attrs, &request_ilayouts, &produce_olayouts))
<< "Layout request fail";
CHECK_EQ(request_ilayouts.size(), new_node->num_inputs());
CHECK_EQ(produce_olayouts.size(), new_node->num_outputs());
for (size_t i = 0; i < new_node->num_outputs(); ++i) {
produce_vec[idx.entry_id(nid, i)] = produce_olayouts[i];
}
transformed.emplace(n.get(), std::move(tnodes));
}
for (size_t idx = 0; idx < n->inputs.size(); ++idx) {
const nnvm::NodeEntry& e = n->inputs[idx];
const nnvm::NodePtr& in = mirror_map.at(e.node.get());
new_node->inputs[idx] =
nnvm::NodeEntry{in, e.index, e.version};
bool otrans = olayouts.count(in->op());
bool itrans = ilayouts.count(new_node->op());
if (otrans && itrans) {
LayoutInfo ilayout = GetLayout(ilayouts, new_node, idx);
LayoutInfo olayout = GetLayout(olayouts, in, e.index);
if (IsPairedLayouts(olayout, ilayout)) {
continue;
bool map_layout = is_map_op(nid);
if (map_layout) {
const TLayoutInfo& layout = produce_vec[idx.entry_id(inode.inputs[0])];
for (const auto& e : inode.inputs) {
if (produce_vec[idx.entry_id(e)] != layout) {
map_layout = false;
break;
}
}
if (otrans) {
nnvm::NodePtr tnode = transformed.at(in.get())[e.index];
if (tnode.get()) {
new_node->inputs[idx] =
nnvm::NodeEntry{tnode, 0, 0};
if (map_layout) {
for (size_t i = 0; i < inode.source->num_outputs(); ++i) {
produce_vec[idx.entry_id(nid, i)] = layout;
}
}
}
for (size_t i = 0; i < inode.inputs.size(); ++i) {
const auto& e = inode.inputs[i];
const nnvm::NodePtr& in = mirror_vec[e.node_id];
new_node->inputs[i] =
nnvm::NodeEntry{in, e.index, e.version};
if (itrans) {
LayoutInfo layout = GetLayout(ilayouts, new_node, idx);
if (!IsIdentityLayout(layout)) {
nnvm::NodePtr tnode =
CreateLayoutTransformNode(layout.src, layout.dst);
tnode->attrs.name = n->inputs[idx].node->attrs.name + "_" + layout.dst;
tnode->inputs.emplace_back(new_node->inputs[idx]);
new_node->inputs[idx] = nnvm::NodeEntry{tnode, 0, 0};
TLayoutInfo produce = produce_vec[idx.entry_id(e)];
TLayoutInfo request = request_ilayouts[i];
if (!map_layout && (produce != request)) {
nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request);
tnode->attrs.name =
idx[e.node_id].source->attrs.name + "_" + request;
tnode->inputs.emplace_back(new_node->inputs[i]);
new_node->inputs[i] = nnvm::NodeEntry{tnode, 0, 0};
}
}
mirror_vec[nid] = new_node;
}
mirror_map[n.get()] = std::move(new_node);
});
std::vector<nnvm::NodeEntry> outputs;
for (const auto& e : src.outputs) {
if (olayouts.count(e.node->op())) {
nnvm::NodePtr tnode = transformed.at(e.node.get())[e.index];
if (tnode.get()) {
for (const auto& e : idx.outputs()) {
TLayoutInfo produce = produce_vec[idx.entry_id(e)];
if (produce != GetDefaultLayout()) {
nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, GetDefaultLayout());
tnode->attrs.name =
idx[e.node_id].source->attrs.name + "_default";
tnode->inputs.emplace_back(
nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
outputs.emplace_back(nnvm::NodeEntry{tnode, 0, 0});
continue;
}
}
} else {
outputs.emplace_back(
nnvm::NodeEntry{mirror_map.at(e.node.get()), e.index, e.version});
nnvm::NodeEntry{mirror_vec[e.node_id], e.index, e.version});
}
}
nnvm::Graph ret;
......@@ -535,7 +548,7 @@ NNVM_REGISTER_OP(layout_transform)
nnvm::Graph PruneGraph(nnvm::Graph src) {
const auto& params = src.GetAttr<std::unordered_set<std::string>>("params");
const auto& params = src.GetAttr<std::unordered_set<std::string> >("params");
std::unordered_set<nnvm::Node*> pruned;
nnvm::NodeEntryMap<nnvm::NodePtr> entry_var;
......
......@@ -52,35 +52,22 @@ using FTVMSchedule = std::function<
const Array<Tensor>& outs,
const std::string& target)>;
/*!
* \brief Layout transform information,
* from source layout to destination layout.
*/
struct LayoutInfo {
using Layout = std::string;
Layout src;
Layout dst;
};
/*! \brief Layout Information. */
using TLayoutInfo = std::string;
/*!
* \brief Layout info of the node.
* \param attrs The attribute of the node.
* \return layouts A vector of inputs/outputs layout info.
*/
using FTVMLayoutInfo = std::function<
std::vector<LayoutInfo>(const NodeAttrs& attrs)>;
/*!
* \brief Inputs layout info of the node.
* \brief The producer consumer function of node layout
* \param attrs The attribute of the node.
* \return layouts A vector of inputs layout info.
* \param ilayouts The input layouts that the node request.
* \param olayouts The output layouts that the node produce.
* \return bool The success flag.
*/
using FTVMInputsLayoutInfo = FTVMLayoutInfo;
/*!
* \brief Outputs layout info of the node.
* \param attrs The attribute of the node.
* \return layouts A vector of outputs layout info.
*/
using FTVMOutputsLayoutInfo = FTVMLayoutInfo;
using FTVMLayoutRequest = std::function<bool (const NodeAttrs& attrs,
std::vector<TLayoutInfo> *ilayouts,
std::vector<TLayoutInfo> *olayouts)>;
/*! \brief The default layout. */
const TLayoutInfo& GetDefaultLayout();
/*! \brief Parameters of layout transform operator */
struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
......@@ -93,7 +80,7 @@ struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
};
/*! \brief Transform from normal operator to vectorized operator */
using FTVMVectorizedOp = std::function<nnvm::NodePtr (nnvm::NodePtr)>;
using FTVMVectorizedOp = std::function<nnvm::NodePtr (const nnvm::Node*)>;
// The storage result of op
enum OpPatternKind : int {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment