Commit a02916b5 by hlu1, committed by Tianqi Chen

winograd_nnpack (#2721)

parent 7f942474
@@ -155,6 +155,24 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
}
};
/*! \brief Attributes used in winograd weight transformation operators */
struct Conv2DWinogradNNPACKWeightTransformAttrs
: public tvm::AttrsNode<Conv2DWinogradNNPACKWeightTransformAttrs> {
int convolution_algorithm;
DataType out_dtype;
TVM_DECLARE_ATTRS(Conv2DWinogradNNPACKWeightTransformAttrs,
"relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs") {
TVM_ATTR_FIELD(convolution_algorithm)
.describe(
"The convolution algorithm for Winograd NNPACK. "
"E.g. tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8 for WT_8x8, "
"tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8_FP16 for WT_8x8_FP16");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
/*! \brief Attributes used in softmax operators */
struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
...
@@ -183,6 +183,26 @@ struct WinogradWeightTransformParam : public dmlc::Parameter<WinogradWeightTransformParam>
static const constexpr int kWeight = 0;
};
struct WinogradNNPACKWeightTransformParam
: public dmlc::Parameter<WinogradNNPACKWeightTransformParam> {
int convolution_algorithm;
int out_dtype;
DMLC_DECLARE_PARAMETER(WinogradNNPACKWeightTransformParam) {
DMLC_DECLARE_FIELD(convolution_algorithm)
.describe(
"The convolution algorithm for Winograd NNPACK. "
"E.g. tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8 for WT_8x8, "
"tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8_FP16 for WT_8x8_FP16");
DMLC_DECLARE_DTYPE_FIELD(out_dtype)
.add_enum("same", -1)
.set_default(-1)
.describe("Output data type, set to explicit type under mixed precision setting");
}
static const constexpr int kWeight = 0;
};
struct WinogradConv2DParam : public dmlc::Parameter<WinogradConv2DParam> {
int channels;
TShape kernel_size;
...
@@ -161,6 +161,10 @@ def alter_conv2d_layout(attrs, inputs, tinfos):
sym.contrib.conv2d_winograd_without_weight_transform
sym.contrib_conv2d_winograd_weight_transform = \
sym.contrib.conv2d_winograd_weight_transform
sym.contrib_conv2d_winograd_nnpack_without_weight_transform = \
sym.contrib.conv2d_winograd_nnpack_without_weight_transform
sym.contrib_conv2d_winograd_nnpack_weight_transform = \
sym.contrib.conv2d_winograd_nnpack_weight_transform
sym.nn = sym
# map relay argument names to nnvm argument names
@@ -274,6 +278,49 @@ reg.register_pattern("_contrib_conv2d_winograd_without_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_compute("_contrib_conv2d_winograd_nnpack_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, _):
convolution_algorithm = attrs.get_int('convolution_algorithm')
out_dtype = attrs.get_str('out_dtype')
return topi.nn.conv2d_winograd_nnpack_weight_transform(
inputs[0], convolution_algorithm, out_dtype)
@reg.register_schedule("_contrib_conv2d_winograd_nnpack_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
with tvm.target.create(target):
return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)
reg.register_pattern("_contrib_conv2d_winograd_nnpack_weight_transform", OpPattern.OPAQUE)
@reg.register_compute("_contrib_conv2d_winograd_nnpack_without_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, inputs, _):
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs.get_str("layout")
out_dtype = attrs.get_str("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "same" else out_dtype
assert dilation == (1, 1), "Dilation is not supported yet"
assert groups == 1, "Grouped convolution is not supported"
# pylint: disable=assignment-from-no-return
out = topi.nn.conv2d_winograd_nnpack_without_weight_transform(
inputs[0], inputs[1], inputs[2] if attrs.get_bool("use_bias") else None,
strides, padding, dilation, layout, out_dtype)
return out
@reg.register_schedule("_contrib_conv2d_winograd_nnpack_without_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target):
with tvm.target.create(target):
return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs)
reg.register_pattern("_contrib_conv2d_winograd_nnpack_without_weight_transform",
OpPattern.OPAQUE)
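For context, a rough sketch of how the two nnvm symbols registered above are meant to be chained at graph-construction time. Shapes, attribute values, and the bias handling are illustrative assumptions, not taken from this patch:

```python
# Illustrative only: pre-transform the kernel once, then call the
# transform-free convolution on the result.
import nnvm.symbol as sym
from tvm.contrib.nnpack import ConvolutionAlgorithm

data = sym.Variable("data")      # assumed NCHW, e.g. (1, 64, 56, 56)
weight = sym.Variable("weight")  # assumed OIHW, e.g. (64, 64, 3, 3)

tw = sym.contrib.conv2d_winograd_nnpack_weight_transform(
    weight, convolution_algorithm=ConvolutionAlgorithm.WT_8x8)
out = sym.contrib.conv2d_winograd_nnpack_without_weight_transform(
    data, tw, channels=64, kernel_size=(3, 3), strides=(1, 1),
    padding=(1, 1), layout="NCHW", use_bias=False)
```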
# conv2d_transpose
@reg.register_compute("conv2d_transpose")
def compute_conv2d_transpose(attrs, inputs, _):
...
@@ -130,13 +130,14 @@ inline bool Conv2DInferShape(const nnvm::NodeAttrs& attrs,
return true;
}
template<class Param>
inline bool WinogradConv2DInferShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape) {
static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");
- const WinogradConv2DParam& param = nnvm::get<WinogradConv2DParam>(attrs.parsed);
+ const Param& param = nnvm::get<Param>(attrs.parsed);
const Layout in_layout(param.layout);
const Layout kernel_layout(param.kernel_layout);
@@ -403,7 +404,7 @@ NNVM_REGISTER_OP(_contrib_conv2d_winograd_without_weight_transform)
.set_attr_parser(ParamParser<WinogradConv2DParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<WinogradConv2DParam>)
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<WinogradConv2DParam>)
- .set_attr<FInferShape>("FInferShape", WinogradConv2DInferShape)
+ .set_attr<FInferShape>("FInferShape", WinogradConv2DInferShape<WinogradConv2DParam>)
.set_attr<FInferType>("FInferType", Conv2DInferType<WinogradConv2DParam>)
.set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout<WinogradConv2DParam>)
.set_num_outputs(1)
@@ -412,6 +413,82 @@ NNVM_REGISTER_OP(_contrib_conv2d_winograd_without_weight_transform)
DMLC_REGISTER_PARAMETER(WinogradConv2DParam);
inline bool Conv2DWinogradNNPACKWTInferType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_type,
std::vector<int>* out_type) {
const WinogradNNPACKWeightTransformParam& param =
nnvm::get<WinogradNNPACKWeightTransformParam>(attrs.parsed);
CHECK_EQ(in_type->size(), 1U) << "Input:[weight]";
CHECK_EQ(out_type->size(), 1U);
if (param.out_dtype != -1) {
NNVM_ASSIGN_OUTPUT_TYPE(attrs, *out_type, 0, param.out_dtype);
} else {
ElemwiseType<1, 1>(attrs, in_type, out_type);
}
return true;
}
NNVM_REGISTER_OP(_contrib_conv2d_winograd_nnpack_weight_transform)
.describe(R"code(Weight transformation of winograd fast convolution algorithm.
Separate this into another nnvm symbol in order to enable Precompute Pass to compute the
weight transformation in advance.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
)code" NNVM_ADD_FILELINE)
.add_argument("weight", "4D Tensor", "Weight tensor.")
.add_arguments(WinogradNNPACKWeightTransformParam::__FIELDS__())
.set_attr_parser(ParamParser<WinogradNNPACKWeightTransformParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<WinogradNNPACKWeightTransformParam>)
.set_attr<FInferShape>("FInferShape", [](const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) {
const TShape &wshape = (*in_shape)[0];
CHECK_EQ(wshape.ndim(), 4) << "Weight should be a 4 dimensional tensor";
TShape oshape({wshape[0], wshape[1], 8, 8});
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, 0, oshape);
return true;
})
.set_attr<FCorrectLayout>("FCorrectLayout", [](const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
Layout layout("OIHW");
NNVM_ASSIGN_LAYOUT(*ilayouts, 0, layout);
NNVM_ASSIGN_LAYOUT(*olayouts, 0, layout);
return true;
})
.set_attr<FInferType>("FInferType", Conv2DWinogradNNPACKWTInferType)
.set_num_outputs(1)
.set_num_inputs(1)
.set_support_level(5);
DMLC_REGISTER_PARAMETER(WinogradNNPACKWeightTransformParam);
NNVM_REGISTER_OP(_contrib_conv2d_winograd_nnpack_without_weight_transform)
.describe(R"code(Compute conv2d with winograd nnpack.
- **data**: Input is 4D array of shape (batch_size, in_channels, height, width)
- **weight**: Any shape
We do not check the shape of this input tensor.
- **bias**: (channels,)
- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width)
)code" NNVM_ADD_FILELINE)
.add_argument("data", "4D Tensor", "Input data.")
.add_argument("weight", "4D Tensor", "Transformed weight tensor.")
.add_argument("bias", "1D Tensor", "Bias parameter.")
.add_arguments(Conv2DParam::__FIELDS__())
.set_attr_parser(ParamParser<Conv2DParam>)
.set_attr<FGetAttrDict>("FGetAttrDict", ParamGetAttrDict<Conv2DParam>)
.set_attr<FListInputNames>("FListInputNames", UseBiasListInputNames<Conv2DParam>)
.set_attr<FInferShape>("FInferShape", WinogradConv2DInferShape<Conv2DParam>)
.set_attr<FInferType>("FInferType", Conv2DInferType<Conv2DParam>)
.set_attr<FCorrectLayout>("FCorrectLayout", Conv2DCorrectLayout<Conv2DParam>)
.set_num_outputs(1)
.set_num_inputs(UseBiasNumInputs<Conv2DParam>)
.set_support_level(5);
NNVM_REGISTER_OP(_conv2d_grad)
.describe(R"code(2D convolution grad.
...
@@ -149,11 +149,12 @@ def convolution_inference_without_weight_transform(
ins[1],
ins[2] if bias is not None else 0,
outs[0], padding[0], padding[1], padding[2], padding[3],
- stride[0], stride[1], nthreads, algorithm), name="C")
+ stride[0], stride[1], nthreads, algorithm), name="C", dtype='float32')
def convolution_inference_weight_transform(
kernel, nthreads=1,
- algorithm=ConvolutionAlgorithm.AUTO):
+ algorithm=ConvolutionAlgorithm.AUTO,
dtype='float32'):
"""Create an extern op to do inference convolution of 3D tensor data and """Create an extern op to do inference convolution of 3D tensor data and
4D tensor kernel and 1D tensor bias with nnpack. 4D tensor kernel and 1D tensor bias with nnpack.
@@ -171,13 +172,14 @@ def convolution_inference_weight_transform(
"""
assert algorithm in (ConvolutionAlgorithm.WT_8x8, ConvolutionAlgorithm.WT_8x8_FP16)
output_channels, input_channels, _, _ = kernel.shape
transform_tile_size = 8
if not isinstance(dtype, str):
dtype = dtype.dtype
return _api.extern(
(output_channels, input_channels, transform_tile_size, transform_tile_size),
[kernel],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.nnpack.convolution_inference_weight_transform",
- ins[0], outs[0], nthreads, algorithm), name="transform_kernel")
+ ins[0], outs[0], nthreads, algorithm), name="transform_kernel", dtype=dtype)
_init_api("tvm.contrib.nnpack")
@@ -326,6 +326,58 @@ def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)
# winograd nnpack related operators
@reg.register_compute("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_without_weight_transform(
attrs, inputs, out_dtype, target):
"""Compute definition of conv2d_winograd_nnpack_without_weight_transform"""
# pylint: disable=assignment-from-no-return
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
data_layout = attrs.get_str("data_layout")
out_dtype = attrs.get_str("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
assert dilation == (1, 1), "Dilation is not supported yet"
assert groups == 1, "Grouped convolution is not supported"
# No bias
out = topi.nn.conv2d_winograd_nnpack_without_weight_transform(
inputs[0], inputs[1], None, strides, padding, dilation, data_layout,
out_dtype)
return [out]
@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_without_weight_transform(attrs, outs, target):
"""Schedule definition of conv2d_winograd_nnpack_without_weight_transform"""
with target:
return topi.generic.schedule_conv2d_winograd_nnpack_without_weight_transform(outs)
reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_without_weight_transform",
OpPattern.OPAQUE)
@reg.register_compute("nn.contrib_conv2d_winograd_nnpack_weight_transform")
def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_dtype, target):
"""Compute definition of contrib_conv2d_winograd_nnpack_weight_transform"""
convolution_algorithm = attrs.get_int('convolution_algorithm')
out = topi.nn.conv2d_winograd_nnpack_weight_transform(
inputs[0], convolution_algorithm, out_dtype)
return [out]
@reg.register_schedule("nn.contrib_conv2d_winograd_nnpack_weight_transform")
def schedule_contrib_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
"""Schedule definition of contrib_conv2d_winograd_nnpack_weight_transform"""
with target:
return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)
reg.register_pattern("nn.contrib_conv2d_winograd_nnpack_weight_transform",
OpPattern.OPAQUE)
@reg.register_compute("nn.contrib_conv2d_NCHWc") @reg.register_compute("nn.contrib_conv2d_NCHWc")
def compute_contrib_conv2d_NCHWc(attrs, inputs, out_dtype, target): def compute_contrib_conv2d_NCHWc(attrs, inputs, out_dtype, target):
"""Compute definition of conv2d NCHWc""" """Compute definition of conv2d NCHWc"""
......
#pylint: disable=invalid-name, too-many-lines
"""Neural network operations.""" """Neural network operations."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ...expr import TupleWrapper from ...expr import TupleWrapper
...@@ -862,6 +863,72 @@ def contrib_conv2d_winograd_without_weight_transform(data, ...@@ -862,6 +863,72 @@ def contrib_conv2d_winograd_without_weight_transform(data,
kernel_layout, out_layout, out_dtype) kernel_layout, out_layout, out_dtype)
def contrib_conv2d_winograd_nnpack_without_weight_transform(data,
weight,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCHW",
kernel_layout="OIHW",
out_layout="",
out_dtype=""):
r"""2D convolution with the NNPACK implementation of winograd algorithm.
The basic parameters are the same as the ones in vanilla conv2d.
It assumes the weight is pre-transformed by nn.contrib_conv2d_winograd_nnpack_weight_transform
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
weight : tvm.relay.Expr
The weight expressions.
strides : tuple of int, optional
The strides of convolution.
padding : tuple of int, optional
The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution.
kernel_size : tuple of int, optional
The spatial dimensions of the convolution kernel.
data_layout : str, optional
Layout of the input.
kernel_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output, by default, out_layout is the same as data_layout
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.contrib_conv2d_winograd_nnpack_without_weight_transform(
data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
def contrib_conv2d_nchwc(data,
kernel,
strides=(1, 1),
@@ -1013,3 +1080,28 @@ def contrib_conv2d_winograd_weight_transform(weight,
The computed result.
"""
return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size)
def contrib_conv2d_winograd_nnpack_weight_transform(weight,
convolution_algorithm,
out_dtype=""):
r"""Weight Transformation part for 2D convolution with winograd algorithm.
We separate this as a single op to enable pre-compute for inference.
Use this together with nn.contrib_conv2d_winograd_without_weight_transform
Parameters
----------
weight : tvm.relay.Expr
The weight expressions.
convolution_algorithm : int
The NNPACK convolution algorithm, e.g. tvm.contrib.nnpack.ConvolutionAlgorithm.WT_8x8.
out_dtype : str, optional
Specifies the output data type of the transformed weight.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.contrib_conv2d_winograd_nnpack_weight_transform(
weight, convolution_algorithm, out_dtype)
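A minimal sketch of how the two new Relay ops are intended to be paired; shapes and attribute values here are illustrative, not part of the patch:

```python
from tvm import relay
from tvm.contrib.nnpack import ConvolutionAlgorithm

data = relay.var("data", shape=(1, 64, 56, 56), dtype="float32")
weight = relay.var("weight", shape=(64, 64, 3, 3), dtype="float32")

# Transform the kernel once; a precompute/fold pass can lift this out.
tw = relay.nn.contrib_conv2d_winograd_nnpack_weight_transform(
    weight, convolution_algorithm=ConvolutionAlgorithm.WT_8x8,
    out_dtype="float32")
# Then run the transform-free convolution on the transformed kernel.
out = relay.nn.contrib_conv2d_winograd_nnpack_without_weight_transform(
    data, tw, strides=(1, 1), padding=(1, 1), channels=64,
    kernel_size=(3, 3), out_dtype="float32")
```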
@@ -19,5 +19,10 @@ class Conv2DWinogradWeightTransformAttrs(Attrs):
@register_relay_attr_node
class Conv2DWinogradNNPACKWeightTransformAttrs(Attrs):
"""Attribute of nn.contrib_conv2d_winograd_nnpack_weight_transform"""
@register_relay_attr_node
class GlobalPool2DAttrs(Attrs):
"""Attribute of nn.global_pool"""
@@ -189,20 +189,20 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_transform")
CHECK(workspace_buffer != nullptr);
for (auto n = 0; n < input->shape[0]; ++n) {
nnp_status status = nnp_convolution_inference(
algo, nnp_convolution_transform_strategy_reuse, input_channels, output_channels,
input_size, input_padding, kernel_size, stride_size,
static_cast<float *>(input->data) + n * input->shape[1] *
input->shape[2] *
input->shape[3],
static_cast<float *>(transformed_kernel->data),
bias ? static_cast<float *>(bias->data) : zero_bias->data(),
static_cast<float *>(output->data) + n * output->shape[1] *
output->shape[2] *
output->shape[3],
workspace_buffer, &workspace_size,
nnp_activation_identity, nullptr, entry->threadpool, nullptr);
CHECK_EQ(status, nnp_status_success);
}
cpu_api->FreeWorkspace(ctx, workspace_buffer);
...
@@ -344,6 +344,7 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW`
// relay.nn.contrib_conv2d_winograd_without_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradAttrs);
template<class Param>
bool Conv2DWinogradRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
@@ -354,7 +355,7 @@ bool Conv2DWinogradRel(const Array<Type>& types,
static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");
- const Conv2DWinogradAttrs* param = attrs.as<Conv2DWinogradAttrs>();
+ const Param* param = attrs.as<Param>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
@@ -467,7 +468,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
- .add_type_rel("Conv2DWinograd", Conv2DWinogradRel)
+ .add_type_rel("Conv2DWinograd", Conv2DWinogradRel<Conv2DWinogradAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DWinogradAttrs>);
@@ -511,8 +512,8 @@ Expr MakeConv2DWinogradWeightTransform(Expr weight,
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeConv2DWinogradWeightTransform, args, rv);
});
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
@@ -530,6 +531,124 @@ weight transformation in advance.
.add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel);
// Positional relay function to create conv2d winograd nnpack operator
// used by frontend FFI.
Expr MakeConv2DWinogradNNPACK(Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_node<Conv2DAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = channels;
attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.contrib_conv2d_winograd_nnpack_without_weight_transform");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_without_weight_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 12>(MakeConv2DWinogradNNPACK, args, rv);
});
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
.describe(R"code(Compute conv2d with winograd nnpack. Only supports NCHW layout.
This operator assumes the weight tensor is already pre-transformed by
nn.contrib_conv2d_winograd_nnpack_weight_transform.
- **data**: Input is 4D array of shape (batch_size, in_channels, height, width)
- **weight**: Any shape
We do not check the shape of this input tensor, since different backends
have different layout strategies.
- **out**: Output is 4D array of shape (batch_size, channels, out_height, out_width)
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.Conv2DAttrs")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2DWinogradNNPACKRel", Conv2DWinogradRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>);
// relay.nn.contrib_conv2d_winograd_nnpack_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs);
bool Conv2DWinogradNNPACKWeightTransformRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
return false;
}
const Conv2DWinogradNNPACKWeightTransformAttrs* param =
attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>();
CHECK(param != nullptr);
CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
std::vector<IndexExpr> oshape{
data->shape[0],
data->shape[1],
8,
8,
};
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape), out_dtype));
return true;
}
Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight,
int convolution_algorithm,
DataType out_dtype) {
auto attrs = make_node<Conv2DWinogradNNPACKWeightTransformAttrs>();
attrs->convolution_algorithm = convolution_algorithm;
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.contrib_conv2d_winograd_nnpack_weight_transform");
return CallNode::make(op, {weight}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeConv2DWinogradNNPACKWeightTransform, args, rv);
});
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_weight_transform")
.describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK.
Separate this into another symbol in order to enable Precompute Pass to compute the
weight transformation in advance.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs")
.set_num_inputs(1)
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel);
// Positional relay function to create conv2d NCHWc operator
// used by frontend FFI.
Expr MakeConv2DNCHWc(Expr data,
...
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm.task.space import FallbackConfigEntity
from tvm.contrib import nnpack
from tvm.contrib.pickle_memoize import memoize
import topi
import topi.testing
from topi.util import get_const_tuple
from nose import SkipTest
def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False,
devices=['cuda', 'llvm -device=arm_cpu', 'opencl -device=mali']):
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
bias = tvm.placeholder((num_filter, 1, 1), name='bias')
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv2d_nchw.verify_conv2d_nchw")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = np.random.uniform(size=bias_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
if add_bias:
b_np = np.random.uniform(size=bias_shape).astype(dtype)
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
raise SkipTest("Skip because %s is not enabled" % device)
print("Running on target: %s" % device)
with tvm.target.create(device):
C = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NCHW', out_dtype=dtype)
if add_bias:
C = topi.add(C, bias)
if add_relu:
C = topi.nn.relu(C)
s = topi.generic.schedule_conv2d_nchw([C])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias:
func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, b, c)
else:
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
for device in devices:
check_device(device)
class WinogradFallback(autotvm.FallbackContext):
def _query_inside(self, target, workload):
key = (target, workload)
if key in self.memory:
return self.memory[key]
cfg = FallbackConfigEntity()
cfg.template_key = 'winograd_nnpack_fp32'
self.memory[key] = cfg
return cfg
def test_conv2d_nchw():
if not tvm.get_global_func("tvm.contrib.nnpack.convolution_inference_without_weight_transform", True):
raise SkipTest("skip because extern function is not available")
if not nnpack.is_available():
raise SkipTest("skip because nnpack is not available")
devices = ['llvm -device=arm_cpu']
autotvm.DispatchContext.current.silent = True
with WinogradFallback():
# resnet 18 workloads
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, devices=devices)
verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1, devices=devices)
verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1, devices=devices)
verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1, devices=devices)
# unet workloads
verify_conv2d_nchw(1, 3, 192, 12, 3, 1, 1, add_bias=True, devices=devices)
verify_conv2d_nchw(1, 4, 192, 12, 3, 1, 1, add_bias=True, devices=devices)
verify_conv2d_nchw(1, 12, 96, 24, 3, 1, 1, add_bias=True, devices=devices)
verify_conv2d_nchw(1, 24, 48, 48, 3, 1, 1, add_bias=True, devices=devices)
verify_conv2d_nchw(1, 48, 24, 96, 3, 1, 1, add_bias=True, devices=devices)
verify_conv2d_nchw(1, 96, 12, 180, 3, 1, 1, add_bias=True, devices=devices)
verify_conv2d_nchw(1, 180, 6, 220, 3, 1, 1, add_bias=True, devices=devices)
verify_conv2d_nchw(1, 220, 6, 180, 3, 1, 1, add_bias=True, devices=devices)
verify_conv2d_nchw(1, 180, 12, 96, 3, 1, 1, add_bias=True, devices=devices)
verify_conv2d_nchw(1, 96, 24, 48, 3, 1, 1, add_bias=True, devices=devices)
verify_conv2d_nchw(1, 48, 48, 24, 3, 1, 1, add_bias=True, devices=devices)
verify_conv2d_nchw(1, 24, 96, 12, 3, 1, 1, add_bias=True, devices=devices)
verify_conv2d_nchw(1, 12, 192, 1, 3, 1, 1, add_bias=True, devices=devices)
# relu, bias
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_bias=True, devices=devices)
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_relu=True, devices=devices)
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1, add_relu=True, add_bias=True, devices=devices)
# weird workloads
verify_conv2d_nchw(1, 3, 3, 3, 3, 1, 1, devices=devices)
verify_conv2d_nchw(1, 13, 71, 59, 3, 1, 1, devices=devices)
if __name__ == "__main__":
import nose
nose.runmodule()
@@ -122,6 +122,39 @@ def schedule_conv2d_winograd_without_weight_transform(outs):
@tvm.target.generic_func
def schedule_conv2d_winograd_nnpack_weight_transform(outs):
"""Schedule for weight transformation of winograd
Parameters
----------
outs: Array of Tensor
The computation graph description of this operator
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
# Typically this is computed in nnvm PreCompute pass
s = tvm.create_schedule([x.op for x in outs])
return s
@tvm.target.generic_func
def schedule_conv2d_winograd_nnpack_without_weight_transform(outs):
"""Schedule for winograd without weight transformation
Parameters
----------
outs: Array of Tensor
The computation graph description of this operator
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_conv2d_transpose_nchw(outs):
"""Schedule for conv2d_transpose_nchw
...
@@ -410,6 +410,48 @@ def conv2d_winograd_without_weight_transform(input, filter, strides, padding, di
raise ValueError("missing register for topi.nn.conv2d_winograd_without_weight_transform")
def conv2d_winograd_nnpack_weight_transform(kernel, convolution_algorithm, out_dtype):
"""Weight transformation for winograd
Parameters
----------
kernel: Tensor
The raw kernel tensor with layout "OIHW". Only 3x3 kernel is supported for now.
convolution_algorithm: int
The convolution algorithm for Winograd NNPACK.
Returns
-------
output : tvm.Tensor
4-D with shape [CO, CI, 8, 8]
"""
from tvm.contrib import nnpack
return nnpack.convolution_inference_weight_transform(
kernel, algorithm=convolution_algorithm, dtype=out_dtype)
@tvm.target.generic_func
def conv2d_winograd_nnpack_without_weight_transform(
input, filter, bias, strides, padding, dilation, layout, out_dtype):
"""Compute convolution in winograd algorithm. The filter is supposed to be transformed
in advance.
Parameters
----------
input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
filter : tvm.Tensor
4-D with shape [num_filter, in_channel, 8, 8]
bias : tvm.Tensor
1-D with shape [num_filter]
strides : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
raise ValueError("missing register for topi.nn.conv2d_winograd_without_weight_transform")
@tvm.target.generic_func
def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtype=None):
"""Group convolution operator in NCHW layout.
...
@@ -10,7 +10,8 @@ from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
- def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
+ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False,
devices=['cuda', 'llvm -device=arm_cpu', 'opencl -device=mali']):
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
in_height = in_width = in_size in_height = in_width = in_size
...@@ -67,7 +68,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p ...@@ -67,7 +68,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in ['cuda', 'llvm -device=arm_cpu', 'opencl -device=mali']: for device in devices:
check_device(device) check_device(device)
......