Unverified Commit 02eb1833 by Josh Fromm Committed by GitHub

[Relay][Topi][AutoTVM] Winograd support for Conv3D (#5186)

* Functional conv3d winograd working.

* Formatted python code.

* registered conv3d winograd compute and started adding relay without_weight_transform operator.

* Add topi testing for conv3d winograd.

* Format file.

* small tweak to unrolling to prevent build sticking.

* Refactoring convolution ops in relay.

* Refactored relay convolutions.

* Bug fixes.

* Fixed static bug in convolution.

* Added conv3d alter op layout and related support.

* Bug fixes and testing done.

* Fix a few autotvm bugs.

* Drop silly debug print.

* Removed debug_skip_region.

* Add variant of conv3d_winograd that doesn't transform depth.

* initial infrastructure done for depthless conv.

* Fix no_depth schedule bugs.

* automatic topi switching between depth and depthless winograd.

* Fixed bug in schedule.

* lint fixes.

* Removed indents in convolution.cc

* missed a few indents oops.

* fixed flop count.

* One more small tweak.

* Change kernel pack inner axes order.

* Style changes.

* Comment fixes.
parent c76cbd8d
...@@ -82,8 +82,13 @@ This level enables typical convnet models. ...@@ -82,8 +82,13 @@ This level enables typical convnet models.
tvm.relay.nn.pad tvm.relay.nn.pad
tvm.relay.nn.lrn tvm.relay.nn.lrn
tvm.relay.nn.l2_normalize tvm.relay.nn.l2_normalize
tvm.relay.nn.bitpack
tvm.relay.nn.bitserial_dense
tvm.relay.nn.bitserial_conv2d
tvm.relay.nn.contrib_conv2d_winograd_without_weight_transform tvm.relay.nn.contrib_conv2d_winograd_without_weight_transform
tvm.relay.nn.contrib_conv2d_winograd_weight_transform tvm.relay.nn.contrib_conv2d_winograd_weight_transform
tvm.relay.nn.contrib_conv3d_winograd_without_weight_transform
tvm.relay.nn.contrib_conv3d_winograd_weight_transform
**Level 3: Additional Math And Transform Operators** **Level 3: Additional Math And Transform Operators**
......
...@@ -156,12 +156,12 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> { ...@@ -156,12 +156,12 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
}; };
/*! \brief Attributes used in winograd weight transformation operators */ /*! \brief Attributes used in winograd weight transformation operators */
struct Conv2DWinogradWeightTransformAttrs : struct ConvWinogradWeightTransformAttrs :
public tvm::AttrsNode<Conv2DWinogradWeightTransformAttrs> { public tvm::AttrsNode<ConvWinogradWeightTransformAttrs> {
int tile_size; int tile_size;
TVM_DECLARE_ATTRS(Conv2DWinogradWeightTransformAttrs, TVM_DECLARE_ATTRS(ConvWinogradWeightTransformAttrs,
"relay.attrs.Conv2DWinogradWeightTransformAttrs") { "relay.attrs.ConvWinogradWeightTransformAttrs") {
TVM_ATTR_FIELD(tile_size) TVM_ATTR_FIELD(tile_size)
.describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)"); .describe("Tile size of winograd. E.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)");
} }
...@@ -306,6 +306,69 @@ struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> { ...@@ -306,6 +306,69 @@ struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
} }
}; };
/*! \brief Attributes used in 3d winograd convolution operators */
struct Conv3DWinogradAttrs : public tvm::AttrsNode<Conv3DWinogradAttrs> {
int tile_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> dilation;
int groups;
IndexExpr channels;
Array<IndexExpr> kernel_size;
std::string data_layout;
std::string kernel_layout;
std::string out_layout;
DataType out_dtype;
TVM_DECLARE_ATTRS(Conv3DWinogradAttrs, "relay.attrs.Conv3DWinogradAttrs") {
TVM_ATTR_FIELD(tile_size)
.describe("The tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)");
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"three int : back, bottom, right will use same padding as front, top, left"
"six int : padding width in the order of (front, top, left, back, bottom,"
"right)");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1)
.describe("Controls the connections between inputs and outputs."
"At groups=1, all inputs are convolved to all outputs."
"At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated.");
TVM_ATTR_FIELD(channels)
.describe("The number of output channels in the convolution."
" If it is not set, inferred by shape of the weight.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
TVM_ATTR_FIELD(data_layout).set_default("NCDHW")
.describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Convolution is applied on the 'D', 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(kernel_layout).set_default("OIDHW")
.describe("Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc."
"'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height,"
"and width dimensions respectively.");
TVM_ATTR_FIELD(out_layout).set_default("")
.describe("Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Default to be same as input layout.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
/*! \brief Attributes used in softmax operators */ /*! \brief Attributes used in softmax operators */
struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> { struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
int axis; int axis;
......
...@@ -178,6 +178,29 @@ def legalize_conv2d_transpose(attrs, inputs, types): ...@@ -178,6 +178,29 @@ def legalize_conv2d_transpose(attrs, inputs, types):
reg.register_strategy("nn.conv3d", strategy.conv3d_strategy) reg.register_strategy("nn.conv3d", strategy.conv3d_strategy)
reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE) reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_alter_op_layout("nn.conv3d")
def alter_op_layout_conv3d(attrs, inputs, tinfos, out_type):
"""Alternate the layout of conv3d"""
return topi.nn.conv3d_alter_layout(attrs, inputs, tinfos, out_type)
# conv3d_winograd related operators
reg.register_strategy("nn.contrib_conv3d_winograd_without_weight_transform",
strategy.conv3d_winograd_without_weight_transfrom_strategy)
reg.register_pattern("nn.contrib_conv3d_winograd_without_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_compute("nn.contrib_conv3d_winograd_weight_transform")
def compute_contrib_conv3d_winograd_weight_transform(attrs, inputs, out_dtype):
"""Compute definition of contrib_conv3d_winograd_weight_transform"""
out = topi.nn.conv3d_winograd_weight_transform(
inputs[0], attrs.get_int('tile_size'))
return [out]
reg.register_schedule("nn.contrib_conv3d_winograd_weight_transform",
strategy.schedule_conv3d_winograd_weight_transform)
reg.register_pattern("nn.contrib_conv3d_winograd_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)
# conv1d_transpose # conv1d_transpose
reg.register_strategy("nn.conv1d_transpose", strategy.conv1d_transpose_strategy) reg.register_strategy("nn.conv1d_transpose", strategy.conv1d_transpose_strategy)
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ...expr import TupleWrapper from ...expr import TupleWrapper
from . import _make from . import _make
from .util import get_pad_tuple2d from .util import get_pad_tuple2d, get_pad_tuple3d
def conv1d(data, def conv1d(data,
...@@ -295,13 +295,84 @@ def conv3d(data, ...@@ -295,13 +295,84 @@ def conv3d(data,
strides = (strides, strides, strides) strides = (strides, strides, strides)
if isinstance(dilation, int): if isinstance(dilation, int):
dilation = (dilation, dilation, dilation) dilation = (dilation, dilation, dilation)
if isinstance(padding, int): padding = get_pad_tuple3d(padding)
padding = (padding, padding, padding)
return _make.conv3d(data, weight, strides, padding, dilation, return _make.conv3d(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout, groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype) kernel_layout, out_layout, out_dtype)
def contrib_conv3d_winograd_without_weight_transform(data,
weight,
tile_size,
strides=(1, 1, 1),
padding=(0, 0, 0),
dilation=(1, 1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCDHW",
kernel_layout="OIDHW",
out_layout="",
out_dtype=""):
r"""3D convolution with winograd algorithm.
The basic parameters are the same as the ones in vanilla conv3d.
It assumes the weight is pre-transformed by nn.contrib_conv3d_winograd_weight_transform
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
weight : tvm.relay.Expr
The weight expressions.
tile_size : int
The Tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)
strides : tuple of int, optional
The strides of convolution.
padding : tuple of int, optional
The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution.
kernel_size : tuple of int, optional
The spatial of the convolution kernel.
data_layout : str, optional
Layout of the input.
kernel_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output, by default, out_layout is the same as data_layout
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
# convert 3-way padding to 6-way padding
padding = get_pad_tuple3d(padding)
return _make.contrib_conv3d_winograd_without_weight_transform(
data, weight, tile_size, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
def conv2d_transpose(data, def conv2d_transpose(data,
weight, weight,
strides=(1, 1), strides=(1, 1),
...@@ -1952,6 +2023,29 @@ def contrib_conv2d_winograd_weight_transform(weight, ...@@ -1952,6 +2023,29 @@ def contrib_conv2d_winograd_weight_transform(weight,
return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size) return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size)
def contrib_conv3d_winograd_weight_transform(weight,
tile_size):
r"""Weight Transformation part for 3D convolution with winograd algorithm.
We separate this as a single op to enable pre-compute for inference.
Use this together with nn.contrib_conv3d_winograd_without_weight_transform
Parameters
----------
weight : tvm.relay.Expr
The weight expressions.
tile_size : int
The Tile size of winograd. E.g. 2 for F(2x2x2, 3x3x3) and 4 for F(4x4x4, 3x3x3)
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.contrib_conv3d_winograd_weight_transform(weight, tile_size)
def contrib_conv2d_winograd_nnpack_weight_transform(weight, def contrib_conv2d_winograd_nnpack_weight_transform(weight,
convolution_algorithm, convolution_algorithm,
out_dtype=""): out_dtype=""):
......
...@@ -54,3 +54,46 @@ def get_pad_tuple2d(padding): ...@@ -54,3 +54,46 @@ def get_pad_tuple2d(padding):
pad_top = (pad_h + 1) // 2 pad_top = (pad_h + 1) // 2
pad_left = (pad_w + 1) // 2 pad_left = (pad_w + 1) // 2
return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left
def get_pad_tuple3d(padding):
"""Common code to get the pad option
Parameters
----------
padding : Union[int, Tuple[int, ...]]
Padding size
Returns
-------
pad_front : int
Padding size on front
pad_top : int
Padding size on top
pad_left : int
Padding size on left
pad_back : int
Padding size on back
pad_down : int
Padding size on down.
pad_right : int
Padding size on right.
"""
# compute the padding size
if isinstance(padding, container.Array):
padding = list(padding)
if isinstance(padding, (tuple, list)):
if len(padding) == 3:
pad_d = padding[0] * 2
pad_h = padding[1] * 2
pad_w = padding[2] * 2
elif len(padding) == 6:
return padding[0], padding[1], padding[2], padding[3], padding[4], padding[5]
else:
raise ValueError("Size of padding can only be 3 or 6")
elif isinstance(padding, int):
pad_d = pad_h = pad_w = padding * 2
else:
raise ValueError("Unknown padding option %s" % padding)
pad_front = (pad_d + 1) // 2
pad_top = (pad_h + 1) // 2
pad_left = (pad_w + 1) // 2
return pad_front, pad_top, pad_left, pad_d - pad_front, pad_h - pad_top, pad_w - pad_left
...@@ -34,9 +34,19 @@ class Conv2DWinogradAttrs(Attrs): ...@@ -34,9 +34,19 @@ class Conv2DWinogradAttrs(Attrs):
"""Attributes for nn.contrib_conv2d_winograd_without_weight_transform""" """Attributes for nn.contrib_conv2d_winograd_without_weight_transform"""
@tvm._ffi.register_object("relay.attrs.Conv2DWinogradWeightTransformAttrs") @tvm._ffi.register_object("relay.attrs.Conv3DAttrs")
class Conv2DWinogradWeightTransformAttrs(Attrs): class Conv3DAttrs(Attrs):
"""Attributes for nn.contrib_conv2d_winograd_weight_transform""" """Attributes for nn.conv3d"""
@tvm._ffi.register_object("relay.attrs.Conv3DWinogradAttrs")
class Conv3DWinogradAttrs(Attrs):
"""Attributes for nn.contrib_conv3d_winograd_without_weight_transform"""
@tvm._ffi.register_object("relay.attrs.ConvWinogradWeightTransformAttrs")
class ConvWinogradWeightTransformAttrs(Attrs):
"""Attributes for nn.contrib_convNd_winograd_weight_transform"""
@tvm._ffi.register_object("relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs") @tvm._ffi.register_object("relay.attrs.Conv2DWinogradNNPACKWeightTransformAttrs")
......
...@@ -233,13 +233,25 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target): ...@@ -233,13 +233,25 @@ def conv2d_transpose_strategy_cuda(attrs, inputs, out_type, target):
def conv3d_strategy_cuda(attrs, inputs, out_type, target): def conv3d_strategy_cuda(attrs, inputs, out_type, target):
"""conv3d cuda strategy""" """conv3d cuda strategy"""
strategy = _op.OpStrategy() strategy = _op.OpStrategy()
_, kernel = inputs
layout = attrs.data_layout layout = attrs.data_layout
_, stride_h, stride_w = attrs.get_int_tuple("strides")
_, dilation_h, dilation_w = attrs.get_int_tuple("dilation")
assert layout in ["NCDHW", "NDHWC"], "Not support this layout {} yet".format(layout) assert layout in ["NCDHW", "NDHWC"], "Not support this layout {} yet".format(layout)
if layout == "NCDHW": if layout == "NCDHW":
strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ncdhw), strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ncdhw),
wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw), wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw),
name="conv3d_ncdhw.cuda", name="conv3d_ncdhw.cuda",
plevel=10) plevel=10)
_, _, _, kh, kw = get_const_tuple(kernel.shape)
if 2 < kh < 8 and 2 < kw < 8 and kh == kw and \
stride_h == 1 and stride_w == 1 and \
dilation_h == 1 and dilation_w == 1:
strategy.add_implementation(
wrap_compute_conv3d(topi.cuda.conv3d_ncdhw_winograd),
wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw_winograd),
name="conv3d_ncdhw_winograd.cuda",
plevel=5)
else: # layout == "NDHWC": else: # layout == "NDHWC":
strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ndhwc), strategy.add_implementation(wrap_compute_conv3d(topi.cuda.conv3d_ndhwc),
wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc), wrap_topi_schedule(topi.cuda.schedule_conv3d_ndhwc),
...@@ -252,6 +264,26 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target): ...@@ -252,6 +264,26 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
plevel=15) plevel=15)
return strategy return strategy
@conv3d_winograd_without_weight_transfrom_strategy.register(["cuda", "gpu"])
def conv3d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_type, target):
"""conv3d_winograd_without_weight_transfrom cuda strategy"""
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs.data_layout
assert dilation == (1, 1, 1), "Do not support dilate now"
assert groups == 1, "Do not supoort arbitrary group number"
strategy = _op.OpStrategy()
if layout == "NCDHW":
strategy.add_implementation(
wrap_compute_conv3d(topi.cuda.conv3d_ncdhw_winograd_without_weight_transform),
wrap_topi_schedule(
topi.cuda.schedule_conv3d_ncdhw_winograd_without_weight_transform),
name="conv3d_ncdhw_winograd_without_weight_transform.cuda")
else:
raise RuntimeError("Unsupported conv3d_winograd_without_weight_transfrom layout {}".
format(layout))
return strategy
@conv1d_strategy.register(["cuda", "gpu"]) @conv1d_strategy.register(["cuda", "gpu"])
def conv1d_strategy_cuda(attrs, inputs, out_type, target): def conv1d_strategy_cuda(attrs, inputs, out_type, target):
"""conv1d cuda strategy""" """conv1d cuda strategy"""
......
...@@ -374,6 +374,19 @@ def conv3d_strategy(attrs, inputs, out_type, target): ...@@ -374,6 +374,19 @@ def conv3d_strategy(attrs, inputs, out_type, target):
raise ValueError("Not support this layout {} yet".format(layout)) raise ValueError("Not support this layout {} yet".format(layout))
return strategy return strategy
# conv3d_winograd_without_weight_transform
@override_native_generic_func("conv3d_winograd_without_weight_transform_strategy")
def conv3d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, target):
"""conv3d_winograd_without_weight_transfrom generic strategy"""
raise ValueError("No generic implemenation for conv3d_winograd_without_weight_transform")
# conv3d_winograd_weight_transform
@generic_func
def schedule_conv3d_winograd_weight_transform(attrs, outs, target):
"""Schedule conv3d_winograd_weight_transform"""
with target:
return topi.generic.schedule_conv3d_winograd_weight_transform(outs)
# conv1d # conv1d
def wrap_compute_conv1d(topi_compute): def wrap_compute_conv1d(topi_compute):
"""wrap conv1d topi compute""" """wrap conv1d topi compute"""
......
...@@ -25,6 +25,7 @@ from tvm.relay import transform ...@@ -25,6 +25,7 @@ from tvm.relay import transform
from tvm.relay.testing import ctx_list, run_infer_type from tvm.relay.testing import ctx_list, run_infer_type
from tvm.contrib import util from tvm.contrib import util
import topi.testing import topi.testing
from topi.cuda.conv3d_winograd import _infer_tile_size
def test_conv1d_infer_type(): def test_conv1d_infer_type():
...@@ -326,7 +327,7 @@ def test_conv2d_winograd(): ...@@ -326,7 +327,7 @@ def test_conv2d_winograd():
cfg['tile_y'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1]) cfg['tile_y'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
cfg['tile_x'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1]) cfg['tile_x'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
cfg['tile_rc'] = autotvm.task.space.SplitEntity([-1, 1]) cfg['tile_rc'] = autotvm.task.space.SplitEntity([-1, 1])
cfg['auto_unroll_max_setp'] = autotvm.task.space.OtherOptionEntity(1500) cfg['auto_unroll_max_step'] = autotvm.task.space.OtherOptionEntity(1500)
cfg['unroll_explicit'] = autotvm.task.space.OtherOptionEntity(1) cfg['unroll_explicit'] = autotvm.task.space.OtherOptionEntity(1)
self.memory[key] = cfg self.memory[key] = cfg
return cfg return cfg
...@@ -522,6 +523,94 @@ def test_conv3d_ndhwc_run(): ...@@ -522,6 +523,94 @@ def test_conv3d_ndhwc_run():
run_test_conv3d("float32", "float32", 1, dshape, kshape, run_test_conv3d("float32", "float32", 1, dshape, kshape,
padding=(1, 1, 1), channels=10, kernel_size=(3, 3 ,3), except_targets=["cuda"]) padding=(1, 1, 1), channels=10, kernel_size=(3, 3 ,3), except_targets=["cuda"])
def test_conv3d_winograd():
class WinogradFallback(autotvm.FallbackContext):
def _query_inside(self, target, workload):
key = (target, workload)
if key in self.memory:
return self.memory[key]
cfg = autotvm.task.space.FallbackConfigEntity()
cfg.is_fallback = False
cfg.cost = 0.1 if 'winograd' in workload[0] else 1
cfg['tile_b'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
cfg['tile_y'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
cfg['tile_x'] = autotvm.task.space.SplitEntity([-1, 1, 1, 1])
cfg['tile_rc'] = autotvm.task.space.SplitEntity([-1, 1])
cfg['auto_unroll_max_step'] = autotvm.task.space.OtherOptionEntity(0)
cfg['unroll_explicit'] = autotvm.task.space.OtherOptionEntity(1)
self.memory[key] = cfg
return cfg
def run_test_conv3d_cuda(dtype, out_dtype, scale, dshape, kshape,
padding=(1, 1, 1),
groups=1,
dilation=(1, 1, 1),
prepack=False,
**attrs):
x = relay.var("x", shape=dshape, dtype=dtype)
w = relay.var("w", shape=kshape, dtype=dtype)
if prepack:
tile_size = _infer_tile_size(np.zeros(shape=dshape), np.zeros(shape=kshape))
w_packed = relay.nn.contrib_conv3d_winograd_weight_transform(w, tile_size)
y = relay.nn.contrib_conv3d_winograd_without_weight_transform(
x, w_packed, tile_size,
padding=padding,
dilation=dilation,
groups=groups,
channels=kshape[0],
**attrs)
else:
y = relay.nn.conv3d(x, w,
padding=padding,
dilation=dilation,
groups=groups,
**attrs)
func = relay.Function([x, w], y)
mod = tvm.IRModule()
mod['main'] = func
mod = relay.transform.InferType()(mod)
data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
ref_res = topi.testing.conv3d_ncdhw_python(
data.astype(out_dtype), kernel.astype(out_dtype), 1, padding,
groups=groups)
with WinogradFallback(), relay.build_config(opt_level=3):
for target, ctx in ctx_list():
if target != 'cuda':
continue
params = {'w': tvm.nd.array(kernel)}
graph, lib, params = relay.build_module.build(mod, target=target, params=params)
module = tvm.contrib.graph_runtime.create(graph, lib, ctx)
module.set_input('x', tvm.nd.array(data))
module.set_input(**params)
module.run()
op_res1 = module.get_output(0)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-3, atol=1e-3)
# normal winograd: stride 1, padding 1, kernel 3x3x3
dshape = (1, 32, 16, 16, 16)
kshape = (64, 32, 3, 3, 3)
run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape,
padding=(1, 1, 1), kernel_size=(3, 3, 3))
# Without depth transform using 1x3x3 kernel.
kshape = (64, 32, 1, 3, 3)
run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape,
padding=(0, 1, 1), kernel_size=(1, 3, 3))
# extended winograd: stride 1, padding N, kernel NxNxN
dshape = (1, 61, 20, 20, 20)
kshape = (120, 61, 5, 5, 5)
run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape,
padding=(2, 2, 2), channels=120, kernel_size=(5, 5, 5))
# Without depth transform
kshape = (120, 61, 1, 5, 5)
run_test_conv3d_cuda("float32", "float32", 1, dshape, kshape,
padding=(0, 2, 2), channels=120, kernel_size=(1, 5, 5))
def test_conv2d_transpose_infer_type(): def test_conv2d_transpose_infer_type():
# symbolic in batch dimension # symbolic in batch dimension
...@@ -1268,6 +1357,7 @@ if __name__ == "__main__": ...@@ -1268,6 +1357,7 @@ if __name__ == "__main__":
test_conv2d_winograd() test_conv2d_winograd()
test_conv3d_run() test_conv3d_run()
test_conv3d_ndhwc_run() test_conv3d_ndhwc_run()
test_conv3d_winograd()
test_bitserial_conv2d_infer_type() test_bitserial_conv2d_infer_type()
test_batch_flatten() test_batch_flatten()
test_upsampling() test_upsampling()
......
...@@ -31,6 +31,8 @@ from . import conv2d_alter_op ...@@ -31,6 +31,8 @@ from . import conv2d_alter_op
from .conv2d_transpose_nchw import * from .conv2d_transpose_nchw import *
from .deformable_conv2d import * from .deformable_conv2d import *
from .conv3d import * from .conv3d import *
from .conv3d_winograd import *
from . import conv3d_alter_op
from .reduction import schedule_reduce from .reduction import schedule_reduce
from .softmax import schedule_softmax from .softmax import schedule_softmax
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument
"""Conv3D alter op and legalize functions for cuda backend"""
import logging
import tvm
from tvm import te
from tvm import relay
from tvm import autotvm
from .. import nn
from ..util import get_const_tuple
from .conv3d_winograd import _infer_tile_size
logger = logging.getLogger('topi')
@nn.conv3d_alter_layout.register(["cuda", "gpu"])
def _alter_conv3d_layout(attrs, inputs, tinfos, out_type):
target = tvm.target.Target.current(allow_none=False)
dispatch_ctx = autotvm.task.DispatchContext.current
_, outs = relay.backend.compile_engine.select_implementation(
relay.op.get("nn.conv3d"), attrs, tinfos, out_type, target)
workload = autotvm.task.get_workload(outs)
if workload is None:
# The best implementation is not an AutoTVM template,
# we then assume it's not necessary to alter this op.
return None
cfg = dispatch_ctx.query(target, workload)
if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(target, workload)
return None
topi_tmpl = workload[0]
new_attrs = {k: attrs[k] for k in attrs.keys()}
strides = attrs.get_int_tuple("strides")
padding = attrs.get_int_tuple("padding")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int('groups')
data_layout = attrs["data_layout"]
kernel_layout = attrs["kernel_layout"]
data, kernel = tinfos
out_dtype = out_type.dtype
if topi_tmpl == "conv3d_ncdhw_winograd.cuda":
if dilation != (1, 1, 1):
logger.warning("Does not support weight pre-transform for dilated 3D convolution.")
return None
assert data_layout == "NCDHW" and kernel_layout == "OIDHW"
N, CI, D, H, W = get_const_tuple(data.shape)
CO, _, KD, KH, KW = get_const_tuple(kernel.shape)
# Pre-compute weight transformation in winograd
tile_size = _infer_tile_size(tinfos[0], tinfos[1])
weight = relay.nn.contrib_conv3d_winograd_weight_transform(inputs[1], tile_size=tile_size)
new_attrs['tile_size'] = tile_size
new_attrs['channels'] = CO
# Store the same config for the altered operators (workload)
new_data = data
# Check if depth is transformed or not
if 2 < KD < 8 and KD == KH:
new_weight = te.placeholder(
(KD + tile_size - 1, KH + tile_size - 1, KW + tile_size - 1, CO, CI),
dtype=kernel.dtype)
else:
new_weight = te.placeholder(
(KH + tile_size - 1, KW + tile_size - 1, KD, CO, CI),
dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_weight, strides, padding, dilation, out_dtype],
"conv3d_ncdhw_winograd_without_weight_transform.cuda")
dispatch_ctx.update(target, new_workload, cfg)
return relay.nn.contrib_conv3d_winograd_without_weight_transform(
inputs[0], weight, **new_attrs)
return None
...@@ -187,6 +187,43 @@ def schedule_conv2d_winograd_weight_transform(outs): ...@@ -187,6 +187,43 @@ def schedule_conv2d_winograd_weight_transform(outs):
return s return s
def schedule_conv3d_winograd_weight_transform(outs):
"""Schedule for weight transformation of 3D winograd
Parameters
----------
outs: Array of Tensor
The computation graph description of this operator
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
# Typically this is computed in PreCompute pass
# so we make a schedule here for cpu llvm
s = te.create_schedule([x.op for x in outs])
output = outs[0]
_, G = s[output].op.input_tensors
s[G].compute_inline()
transform_depth = len(s[output].op.reduce_axis) == 3
if transform_depth:
omg, eps, nu, ci, co = s[output].op.axis
r_kd, r_kh, r_kw = s[output].op.reduce_axis
s[output].reorder(co, ci, omg, eps, nu, r_kd, r_kh, r_kw)
for axis in [r_kd, r_kh, r_kw]:
s[output].unroll(axis)
else:
eps, nu, d, ci, co = s[output].op.axis
r_kh, r_kw = s[output].op.reduce_axis
s[output].reorder(co, ci, d, eps, nu, r_kh, r_kw)
for axis in [r_kh, r_kw]:
s[output].unroll(axis)
s[output].parallel(co)
return s
def schedule_conv2d_winograd_without_weight_transform(outs): def schedule_conv2d_winograd_without_weight_transform(outs):
"""Schedule for winograd without weight transformation """Schedule for winograd without weight transformation
......
...@@ -17,11 +17,13 @@ ...@@ -17,11 +17,13 @@
# pylint: disable=invalid-name, unused-variable, too-many-locals # pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=unused-argument, redefined-builtin, no-else-return # pylint: disable=unused-argument, redefined-builtin, no-else-return
"""Conv3D operators""" """Conv3D operators"""
import tvm
from tvm import te from tvm import te
from .pad import pad from .pad import pad
from .util import get_pad_tuple3d from .util import get_pad_tuple3d
from ..util import simplify from ..util import simplify, get_const_tuple
from .winograd_util import winograd_transform_matrices
def conv3d_ncdhw(Input, Filter, stride, padding, dilation, out_dtype=None): def conv3d_ncdhw(Input, Filter, stride, padding, dilation, out_dtype=None):
...@@ -159,3 +161,74 @@ def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'): ...@@ -159,3 +161,74 @@ def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
Filter[rd, rh, rw, rc, cc].astype(out_dtype), axis=[rd, rh, rw, rc]), Filter[rd, rh, rw, rc, cc].astype(out_dtype), axis=[rd, rh, rw, rc]),
name="Conv3dOutput", tag="conv3d_ndhwc") name="Conv3dOutput", tag="conv3d_ndhwc")
return Output return Output
def conv3d_winograd_weight_transform(kernel, tile_size):
"""Weight transformation for 3D winograd
Parameters
----------
kernel: Tensor
The raw kernel tensor with layout "NCDHW".
tile_size: int
Tile size of winograd transform. e.g. 2 for F(2x2, 3x3) and 4 for F(4x4, 3x3)
Returns
-------
output : tvm.te.Tensor
5-D with shape [alpha, alpha, alpha, CO, CI]
"""
CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
depth_transform = 2 < KD < 8 and KD == KH
if depth_transform:
assert KD == KH == KW, "Only support NxNxN kernel"
else:
assert KH == KW, "Only supports DxNxN kernel"
r = tile_size + KH - 1
r_kh = te.reduce_axis((0, KH), name='r_kh')
r_kw = te.reduce_axis((0, KW), name='r_kw')
_, _, G = winograd_transform_matrices(tile_size, KH, kernel.dtype)
if depth_transform:
shape = (r, r, r, CO, CI)
r_kd = te.reduce_axis((0, KD), name='r_kd')
return te.compute(
shape,
lambda omg, eps, nu, co, ci: te.sum(
kernel[co][ci][r_kd][r_kh][r_kw] * G[omg][r_kd] * G[eps][r_kh] * G[nu][r_kw],
axis=[r_kd, r_kh, r_kw]),
name='transform_weight')
else:
shape = (r, r, KD, CO, CI)
return te.compute(
shape,
lambda eps, nu, d, co, ci: te.sum(
kernel[co][ci][d][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]),
name='transform_weight')
@tvm.target.generic_func
def conv3d_alter_layout(attrs, inputs, tinfos, out_type):
"""Change Conv3D layout.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current convolution
inputs : tvm.relay.Expr
Grouped input symbols
tinfos : list
Input shape and dtype
out_type: type
The output type
Note
----
Unlike other TOPI functions, this function operates on both graph level and operator level.
"""
# not to change by default
return None
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for 3d convolution with winograd."""
import numpy as np
import tvm
from tvm import te
from tvm import autotvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.nn.util import get_pad_tuple3d
from topi.util import get_const_tuple
from common import get_all_backend
_conv3d_ncdhw_implement = {
"gpu": (topi.cuda.conv3d_ncdhw_winograd, topi.cuda.schedule_conv3d_ncdhw_winograd),
}
def verify_conv3d_ncdhw(batch,
in_channel,
in_size,
num_filter,
depth_kernel,
space_kernel,
stride,
padding,
dilation=1,
add_bias=False,
add_relu=False):
pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d(
padding, (depth_kernel, space_kernel, space_kernel))
padding_sum = pad_front + pad_back + pad_top + pad_left + pad_bottom + pad_right
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" %
(batch, in_channel, in_size, num_filter, space_kernel, stride, padding_sum, dilation))
in_depth = in_height = in_width = in_size
A = te.placeholder((batch, in_channel, in_depth, in_height, in_width), name='A')
W = te.placeholder((num_filter, in_channel, depth_kernel, space_kernel, space_kernel), name='W')
bias = te.placeholder((num_filter, 1, 1, 1), name='bias')
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv3d_ncdhw.verify_conv3d_ncdhw")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = np.random.uniform(size=bias_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation, dilation))
c_np = topi.testing.conv3d_ncdhw_python(a_np, dw_np, stride, padding)
if add_bias:
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
fcompute, fschedule = topi.testing.dispatch(device, _conv3d_ncdhw_implement)
with tvm.target.create(device):
C = fcompute(A, W, (stride, stride, stride), padding, (dilation, dilation, dilation),
dtype)
if add_bias:
C = topi.add(C, bias)
if add_relu:
C = topi.nn.relu(C)
s = fschedule([C])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias:
func = tvm.build(
s, [A, W, bias, C],
device,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
(batch, in_channel, in_size, num_filter, space_kernel, stride, padding_sum, dilation))
func(a, w, b, c)
else:
func = tvm.build(
s, [A, W, C],
device,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
(batch, in_channel, in_size, num_filter, space_kernel, stride, padding_sum, dilation))
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
for device in ["cuda"]:
with autotvm.tophub.context(device): # load tophub pre-tuned parameters
check_device(device)
def test_conv3d_ncdhw():
# Try without depth transformation
#3DCNN workloads
verify_conv3d_ncdhw(1, 61, 20, 120, 3, 3, 1, 0)
verify_conv3d_ncdhw(1, 61, 20, 120, 1, 3, 1, 0)
verify_conv3d_ncdhw(1, 61, 20, 120, 5, 3, 1, 0)
verify_conv3d_ncdhw(1, 61, 20, 120, 5, 5, 1, 2)
verify_conv3d_ncdhw(1, 61, 20, 120, 1, 5, 1, 2)
verify_conv3d_ncdhw(1, 61, 20, 120, 7, 7, 1, 3)
verify_conv3d_ncdhw(1, 128, 12, 256, 3, 3, 1, 1)
verify_conv3d_ncdhw(1, 64, 12, 128, 3, 3, 1, 1)
# bias, relu
verify_conv3d_ncdhw(1, 64, 12, 128, 3, 3, 1, 1, add_relu=True)
verify_conv3d_ncdhw(1, 64, 12, 128, 3, 3, 1, 1, add_relu=True, add_bias=True)
verify_conv3d_ncdhw(1, 64, 12, 128, 1, 3, 1, 1, add_relu=True, add_bias=True)
# dilation = 2
verify_conv3d_ncdhw(1, 16, 12, 16, 3, 3, 1, "VALID", dilation=2)
verify_conv3d_ncdhw(1, 16, 12, 16, 1, 3, 1, "VALID", dilation=2)
# batch size
verify_conv3d_ncdhw(4, 32, 12, 64, 3, 3, 1, 1)
verify_conv3d_ncdhw(4, 32, 12, 64, 1, 3, 1, 1)
# weird workloads
verify_conv3d_ncdhw(2, 2, 2, 2, 3, 3, 1, 2)
verify_conv3d_ncdhw(3, 3, 3, 3, 3, 3, 1, 3)
if __name__ == "__main__":
test_conv3d_ncdhw()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment