Commit 3935329b by ziheng Committed by GitHub

[EXECUTOR] Improve LayoutTransform pass (#273)

* [EXECUTOR] Improve LayoutTransform pass

* Remove offline params for now

* Small fix
parent 591afad9
......@@ -364,12 +364,12 @@ nnvm::NodePtr CreateLayoutTransformNode(const std::string& src,
nnvm::NodePtr n = nnvm::Node::Create();
n->attrs.op = trans_op;
n->attrs.name = src + "_to_" + dst + std::to_string(count++);
n->attrs.dict["src"] = src;
n->attrs.dict["dst"] = dst;
n->attrs.dict["src_layout"] = src;
n->attrs.dict["dst_layout"] = dst;
n->op()->attr_parser(&(n->attrs));
return n;
}
/*!
* \brief A simple layout transform pass that will
* insert layout transform nodes automatically.
......@@ -379,6 +379,8 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {
nnvm::Op::GetAttr<FTVMInputsLayoutInfo>("FTVMInputsLayoutInfo");
static auto& olayouts =
nnvm::Op::GetAttr<FTVMOutputsLayoutInfo>("FTVMOutputsLayoutInfo");
static auto& vec_op =
nnvm::Op::GetAttr<FTVMVectorizedOp>("FTVMVectorizedOp");
std::unordered_map<nnvm::Node*, nnvm::NodePtr> mirror_map;
std::unordered_map<nnvm::Node*, std::vector<nnvm::NodePtr> > transformed;
......@@ -391,13 +393,19 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {
return;
}
if (olayouts.count(n->op())) {
if (vec_op.count(n->op())) {
new_node = vec_op[n->op()](n);
new_node->inputs.resize(new_node->num_inputs());
}
if (olayouts.count(new_node->op())) {
std::vector<nnvm::NodePtr> tnodes(n->num_outputs(), nullptr);
std::vector<LayoutInfo> layouts = olayouts[n->op()](n->attrs);
std::vector<LayoutInfo> layouts = olayouts[new_node->op()](new_node->attrs);
for (uint32_t i = 0; i < n->num_outputs(); ++i) {
const LayoutInfo& layout = layouts[i];
if (!IsIdentityLayout(layout)) {
tnodes[i] = CreateLayoutTransformNode(layout.src, layout.dst);
tnodes[i]->attrs.name = new_node->attrs.name + "_" + layout.dst;
tnodes[i]->inputs.emplace_back(nnvm::NodeEntry{new_node, i, 0});
}
}
......@@ -406,14 +414,14 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {
for (size_t idx = 0; idx < n->inputs.size(); ++idx) {
const nnvm::NodeEntry& e = n->inputs[idx];
const nnvm::NodePtr& in = e.node;
const nnvm::NodePtr& in = mirror_map.at(e.node.get());
new_node->inputs[idx] =
nnvm::NodeEntry{mirror_map.at(in.get()), e.index, e.version};
nnvm::NodeEntry{in, e.index, e.version};
bool otrans = olayouts.count(in->op());
bool itrans = ilayouts.count(n->op());
bool itrans = ilayouts.count(new_node->op());
if (otrans && itrans) {
LayoutInfo ilayout = GetLayout(ilayouts, n, idx);
LayoutInfo ilayout = GetLayout(ilayouts, new_node, idx);
LayoutInfo olayout = GetLayout(olayouts, in, e.index);
if (IsPairedLayouts(olayout, ilayout)) {
continue;
......@@ -429,10 +437,11 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {
}
if (itrans) {
LayoutInfo layout = GetLayout(ilayouts, n, idx);
LayoutInfo layout = GetLayout(ilayouts, new_node, idx);
if (!IsIdentityLayout(layout)) {
nnvm::NodePtr tnode =
CreateLayoutTransformNode(layout.src, layout.dst);
tnode->attrs.name = n->inputs[idx].node->attrs.name + "_" + layout.dst;
tnode->inputs.emplace_back(new_node->inputs[idx]);
new_node->inputs[idx] = nnvm::NodeEntry{tnode, 0, 0};
}
......@@ -462,8 +471,33 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {
NNVM_REGISTER_PASS(LayoutTransform)
.set_body(LayoutTransform);
DMLC_REGISTER_PARAMETER(LayoutTransformParam);
/*! \brief Parse keyword arguments as PType arguments and save to parsed */
template<typename PType>
inline void ParamParser(nnvm::NodeAttrs* attrs) {
PType param;
try {
param.Init(attrs->dict);
} catch (const dmlc::ParamError& e) {
std::ostringstream os;
os << e.what();
os << ", in operator " << attrs->op->name << "("
<< "name=\"" << attrs->name << "\"";
for (const auto& k : attrs->dict) {
os << ", " << k.first << "=\"" << k.second << "\"";
}
os << ")";
throw dmlc::ParamError(os.str());
}
attrs->parsed = std::move(param);
}
NNVM_REGISTER_OP(layout_transform)
.set_attr_parser(ParamParser<LayoutTransformParam>)
.set_num_inputs(1)
.set_num_outputs(1);
.set_num_outputs(1)
.add_argument("data", "NDArray-or-Symbol", "Input data")
.add_arguments(LayoutTransformParam::__FIELDS__());
} // namespace contrib
} // namespace tvm
......@@ -82,6 +82,19 @@ using FTVMInputsLayoutInfo = FTVMLayoutInfo;
*/
using FTVMOutputsLayoutInfo = FTVMLayoutInfo;
/*! \brief Parameters of layout transform operator */
struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
std::string src_layout;
std::string dst_layout;
DMLC_DECLARE_PARAMETER(LayoutTransformParam) {
DMLC_DECLARE_FIELD(src_layout);
DMLC_DECLARE_FIELD(dst_layout);
}
};
/*! \brief Transform from normal operator to vectorized operator */
using FTVMVectorizedOp = std::function<nnvm::NodePtr (nnvm::NodePtr)>;
// The storage result of op
enum OpPatternKind : int {
// Elementwise operation
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment