Unverified commit 02eb1833 by Josh Fromm, committed by GitHub

[Relay][Topi][AutoTVM] Winograd support for Conv3D (#5186)

* Functional conv3d winograd working.

* Formatted python code.

* Registered conv3d winograd compute and started adding the relay without_weight_transform operator.

* Add topi testing for conv3d winograd.

* Format file.

* Small tweak to unrolling to prevent the build from sticking.

* Refactoring convolution ops in relay.

* Refactored relay convolutions.

* Bug fixes.

* Fixed static bug in convolution.

* Added conv3d alter op layout and related support.

* Bug fixes and testing done.

* Fix a few autotvm bugs.

* Drop silly debug print.

* Removed debug_skip_region.

* Add variant of conv3d_winograd that doesn't transform depth.

* Initial infrastructure done for depthless conv.

* Fix no_depth schedule bugs.

* Automatic topi switching between depth and depthless winograd.

* Fixed bug in schedule.

* Lint fixes.

* Removed indents in convolution.cc

* Missed a few indents, oops.

* Fixed flop count.

* One more small tweak.

* Change kernel pack inner axes order.

* Style changes.

* Comment fixes.
parent c76cbd8d
......@@ -82,8 +82,13 @@ This level enables typical convnet models.
tvm.relay.nn.pad
tvm.relay.nn.lrn
tvm.relay.nn.l2_normalize
tvm.relay.nn.bitpack
tvm.relay.nn.bitserial_dense
tvm.relay.nn.bitserial_conv2d
tvm.relay.nn.contrib_conv2d_winograd_without_weight_transform
tvm.relay.nn.contrib_conv2d_winograd_weight_transform
tvm.relay.nn.contrib_conv3d_winograd_without_weight_transform
tvm.relay.nn.contrib_conv3d_winograd_weight_transform
**Level 3: Additional Math And Transform Operators**
......
......@@ -156,12 +156,12 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
};
/*! \brief Attributes used in winograd weight transformation operators */
struct Conv2DWinogradWeightTransformAttrs :
public tvm::AttrsNode<Conv2DWinogradWeightTransformAttrs> {
struct ConvWinogradWeightTransformAttrs :
public tvm::AttrsNode<ConvWinogradWeightTransformAttrs> {
int tile_size;
TVM_DECLARE_ATTRS(Conv2DWinogradWeightTransformAttrs,
"relay.attrs.Conv2DWinogradWeightTransformAttrs") {
TVM_DECLARE_ATTRS(ConvWinogradWeightTransformAttrs,
"relay.attrs.ConvWinogradWeightTransformAttrs") {
TVM_ATTR_FIELD(tile_size)
.describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
}
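For reference (standard Winograd notation, not part of this diff): an F(m, r) transform uses tiles of edge alpha = m + r - 1, so tile_size 2 with a 3x3 kernel packs the weight into 4x4 transformed tiles and tile_size 4 into 6x6 tiles; the 3D operators below apply the same rule per transformed axis.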
......@@ -306,6 +306,69 @@ struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
}
};
/*! \brief Attributes used in 3d winograd convolution operators */
struct Conv3DWinogradAttrs : public tvm::AttrsNode<Conv3DWinogradAttrs> {
int tile_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> dilation;
int groups;
IndexExpr channels;
Array<IndexExpr> kernel_size;
std::string data_layout;
std::string kernel_layout;
std::string out_layout;
DataType out_dtype;
TVM_DECLARE_ATTRS(Conv3DWinogradAttrs, "relay.attrs.Conv3DWinogradAttrs") {
TVM_ATTR_FIELD(tile_size)
.describe("The tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)");
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"three int : back, bottom, right will use same padding as front, top, left"
"six int : padding width in the order of (front, top, left, back, bottom,"
"right)");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1)
.describe("Controls the connections between inputs and outputs."
"At groups=1, all inputs are convolved to all outputs."
"At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated.");
TVM_ATTR_FIELD(channels)
.describe("The number of output channels in the convolution."
" If it is not set, inferred by shape of the weight.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
TVM_ATTR_FIELD(data_layout).set_default("NCDHW")
.describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Convolution is applied on the 'D', 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(kernel_layout).set_default("OIDHW")
.describe("Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc."
"'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height,"
"and width dimensions respectively.");
TVM_ATTR_FIELD(out_layout).set_default("")
.describe("Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Default to be same as input layout.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
/*! \brief Attributes used in softmax operators */
struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
int axis;
......
......@@ -178,6 +178,29 @@ def legalize_conv2d_transpose(attrs, inputs, types):
reg.register_strategy("nn.conv3d", strategy.conv3d_strategy)
reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_alter_op_layout("nn.conv3d")
def alter_op_layout_conv3d(attrs, inputs, tinfos, out_type):
"""Alternate the layout of conv3d"""
return topi.nn.conv3d_alter_layout(attrs, inputs, tinfos, out_type)
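This hook delegates to the backend's conv3d_alter_layout. As a hedged sketch (not the TOPI implementation added by this PR; the tile_size constant is assumed), a backend hook can rewrite nn.conv3d into the winograd pair so the weight transform is constant-folded at compile time:

from tvm import relay

def _example_alter_conv3d(attrs, inputs, tinfos, out_type):
    # Illustrative only: pack the kernel once, then call the transform-free op.
    data, weight = inputs
    tile_size = 4  # assumed; real backends pick this from the tuned config
    new_attrs = {k: attrs[k] for k in attrs.keys()}
    new_attrs["tile_size"] = tile_size
    packed = relay.nn.contrib_conv3d_winograd_weight_transform(weight, tile_size)
    return relay.nn.contrib_conv3d_winograd_without_weight_transform(
        data, packed, **new_attrs)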
# conv3d_winograd related operators
reg.register_strategy("nn.contrib_conv3d_winograd_without_weight_transform",
strategy.conv3d_winograd_without_weight_transfrom_strategy)
reg.register_pattern("nn.contrib_conv3d_winograd_without_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_compute("nn.contrib_conv3d_winograd_weight_transform")
def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype):
"""Compute definition of contrib_conv3d_winograd_weight_transform"""
out = topi.nn.conv3d_winograd_weight_transform(
inputs[0], attrs.get_int('tile_size'))
return [out]
reg.register_schedule("nn.contrib_conv3d_winograd_weight_transform",
strategy.schedule_conv3d_winograd_weight_transform)
reg.register_pattern("nn.contrib_conv3d_winograd_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)
# conv1d_transpose
reg.register_strategy("nn.conv1d_transpose", strategy.conv1d_transpose_strategy)
......
......@@ -19,7 +19,7 @@
from __future__ import absolute_import as _abs
from ...expr import TupleWrapper
from . import _make
from .util import get_pad_tuple2d
from .util import get_pad_tuple2d, get_pad_tuple3d
def conv1d(data,
......@@ -295,13 +295,84 @@ def conv3d(data,
strides = (strides, strides, strides)
if isinstance(dilation, int):
dilation = (dilation, dilation, dilation)
if isinstance(padding, int):
padding = (padding, padding, padding)
padding = get_pad_tuple3d(padding)
return _make.conv3d(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
def contrib_conv3d_winograd_without_weight_transform(data,
weight,
tile_size,
strides=(1, 1, 1),
padding=(0, 0, 0),
dilation=(1, 1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCDHW",
kernel_layout="OIDHW",
out_layout="",
out_dtype=""):
r"""3D convolution with winograd algorithm.
The basic parameters are the same as the ones in vanilla conv3d.
It assumes the weight is pre-transformed by nn.contrib_conv3d_winograd_weight_transform.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
weight : tvm.relay.Expr
The weight expressions.
tile_size : int
The tile size of winograd, e.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3).
strides : tuple of int, optional
The strides of convolution.
padding : tuple of int, optional
The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution.
kernel_size : tuple of int, optional
The spatial dimensions of the convolution kernel.
data_layout : str, optional
Layout of the input.
kernel_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output, by default, out_layout is the same as data_layout
out_dtype : str, optional
Specifies the output data type for mixed precision conv3d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
# convert 3-way padding to 6-way padding
padding = get_pad_tuple3d(padding)
return _make.contrib_conv3d_winograd_without_weight_transform(
data, weight, tile_size, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
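A minimal usage sketch (shapes and tile size assumed for illustration): pre-pack the kernel with nn.contrib_conv3d_winograd_weight_transform and feed the result here. kernel_size and channels must be given explicitly because the packed weight no longer exposes them.

from tvm import relay

data = relay.var("data", shape=(1, 16, 8, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(32, 16, 3, 3, 3), dtype="float32")
packed = relay.nn.contrib_conv3d_winograd_weight_transform(weight, tile_size=4)
out = relay.nn.contrib_conv3d_winograd_without_weight_transform(
    data, packed, tile_size=4, padding=(1, 1, 1),
    channels=32, kernel_size=(3, 3, 3))
func = relay.Function([data, weight], out)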
def conv2d_transpose(data,
weight,
strides=(1, 1),
......@@ -1952,6 +2023,29 @@ def contrib_conv2d_winograd_weight_transform(weight,
return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size)
def contrib_conv3d_winograd_weight_transform(weight,
tile_size):
r"""Weight Transformation part for 3D convolution with winograd algorithm.
We separate this as a single op to enable pre-compute for inference.
Use this together with nn.contrib_conv3d_winograd_without_weight_transform.
Parameters
----------
weight : tvm.relay.Expr
The weight expressions.
tile_size : int
The tile size of winograd, e.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3).
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.contrib_conv3d_winograd_weight_transform(weight, tile_size)
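As a quick sanity check (shapes assumed), type inference reflects the packing rule: with tile_size 4 and a 3x3x3 kernel the transformed tile edge is 4 + 3 - 1 = 6.

import tvm
from tvm import relay

w = relay.var("w", shape=(32, 16, 3, 3, 3), dtype="float32")
t = relay.nn.contrib_conv3d_winograd_weight_transform(w, tile_size=4)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(relay.Function([w], t)))
print(mod["main"].body.checked_type)  # expect Tensor[(6, 6, 6, 32, 16), float32]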
def contrib_conv2d_winograd_nnpack_weight_transform(weight,
convolution_algorithm,
out_dtype=""):
......
......@@ -54,3 +54,46 @@ def get_pad_tuple2d(padding):
pad_top = (pad_h + 1) // 2
pad_left = (pad_w + 1) // 2
return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left
def get_pad_tuple3d(padding):
"""Common code to get the pad option
Parameters
----------
padding : Union[int, Tuple[int, ...]]
Padding size
Returns
-------
pad_front : int
Padding size on front
pad_top : int
Padding size on top
pad_left : int
Padding size on left
pad_back : int
Padding size on back
pad_down : int
Padding size on down.
pad_right : int
Padding size on right.
"""
# compute the padding size
if isinstance(padding, container.Array):
padding = list(padding)
if isinstance(padding, (tuple, list)):
if len(padding) == 3:
pad_d = padding[0] * 2
pad_h = padding[1] * 2
pad_w = padding[2] * 2
elif len(padding) == 6:
return padding[0], padding[1], padding[2], padding[3], padding[4], padding[5]
else:
raise ValueError("Size of padding can only be 3 or 6")
elif isinstance(padding, int):
pad_d = pad_h = pad_w = padding * 2
else:
raise ValueError("Unknown padding option %s" % padding)
pad_front = (pad_d + 1) // 2
pad_top = (pad_h + 1) // 2
pad_left = (pad_w + 1) // 2
return pad_front, pad_top, pad_left, pad_d - pad_front, pad_h - pad_top, pad_w - pad_left
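A few illustrative cases of the normalization above: a single int pads all six sides equally, a 3-tuple gives symmetric front/top/left and back/bottom/right padding, and a 6-tuple passes through unchanged.

from tvm.relay.op.nn.util import get_pad_tuple3d  # assumed import path

assert get_pad_tuple3d(1) == (1, 1, 1, 1, 1, 1)
assert get_pad_tuple3d((1, 2, 3)) == (1, 2, 3, 1, 2, 3)
assert get_pad_tuple3d((0, 1, 2, 3, 4, 5)) == (0, 1, 2, 3, 4, 5)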
......@@ -34,9 +34,19 @@ class Conv2DWinogradAttrs(Attrs):
"""Attributes for nn.contrib_conv2d_winograd_without_weight_transform"""
@tvm._ffi.register_object("relay.attrs.Conv2DWinogradWeightTransformAttrs")
class Conv2DWinogradWeightTransformAttrs(Attrs):
"""Attributes for nn.contrib_conv2d_winograd_weight_transform"""
@tvm._ffi.register_object("relay.attrs.Conv3DAttrs")
class Conv3DAttrs(Attrs):
"""Attributes for nn.conv3d"""
@tvm._ffi.register_object("relay.attrs.Conv3DWinogradAttrs")
class Conv3DWinogradAttrs(Attrs):
"""Attributes for nn.contrib_conv3d_winograd_without_weight_transform"""
@tvm._ffi.register_object("relay.attrs.ConvWinogradWeightTransformAttrs")
class ConvWinogradWeightTransformAttrs(Attrs):
"""Attributes for nn.contrib_convNd_winograd_weight_transform"""
@tvm._ffi.register_object("relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs")
......
......@@ -233,13 +233,25 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
def conv3d_strategy_cuda(attrs, inputs, out_type, target):
"""conv3d cuda strategy"""
strategy = _op.OpStrategy()
_, kernel = inputs
layout = attrs.data_layout
_, stride_h, stride_w = attrs.get_int_tuple("strides")
_, dilation_h, dilation_w = attrs.get_int_tuple("dilation")
assert layout in ["NCDHW", "NDHWC"], "Not support this layout {} yet".format(layout)
if layout == "NCDHW":
strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ncdhw),
wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw),
name="conv3d_ncdhw.cuda",
plevel=10)
_, _, _, kh, kw = get_const_tuple(kernel.shape)
if 2 < kh < 8 and 2 < kw < 8 and kh == kw and \
stride_h == 1 and stride_w == 1 and \
dilation_h == 1 and dilation_w == 1:
strategy.add_implementation(
wrap_compute_conv3d(topi.cuda.conv3d_ncdhw_winograd),
wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw_winograd),
name="conv3d_ncdhw_winograd.cuda",
plevel=5)
else: # layout == "NDHWC":
strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ndhwc),
wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc),
......@@ -252,6 +264,26 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
plevel=15)
return strategy
@conv3d_winograd_without_weight_transfrom_strategy.register(["cuda", "gpu"])
def conv3d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_type, target):
"""conv3d_winograd_without_weight_transfrom cuda strategy"""
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs.data_layout
assert dilation == (1, 1, 1), "Do not support dilate now"
assert groups == 1, "Do not support arbitrary group number"
strategy = _op.OpStrategy()
if layout == "NCDHW":
strategy.add_implementation(
wrap_compute_conv3d(topi.cuda.conv3d_ncdhw_winograd_without_weight_transform),
wrap_topi_schedule(
topi.cuda.schedule_conv3d_ncdhw_winograd_without_weight_transform),
name="conv3d_ncdhw_winograd_without_weight_transform.cuda")
else:
raise RuntimeError("Unsupported conv3d_winograd_without_weight_transfrom layout {}".
format(layout))
return strategy
@conv1d_strategy.register(["cuda", "gpu"])
def conv1d_strategy_cuda(attrs, inputs, out_type, target):
"""conv1d cuda strategy"""
......
......@@ -374,6 +374,19 @@ def conv3d_strategy(attrs, inputs, out_type, target):
raise ValueError("Not support this layout {} yet".format(layout))
return strategy
# conv3d_winograd_without_weight_transform
@override_native_generic_func("conv3d_winograd_without_weight_transform_strategy")
def conv3d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, target):
"""conv3d_winograd_without_weight_transfrom generic strategy"""
raise ValueError("No generic implemenation for conv3d_winograd_without_weight_transform")
# conv3d_winograd_weight_transform
@generic_func
def schedule_conv3d_winograd_weight_transform(attrs, outs, target):
"""Schedule conv3d_winograd_weight_transform"""
with target:
return topi.generic.schedule_conv3d_winograd_weight_transform(outs)
# conv1d
def wrap_compute_conv1d(topi_compute):
"""wrap conv1d topi compute"""
......
......@@ -59,10 +59,113 @@ Expr MakeConv(Expr data,
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get(op_name);
const Op& op = Op::Get(op_name);
return Call(op, {data, weight}, Attrs(attrs), {});
}
template <typename T>
Expr MakeConvWinograd(Expr data,
Expr weight,
int tile_size,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype,
std::string op_name) {
auto attrs = make_object<T>();
attrs->tile_size = tile_size;
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = std::move(channels);
attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
const Op& op = Op::Get(op_name);
return Call(op, {data, weight}, Attrs(attrs), {});
}
Expr MakeConvWinogradWeightTransform(Expr weight,
int tile_size,
std::string op_name) {
auto attrs = make_object<ConvWinogradWeightTransformAttrs>();
attrs->tile_size = tile_size;
const Op& op = Op::Get(op_name);
return Call(op, {weight}, Attrs(attrs), {});
}
template <typename T>
Expr MakeConvTranspose(Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
Array<IndexExpr> output_padding,
DataType out_dtype,
std::string op_name) {
auto attrs = make_object<T>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = std::move(channels);
attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->output_padding = std::move(output_padding);
attrs->out_dtype = std::move(out_dtype);
const Op& op = Op::Get(op_name);
return Call(op, {data, weight}, Attrs(attrs), {});
}
template <typename T>
Expr MakeDeformableConv(Expr data,
Expr offset,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int deformable_groups,
int groups,
int channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype,
std::string op_name) {
auto attrs = make_object<T>();
attrs->strides = strides;
attrs->padding = padding;
attrs->dilation = dilation;
attrs->deformable_groups = deformable_groups;
attrs->groups = groups;
attrs->channels = channels;
attrs->kernel_size = kernel_size;
attrs->data_layout = data_layout;
attrs->kernel_layout = kernel_layout;
attrs->out_layout = out_layout;
attrs->out_dtype = out_dtype;
const Op& op = Op::Get(op_name);
return Call(op, {data, offset, weight}, Attrs{attrs}, {});
}
// relay.nn.conv1d
TVM_REGISTER_NODE_TYPE(Conv1DAttrs);
......@@ -153,6 +256,7 @@ with the layer input to produce a tensor of outputs.
.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
// relay.nn.conv3d
TVM_REGISTER_NODE_TYPE(Conv3DAttrs);
......@@ -198,107 +302,12 @@ with the layer input to produce a tensor of outputs.
.add_type_rel("Conv3D", Conv3DRel<Conv3DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv3DAttrs>);
// relay.nn.conv2d_transpose
TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
bool Conv2DTransposeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr) return false;
static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");
const Conv2DTransposeAttrs* param = attrs.as<Conv2DTransposeAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
auto dshape_nchw = trans_in_layout.ForwardShape(data->shape);
// infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2);
Array<IndexExpr> wshape({dshape_nchw[1],
indexdiv(param->channels, param->groups),
param->kernel_size[0],
param->kernel_size[1]});
wshape = trans_kernel_layout.BackwardShape(wshape);
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
channels = param->channels;
// assign result to reporter
reporter->Assign(types[1], TensorType(wshape, data->dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
if (param->kernel_size.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
// check the size
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
reporter->AssertEQ(param->kernel_size[1], wshape[3]))
<< "Conv2D: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size
<< " wshape=" << Array<IndexExpr>(wshape);
}
if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[1]))
<< "Conv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels
<< " wshape=" << Array<IndexExpr>(wshape);
}
CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
channels = wshape[1];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
}
// dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
pad_h + param->output_padding[0]));
oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
pad_w + param->output_padding[1]));
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = trans_out_layout.BackwardShape(oshape);
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
Expr MakeConv2DTranspose(Expr data,
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_transpose")
.set_body_typed([](Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
......@@ -311,25 +320,11 @@ Expr MakeConv2DTranspose(Expr data,
std::string out_layout,
Array<IndexExpr> output_padding,
DataType out_dtype) {
auto attrs = make_object<Conv2DTransposeAttrs>();
attrs->channels = std::move(channels);
attrs->kernel_size = std::move(kernel_size);
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->output_padding = std::move(output_padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.conv2d_transpose");
return Call(op, {data, weight}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d_transpose")
.set_body_typed(MakeConv2DTranspose);
return MakeConvTranspose<Conv2DTransposeAttrs>(
data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, output_padding, out_dtype, "nn.conv2d_transpose");
});
RELAY_REGISTER_OP("nn.conv2d_transpose")
.describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution).
......@@ -360,104 +355,13 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW`
.set_support_level(2)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ConvInferCorrectLayout<Conv2DTransposeAttrs>)
.add_type_rel("Conv2DTranspose", Conv2DTransposeRel);
.add_type_rel("Conv2DTranspose", Conv2DTransposeRel<Conv2DTransposeAttrs>);
// relay.nn.conv1d_transpose
TVM_REGISTER_NODE_TYPE(Conv1DTransposeAttrs);
bool Conv1DTransposeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr) return false;
static const Layout kNCW("NCW");
static const Layout kOIW("OIW");
const Conv1DTransposeAttrs* param = attrs.as<Conv1DTransposeAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCW."
<< " But got " << in_layout;
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIW."
<< " But got "<< kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCW."
<< " But got " << out_layout;
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
auto dshape_ncw = trans_in_layout.ForwardShape(data->shape);
// infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 1);
CHECK_EQ(param->dilation.size(), 1);
Array<IndexExpr> wshape({dshape_ncw[1],
indexdiv(param->channels, param->groups),
param->kernel_size[0]});
wshape = trans_kernel_layout.BackwardShape(wshape);
dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
channels = param->channels;
// assign result to reporter
reporter->Assign(types[1], TensorType(wshape, data->dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
if (param->kernel_size.defined()) {
CHECK_EQ(param->kernel_size.size(), 1);
// check the size
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]))
<< "Conv1D: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size
<< " wshape=" << Array<IndexExpr>(wshape);
}
if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[1]))
<< "Conv1D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels
<< " wshape=" << Array<IndexExpr>(wshape);
}
CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0]));
channels = wshape[1];
dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0];
}
// dilation
IndexExpr pad_w;
GetPaddingWidth(param->padding, &pad_w);
Array<IndexExpr> oshape({dshape_ncw[0], channels, 0});
oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x -
pad_w + param->output_padding[0]));
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = trans_out_layout.BackwardShape(oshape);
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
Expr MakeConv1DTranspose(Expr data,
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d_transpose")
.set_body_typed([](Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
......@@ -470,25 +374,11 @@ Expr MakeConv1DTranspose(Expr data,
std::string out_layout,
Array<IndexExpr> output_padding,
DataType out_dtype) {
auto attrs = make_object<Conv1DTransposeAttrs>();
attrs->channels = std::move(channels);
attrs->kernel_size = std::move(kernel_size);
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->output_padding = std::move(output_padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.conv1d_transpose");
return Call(op, {data, weight}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d_transpose")
.set_body_typed(MakeConv1DTranspose);
return MakeConvTranspose<Conv1DTransposeAttrs>(
data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, output_padding, out_dtype, "nn.conv1d_transpose");
});
RELAY_REGISTER_OP("nn.conv1d_transpose")
.describe(R"code(Transposed 1D convolution layer (sometimes called Deconvolution).
......@@ -516,97 +406,13 @@ said convolution.
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.add_type_rel("Conv1DTranspose", Conv1DTransposeRel);
.add_type_rel("Conv1DTranspose", Conv1DTransposeRel<Conv1DTransposeAttrs>);
// relay.nn.contrib_conv2d_winograd_without_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradAttrs);
template<class Param>
bool Conv2DWinogradRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");
const Param* param = attrs.as<Param>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;
Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
CHECK(param->kernel_size.defined() && param->channels.defined())
<< "The kernel size and channels of a Conv must be set or infered by previous pass";
CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2);
channels = param->channels;
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
// NOTE: Do not check weight shape here!
// Different backend requires different layout to compute
// the batch gemm stage in winograd efficiently, but we want to
// make this op work for all backends.
// So we accept all weight shapes, and assume the TOPI developers
// can handle this correctly in alter_op_layout.
// dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
if (!dshape_nchw[2].as<tir::AnyNode>()) {
oshape.Set(2, (dshape_nchw[2] + pad_h
- dilated_ksize_y) / param->strides[0] + 1);
} else {
oshape.Set(2, dshape_nchw[2]);
}
if (!dshape_nchw[3].as<tir::AnyNode>()) {
oshape.Set(3, (dshape_nchw[3] + pad_w
- dilated_ksize_x) / param->strides[1] + 1);
} else {
oshape.Set(3, dshape_nchw[3]);
}
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
// Positional relay function to create conv2d winograd operator
// used by frontend FFI.
Expr MakeConv2DWinograd(Expr data,
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform")
.set_body_typed([](Expr data,
Expr weight,
int tile_size,
Array<IndexExpr> strides,
......@@ -619,25 +425,11 @@ Expr MakeConv2DWinograd(Expr data,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_object<Conv2DWinogradAttrs>();
attrs->tile_size = tile_size;
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = channels;
attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.contrib_conv2d_winograd_without_weight_transform");
return Call(op, {data, weight}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform")
.set_body_typed(MakeConv2DWinograd);
return MakeConvWinograd<Conv2DWinogradAttrs>(
data, weight, tile_size, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_winograd_without_weight_transform");
});
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
......@@ -662,46 +454,14 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
ConvInferCorrectLayout<Conv2DWinogradAttrs>);
// relay.nn.contrib_conv2d_winograd_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradWeightTransformAttrs);
bool Conv2DWinogradWeightTransformRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
const Conv2DWinogradWeightTransformAttrs* param = attrs.as<Conv2DWinogradWeightTransformAttrs>();
CHECK(param != nullptr);
CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
// each pad width element should be a pair of positive integers
std::vector<IndexExpr> oshape {
param->tile_size + data->shape[2] - 1,
param->tile_size + data->shape[3] - 1,
data->shape[0],
data->shape[1],
};
reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
data->dtype));
return true;
}
Expr MakeConv2DWinogradWeightTransform(Expr weight,
int tile_size) {
auto attrs = make_object<Conv2DWinogradWeightTransformAttrs>();
attrs->tile_size = tile_size;
static const Op& op = Op::Get("nn.contrib_conv2d_winograd_weight_transform");
return Call(op, {weight}, Attrs(attrs), {});
}
TVM_REGISTER_NODE_TYPE(ConvWinogradWeightTransformAttrs);
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
.set_body_typed(MakeConv2DWinogradWeightTransform);
.set_body_typed([](Expr weight,
int tile_size) {
return MakeConvWinogradWeightTransform(
weight, tile_size, "nn.contrib_conv2d_winograd_weight_transform");
});
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
.describe(R"code(Weight transformation of winograd fast convolution algorithm.
......@@ -711,47 +471,82 @@ weight transformation in advance.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
)code" TVM_ADD_FILELINE)
.set_attrs_type<Conv2DWinogradWeightTransformAttrs>()
.set_attrs_type<ConvWinogradWeightTransformAttrs>()
.set_num_inputs(1)
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2DWinogradWeightTransform", Conv2DWinogradWeightTransformRel);
// relay.nn.contrib_conv3d_winograd_without_weight_transform
TVM_REGISTER_NODE_TYPE(Conv3DWinogradAttrs);
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_without_weight_transform")
.set_body_typed([](Expr data,
Expr weight,
int tile_size,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
return MakeConvWinograd<Conv3DWinogradAttrs>(
data, weight, tile_size, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype, "nn.contrib_conv3d_winograd_without_weight_transform");
});
RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_without_weight_transform")
.describe(R"code(Compute conv3d with winograd algorithm. Only supports NCDHW layout.
This operator assumes the weight tensor is already pre-transformed by
nn.contrib_conv3d_winograd_weight_transform.
- **data**: Input is 5D array of shape (batch_size, in_channels, depth, height, width)
- **weight**: Any shape
We do not check the shape for this input tensor, since different backends
have different layout strategies.
- **out**: Output is 5D array of shape (batch_size, channels, out_depth, out_height, out_width)
)code" TVM_ADD_FILELINE)
.set_attrs_type<Conv3DWinogradAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv3DWinograd", Conv3DWinogradRel<Conv3DWinogradAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
ConvInferCorrectLayout<Conv3DWinogradAttrs>);
// relay.nn.contrib_conv3d_winograd_weight_transform
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv3d_winograd_weight_transform")
.set_body_typed([](Expr weight,
int tile_size) {
return MakeConvWinogradWeightTransform(
weight, tile_size, "nn.contrib_conv3d_winograd_weight_transform");
});
RELAY_REGISTER_OP("nn.contrib_conv3d_winograd_weight_transform")
.describe(R"code(Weight transformation of winograd fast 3d convolution algorithm.
Separate this into another operator in order to enable Precompute Pass to compute the
weight transformation in advance.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2])
)code" TVM_ADD_FILELINE)
.set_attrs_type<ConvWinogradWeightTransformAttrs>()
.set_num_inputs(1)
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv3DWinogradWeightTransform", Conv3DWinogradWeightTransformRel);
// relay.nn.contrib_conv2d_winograd_nnpack_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs);
bool Conv2DWinogradNNPACKWeightTransformRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
return false;
}
const Conv2DWinogradNNPACKWeightTransformAttrs* param =
attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>();
CHECK(param != nullptr);
CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
std::vector<IndexExpr> oshape{
data->shape[0],
data->shape[1],
8,
8,
};
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), out_dtype));
return true;
}
Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight,
int convolution_algorithm,
DataType out_dtype) {
......@@ -779,10 +574,12 @@ weight transformation in advance.
.set_support_level(10)
.add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel);
// Positional relay function to create conv2d NCHWc operator
// used by frontend FFI.
Expr MakeConv2DNCHWc(Expr data,
Expr kernel,
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc")
.set_body_typed([](Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
......@@ -793,24 +590,11 @@ Expr MakeConv2DNCHWc(Expr data,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_object<Conv2DAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = channels;
attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.contrib_conv2d_NCHWc");
return Call(op, {data, kernel}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc")
.set_body_typed(MakeConv2DNCHWc);
return MakeConv<Conv2DAttrs>(
data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_NCHWc");
});
RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
.describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout.
......@@ -831,8 +615,9 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
// Positional relay function to create depthwise conv2d NCHWc operator
// used by frontend FFI.
Expr MakeDepthwiseConv2DNCHWc(Expr data,
Expr kernel,
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc")
.set_body_typed([](Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
......@@ -843,23 +628,11 @@ Expr MakeDepthwiseConv2DNCHWc(Expr data,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_object<Conv2DAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = channels;
attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.contrib_depthwise_conv2d_NCHWc");
return Call(op, {data, kernel}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc")
.set_body_typed(MakeDepthwiseConv2DNCHWc);
return MakeConv<Conv2DAttrs>(
data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype, "nn.contrib_depthwise_conv2d_NCHWc");
});
RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
......@@ -879,85 +652,6 @@ RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
ConvInferCorrectLayout<Conv2DAttrs>);
bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[2].as<TensorTypeNode>();
CHECK(data);
auto* param = attrs.as<DeformableConv2DAttrs>();
CHECK_EQ(param->data_layout, "NCHW") << "data layout not supported.";
CHECK_EQ(param->kernel_layout, "OIHW") << "kernel_layout not supported.";
IndexExpr channels, dilated_ksize_y, dilated_ksize_x, ksize_y, ksize_x;
// infer weight shape if kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2);
Array<IndexExpr> wshape(
{param->channels,
indexdiv(data->shape[1], param->groups),
param->kernel_size[0],
param->kernel_size[1]});
channels = param->channels;
ksize_y = param->kernel_size[0];
ksize_x = param->kernel_size[1];
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
// assign result to reporter
reporter->Assign(types[2], TensorType(wshape, data->dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = weight->shape;
if (param->kernel_size.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
// check the size
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
reporter->AssertEQ(param->kernel_size[1], wshape[3]))
<< "DeformableConv2D: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size
<< " wshape=" << wshape;
}
if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[0]))
<< "DeformableConv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels
<< " wshape=" << wshape;
}
CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1]));
channels = wshape[0];
ksize_y = wshape[2];
ksize_x = wshape[3];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
}
// dilation
Array<IndexExpr> oshape({data->shape[0], channels, 0, 0});
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y,
param->strides[0]) + 1);
oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x,
param->strides[1]) + 1);
DataType out_dtype = param->out_dtype;
// infer offset shape
Array<IndexExpr> offset_shape({data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups,
oshape[2], oshape[3]});
reporter->Assign(types[1], TensorType(offset_shape, data->dtype));
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
reporter->Assign(types[3], TensorType(oshape, out_dtype));
return true;
}
TVM_REGISTER_NODE_TYPE(DeformableConv2DAttrs);
RELAY_REGISTER_OP("nn.deformable_conv2d")
......@@ -986,11 +680,12 @@ by concating all the *g* results.
.add_argument("offset", "Tensor", "The offset tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(5)
.add_type_rel("DeformableConv2D", DeformableConv2DRel);
.add_type_rel("DeformableConv2D", DeformableConv2DRel<DeformableConv2DAttrs>);
// Positional relay function to create deformable_conv2d operator
// used by frontend FFI.
Expr MakeDeformableConv2D(Expr data,
TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d")
.set_body_typed([](Expr data,
Expr offset,
Expr weight,
Array<IndexExpr> strides,
......@@ -1004,24 +699,11 @@ Expr MakeDeformableConv2D(Expr data,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_object<DeformableConv2DAttrs>();
attrs->strides = strides;
attrs->padding = padding;
attrs->dilation = dilation;
attrs->deformable_groups = deformable_groups;
attrs->groups = groups;
attrs->channels = channels;
attrs->kernel_size = kernel_size;
attrs->data_layout = data_layout;
attrs->kernel_layout = kernel_layout;
attrs->out_layout = out_layout;
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("nn.deformable_conv2d");
return Call(op, {data, offset, weight}, Attrs{attrs}, {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d")
.set_body_typed(MakeDeformableConv2D);
return MakeDeformableConv<DeformableConv2DAttrs>(
data, offset, weight, strides, padding, dilation,
deformable_groups, groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype, "nn.deformable_conv2d");
});
} // namespace relay
} // namespace tvm
......@@ -29,12 +29,15 @@
#include <string>
#include <utility>
#include <vector>
#include "../op_common.h"
namespace tvm {
namespace relay {
// Standard convolution operator shape relations
template <typename AttrType>
bool Conv1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
......@@ -363,6 +366,533 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return true;
}
// Winograd convolution shape relations
inline bool Conv2DWinogradWeightTransformRel(const Array<Type>& types, int num_inputs,
const Attrs& attrs, const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
const ConvWinogradWeightTransformAttrs* param = attrs.as<ConvWinogradWeightTransformAttrs>();
CHECK(param != nullptr);
CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
std::vector<IndexExpr> oshape {
param->tile_size + data->shape[2] - 1,
param->tile_size + data->shape[3] - 1,
data->shape[0],
data->shape[1],
};
reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape),
data->dtype));
return true;
}
inline bool Conv3DWinogradWeightTransformRel(const Array<Type>& types, int num_inputs,
const Attrs& attrs, const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
const ConvWinogradWeightTransformAttrs* param = attrs.as<ConvWinogradWeightTransformAttrs>();
CHECK(param != nullptr);
CHECK_EQ(data->shape.size(), 5) << "Only support NCDHW normal kernel layout";
// Shape of packed weights depends on whether depth is being transformed or not.
Array<IndexExpr> oshape({0, 0, 0, data->shape[0], data->shape[1]});
auto* depth_imm = data->shape[2].as<IntImmNode>();
CHECK(depth_imm) << "Winograd 3D weight transform requires a static kernel depth.";
bool transform_depth = (depth_imm->value > 2) && (depth_imm->value < 8);
if (transform_depth) {
oshape.Set(0, param->tile_size + data->shape[2] - 1);
oshape.Set(1, param->tile_size + data->shape[3] - 1);
oshape.Set(2, param->tile_size + data->shape[4] - 1);
} else {
oshape.Set(0, param->tile_size + data->shape[3] - 1);
oshape.Set(1, param->tile_size + data->shape[4] - 1);
oshape.Set(2, data->shape[2]);
}
reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
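Worked example of the rule above (illustrative shapes): with tile_size m = 4, a 3x3x3 kernel (2 < depth < 8) packs to (m + 3 - 1, m + 3 - 1, m + 3 - 1, O, I) = (6, 6, 6, O, I), while a 1x3x3 kernel keeps its depth untransformed and packs to (6, 6, 1, O, I); the corresponding compute and schedule must branch on the same depth test.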
inline bool Conv2DWinogradNNPACKWeightTransformRel(const Array<Type>& types, int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
return false;
}
const Conv2DWinogradNNPACKWeightTransformAttrs* param =
attrs.as<Conv2DWinogradNNPACKWeightTransformAttrs>();
CHECK(param != nullptr);
CHECK_EQ(data->shape.size(), 4) << "Only support NCHW normal kernel layout";
std::vector<IndexExpr> oshape{
data->shape[0],
data->shape[1],
8,
8,
};
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
reporter->Assign(types[1], TensorType(Array<IndexExpr>(oshape), out_dtype));
return true;
}
template<typename AttrType>
bool Conv2DWinogradRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");
const AttrType* param = attrs.as<AttrType>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;
Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
CHECK(param->kernel_size.defined() && param->channels.defined())
<< "The kernel size and channels of a Conv must be set or inferred by previous pass";
CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2);
channels = param->channels;
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
// NOTE: Do not check weight shape here!
// Different backend requires different layout to compute
// the batch gemm stage in winograd efficiently, but we want to
// make this op work for all backends.
// So we accept all weight shapes, and assume the TOPI developers
// can handle this correctly in alter_op_layout.
// dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
if (!dshape_nchw[2].as<tir::AnyNode>()) {
oshape.Set(2, (dshape_nchw[2] + pad_h
- dilated_ksize_y) / param->strides[0] + 1);
} else {
oshape.Set(2, dshape_nchw[2]);
}
if (!dshape_nchw[3].as<tir::AnyNode>()) {
oshape.Set(3, (dshape_nchw[3] + pad_w
- dilated_ksize_x) / param->strides[1] + 1);
} else {
oshape.Set(3, dshape_nchw[3]);
}
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
template<typename AttrType>
bool Conv3DWinogradRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) return false;
static const Layout kNCDHW("NCDHW");
static const Layout kOIDHW("OIDHW");
const AttrType* param = attrs.as<AttrType>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCDHW."
<< " But got " << in_layout;
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIDHW."
<< " But got "<< kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCDHW."
<< " But got " << out_layout;
Array<IndexExpr> dshape_ncdhw = trans_in_layout.ForwardShape(data->shape);
IndexExpr channels, dilated_ksize_d, dilated_ksize_y, dilated_ksize_x;
CHECK(param->kernel_size.defined() && param->channels.defined())
<< "The kernel size and channels of a Conv must be set or inferred by previous pass";
CHECK_EQ(param->kernel_size.size(), 3);
CHECK_EQ(param->dilation.size(), 3);
channels = param->channels;
dilated_ksize_d = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2];
// NOTE: Do not check weight shape here!
// Different backend requires different layout to compute
// the batch gemm stage in winograd efficiently, but we want to
// make this op work for all backends.
// So we accept all weight shapes, and assume the TOPI developers
// can handle this correctly in alter_op_layout.
// dilation
Array<IndexExpr> oshape({dshape_ncdhw[0], channels, 0, 0, 0});
IndexExpr pad_d, pad_h, pad_w;
GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w);
if (!dshape_ncdhw[2].as<tir::AnyNode>()) {
oshape.Set(2, (dshape_ncdhw[2] + pad_d
- dilated_ksize_d) / param->strides[0] + 1);
} else {
oshape.Set(2, dshape_ncdhw[2]);
}
if (!dshape_ncdhw[3].as<tir::AnyNode>()) {
oshape.Set(3, (dshape_ncdhw[3] + pad_h
- dilated_ksize_y) / param->strides[1] + 1);
} else {
oshape.Set(3, dshape_ncdhw[3]);
}
if (!dshape_ncdhw[4].as<tir::AnyNode>()) {
oshape.Set(4, (dshape_ncdhw[4] + pad_w
- dilated_ksize_x) / param->strides[2] + 1);
} else {
oshape.Set(4, dshape_ncdhw[4]);
}
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
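As a numeric check of the shape arithmetic above (values assumed): an NCDHW input of depth 8 with kernel depth 3, total depth padding 2, and stride 1 gives (8 + 2 - 3) / 1 + 1 = 8 output slices, the same relation as ordinary conv3d; the winograd tiling changes how the output is computed, not its shape.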
// Transposed convolution shape relations
template <typename AttrType>
bool Conv1DTransposeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr) return false;
static const Layout kNCW("NCW");
static const Layout kOIW("OIW");
const AttrType* param = attrs.as<AttrType>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCW."
<< " But got " << in_layout;
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIW."
<< " But got "<< kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCW."
<< " But got " << out_layout;
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
auto dshape_ncw = trans_in_layout.ForwardShape(data->shape);
// infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 1);
CHECK_EQ(param->dilation.size(), 1);
Array<IndexExpr> wshape({dshape_ncw[1],
indexdiv(param->channels, param->groups),
param->kernel_size[0]});
wshape = trans_kernel_layout.BackwardShape(wshape);
dilated_ksize_x = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
channels = param->channels;
// assign result to reporter
reporter->Assign(types[1], TensorType(wshape, data->dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
if (param->kernel_size.defined()) {
CHECK_EQ(param->kernel_size.size(), 1);
// check the size
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]))
<< "Conv1D: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size
<< " wshape=" << Array<IndexExpr>(wshape);
}
if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[1]))
<< "Conv1D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels
<< " wshape=" << Array<IndexExpr>(wshape);
}
CHECK(reporter->AssertEQ(indexdiv(dshape_ncw[1], param->groups), wshape[0]));
channels = wshape[1];
dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilation[0];
}
// dilation
IndexExpr pad_w;
GetPaddingWidth(param->padding, &pad_w);
Array<IndexExpr> oshape({dshape_ncw[0], channels, 0});
oshape.Set(2, (param->strides[0] * (dshape_ncw[2] - 1) + dilated_ksize_x -
pad_w + param->output_padding[0]));
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = trans_out_layout.BackwardShape(oshape);
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
template <typename AttrType>
bool Conv2DTransposeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr) return false;
static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");
const Conv2DTransposeAttrs* param = attrs.as<AttrType>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
auto dshape_nchw = trans_in_layout.ForwardShape(data->shape);
// infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2);
Array<IndexExpr> wshape({dshape_nchw[1],
indexdiv(param->channels, param->groups),
param->kernel_size[0],
param->kernel_size[1]});
wshape = trans_kernel_layout.BackwardShape(wshape);
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
channels = param->channels;
// assign result to reporter
reporter->Assign(types[1], TensorType(wshape, data->dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
if (param->kernel_size.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
// check the size
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
reporter->AssertEQ(param->kernel_size[1], wshape[3]))
<< "Conv2D: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size
<< " wshape=" << Array<IndexExpr>(wshape);
}
if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[1]))
<< "Conv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels
<< " wshape=" << Array<IndexExpr>(wshape);
}
CHECK(reporter->AssertEQ(indexdiv(dshape_nchw[1], param->groups), wshape[0]));
channels = wshape[1];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
}
// dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
pad_h + param->output_padding[0]));
oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
pad_w + param->output_padding[1]));
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = trans_out_layout.BackwardShape(oshape);
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
// Deformable Convolution shape relations.
template <typename AttrType>
bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[2].as<TensorTypeNode>();
CHECK(data);
auto* param = attrs.as<AttrType>();
CHECK_EQ(param->data_layout, "NCHW") << "data layout not supported.";
CHECK_EQ(param->kernel_layout, "OIHW") << "kernel_layout not supported.";
IndexExpr channels, dilated_ksize_y, dilated_ksize_x, ksize_y, ksize_x;
// infer weight shape if kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2);
Array<IndexExpr> wshape(
{param->channels,
indexdiv(data->shape[1], param->groups),
param->kernel_size[0],
param->kernel_size[1]});
channels = param->channels;
ksize_y = param->kernel_size[0];
ksize_x = param->kernel_size[1];
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
// assign result to reporter
reporter->Assign(types[2], TensorType(wshape, data->dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = weight->shape;
if (param->kernel_size.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
// check the size
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
reporter->AssertEQ(param->kernel_size[1], wshape[3]))
<< "DeformableConv2D: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size
<< " wshape=" << wshape;
}
if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[0]))
<< "DeformableConv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels
<< " wshape=" << wshape;
}
CHECK(reporter->AssertEQ(indexdiv(data->shape[1], param->groups), wshape[1]));
channels = wshape[0];
ksize_y = wshape[2];
ksize_x = wshape[3];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
}
// dilation
Array<IndexExpr> oshape({data->shape[0], channels, 0, 0});
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y,
param->strides[0]) + 1);
oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x,
param->strides[1]) + 1);
DataType out_dtype = param->out_dtype;
// infer offset shape
Array<IndexExpr> offset_shape({data->shape[0], 2 * ksize_y * ksize_x * param->deformable_groups,
oshape[2], oshape[3]});
reporter->Assign(types[1], TensorType(offset_shape, data->dtype));
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
reporter->Assign(types[3], TensorType(oshape, out_dtype));
return true;
}
template<typename T>
Array<Array<Layout> > ConvInferCorrectLayout(
const Attrs& attrs,
......@@ -378,6 +908,7 @@ Array<Array<Layout> > ConvInferCorrectLayout(
params->data_layout : params->out_layout}};
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_NN_CONVOLUTION_H_
......@@ -25,6 +25,7 @@ from tvm.relay import transform
from tvm.relay.testing import ctx_list, run_infer_type
from tvm.contrib import util
import topi.testing
from topi.cuda.conv3d_winograd import _infer_tile_size
def test_conv1d_infer_type():
......@@ -326,7 +327,7 @@ def test_conv2d_winograd():
cfg['tile_y'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
cfg['tile_x'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
cfg['tile_rc'] = autotvm.task.space.SplitEntity([-1, 1])
cfg['auto_unroll_max_setp'] = autotvm.task.space.OtherOptionEntity(1500)
cfg['auto_unroll_max_step'] = autotvm.task.space.OtherOptionEntity(1500)
cfg['unroll_explicit'] = autotvm.task.space.OtherOptionEntity(1)
self.memory[key] = cfg
return cfg
......@@ -522,6 +523,94 @@ def test_conv3d_ndhwc_run():
run_test_conv3d("float32", "float32", 1, dshape, kshape,
padding=(1, 1, 1), channels=10, kernel_size=(3, 3, 3), except_targets=["cuda"])
def test_conv3d_winograd():
class WinogradFallback(autotvm.FallbackContext):
def _query_inside(self, target, workload):
key = (target, workload)
if key in self.memory:
return self.memory[key]
cfg = autotvm.task.space.FallbackConfigEntity()
cfg.is_fallback = False
cfg.cost = 0.1 if 'winograd' in workload[0] else 1
cfg['tile_b'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
cfg['tile_y'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
cfg['tile_x'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
cfg['tile_rc'] = autotvm.task.space.SplitEntity([-1, 1])
cfg['auto_unroll_max_step'] = autotvm.task.space.OtherOptionEntity(0)
cfg['unroll_explicit'] = autotvm.task.space.OtherOptionEntity(1)
self.memory[key] = cfg
return cfg
def run_test_conv3d_cuda(dtype, out_dtype, scale, dshape, kshape,
padding=(1, 1, 1),
groups=1,
dilation=(1, 1, 1),
prepack=False,
**attrs):
x = relay.var("x", shape=dshape, dtype=dtype)
w = relay.var("w", shape=kshape, dtype=dtype)
if prepack:
tile_size = _infer_tile_size(np.zeros(shape=dshape), np.zeros(shape=kshape))
w_packed = relay.nn.contrib_conv3d_winograd_weight_transform(w, tile_size)
y = relay.nn.contrib_conv3d_winograd_without_weight_transform(
x, w_packed, tile_size,
padding=padding,
dilation=dilation,
groups=groups,
channels=kshape[0],
**attrs)
else:
y = relay.nn.conv3d(x, w,
padding=padding,
dilation=dilation,
groups=groups,
**attrs)
func = relay.Function([x, w], y)
mod = tvm.IRModule()
mod['main'] = func
mod = relay.transform.InferType()(mod)
data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
ref_res = topi.testing.conv3d_ncdhw_python(
data.astype(out_dtype), kernel.astype(out_dtype), 1, padding,
groups=groups)
with WinogradFallback(), relay.build_config(opt_level=3):
for target, ctx in ctx_list():
if target != 'cuda':
continue
params = {'w': tvm.nd.array(kernel)}
graph, lib, params = relay.build_module.build(mod, target=target, params=params)
module = tvm.contrib.graph_runtime.create(graph, lib, ctx)
module.set_input('x', tvm.nd.array(data))
module.set_input(**params)
module.run()
op_res1 = module.get_output(0)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-3, atol=1e-3)
# normal winograd: stride 1, padding 1, kernel 3x3x3
dshape = (1, 32, 16, 16, 16)
kshape = (64, 32, 3, 3, 3)
run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape,
padding=(1, 1, 1), kernel_size=(3, 3, 3))
# Without depth transform using 1x3x3 kernel.
kshape = (64, 32, 1, 3, 3)
run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape,
padding=(0, 1, 1), kernel_size=(1, 3, 3))
# extended winograd: stride 1, padding N, kernel NxNxN
dshape = (1, 61, 20, 20, 20)
kshape = (120, 61, 5, 5, 5)
run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape,
padding=(2, 2, 2), channels=120, kernel_size=(5, 5, 5))
# Without depth transform
kshape = (120, 61, 1, 5, 5)
run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape,
padding=(0, 2, 2), channels=120, kernel_size=(1, 5, 5))
def test_conv2d_transpose_infer_type():
# symbolic in batch dimension
......@@ -1268,6 +1357,7 @@ if __name__ == "__main__":
test_conv2d_winograd()
test_conv3d_run()
test_conv3d_ndhwc_run()
test_conv3d_winograd()
test_bitserial_conv2d_infer_type()
test_batch_flatten()
test_upsampling()
......
......@@ -31,6 +31,8 @@ from . import conv2d_alter_op
from .conv2d_transpose_nchw import *
from .deformable_conv2d import *
from .conv3d import *
from .conv3d_winograd import *
from . import conv3d_alter_op
from .reduction import schedule_reduce
from .softmax import schedule_softmax
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument
"""Conv3D alter op and legalize functions for cuda backend"""
import logging
import tvm
from tvm import te
from tvm import relay
from tvm import autotvm
from .. import nn
from ..util import get_const_tuple
from .conv3d_winograd import _infer_tile_size
logger = logging.getLogger('topi')
@nn.conv3d_alter_layout.register(["cuda", "gpu"])
def _alter_conv3d_layout(attrs, inputs, tinfos, out_type):
target = tvm.target.Target.current(allow_none=False)
dispatch_ctx = autotvm.task.DispatchContext.current
_, outs = relay.backend.compile_engine.select_implementation(
relay.op.get("nn.conv3d"), attrs, tinfos, out_type, target)
workload = autotvm.task.get_workload(outs)
if workload is None:
# The best implementation is not an AutoTVM template,
# so we assume it's not necessary to alter this op.
return None
cfg = dispatch_ctx.query(target, workload)
if cfg.is_fallback:  # if this is a fallback config, clear the query cache and return None
autotvm.task.clear_fallback_cache(target, workload)
return None
topi_tmpl = workload[0]
new_attrs = {k: attrs[k] for k in attrs.keys()}
strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int('groups')
data_layout = attrs["data_layout"]
kernel_layout = attrs["kernel_layout"]
data, kernel = tinfos
out_dtype = out_type.dtype
if topi_tmpl == "conv3d_ncdhw_winograd.cuda":
if dilation != (1, 1, 1):
logger.warning("Does not support weight pre-transform for dilated 3D convolution.")
return None
assert data_layout == "NCDHW" and kernel_layout == "OIDHW"
N, CI, D, H, W = get_const_tuple(data.shape)
CO, _, KD, KH, KW = get_const_tuple(kernel.shape)
# Pre-compute weight transformation in winograd
tile_size = _infer_tile_size(tinfos[0], tinfos[1])
weight = relay.nn.contrib_conv3d_winograd_weight_transform(inputs[1], tile_size=tile_size)
new_attrs['tile_size'] = tile_size
new_attrs['channels'] = CO
# Store the same config for the altered operators (workload)
new_data = data
# Check if depth is transformed or not
if 2 < KD < 8 and KD == KH:
new_weight = te.placeholder(
(KD + tile_size - 1, KH + tile_size - 1, KW + tile_size - 1, CO, CI),
dtype=kernel.dtype)
else:
new_weight = te.placeholder(
(KH + tile_size - 1, KW + tile_size - 1, KD, CO, CI),
dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_weight, strides, padding, dilation, out_dtype],
"conv3d_ncdhw_winograd_without_weight_transform.cuda")
dispatch_ctx.update(target, new_workload, cfg)
return relay.nn.contrib_conv3d_winograd_without_weight_transform(
inputs[0], weight, **new_attrs)
return None
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument
"""Winograd template for cuda backend"""
import logging
import tvm
from tvm import te
from tvm import autotvm
from .. import nn
from ..util import get_const_int, get_const_tuple, traverse_inline, simplify
from ..nn.winograd_util import winograd_transform_matrices
logger = logging.getLogger('conv3d_winograd')
def _infer_tile_size(data, kernel):
N, CI, D, H, W = get_const_tuple(data.shape)
if H % 8 == 0:
return 4
return 2
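# Illustrative note (not part of the original kernel): for a winograd
# transform F(m, r) the packed tile spans alpha = m + r - 1 points per
# dimension, so with the common 3x3x3 kernel a tile size of 4 gives
# alpha = 6 and a tile size of 2 gives alpha = 4. The larger tile is
# presumably chosen only when H divides evenly by 8, keeping the tiling regular.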
def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed):
"""Compute declaration for winograd"""
tile_size = _infer_tile_size(data, kernel)
N, CI, D, H, W = get_const_tuple(data.shape)
if isinstance(dilation, int):
dilation_d = dilation_h = dilation_w = dilation
else:
dilation_d, dilation_h, dilation_w = dilation
DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
if not pre_computed: # kernel tensor is raw tensor, do strict check
if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
alpha = KW + tile_size - 1
assert DSTR == 1 and HSTR == 1 and WSTR == 1 and KD == KH and KH == KW
else:
# kernel tensor is pre-transformed. this op is created by alter op layout.
# dilation is not supported
alpha, _, _, CO, CI = get_const_tuple(kernel.shape)
KD = KH = KW = alpha + 1 - tile_size
assert DSTR == 1 and HSTR == 1 and WSTR == 1 and \
dilation_d == 1 and dilation_h == 1 and dilation_w == 1
pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
r = KW
m = tile_size
A, B, G = winograd_transform_matrices(m, r, out_dtype)
D = (D + pf + pb - KD) // DSTR + 1
H = (H + pt + pd - KH) // HSTR + 1
W = (W + pl + pr - KW) // WSTR + 1
nD, nH, nW = (D + m - 1) // m, (H + m - 1) // m, (W + m - 1) // m
P = N * nD * nH * nW
# transform kernel
if not pre_computed:
# If we are currently tuning, avoid counting prepacking in time costs.
# Just use a placeholder with the packed shape instead.
if autotvm.GLOBAL_SCOPE.in_tuning:
kernel_pack = te.placeholder((alpha, alpha, alpha, CO, CI),
dtype=kernel.dtype,
name='kernel_pack')
else:
r_kd = te.reduce_axis((0, KD), name='r_kd')
r_kh = te.reduce_axis((0, KH), name='r_kh')
r_kw = te.reduce_axis((0, KW), name='r_kw')
kernel_pack = te.compute(
(alpha, alpha, alpha, CO, CI),
lambda omg, eps, nu, co, ci: te.sum(
kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
axis=[r_kd, r_kh, r_kw]),
name='kernel_pack')
else:
kernel_pack = kernel
idxdiv = tvm.tir.indexdiv
idxmod = tvm.tir.indexmod
# pack input tile
input_tile = te.compute((CI, P, alpha, alpha, alpha),
lambda c, p, omg, eps, nu: data_pad[idxdiv(p, (nD * nH * nW))]
[c]
[idxmod(idxdiv(p, nH * nW), nD) * m + omg]
[idxmod(idxdiv(p, nW), nH) * m + eps]
[idxmod(p, nW) * m + nu],
name='d')
# transform data
r_a = te.reduce_axis((0, alpha), 'r_a')
r_b = te.reduce_axis((0, alpha), 'r_b')
r_c = te.reduce_axis((0, alpha), 'r_c')
data_pack = te.compute(
(alpha, alpha, alpha, CI, P),
lambda omg, eps, nu, ci, p: te.sum(
input_tile[ci][p][r_a][r_b][r_c] * B[r_a][omg] * B[r_b][eps] * B[r_c][nu],
axis=[r_a, r_b, r_c]),
name='data_pack')
# do batch gemm
ci = te.reduce_axis((0, CI), name='ci')
bgemm = te.compute(
(alpha, alpha, alpha, CO, P),
lambda omg, eps, nu, co, p: te.sum(
kernel_pack[omg][eps][nu][co][ci] * data_pack[omg][eps][nu][ci][p], axis=[ci]),
name='bgemm')
# inverse transform
r_a = te.reduce_axis((0, alpha), 'r_a')
r_b = te.reduce_axis((0, alpha), 'r_b')
r_c = te.reduce_axis((0, alpha), 'r_c')
inverse = te.compute((CO, P, m, m, m),
lambda co, p, vd, vh, vw: te.sum(
bgemm[r_a][r_b][r_c][co][p] * A[r_a][vd] * A[r_b][vh] * A[r_c][vw],
axis=[r_a, r_b, r_c]),
name='inverse')
# output
output = te.compute((N, CO, D, H, W),
lambda n, co, d, h, w: inverse[co, n * nD * nH * nW + idxdiv(d, m) * nH * nW
+ idxdiv(h, m) * nW + idxdiv(w, m),
idxmod(d, m),
idxmod(h, m),
idxmod(w, m)],
name='output',
tag='conv3d_ncdhw_winograd')
cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
return output
def winograd_without_depth_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype,
pre_computed):
"""Compute declaration for winograd without transforming depth"""
tile_size = _infer_tile_size(data, kernel)
N, CI, D, H, W = get_const_tuple(data.shape)
if isinstance(dilation, int):
dilation_d = dilation_h = dilation_w = dilation
else:
dilation_d, dilation_h, dilation_w = dilation
DSTR, HSTR, WSTR = (strides, strides, strides) if isinstance(strides, int) else strides
if not pre_computed: # kernel tensor is raw tensor, do strict check
if dilation_d != 1 or dilation_h != 1 or dilation_w != 1:
kernel = nn.dilate(kernel, (1, 1, dilation_d, dilation_h, dilation_w))
CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
alpha = KW + tile_size - 1
assert HSTR == 1 and WSTR == 1 and KH == KW
else:
# kernel tensor is pre-transformed. this op is created by alter op layout.
# dilation is not supported
alpha, _, KD, CO, CI = get_const_tuple(kernel.shape)
KH = KW = alpha + 1 - tile_size
assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1
pf, pt, pl, pb, pd, pr = nn.get_pad_tuple3d(padding, (KD, KH, KW))
data_pad = nn.pad(data, (0, 0, pf, pt, pl), (0, 0, pb, pd, pr), name="data_pad")
out_depth = simplify((D - KD + pf + pb) // DSTR + 1)
D += pf + pb
r = KW
m = tile_size
A, B, G = winograd_transform_matrices(m, r, out_dtype)
H = (H + pt + pd - KH) // HSTR + 1
W = (W + pl + pr - KW) // WSTR + 1
nH, nW = (H + m-1) // m, (W + m-1) // m
P = N * nH * nW
# transform kernel
if not pre_computed:
# During autotuning don't count kernel packing as a time cost
# as it will later be removed via alter_op_layout.
if autotvm.GLOBAL_SCOPE.in_tuning:
kernel_pack = te.placeholder((alpha, alpha, KD, CO, CI),
dtype=kernel.dtype,
name='kernel_pack')
else:
r_kh = te.reduce_axis((0, KH), name='r_kh')
r_kw = te.reduce_axis((0, KW), name='r_kw')
kernel_pack = te.compute(
(alpha, alpha, KD, CO, CI),
lambda eps, nu, d, co, ci: te.sum(
kernel[co][ci][d][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]),
name='kernel_pack')
else:
kernel_pack = kernel
idxdiv = tvm.tir.indexdiv
idxmod = tvm.tir.indexmod
# pack input tile
input_tile = te.compute((CI, D, P, alpha, alpha), lambda c, d, p, eps, nu:
data_pad[idxdiv(p, (nH * nW))][c][d]
[idxmod(idxdiv(p, nW), nH) * m + eps]
[idxmod(p, nW) * m + nu], name='d')
# transform data
r_a = te.reduce_axis((0, alpha), 'r_a')
r_b = te.reduce_axis((0, alpha), 'r_b')
data_pack = te.compute((alpha, alpha, CI, D, P), lambda eps, nu, ci, d, p:
te.sum(input_tile[ci][d][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu],
axis=[r_a, r_b]), name='data_pack')
# do batch gemm
ci = te.reduce_axis((0, CI), name='ci')
rz = te.reduce_axis((0, KD), name='rz')
bgemm = te.compute((alpha, alpha, CO, out_depth, P), lambda eps, nu, co, d, p:
te.sum(kernel_pack[eps][nu][rz][co][ci] *
data_pack[eps][nu][ci][d * DSTR + rz][p],
axis=[ci, rz]), name='bgemm')
# inverse transform
r_a = te.reduce_axis((0, alpha), 'r_a')
r_b = te.reduce_axis((0, alpha), 'r_b')
inverse = te.compute((CO, out_depth, P, m, m), lambda co, d, p, vh, vw:
te.sum(bgemm[r_a][r_b][co][d][p] * A[r_a][vh] * A[r_b][vw],
axis=[r_a, r_b]), name='inverse')
# output
output = te.compute((N, CO, out_depth, H, W), lambda n, co, d, h, w:
inverse[co,
d,
n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m),
idxmod(h, m),
idxmod(w, m)],
name='output', tag='conv3d_ncdhw_winograd_without_depth')
cfg.add_flop(2 * N * CO * D * H * W * CI * KD * KH * KW)
return output
def schedule_winograd_cuda(cfg, s, output, pre_computed):
"""Schedule winograd template"""
# get stages
inverse = s[output].op.input_tensors[0]
bgemm, A = s[inverse].op.input_tensors
kernel_pack, data_pack = s[bgemm].op.input_tensors
input_tile, B = s[data_pack].op.input_tensors
pad_data = s[input_tile].op.input_tensors[0]
# data transform
s[B].compute_inline()
data_l = s.cache_write(data_pack, 'local')
omg, eps, nu, c, p = s[data_l].op.axis
r_a, r_b, r_c = s[data_l].op.reduce_axis
# TODO unrolling by omg, eps, nu may improve performance but
# in some cases causes extremely long build times due to imperfect tiling.
for axis in [r_a, r_b, r_c]:
s[data_l].unroll(axis)
omg, eps, nu, c, p = s[data_pack].op.axis
p, pi = s[data_pack].split(p, 1)
fused = s[data_pack].fuse(c, p)
bb, tt = s[data_pack].split(fused, 128)
s[data_pack].reorder(bb, tt, pi, omg, eps, nu)
s[data_pack].bind(bb, te.thread_axis("blockIdx.x"))
s[data_pack].bind(tt, te.thread_axis("threadIdx.x"))
s[data_l].compute_at(s[data_pack], pi)
s[input_tile].compute_at(s[data_pack], pi)
s[pad_data].compute_inline()
# transform kernel
if not pre_computed and not autotvm.GLOBAL_SCOPE.in_tuning:
kernel, G = s[kernel_pack].op.input_tensors
omg, eps, nu, co, ci = s[kernel_pack].op.axis
s[G].compute_inline()
r_a, r_b, r_c = s[kernel_pack].op.reduce_axis
# Could add additional unrolling by omg, eps, nu in the future.
for axis in [r_a, r_b, r_c]:
s[kernel_pack].unroll(axis)
fused = s[kernel_pack].fuse(co, ci)
bb, tt = s[kernel_pack].split(fused, 128)
s[kernel_pack].reorder(bb, tt, omg, eps, nu, r_a, r_b, r_c)
s[kernel_pack].bind(bb, te.thread_axis("blockIdx.x"))
s[kernel_pack].bind(tt, te.thread_axis("threadIdx.x"))
else:
kernel = kernel_pack
if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
##### space definition begin #####
b1, b2, b3, y, x = s[bgemm].op.axis
rc = s[bgemm].op.reduce_axis[0]
alpha = get_const_int(b1.dom.extent)
cfg.define_split(
"tile_b",
cfg.axis(alpha * alpha * alpha),
num_outputs=4,
filter=lambda x: x.size[-3:] == [1, 1, 1])
cfg.define_split("tile_y", y, num_outputs=4)
cfg.define_split("tile_x", x, num_outputs=4)
cfg.define_split("tile_rc", rc, num_outputs=2)
cfg.define_knob("auto_unroll_max_step", [0, 128, 1500])
target = tvm.target.Target.current()
if target.target_name in ['nvptx', 'rocm']:
cfg.define_knob("unroll_explicit", [1])
else:
cfg.define_knob("unroll_explicit", [0, 1])
##### space definition end #####
# batch gemm
C = bgemm
A0, B0 = kernel_pack, data_pack
OL = s.cache_write(C, 'local')
AA = s.cache_read(A0, 'shared', [OL])
BB = s.cache_read(B0, 'shared', [OL])
b = s[bgemm].fuse(b1, b2, b3)
# tile and bind spatial axes
bgemm_scope, b = s[bgemm].split(b, nparts=1)
bz, vz, tz, zi = cfg["tile_b"].apply(s, C, b)
by, vy, ty, yi = cfg["tile_y"].apply(s, C, y)
bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x)
s[C].bind(bz, te.thread_axis("blockIdx.z"))
s[C].bind(by, te.thread_axis("blockIdx.y"))
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(vz, te.thread_axis("vthread"))
s[C].bind(vy, te.thread_axis("vthread"))
s[C].bind(vx, te.thread_axis("vthread"))
s[C].bind(tz, te.thread_axis("threadIdx.z"))
s[C].bind(ty, te.thread_axis("threadIdx.y"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
s[C].reorder(bgemm_scope, bz, by, bx, vz, vy, vx, tz, ty, tx, zi, yi, xi)
# tile reduction axes
s[OL].compute_at(s[C], tx)
b1, b2, b3, y, x = s[OL].op.axis
b = s[OL].fuse(b1, b2, b3)
rc, = s[OL].op.reduce_axis
rco, rci = cfg['tile_rc'].apply(s, OL, rc)
s[OL].reorder(rco, rci, b, y, x)
s[AA].compute_at(s[OL], rco)
s[BB].compute_at(s[OL], rco)
# cooperative fetching
for load in [AA, BB]:
fused = s[load].fuse(*list(s[load].op.axis))
fused, tx = s[load].split(fused, cfg["tile_x"].size[2])
fused, ty = s[load].split(fused, cfg["tile_y"].size[2])
fused, tz = s[load].split(fused, cfg["tile_b"].size[2])
s[load].bind(tz, te.thread_axis("threadIdx.z"))
s[load].bind(ty, te.thread_axis("threadIdx.y"))
s[load].bind(tx, te.thread_axis("threadIdx.x"))
s[C].pragma(bgemm_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
s[C].pragma(bgemm_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
# schedule inverse, output and fusion
if output.op in s.outputs:
OL = None
else:
OL = output
s[OL].set_scope('local')
output = s.outputs[0]
m = alpha - 3 + 1
n, co, d, h, w = s[output].op.axis
do, di = s[output].split(d, m)
ho, hi = s[output].split(h, m)
wo, wi = s[output].split(w, m)
s[output].reorder(n, co, do, ho, wo, di, hi, wi)
inverse_scope, n = s[output].split(n, nparts=1)
fused = s[output].fuse(n, co, do, ho, wo)
bb, tt = s[output].split(fused, 128)
s[output].bind(bb, te.thread_axis("blockIdx.x"))
s[output].bind(tt, te.thread_axis("threadIdx.x"))
if OL is not None:
s[OL].compute_at(s[output], tt)
s[A].compute_inline()
co, p, vd, vh, vw = s[inverse].op.axis
r_a, r_b, r_c = s[inverse].op.reduce_axis
# Could add additional unrolling of vd, vh, vw, in the future
for axis in [r_a, r_b, r_c]:
s[inverse].unroll(axis)
s[inverse].compute_at(s[output], tt)
return s
def schedule_winograd_no_depth_cuda(cfg, s, output, pre_computed):
"""Schedule winograd template"""
# get stages
inverse = s[output].op.input_tensors[0]
bgemm, A = s[inverse].op.input_tensors
kernel_pack, data_pack = s[bgemm].op.input_tensors
input_tile, B = s[data_pack].op.input_tensors
pad_data = s[input_tile].op.input_tensors[0]
# data transform
s[B].compute_inline()
data_l = s.cache_write(data_pack, 'local')
eps, nu, c, d, p = s[data_l].op.axis
r_a, r_b = s[data_l].op.reduce_axis
for axis in [eps, nu, r_a, r_b]:
s[data_l].unroll(axis)
eps, nu, c, d, p = s[data_pack].op.axis
p, pi = s[data_pack].split(p, 1)
fused = s[data_pack].fuse(c, d, p)
bb, tt = s[data_pack].split(fused, 128)
s[data_pack].reorder(bb, tt, pi, eps, nu)
s[data_pack].bind(bb, te.thread_axis("blockIdx.x"))
s[data_pack].bind(tt, te.thread_axis("threadIdx.x"))
s[data_l].compute_at(s[data_pack], pi)
s[input_tile].compute_at(s[data_pack], pi)
s[pad_data].compute_inline()
# transform kernel
if not pre_computed and not autotvm.GLOBAL_SCOPE.in_tuning:
kernel, G = s[kernel_pack].op.input_tensors
eps, nu, kd, co, ci = s[kernel_pack].op.axis
s[G].compute_inline()
r_a, r_b = s[kernel_pack].op.reduce_axis
for axis in [eps, nu, r_a, r_b]:
s[kernel_pack].unroll(axis)
fused = s[kernel_pack].fuse(kd, co, ci)
bb, tt = s[kernel_pack].split(fused, 128)
s[kernel_pack].reorder(bb, tt, eps, nu, r_a, r_b)
s[kernel_pack].bind(bb, te.thread_axis("blockIdx.x"))
s[kernel_pack].bind(tt, te.thread_axis("threadIdx.x"))
else:
kernel = kernel_pack
if isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
##### space definition begin #####
b1, b2, z, y, x = s[bgemm].op.axis
# Combine channel and depth axes.
rc = s[bgemm].op.reduce_axis[0]
rz = s[bgemm].op.reduce_axis[1]
alpha = get_const_int(b1.dom.extent)
cfg.define_split("tile_b", cfg.axis(alpha * alpha), num_outputs=4,
filter=lambda x: x.size[-3:] == [1, 1, 1])
cfg.define_split("tile_y", y, num_outputs=4)
cfg.define_split("tile_x", x, num_outputs=4)
cfg.define_split("tile_rc", rc, num_outputs=2)
cfg.define_split("tile_rz", rz, num_outputs=2)
cfg.define_knob("auto_unroll_max_step", [0, 128, 1500])
target = tvm.target.Target.current()
if target.target_name in ['nvptx', 'rocm']:
cfg.define_knob("unroll_explicit", [1])
else:
cfg.define_knob("unroll_explicit", [0, 1])
##### space definition end #####
# batch gemm
C = bgemm
A0, B0 = kernel_pack, data_pack
OL = s.cache_write(C, 'local')
AA = s.cache_read(A0, 'shared', [OL])
BB = s.cache_read(B0, 'shared', [OL])
b = s[bgemm].fuse(b1, b2)
y = s[bgemm].fuse(z, y)
# tile and bind spatial axes
bgemm_scope, b = s[bgemm].split(b, nparts=1)
bz, vz, tz, zi = cfg["tile_b"].apply(s, C, b)
by, vy, ty, yi = cfg["tile_y"].apply(s, C, y)
bx, vx, tx, xi = cfg["tile_x"].apply(s, C, x)
s[C].bind(bz, te.thread_axis("blockIdx.z"))
s[C].bind(by, te.thread_axis("blockIdx.y"))
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(vz, te.thread_axis("vthread"))
s[C].bind(vy, te.thread_axis("vthread"))
s[C].bind(vx, te.thread_axis("vthread"))
s[C].bind(tz, te.thread_axis("threadIdx.z"))
s[C].bind(ty, te.thread_axis("threadIdx.y"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
s[C].reorder(bgemm_scope, bz, by, bx, vz, vy, vx, tz, ty, tx, zi, yi, xi)
# tile reduction axes
s[OL].compute_at(s[C], tx)
b1, b2, y1, y2, x = s[OL].op.axis
y = s[OL].fuse(y1, y2)
b = s[OL].fuse(b1, b2)
rc, rz = s[OL].op.reduce_axis
rco, rci = cfg['tile_rc'].apply(s, OL, rc)
rzo, rzi = cfg['tile_rz'].apply(s, OL, rz)
s[OL].reorder(rco, rzo, rci, rzi, b, y, x)
s[AA].compute_at(s[OL], rco)
s[BB].compute_at(s[OL], rco)
# cooperative fetching
for load in [AA, BB]:
fused = s[load].fuse(*list(s[load].op.axis))
fused, tx = s[load].split(fused, cfg["tile_x"].size[2])
fused, ty = s[load].split(fused, cfg["tile_y"].size[2])
fused, tz = s[load].split(fused, cfg["tile_b"].size[2])
s[load].bind(tz, te.thread_axis("threadIdx.z"))
s[load].bind(ty, te.thread_axis("threadIdx.y"))
s[load].bind(tx, te.thread_axis("threadIdx.x"))
s[C].pragma(bgemm_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
s[C].pragma(bgemm_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
# schedule inverse, output and fusion
if output.op in s.outputs:
OL = None
else:
OL = output
s[OL].set_scope('local')
output = s.outputs[0]
m = alpha - 3 + 1
n, co, d, h, w = s[output].op.axis
do, di = s[output].split(d, m)
ho, hi = s[output].split(h, m)
wo, wi = s[output].split(w, m)
s[output].reorder(n, co, do, ho, wo, di, hi, wi)
inverse_scope, n = s[output].split(n, nparts=1)
fused = s[output].fuse(n, co, do, ho, wo)
bb, tt = s[output].split(fused, 128)
s[output].bind(bb, te.thread_axis("blockIdx.x"))
s[output].bind(tt, te.thread_axis("threadIdx.x"))
if OL is not None:
s[OL].compute_at(s[output], tt)
s[A].compute_inline()
co, d, p, vh, vw = s[inverse].op.axis
r_a, r_b = s[inverse].op.reduce_axis
for axis in [vh, vw, r_a, r_b]:
s[inverse].unroll(axis)
s[inverse].compute_at(s[output], tt)
return s
@autotvm.register_topi_compute("conv3d_ncdhw_winograd.cuda")
def conv3d_ncdhw_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype):
CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
# Check if we can transform depth.
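# Depth is transformed only for cubic kernels (winograd_cuda additionally
# asserts KD == KH == KW) whose depth falls in the supported 3-7 range;
# otherwise only height and width are transformed.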
if 2 < KD < 8 and KD == KH:
return winograd_cuda(
cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=False)
return winograd_without_depth_cuda(
cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=False)
@autotvm.register_topi_schedule("conv3d_ncdhw_winograd.cuda")
def schedule_conv3d_ncdhw_winograd(cfg, outs):
"""Dispatch to schedule approriate for conv3d winograd algorithm used."""
s = te.create_schedule([x.op for x in outs])
def _callback(op):
if 'conv3d_ncdhw_winograd_without_depth' in op.tag:
schedule_winograd_no_depth_cuda(cfg, s, op.output(0), pre_computed=False)
elif 'conv3d_ncdhw_winograd' in op.tag:
schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=False)
traverse_inline(s, outs[0].op, _callback)
return s
@autotvm.register_topi_compute("conv3d_ncdhw_winograd_without_weight_transform.cuda")
def conv3d_ncdhw_winograd_without_weight_transform(cfg, data, kernel, strides, padding, dilation,
out_dtype):
A, B, C, _, _ = get_const_tuple(kernel.shape)
# Check if we can transform depth.
if A == B == C:
return winograd_cuda(
cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=True)
return winograd_without_depth_cuda(
cfg, data, kernel, strides, padding, dilation, out_dtype, pre_computed=True)
@autotvm.register_topi_schedule("conv3d_ncdhw_winograd_without_weight_transform.cuda")
def schedule_conv3d_ncdhw_winograd_without_weight_transform(cfg, outs):
"""TOPI schedule callback"""
s = te.create_schedule([x.op for x in outs])
def _callback(op):
if 'conv3d_ncdhw_winograd_without_depth' in op.tag:
schedule_winograd_no_depth_cuda(cfg, s, op.output(0), pre_computed=True)
elif 'conv3d_ncdhw_winograd' in op.tag:
schedule_winograd_cuda(cfg, s, op.output(0), pre_computed=True)
traverse_inline(s, outs[0].op, _callback)
return s
......@@ -187,6 +187,43 @@ def schedule_conv2d_winograd_weight_transform(outs):
return s
def schedule_conv3d_winograd_weight_transform(outs):
"""Schedule for weight transformation of 3D winograd
Parameters
----------
outs: Array of Tensor
The computation graph description of this operator
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
# Typically this is computed in the PreCompute pass,
# so we provide a schedule here for the CPU (LLVM) backend.
s = te.create_schedule([x.op for x in outs])
output = outs[0]
_, G = s[output].op.input_tensors
s[G].compute_inline()
transform_depth = len(s[output].op.reduce_axis) == 3
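# Three reduction axes means depth was transformed as well (r_kd, r_kh, r_kw);
# two means only height and width were transformed.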
if transform_depth:
omg, eps, nu, ci, co = s[output].op.axis
r_kd, r_kh, r_kw = s[output].op.reduce_axis
s[output].reorder(co, ci, omg, eps, nu, r_kd, r_kh, r_kw)
for axis in [r_kd, r_kh, r_kw]:
s[output].unroll(axis)
else:
eps, nu, d, ci, co = s[output].op.axis
r_kh, r_kw = s[output].op.reduce_axis
s[output].reorder(co, ci, d, eps, nu, r_kh, r_kw)
for axis in [r_kh, r_kw]:
s[output].unroll(axis)
s[output].parallel(co)
return s
def schedule_conv2d_winograd_without_weight_transform(outs):
"""Schedule for winograd without weight transformation
......
......@@ -17,11 +17,13 @@
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=unused-argument, redefined-builtin, no-else-return
"""Conv3D operators"""
import tvm
from tvm import te
from .pad import pad
from .util import get_pad_tuple3d
from ..util import simplify
from ..util import simplify, get_const_tuple
from .winograd_util import winograd_transform_matrices
def conv3d_ncdhw(Input, Filter, stride, padding, dilation, out_dtype=None):
......@@ -159,3 +161,74 @@ def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
Filter[rd, rh, rw, rc, cc].astype(out_dtype), axis=[rd, rh, rw, rc]),
name="Conv3dOutput", tag="conv3d_ndhwc")
return Output
def conv3d_winograd_weight_transform(kernel, tile_size):
"""Weight transformation for 3D winograd
Parameters
----------
kernel: Tensor
The raw kernel tensor with layout "NCDHW".
tile_size: int
Tile size of winograd transform, e.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)
Returns
-------
output : tvm.te.Tensor
5-D with shape [alpha, alpha, alpha, CO, CI]
"""
CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
depth_transform = 2 < KD < 8 and KD == KH
if depth_transform:
assert KD == KH == KW, "Only support NxNxN kernel"
else:
assert KH == KW, "Only supports DxNxN kernel"
r = tile_size + KH - 1
r_kh = te.reduce_axis((0, KH), name='r_kh')
r_kw = te.reduce_axis((0, KW), name='r_kw')
_, _, G = winograd_transform_matrices(tile_size, KH, kernel.dtype)
if depth_transform:
shape = (r, r, r, CO, CI)
r_kd = te.reduce_axis((0, KD), name='r_kd')
return te.compute(
shape,
lambda omg, eps, nu, co, ci: te.sum(
kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
axis=[r_kd, r_kh, r_kw]),
name='transform_weight')
else:
shape = (r, r, KD, CO, CI)
return te.compute(
shape,
lambda eps, nu, d, co, ci: te.sum(
kernel[co][ci][d][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]),
name='transform_weight')
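# Minimal usage sketch (hypothetical shapes, for illustration only):
#   kernel = te.placeholder((64, 32, 3, 3, 3), name="kernel")  # OIDHW
#   packed = conv3d_winograd_weight_transform(kernel, tile_size=4)
#   # A 3x3x3 kernel with tile_size 4 gives alpha = 4 + 3 - 1 = 6,
#   # so `packed` has shape (6, 6, 6, 64, 32).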
@tvm.target.generic_func
def conv3d_alter_layout(attrs, inputs, tinfos, out_type):
"""Change Conv3D layout.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : tvm.relay.Expr
Grouped input symbols
tinfos : list
Input shape and dtype
out_type: type
The output type
Note
----
Unlike other TOPI functions, this function operates on both graph level and operator level.
"""
# not to change by default
return None
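# Backends override this generic function via registration; the CUDA backend
# in this change follows the same pattern, roughly:
#
#   @nn.conv3d_alter_layout.register(["cuda", "gpu"])
#   def _alter_conv3d_layout(attrs, inputs, tinfos, out_type):
#       ...  # return a rewritten expression, or None to keep the op as-is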
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for 3d convolution with winograd."""
import numpy as np
import tvm
from tvm import te
from tvm import autotvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.nn.util import get_pad_tuple3d
from topi.util import get_const_tuple
from common import get_all_backend
_conv3d_ncdhw_implement = {
"gpu": (topi.cuda.conv3d_ncdhw_winograd, topi.cuda.schedule_conv3d_ncdhw_winograd),
}
def verify_conv3d_ncdhw(batch,
in_channel,
in_size,
num_filter,
depth_kernel,
space_kernel,
stride,
padding,
dilation=1,
add_bias=False,
add_relu=False):
pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d(
padding, (depth_kernel, space_kernel, space_kernel))
padding_sum = pad_front + pad_back + pad_top + pad_left + pad_bottom + pad_right
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" %
(batch, in_channel, in_size, num_filter, space_kernel, stride, padding_sum, dilation))
in_depth = in_height = in_width = in_size
A = te.placeholder((batch, in_channel, in_depth, in_height, in_width), name='A')
W = te.placeholder((num_filter, in_channel, depth_kernel, space_kernel, space_kernel), name='W')
bias = te.placeholder((num_filter, 1, 1, 1), name='bias')
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv3d_ncdhw.verify_conv3d_ncdhw")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = np.random.uniform(size=bias_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation, dilation))
c_np = topi.testing.conv3d_ncdhw_python(a_np, dw_np, stride, padding)
if add_bias:
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
fcompute, fschedule = topi.testing.dispatch(device, _conv3d_ncdhw_implement)
with tvm.target.create(device):
C = fcompute(A, W, (stride, stride, stride), padding, (dilation, dilation, dilation),
dtype)
if add_bias:
C = topi.add(C, bias)
if add_relu:
C = topi.nn.relu(C)
s = fschedule([C])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias:
func = tvm.build(
s, [A, W, bias, C],
device,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
(batch, in_channel, in_size, num_filter, space_kernel, stride, padding_sum, dilation))
func(a, w, b, c)
else:
func = tvm.build(
s, [A, W, C],
device,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
(batch, in_channel, in_size, num_filter, space_kernel, stride, padding_sum, dilation))
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
for device in ["cuda"]:
with autotvm.tophub.context(device): # load tophub pre-tuned parameters
check_device(device)
def test_conv3d_ncdhw():
# 3D CNN workloads, both with and without depth transformation
verify_conv3d_ncdhw(1, 61, 20, 120, 3, 3, 1, 0)
verify_conv3d_ncdhw(1, 61, 20, 120, 1, 3, 1, 0)
verify_conv3d_ncdhw(1, 61, 20, 120, 5, 3, 1, 0)
verify_conv3d_ncdhw(1, 61, 20, 120, 5, 5, 1, 2)
verify_conv3d_ncdhw(1, 61, 20, 120, 1, 5, 1, 2)
verify_conv3d_ncdhw(1, 61, 20, 120, 7, 7, 1, 3)
verify_conv3d_ncdhw(1, 128, 12, 256, 3, 3, 1, 1)
verify_conv3d_ncdhw(1, 64, 12, 128, 3, 3, 1, 1)
# bias, relu
verify_conv3d_ncdhw(1, 64, 12, 128, 3, 3, 1, 1, add_relu=True)
verify_conv3d_ncdhw(1, 64, 12, 128, 3, 3, 1, 1, add_relu=True, add_bias=True)
verify_conv3d_ncdhw(1, 64, 12, 128, 1, 3, 1, 1, add_relu=True, add_bias=True)
# dilation = 2
verify_conv3d_ncdhw(1, 16, 12, 16, 3, 3, 1, "VALID", dilation=2)
verify_conv3d_ncdhw(1, 16, 12, 16, 1, 3, 1, "VALID", dilation=2)
# batch size
verify_conv3d_ncdhw(4, 32, 12, 64, 3, 3, 1, 1)
verify_conv3d_ncdhw(4, 32, 12, 64, 1, 3, 1, 1)
# weird workloads
verify_conv3d_ncdhw(2, 2, 2, 2, 3, 3, 1, 2)
verify_conv3d_ncdhw(3, 3, 3, 3, 3, 3, 1, 3)
if __name__ == "__main__":
test_conv3d_ncdhw()