Commit 7f5d22d7 authored by Yizhi Liu, committed by Tianqi Chen

enable AlterOpLayout to keep OP unchanged (#471)

parent be968fef
@@ -80,11 +80,14 @@ using FTVMSchedule = std::function<
  * \param attrs The attribute of the original node.
  * \param inputs The input symbols of the original node.
  * \param tinfos The inferred shape and dtype of the inputs.
+ * \param ret The replaced operator.
+ * \return Whether to replace current operator.
  */
 using FTVMAlterOpLayout = std::function<
-  Symbol(const NodeAttrs& attrs,
-         const Symbol& inputs,
-         const Array<Tensor>& tinfos)>;
+  bool(const NodeAttrs& attrs,
+       const Symbol& inputs,
+       const Array<Tensor>& tinfos,
+       Symbol* ret)>;
 /*!
  * \brief Transform from normal operator to vectorized operator
...
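The signature change above turns the alter-layout hook from "must return a replacement Symbol" into "may decline": the callback now returns a bool and writes the replacement through the Symbol* ret out-parameter only when it actually alters the op. From Python this surfaces as returning None to keep the operator unchanged. A minimal sketch of both outcomes, assuming the nnvm.top.registry decorator and the dict-style attrs seen elsewhere in this diff (the layout check and attribute values are illustrative):

    from nnvm.top import registry as reg
    import nnvm.symbol as sym

    @reg.register_alter_op_layout("conv2d")
    def alter_conv2d_layout(attrs, inputs, tinfos):
        if attrs["layout"] != "NCHW":
            # Decline: maps to `return false` in the C++ bridge,
            # so the original operator is kept unchanged.
            return None
        new_attrs = {k: attrs[k] for k in attrs.keys()}
        new_attrs["layout"] = "NCHW16c"
        # Replace: a returned Symbol maps to `return true`
        # plus `*ret = op.outputs` on the C++ side.
        return sym.conv2d(inputs[0], inputs[1], **new_attrs)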
@@ -120,6 +120,10 @@ def schedule_conv2d(attrs, outs, target):
         return topi.generic.schedule_conv2d_nhwc(outs)
     return topi.generic.schedule_depthwise_conv2d_nchw(outs)
 
+@reg.register_alter_op_layout("conv2d")
+def alter_conv2d_layout(attrs, inputs, tinfos):
+    return topi.nn.conv2d_alter_layout(attrs, inputs, tinfos)
+
 reg.register_pattern("conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 # convolution NCHWc
...
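The registration above just forwards conv2d to TOPI, which lets each target decide whether to rewrite the op. A sketch of how the TOPI-side hook can be laid out, assuming topi.nn.conv2d_alter_layout is a tvm.target.generic_func whose default declines (the cpu override is illustrative):

    import tvm

    @tvm.target.generic_func
    def conv2d_alter_layout(attrs, inputs, tinfos):
        # Default for all targets: keep the operator unchanged.
        return None

    @conv2d_alter_layout.register("cpu")
    def _alter_conv2d_layout_cpu(attrs, inputs, tinfos):
        # A backend would return a transformed symbol here when a packed
        # layout (e.g. NCHW16c) pays off, and fall back to None otherwise.
        return None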
@@ -103,9 +103,11 @@ Graph AlterOpLayout(const Graph& src) {
       tensor_infos.push_back(op_output_tinfos[input.index]);
     }
     // callback registered function to get a new operator.
-    auto op = fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos);
-    *ret = op.outputs;
-    return true;
+    Symbol op;
+    bool do_alter =
+      fn_alter_op_layout(n->attrs, Symbol::CreateGroup(op_inputs), tensor_infos, &op);
+    if (do_alter) *ret = op.outputs;
+    return do_alter;
   };
   Graph ret = nnvm::compiler::GraphTransform(src, transform);
...
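Inside the pass, GraphTransform visits each node and now copies the node over untouched whenever the callback returns false. A minimal sketch of driving the pass from Python, mirroring the test further down (the workload shape and attributes are illustrative):

    import nnvm
    import nnvm.symbol as sym

    data = sym.Variable("data", shape=(1, 32, 224, 224))
    out = sym.conv2d(data, name="conv", channels=64, kernel_size=(3, 3))

    g = nnvm.graph.create(out)
    g = g.apply(["InferShape", "InferType"])  # tinfos come from these passes
    g = g.apply("AlterOpLayout")              # invokes registered callbacks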
@@ -466,8 +466,8 @@ bool Conv2DScaleAxisBackward(
   using top::Conv2DParam;
   const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
   if (out_info.kind != kPending) return false;
-  // only optimize for nchw for now
-  if (param.layout == "NCHW" && out_info.axis == 1) {
+  // only optimize for kernel layout OIHW for now
+  if (param.kernel_layout == "OIHW" && out_info.axis == 1) {
     (*in_axis)[1].kind = kMulConsumer;
     (*in_axis)[1].axis = 0;
     (*in_axis)[1].source = out_info.source;
@@ -492,7 +492,7 @@ bool Conv2DScaleAxisForward(
   const Conv2DParam& param = nnvm::get<Conv2DParam>(attrs.parsed);
   if ((*in_info)[0].kind != kPending) return false;
   // only optimize for nchw for now
-  if (param.layout == "NCHW" && (*in_info)[0].axis == 1) {
+  if (param.kernel_layout == "OIHW" && (*in_info)[0].axis == 1) {
     (*in_info)[1].kind = kMulConsumer;
     (*in_info)[1].axis = 1;
     (*in_info)[1].source = (*in_info)[0].source;
...
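Both hunks switch the guard from the data layout to the kernel layout because the scale is folded into the weight tensor: an output-channel scale (backward) multiplies the OIHW kernel along axis 0 (O), and an input-channel scale (forward) multiplies it along axis 1 (I), so it is kernel_layout == "OIHW" that makes those axis indices valid, independent of the data layout. A self-contained numpy sketch of the backward identity the pass relies on:

    import numpy as np
    from scipy.signal import correlate

    x = np.random.randn(3, 8, 8)        # input, CHW
    w = np.random.randn(4, 3, 3, 3)     # kernel, OIHW
    s = np.random.randn(4)              # per-output-channel scale

    def conv(x, w):
        # direct cross-correlation per output channel, 'valid' padding
        return np.stack([correlate(x, k, mode="valid")[0] for k in w])

    lhs = s[:, None, None] * conv(x, w)        # scale applied after the conv
    rhs = conv(x, s[:, None, None, None] * w)  # scale folded into kernel axis O
    np.testing.assert_allclose(lhs, rhs, rtol=1e-8)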
@@ -70,12 +70,17 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._register_alter_op_layout")
     Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
     auto fpack = [f](const NodeAttrs& attrs,
                      const Symbol& inputs,
-                     const Array<Tensor>& tinfos) {
+                     const Array<Tensor>& tinfos,
+                     Symbol* ret_symbol) {
       TVMRetValue ret = (*f)(GetAttrDict(attrs), inputs, tinfos);
+      if (ret.type_code() == TVMTypeCode::kNull) {
+        return false;
+      }
       CHECK_EQ(ret.type_code(), tvm::runtime::extension_class_info<Symbol>::code)
         << " expected " << "Symbol (code = " << tvm::runtime::extension_class_info<Symbol>::code
         << ") but get code = " << ret.type_code();
-      return *(static_cast<Symbol*>(ret.value().v_handle));
+      *ret_symbol = *(static_cast<Symbol*>(ret.value().v_handle));
+      return true;
     };
     op.set_attr<FTVMAlterOpLayout>("FTVMAlterOpLayout", fpack, args[2]);
   });
...
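The bridge now tolerates a Python callback that returns nothing: a kNull return value is translated into false (keep the op) instead of tripping the CHECK_EQ. For reference, a hedged sketch of the registration shim that feeds this global from Python; the global name and the args[0..2] ordering come from this diff, while the helper's body and its default level are assumptions:

    import tvm

    def register_alter_op_layout(op_name, alter_layout=None, level=10):
        """Illustrative shim: register a layout-alter callback for op_name."""
        freg = tvm.get_global_func("nnvm.compiler._register_alter_op_layout")
        def register(myf):
            freg(op_name, tvm.convert(myf), level)  # args[0], args[1], args[2]
            return myf
        return register(alter_layout) if alter_layout else register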
@@ -75,7 +75,7 @@ inline TShape ConvertLayout(TShape src, const Layout& src_layout, const Layout&
       CHECK_GT(dst_factor, 0);
       CHECK_LE(dst_factor, src_dim_size) << "Converting " << src
                                          << " from " << src_layout
-                                         << " to " << dst_factor
+                                         << " to " << dst_layout
                                          << ": cannot split dimension size of "
                                          << src_dim_size << " by " << dst_factor;
       dst[dst_major_pos] /= dst_factor;
...
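ConvertLayout splits a dimension when the destination layout carries a subdimension factor (e.g. NCHW to NCHW16c splits C into C/16 and 16), and the corrected message now reports the destination layout instead of printing the factor twice. A pure-Python sketch of that split and its failure mode (the helper name is hypothetical):

    def convert_nchw_to_nchw16c(shape, factor=16):
        n, c, h, w = shape
        if factor > c or c % factor != 0:
            raise ValueError(
                "Converting %s from NCHW to NCHW%dc: cannot split dimension "
                "size of %d by %d" % (shape, factor, c, factor))
        return (n, c // factor, h, w, factor)

    assert convert_nchw_to_nchw16c((1, 32, 56, 56)) == (1, 2, 56, 56, 16)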
@@ -32,7 +32,7 @@ def test_alter_conv2d_layout():
     g = g.apply(["InferShape", "InferType"])
     layouts_origin = get_layouts(g)
-    @reg.register_alter_op_layout("conv2d")
+    @reg.register_alter_op_layout("conv2d", level=100)
     def alter_conv2d_layout(attrs, inputs, tinfos):
         new_attrs = {k : attrs[k] for k in attrs.keys()}
         new_attrs["layout"] = "NCHW16c"
...
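The level=100 here lets the test's callback shadow the default conv2d registration added earlier in this diff. With this commit, an even higher-level callback can also decline; a sketch of that "keep unchanged" path in the same test setting (the level value and assertion are illustrative):

    @reg.register_alter_op_layout("conv2d", level=101)
    def keep_conv2d_layout(attrs, inputs, tinfos):
        return None  # decline: keep the original conv2d

    g2 = g.apply("AlterOpLayout")  # g as built before any alteration
    assert get_layouts(g2) == layouts_origin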