Commit 7f5d22d7 by Yizhi Liu Committed by Tianqi Chen

enable AlterOpLayout to keep OP unchanged (#471)

parent be968fef
......@@ -80,11 +80,14 @@ using FTVMSchedule = std::function<
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param tinfos The inferred shape and dtype of the inputs.
* \param ret The replaced operator.
* \return Whether to replace current operator.
*/
using FTVMAlterOpLayout = std::function<
Symbol(const NodeAttrs& attrs,
const Symbol& inputs,
const Array<Tensor>& tinfos)>;
bool(const NodeAttrs& attrs,
const Symbol& inputs,
const Array<Tensor>& tinfos,
Symbol* ret)>;
/*!
* \brief Transform from normal operator to vectorized operator
......
......@@ -120,6 +120,10 @@ def schedule_conv2d(attrs, outs, target):
return topi.generic.schedule_conv2d_nhwc(outs)
return topi.generic.schedule_depthwise_conv2d_nchw(outs)
@reg.register_alter_op_layout("conv2d")
def alter_conv2d_layout(attrs, inputs, tinfos):
return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos)
reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# convolution NCHWc
......
......@@ -103,9 +103,11 @@ Graph AlterOpLayout(const Graph& src) {
tensor_infos.push_back(op_output_tinfos[input.index]);
}
// callback registered function to get a new operator.
auto op = fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos);
*ret = op.outputs;
return true;
Symbol op;
bool do_alter =
fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos, &op);
if (do_alter) *ret = op.outputs;
return do_alter;
};
Graph ret = nnvm::compiler::GraphTransform(src, transform);
......
......@@ -466,8 +466,8 @@ bool Conv2DScaleAxisBackward(
using top::Conv2DParam;
const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
if (out_info.kind != kPending) return false;
// only optimize for nchw for now
if (param.layout == "NCHW" && out_info.axis == 1) {
// only optimize for kernel layout OIHW for now
if (param.kernel_layout == "OIHW" && out_info.axis == 1) {
(*in_axis)[1].kind = kMulConsumer;
(*in_axis)[1].axis = 0;
(*in_axis)[1].source = out_info.source;
......@@ -492,7 +492,7 @@ bool Conv2DScaleAxisForward(
const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
if ((*in_info)[0].kind != kPending) return false;
// only optimize for nchw for now
if (param.layout == "NCHW" && (*in_info)[0].axis == 1) {
if (param.kernel_layout == "OIHW" && (*in_info)[0].axis == 1) {
(*in_info)[1].kind = kMulConsumer;
(*in_info)[1].axis = 1;
(*in_info)[1].source = (*in_info)[0].source;
......
......@@ -70,12 +70,17 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._register_alter_op_layout")
Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
auto fpack = [f](const NodeAttrs& attrs,
const Symbol& inputs,
const Array<Tensor>& tinfos) {
const Array<Tensor>& tinfos,
Symbol* ret_symbol) {
TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, tinfos);
if (ret.type_code() == TVMTypeCode::kNull) {
return false;
}
CHECK_EQ(ret.type_code(), tvm::runtime::extension_class_info<Symbol>::code)
<< " expected " << "Symbol (code = " << tvm::runtime::extension_class_info<Symbol>::code
<< ") but get code = " << ret.type_code();
return *(static_cast<Symbol*>(ret.value().v_handle));
*ret_symbol = *(static_cast<Symbol*>(ret.value().v_handle));
return true;
};
op.set_attr<FTVMAlterOpLayout>("FTVMAlterOpLayout", fpack, args[2]);
});
......
......@@ -75,7 +75,7 @@ inline TShape ConvertLayout(TShape src, const Layout& src_layout, const Layout&
CHECK_GT(dst_factor, 0);
CHECK_LE(dst_factor, src_dim_size) << "Converting " << src
<< " from " << src_layout
<< " to " << dst_factor
<< " to " << dst_layout
<< ": cannot split dimension size of "
<< src_dim_size << " by " << dst_factor;
dst[dst_major_pos] /= dst_factor;
......
......@@ -32,7 +32,7 @@ def test_alter_conv2d_layout():
g = g.apply(["InferShape", "InferType"])
layouts_origin = get_layouts(g)
@reg.register_alter_op_layout("conv2d")
@reg.register_alter_op_layout("conv2d", level=100)
def alter_conv2d_layout(attrs, inputs, tinfos):
new_attrs = {k : attrs[k] for k in attrs.keys()}
new_attrs["layout"] = "NCHW16c"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment