Commit cdf8dff6 by eqy Committed by Yizhi Liu

[RELAY][TOPI] `alter_op_layout` for x86 (#2602)

* alter_op_layout for x86

* cleanup

* cleanup

* fix lint

* fix lint

* fix lint

* fix lint

* change support level

* change other support levels
parent 231bac21
......@@ -305,3 +305,28 @@ def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_compute("nn.contrib_conv2d_NCHWc")
def compute_contrib_conv2d_NCHWc(attrs, inputs, out_dtype, target):
"""Compute definition of conv2d NCHWc"""
# pylint: disable=assignment-from-no-return
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.get_str("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
out = topi.nn.conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation,
out_layout, out_dtype)
return [out]
@reg.register_schedule("nn.contrib_conv2d_NCHWc")
def schedule_contrib_conv2d_NCHWc(attrs, outs, target):
"""Schedule definition of contrib_conv2d_NCHWc"""
with target:
return topi.generic.schedule_conv2d_NCHWc(outs)
reg.register_pattern("nn.contrib_conv2d_NCHWc",
OpPattern.OUT_ELEMWISE_FUSABLE)
......@@ -837,6 +837,72 @@ def contrib_conv2d_winograd_without_weight_transform(data,
kernel_layout, out_layout, out_dtype)
def contrib_conv2d_nchwc(data,
kernel,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCHW8c",
kernel_layout="OIHW",
out_layout="",
out_dtype=""):
r"""Variant of 2D convolution.
This operator takes the weight as the convolution kernel
and convolves it with data to produce an output, following a specialized
NCHWc data layout.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
kernel : tvm.relay.Expr
The kernel expressions.
strides : tuple of int, optional
The strides of convoltution.
padding : tuple of int, optional
The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution.
kernel_size : tuple of int, optional
The spatial of the convolution kernel.
data_layout : str, optional
Layout of the input.
kernel_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output, by default, out_layout is the same as data_layout
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.contrib_conv2d_NCHWc(data, kernel, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
def contrib_conv2d_winograd_weight_transform(weight,
tile_size):
r"""Weight Transformation part for 2D convolution with winograd algorithm.
......
......@@ -453,7 +453,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(5)
.set_support_level(10)
.add_type_rel("Conv2DWinograd", Conv2DWinogradRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DWinogradAttrs>);
......@@ -513,8 +513,61 @@ weight transformation in advance.
.set_attrs_type_key("relay.attrs.Conv2DWinogradWeightTransformAttrs")
.set_num_inputs(1)
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(5)
.set_support_level(10)
.add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel);
// Positional relay function to create conv2d NCHWc operator
// used by frontend FFI.
Expr MakeConv2DNCHWc(Expr data,
Expr kernel,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_node<Conv2DAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = channels;
attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.contrib_conv2d_NCHWc");
return CallNode::make(op, {data, kernel}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_NCHWc")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 12>(MakeConv2DNCHWc, args, rv);
});
RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
.describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout.
- **data**: Input is 5D packed tensor.
- **weight**: 6D packed tensor.
- **out**: Output is 5D packed tensor
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.Conv2D")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2D", Conv2DRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DAttrs>);
} // namespace relay
} // namespace tvm
# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
"""Conv2D schedule on x86"""
import warnings
import tvm
from tvm import autotvm
......@@ -285,10 +284,6 @@ def _topi_nn_conv2d_NCHWc(*args, **kwargs):
@conv2d_alter_layout.register("cpu")
def _alter_conv2d_layout(attrs, inputs, tinfo, F):
import nnvm.symbol as sym
if F != sym:
warnings.warn("Only support alter layout for x86 in NNVM now. "
"This pass is ignored in relay.")
return None
copy_inputs = [s for s in inputs]
new_attrs = {k : attrs[k] for k in attrs.keys()}
......@@ -300,7 +295,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
layout = attrs['layout']
layout_name = 'layout' if F == sym else 'data_layout'
layout = attrs[layout_name]
kh, kw = attrs.get_int_tuple("kernel_size")
dtype = data.dtype
......@@ -326,7 +324,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
_get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise)
ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
new_attrs['layout'] = 'NCHW%dc' % ic_bn
new_attrs[layout_name] = 'NCHW%dc' % ic_bn
new_attrs['out_layout'] = 'NCHW%dc' % oc_bn
new_data = tvm.placeholder((batch_size, in_channel//ic_bn, height, width, ic_bn),
......@@ -342,7 +341,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, kh, kw, oc_bn), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, new_attrs['layout'],
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc)
else:
out_channel, _, kh, kw = get_const_tuple(kernel.shape)
......@@ -353,11 +352,13 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, new_attrs['layout'],
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
dispatch_ctx.update(target, new_workload, cfg)
return sym.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
if F == sym:
return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
@autotvm.register_topi_compute(conv2d_NCHWc, 'cpu', 'direct')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment