Commit 34b98eb7 by optima2005, committed by Yizhi Liu

[CONV] Asymmetric padding (#4511)

* [CONV] Asymmetric padding

* fix lint error

* update for legalize, rocm and cudnn

* add more test cases

* change more symmetric padding

* change conv2d winograd tests according to original cases

* remove 'alter_op_layout.h' header in bitserial.cc
parent 8e2f229a
@@ -67,7 +67,10 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
         .describe("Specifies the strides of the convolution.");
     TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
         .describe("If padding is non-zero, then the input is implicitly zero-padded"
-                  "on both sides for padding number of points");
+                  "Padding support both symmetric and asymmetric as"
+                  "one int : same padding used on all sides"
+                  "two int : bottom, right will use same padding as top, left"
+                  "four int : padding width in the order of (top, left, bottom, right)");
     TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
         .describe("Specifies the dilation rate to use for dilated convolution.");
     TVM_ATTR_FIELD(groups).set_default(1)
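The describe string above (repeated for the other conv attrs below) documents the new padding convention. A minimal sketch of how the three accepted forms expand to (top, left, bottom, right); the helper name normalize_padding is hypothetical, the real normalization lives in GetPaddingHeightWidth (C++) and topi.nn.util.get_pad_tuple (Python):

def normalize_padding(padding):
    """Hypothetical helper: expand int / 2-tuple / 4-tuple padding
    to the canonical (top, left, bottom, right) order."""
    if isinstance(padding, int):
        return (padding, padding, padding, padding)   # one int: all sides
    if len(padding) == 2:
        top, left = padding
        return (top, left, top, left)                 # two ints: bottom/right mirror top/left
    if len(padding) == 4:
        return tuple(padding)                         # four ints: already (top, left, bottom, right)
    raise ValueError("padding must have 1, 2 or 4 elements")

assert normalize_padding(1) == (1, 1, 1, 1)
assert normalize_padding((1, 2)) == (1, 2, 1, 2)
assert normalize_padding((0, 0, 1, 1)) == (0, 0, 1, 1)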
@@ -138,7 +141,10 @@ struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
         .describe("Specifies the strides of the convolution.");
     TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
         .describe("If padding is non-zero, then the input is implicitly zero-padded"
-                  "on both sides for padding number of points");
+                  "Padding support both symmetric and asymmetric as"
+                  "one int : same padding used on all sides"
+                  "two int : bottom, right will use same padding as top, left"
+                  "four int : padding width in the order of (top, left, bottom, right)");
     TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
         .describe("Specifies the dilation rate to use for dilated convolution.");
     TVM_ATTR_FIELD(groups).set_default(1)
@@ -288,10 +294,17 @@ struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
     TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
         .describe("The strides of the convolution.");
     TVM_ATTR_FIELD(output_padding).set_default(Array<IndexExpr>({0, 0}))
-        .describe("Zero-padding added to one side of the output.");
+        .describe("Zero-padding added to one side of the output."
+                  "Padding support both symmetric and asymmetric as"
+                  "one int : same padding used on all sides"
+                  "two int : bottom, right will use same padding as top, left"
+                  "four int : padding width in the order of (top, left, bottom, right)");
     TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
         .describe("If padding is non-zero, then the input is implicitly zero-padded"
-                  "on both sides for padding number of points");
+                  "Padding support both symmetric and asymmetric as"
+                  "one int : same padding used on all sides"
+                  "two int : bottom, right will use same padding as top, left"
+                  "four int : padding width in the order of (top, left, bottom, right)");
     TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
         .describe("Specifies the dilation rate to use for dilated convolution.");
     TVM_ATTR_FIELD(groups).set_default(1)
@@ -817,7 +830,10 @@ struct DeformableConv2DAttrs : public tvm::AttrsNode<DeformableConv2DAttrs> {
         .describe("Specifies the strides of the convolution.");
     TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
         .describe("If padding is non-zero, then the input is implicitly zero-padded"
-                  "on both sides for padding number of points");
+                  "Padding support both symmetric and asymmetric as"
+                  "one int : same padding used on all sides"
+                  "two int : bottom, right will use same padding as top, left"
+                  "four int : padding width in the order of (top, left, bottom, right)");
     TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
         .describe("Specifies the dilation rate to use for dilated convolution.");
     TVM_ATTR_FIELD(deformable_groups).set_default(1)
......
@@ -84,7 +84,7 @@ def memoize(key, save_at_exit=False):
     """
     def _register(f):
         """Registration function"""
-        allow_types = (string_types, int, float)
+        allow_types = (string_types, int, float, tuple)
        fkey = key + "." + f.__name__ + ".pkl"
        if fkey not in Cache.cache_by_key:
            Cache.cache_by_key[fkey] = Cache(fkey, save_at_exit)
......
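tuple joins the allowed key types because the regression tests now memoize reference results keyed by tuple-valued padding arguments. A hedged usage sketch (the cache key string is made up):

from tvm.contrib.pickle_memoize import memoize

@memoize("topi.tests.asym_pad_demo")   # hypothetical cache key
def reference_output(padding):
    # padding may now be a tuple such as (1, 1, 2, 2); before this change,
    # memoize only accepted str/int/float positional arguments.
    return sum(padding)

print(reference_output((1, 1, 2, 2)))  # computed once, then served from the pickle cache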
@@ -372,24 +372,7 @@ def _conv(opname):
             pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
             pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
-            if opname != 'conv_transpose':
-                if attr['data_format'] == 'NHWC':
-                    inputs_data = _op.nn.pad(data=inputs_data,
-                                             pad_width=((0, 0),
-                                                        (pad_v[0], pad_v[1]),
-                                                        (pad_h[0], pad_h[1]),
-                                                        (0, 0)))
-                else:
-                    inputs_data = _op.nn.pad(data=inputs_data,
-                                             pad_width=((0, 0),
-                                                        (0, 0),
-                                                        (pad_v[0], pad_v[1]),
-                                                        (pad_h[0], pad_h[1])))
-                attr['padding'] = [0, 0]
-            else:
-                attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
+            attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]
         else:
             msg = 'Value {} in attribute "padding" of operator Conv is not ' \
                   'valid.'
......
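With asymmetric padding supported natively, the TensorFlow frontend no longer has to insert a separate nn.pad operator for 'SAME' mode; the uneven (top, left, bottom, right) amounts go straight into the conv2d attribute. A sketch of the resulting Relay call (shapes are assumptions):

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 3, 224, 224))
weight = relay.var("weight", shape=(16, 3, 3, 3))
# 'SAME' with stride 2 on a 224x224 input pads one extra row/column at the
# bottom/right only: (top, left, bottom, right) = (0, 0, 1, 1).
out = relay.nn.conv2d(data, weight, strides=(2, 2),
                      padding=(0, 0, 1, 1), kernel_size=(3, 3))
print(relay.Function([data, weight], out))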
@@ -26,6 +26,7 @@
 #include <tvm/relay/attrs/bitserial.h>
 #include <tvm/relay/op.h>
 
+#include "../op_common.h"
 #include "../../pass/infer_layout_util.h"
 
 namespace tvm {
@@ -134,10 +135,12 @@ bool BinaryConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attr
   CHECK(param->channels.defined());
   CHECK(param->kernel_size.defined());
   Array<IndexExpr> oshape({dshape_nchw[0], param->channels, 0, 0});
+  IndexExpr pad_h, pad_w;
+  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
   oshape.Set(
-      2, (dshape_nchw[2] + param->padding[0] * 2 - param->kernel_size[0]) / param->strides[0] + 1);
+      2, (dshape_nchw[2] + pad_h - param->kernel_size[0]) / param->strides[0] + 1);
   oshape.Set(
-      3, (dshape_nchw[3] + param->padding[1] * 2 - param->kernel_size[1]) / param->strides[1] + 1);
+      3, (dshape_nchw[3] + pad_w - param->kernel_size[1]) / param->strides[1] + 1);
   DataType out_dtype = param->out_dtype;
   oshape = trans_in_layout.BackwardShape(oshape);
   // assign output type
......
@@ -166,7 +166,6 @@ with the layer input to produce a tensor of outputs.
 .add_type_rel("Conv3D", Conv3DRel<Conv3DAttrs>)
 .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv3DAttrs>);
 
-
 // relay.nn.conv2d_transpose
 TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
@@ -250,18 +249,8 @@ bool Conv2DTransposeRel(const Array<Type>& types,
   }
   // dilation
   Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
-  auto pad_h = param->padding[0];
-  auto pad_w = param->padding[1];
-  if (param->padding.size() == 2) {
-    pad_h *= 2;
-    pad_w *= 2;
-  } else if (param->padding.size() == 4) {
-    pad_h += param->padding[2];
-    pad_w += param->padding[3];
-  } else {
-    CHECK_EQ(param->padding.size(), 4) << " Padding should be 2 or 4, but got "
-                                       << param->padding.size();
-  }
+  IndexExpr pad_h, pad_w;
+  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
   oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
                  pad_h + param->output_padding[0]));
   oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
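GetPaddingHeightWidth folds the padding attribute into per-axis totals, after which the transposed-convolution shape rule reads directly off the code above. A worked instance:

# out = stride * (in - 1) + dilated_kernel - (pad_top + pad_bottom) + output_padding
in_h, k_h, stride_h = 4, 3, 2
pad_top, pad_bottom, output_padding = 1, 0, 0
out_h = stride_h * (in_h - 1) + k_h - (pad_top + pad_bottom) + output_padding
assert out_h == 8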
@@ -557,14 +546,16 @@ bool Conv2DWinogradRel(const Array<Type>& types,
 
   // dilation
   Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
+  IndexExpr pad_h, pad_w;
+  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
   if (!dshape_nchw[2].as<ir::Any>()) {
-    oshape.Set(2, (dshape_nchw[2] + param->padding[0] * 2
+    oshape.Set(2, (dshape_nchw[2] + pad_h
                    - dilated_ksize_y) / param->strides[0] + 1);
   } else {
     oshape.Set(2, dshape_nchw[2]);
   }
   if (!dshape_nchw[3].as<ir::Any>()) {
-    oshape.Set(3, (dshape_nchw[3] + param->padding[1] * 2
+    oshape.Set(3, (dshape_nchw[3] + pad_w
                    - dilated_ksize_x) / param->strides[1] + 1);
   } else {
     oshape.Set(3, dshape_nchw[3]);
@@ -1015,9 +1006,11 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
 
   // dilation
   Array<IndexExpr> oshape({data->shape[0], channels, 0, 0});
-  oshape.Set(2, indexdiv(data->shape[2] + param->padding[0] * 2 - dilated_ksize_y,
+  IndexExpr pad_h, pad_w;
+  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
+  oshape.Set(2, indexdiv(data->shape[2] + pad_h - dilated_ksize_y,
                          param->strides[0]) + 1);
-  oshape.Set(3, indexdiv(data->shape[3] + param->padding[1] * 2 - dilated_ksize_x,
+  oshape.Set(3, indexdiv(data->shape[3] + pad_w - dilated_ksize_x,
                          param->strides[1]) + 1);
   DataType out_dtype = param->out_dtype;
......
@@ -117,15 +117,17 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
 
   // dilation
   Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
+  IndexExpr pad_h, pad_w;
+  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
   if (!dshape_nchw[2].as<ir::Any>()) {
-    oshape.Set(2, indexdiv(dshape_nchw[2] + param->padding[0] * 2 - dilated_ksize_y,
+    oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y,
                            param->strides[0]) + 1);
   } else {
     oshape.Set(2, dshape_nchw[2]);
   }
   if (!dshape_nchw[3].as<ir::Any>()) {
-    oshape.Set(3, indexdiv(dshape_nchw[3] + param->padding[1] * 2 - dilated_ksize_x,
+    oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x,
                            param->strides[1]) + 1);
   } else {
     oshape.Set(3, dshape_nchw[3]);
......
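All of the forward type relations above now share one shape rule with pad_h = top + bottom and pad_w = left + right. A worked example with uneven padding:

# out = (in + pad_low + pad_high - dilated_kernel) // stride + 1
in_h, k_h, stride_h, dilation_h = 224, 3, 2, 1
pad_top, pad_bottom = 0, 1                  # one extra row at the bottom only
dilated_k_h = (k_h - 1) * dilation_h + 1
out_h = (in_h + pad_top + pad_bottom - dilated_k_h) // stride_h + 1
assert out_h == 112                         # matches TensorFlow 'SAME' for stride 2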
@@ -17,6 +17,7 @@
 import tvm
 import numpy as np
 import scipy.signal
+from topi.nn.util import get_pad_tuple
 from tvm.contrib import nnpack
 import pytest
@@ -59,17 +60,9 @@ def np_conv(na, nw, padding, stride=1):
     else:
         stride_h, stride_w = stride
-    if isinstance(padding, int):
-        pad_h = pad_w = padding * 2
-    else:
-        pad_h, pad_w = padding
-        pad_h *= 2
-        pad_w *= 2
-    pad_top = int(np.ceil(float(pad_h) / 2))
-    pad_bottom = pad_h - pad_top
-    pad_left = int(np.ceil(float(pad_w) / 2))
-    pad_right = pad_w - pad_left
+    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))
+    pad_h = pad_top + pad_bottom
+    pad_w = pad_left + pad_right
 
     out_channel = num_filter
     out_height = (in_height - kernel_h + pad_h) // stride_h + 1
@@ -78,9 +71,9 @@ def np_conv(na, nw, padding, stride=1):
     for n in range(batch):
         for f in range(out_channel):
             for c in range(in_channel):
-                if pad_h > 0:
+                if pad_h > 0 or pad_w > 0:
                     apad = np.zeros((in_height + pad_h, in_width + pad_w))
-                    apad[pad_top:-pad_bottom, pad_left:-pad_right] = na[n, c]
+                    apad[pad_top:pad_top + in_height, pad_left:pad_left + in_width] = na[n, c]
                 else:
                     apad = na[n, c]
                 out = scipy.signal.convolve2d(
......
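The slice rewrite in the NumPy references fixes a real edge case: when one pad amount is zero, the old negative-stop slice apad[pad_top:-pad_bottom] degenerates to an empty selection because -0 is 0. A small demonstration:

import numpy as np

a = np.ones((2, 2))
pad_top, pad_bottom = 1, 0
apad = np.zeros((2 + pad_top + pad_bottom, 2))
assert apad[pad_top:-pad_bottom].size == 0      # old slice: empty when pad_bottom == 0
apad[pad_top:pad_top + 2, 0:2] = a              # new slice: correct for any pad values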
@@ -197,11 +197,11 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
         CO *= VC
         KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
-    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
 
     assert layout == 'NCHW'
     assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
-    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
+    data_pad = pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad")
 
     idxd = tvm.indexdiv
     idxm = tvm.indexmod
@@ -214,8 +214,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
     K = CO
     C = CI
 
-    H = (IH + 2 * HPAD - 3) // HSTR + 1
-    W = (IW + 2 * WPAD - 3) // WSTR + 1
+    H = (IH + pt + pb - 3) // HSTR + 1
+    W = (IW + pl + pr - 3) // WSTR + 1
     nH, nW = (H + m-1) // m, (W + m-1) // m
     P = N * nH * nW
@@ -387,12 +387,13 @@ def conv2d_arm_cpu_winograd_nnpack(
     assert len(kernel.shape) == 4
     CO, _, KH, KW = get_const_tuple(kernel.shape)
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
-    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
 
     assert layout == 'NCHW'
-    assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1
-    H = (IH + 2 * HPAD - 3) // HSTR + 1
-    W = (IW + 2 * WPAD - 3) // WSTR + 1
+    assert KH == 3 and KW == 3 and pt == 1 and pb == 1 and pl == 1 and pr == 1 and HSTR == 1\
+        and WSTR == 1
+    H = (IH + pt + pb - 3) // HSTR + 1
+    W = (IW + pl + pr - 3) // WSTR + 1
 
     cfg.define_knob('winograd_nnpack_algorithm', [convolution_algorithm])
@@ -407,7 +408,7 @@ def conv2d_arm_cpu_winograd_nnpack(
         output = tvm.contrib.nnpack.convolution_inference_without_weight_transform(
             data, transformed_kernel,
             bias=None,
-            padding=[HPAD, HPAD, WPAD, WPAD],
+            padding=[pt, pb, pl, pr],
             stride=[HSTR, WSTR],
             algorithm=cfg['winograd_nnpack_algorithm'].val)
@@ -467,13 +468,14 @@ def conv2d_winograd_nnpack_ww(cfg, data, transformed_kernel, bias, strides,
     assert len(transformed_kernel.shape) == 4
     CO, _, _, _ = get_const_tuple(transformed_kernel.shape)
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
-    HPAD, WPAD, _, _ = get_pad_tuple(padding, (3, 3))
     KH, KW = 3, 3
+    pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
 
     assert layout == 'NCHW'
-    assert KH == 3 and KW == 3 and HPAD == 1 and WPAD == 1 and HSTR == 1 and WSTR == 1
-    H = (IH + 2 * HPAD - 3) // HSTR + 1
-    W = (IW + 2 * WPAD - 3) // WSTR + 1
+    assert KH == 3 and KW == 3 and pt == 1 and pb == 1 and pl == 1 and pr == 1 and HSTR == 1\
+        and WSTR == 1
+    H = (IH + pt + pb - 3) // HSTR + 1
+    W = (IW + pl + pr - 3) // WSTR + 1
     assert N == 1
     with tvm.tag_scope("winograd_nnpack_conv2d_output"):
@@ -481,7 +483,7 @@ def conv2d_winograd_nnpack_ww(cfg, data, transformed_kernel, bias, strides,
             data=data,
             transformed_kernel=transformed_kernel,
             bias=bias,
-            padding=[HPAD, HPAD, WPAD, WPAD],
+            padding=[pt, pb, pl, pr],
             stride=[HSTR, WSTR],
             algorithm=cfg['winograd_nnpack_algorithm'].val)
......
@@ -276,11 +276,11 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
         H_CAT, W_CAT, CO, CI = get_const_tuple(kernel.shape)
         KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
-    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
 
     assert layout == 'NCHW'
     assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
-    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
+    data_pad = pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad")
 
     r = KW
     m = tile_size
@@ -289,8 +289,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
     K = CO
     C = CI
 
-    H = (IH + 2 * HPAD - 3) // HSTR + 1
-    W = (IW + 2 * WPAD - 3) // WSTR + 1
+    H = (IH + pt + pb - 3) // HSTR + 1
+    W = (IW + pl + pr - 3) // WSTR + 1
     nH, nW = (H + m-1) // m, (W + m-1) // m
     P = N * nH * nW
......
@@ -21,6 +21,7 @@ from tvm import autotvm
 from tvm.contrib import cudnn
 
 from .. import nn, generic
+from ..nn.util import get_pad_tuple
 from ..util import get_const_tuple, traverse_inline
 from .conv2d_direct import schedule_direct_cuda
 
@@ -49,8 +50,10 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
     strides : int or a list/tuple of two ints
         stride size, or [stride_height, stride_width]
 
-    padding : int or a list/tuple of two ints
-        padding size, or [pad_height, pad_width]
+    padding : int or a list/tuple of 2 or 4 ints
+        padding size, or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
 
     dilation: int or a list/tuple of two ints
         dilation size, or [dilation_height, dilation_width]
@@ -81,11 +84,13 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
         # handle dilation
         stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
-        pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding
         dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation
 
-        OH = (H + 2 * pad_h - KH) // stride_h + 1
-        OW = (W + 2 * pad_w - KW) // stride_w + 1
+        if isinstance(padding, (list, tuple)) and len(padding) > 2:
+            raise ValueError("Cudnn doesn't support asymmetric padding.")
+        pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
+        OH = (H + pt + pb - KH) // stride_h + 1
+        OW = (W + pl + pr - KW) // stride_w + 1
         cfg.add_flop(2 * N * OH * OW * CO * CI * ((KH - 1) * dilation_h + 1) *\
                      ((KW - 1) * dilation_w + 1))
 
@@ -98,7 +103,7 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
         return cudnn.conv_forward(data,
                                   kernel,
-                                  [pad_h, pad_w],
+                                  [pt, pl],  # cudnn padding pt, pl on both sides of input
                                   [stride_h, stride_w],
                                   [dilation_h, dilation_w],
                                   conv_mode=1,
......
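cudnn takes a single pad value per spatial axis and applies it to both sides, which is why the call site passes only [pt, pl] and the guard rejects four-element padding up front. For the still-allowed forms, get_pad_tuple always returns mirrored values, so nothing is lost:

from topi.nn.util import get_pad_tuple

pt, pl, pb, pr = get_pad_tuple((1, 2), (3, 3))   # two-int padding
assert (pt, pl) == (pb, pr) == (1, 2)            # symmetric, [pt, pl] fully describes it

pt, pl, pb, pr = get_pad_tuple('SAME', (3, 3))   # stride-1 'SAME' for a 3x3 kernel
assert (pt, pl, pb, pr) == (1, 1, 1, 1)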
@@ -64,15 +64,15 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dty
         KH = KW = alpha + 1 - tile_size
     assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1
 
-    HPAD, WPAD, _, _ = nn.get_pad_tuple(padding, kernel)
-    data_pad = nn.pad(data, (0, 0, HPAD, WPAD), (0, 0, HPAD, WPAD), name="data_pad")
+    pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW))
+    data_pad = nn.pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad")
 
     r = KW
     m = tile_size
     A, B, G = winograd_transform_matrices(m, r, out_dtype)
 
-    H = (H + 2 * HPAD - KH) // HSTR + 1
-    W = (W + 2 * WPAD - KW) // WSTR + 1
+    H = (H + pt + pb - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
     nH, nW = (H + m-1) // m, (W + m-1) // m
     P = N * nH * nW
......
@@ -83,10 +83,10 @@ def _create_schedule_template(cfg, data, kernel, strides, padding, dilation, lay
     else:
         raise ValueError("Not support this layout {} with "
                          "schedule template.".format(layout))
-    ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding)
+    pt, pl, pb, pr = get_pad_tuple(padding, kernel)
     sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
-    oh = (h - kh + 2 * ph) // sh + 1
-    ow = (w - kw + 2 * pw) // sw + 1
+    oh = (h - kh + pt + pb) // sh + 1
+    ow = (w - kw + pl + pr) // sw + 1
 
     ic_bn_upper = 32
     oc_bn_upper = 64
     oc_bn_lower = min(oc, 8)
......
@@ -226,19 +226,19 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
         CO *= VC
         KH, KW = H_CAT - tile_size + 1, W_CAT - tile_size + 1
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
-    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
 
     assert layout == 'NCHW'
     assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
-    data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
+    data_pad = pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad")
 
     r = KW
     m = tile_size
     alpha = m + r - 1
     A, B, G = winograd_transform_matrices(m, r, out_dtype)
 
-    H = (IH + 2 * HPAD - 3) // HSTR + 1
-    W = (IW + 2 * WPAD - 3) // WSTR + 1
+    H = (IH + pt + pb - 3) // HSTR + 1
+    W = (IW + pl + pr - 3) // WSTR + 1
     nH, nW = (H + m-1) // m, (W + m-1) // m
     P = N * nH * nW
......
@@ -23,7 +23,7 @@ import tvm
 
 from .pad import pad
 from .util import get_pad_tuple
-from ..util import simplify, get_const_tuple
+from ..util import simplify, get_const_tuple, get_const_int
 from .winograd_util import winograd_transform_matrices
 
 # workload description of conv2d
@@ -46,8 +46,10 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
     strides : int or a list/tuple of two ints
         stride size, or [stride_height, stride_width]
 
-    padding : int or a list/tuple of two ints
-        padding size, or [pad_height, pad_width]
+    padding : int or a list/tuple of 2 or 4 ints
+        padding size, or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
 
     dilation: int or a list/tuple of two ints
         dilation size, or [dilation_height, dilation_width]
@@ -153,7 +155,7 @@ def _get_workload(data, kernel, stride, padding, out_dtype, data_layout='NCHW'):
     else:
         KH, KW, CIG, CO = [x.value for x in kernel.shape]
 
-    HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
+    HPAD, WPAD, _, _ = get_pad_tuple(padding, (get_const_int(KH), get_const_int(KW)))
     GRPS = CI // CIG
     if isinstance(stride, (tuple, list)):
         HSTR, WSTR = stride
@@ -179,8 +181,10 @@ def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
     stride : int or a list/tuple of two ints
         Stride size, or [stride_height, stride_width]
 
-    padding : int or str
-        Padding size, or ['VALID', 'SAME']
+    padding : int or a list/tuple of 2 or 4 ints
+        padding size, or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
 
     dilation: int or a list/tuple of two ints
         dilation size, or [dilation_height, dilation_width]
@@ -221,7 +225,6 @@ def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
     rc = tvm.reduce_axis((0, in_channel), name='rc')
     ry = tvm.reduce_axis((0, kernel_h), name='ry')
     rx = tvm.reduce_axis((0, kernel_w), name='rx')
-
     return tvm.compute(
         (batch, out_channel, out_height, out_width),
         lambda nn, ff, yy, xx: tvm.sum(
@@ -245,8 +248,10 @@ def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None):
     stride : int or a list/tuple of two ints
         Stride size, or [stride_height, stride_width]
 
-    padding : int or str
-        Padding size, or ['VALID', 'SAME']
+    padding : int or a list/tuple of 2 or 4 ints
+        padding size, or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
 
     dilation: int or a list/tuple of two ints
         dilation size, or [dilation_height, dilation_width]
@@ -311,8 +316,10 @@ def conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
     stride : int or a list/tuple of two ints
         Stride size, or [stride_height, stride_width]
 
-    padding : int or str
-        Padding size, or ['VALID', 'SAME']
+    padding : int or a list/tuple of 2 or 4 ints
+        padding size, or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
 
     dilation: int or a list/tuple of two ints
         dilation size, or [dilation_height, dilation_width]
@@ -378,8 +385,10 @@ def conv2d_NCHWc(data, kernel, stride, padding, dilation, layout, out_layout, ou
     stride : int or a list/tuple of two ints
         stride size, or [stride_height, stride_width]
 
-    padding : int or a list/tuple of two ints
-        padding size, or [pad_height, pad_width]
+    padding : int or a list/tuple of 2 or 4 ints
+        padding size, or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
 
     dilation: int or a list/tuple of two ints
         dilation size, or [dilation_height, dilation_width]
@@ -425,8 +434,10 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l
     stride : int or a list/tuple of two ints
         stride size, or [stride_height, stride_width]
 
-    padding : int or a list/tuple of two ints
-        padding size, or [pad_height, pad_width]
+    padding : int or a list/tuple of 2 or 4 ints
+        padding size, or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
 
     dilation: int or a list/tuple of two ints
         dilation size, or [dilation_height, dilation_width]
@@ -448,7 +459,6 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l
     # layout and out_layout are not used here,
     # we keep them for debug convenience when dumping autotvm workload
-    HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding)
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
     dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \
         else (dilation, dilation)
@@ -464,15 +474,22 @@ def conv2d_NCHWc_compute(data, kernel, strides, padding, dilation, layout, out_l
     dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HPAD = pad_top + pad_down
+    WPAD = pad_left + pad_right
+
     # output shape
-    out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1
-    out_width = (iw + 2 * WPAD - dilated_kernel_w) // WSTR + 1
+    out_height = (ih + HPAD - dilated_kernel_h) // HSTR + 1
+    out_width = (iw + WPAD - dilated_kernel_w) // WSTR + 1
     oshape = (n, oc_chunk, out_height, out_width, oc_bn)
+    pad_before = (0, 0, pad_top, pad_left, 0)
+    pad_after = (0, 0, pad_down, pad_right, 0)
 
     # DOPAD
     DOPAD = (HPAD != 0 or WPAD != 0)
     if DOPAD:
-        data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
+        data_pad = pad(data, pad_before, pad_after, name="data_pad")
     else:
         data_pad = data
@@ -517,8 +534,10 @@ def conv2d_NCHWc_int8(data, kernel, strides, padding, dilation, layout, out_layo
     stride : int or a list/tuple of two ints
         stride size, or [stride_height, stride_width]
 
-    padding : int or a list/tuple of two ints
-        padding size, or [pad_height, pad_width]
+    padding : int or a list/tuple of 2 or 4 ints
+        padding size, or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
 
     dilation: int or a list/tuple of two ints
         dilation size, or [dilation_height, dilation_width]
@@ -565,8 +584,10 @@ def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout,
     stride : int or a list/tuple of two ints
         stride size, or [stride_height, stride_width]
 
-    padding : int or a list/tuple of two ints
-        padding size, or [pad_height, pad_width]
+    padding : int or a list/tuple of 2 or 4 ints
+        padding size, or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
 
     dilation: int or a list/tuple of two ints
         dilation size, or [dilation_height, dilation_width]
@@ -588,7 +609,6 @@ def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout,
     # layout and out_layout are not used here,
     # we keep them for debug convenience when dumping autotvm workload
-    HPAD, WPAD = padding if isinstance(padding, (tuple, list)) else (padding, padding)
     HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
     dilation_h, dilation_w = dilation if isinstance(dilation, (tuple, list)) \
         else (dilation, dilation)
@@ -603,15 +623,23 @@ def conv2d_NCHWc_int8_compute(data, kernel, strides, padding, dilation, layout,
     dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
     dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w))
+    HPAD = pad_top + pad_down
+    WPAD = pad_left + pad_right
+
     # output shape
-    out_height = (ih + 2 * HPAD - dilated_kernel_h) // HSTR + 1
-    out_width = (iw + 2 * WPAD - dilated_kernel_w) // WSTR + 1
+    out_height = (ih + HPAD - dilated_kernel_h) // HSTR + 1
+    out_width = (iw + WPAD - dilated_kernel_w) // WSTR + 1
     oshape = (n, oc_chunk, out_height, out_width, oc_bn)
+    pad_before = (0, 0, pad_top, pad_left, 0)
+    pad_after = (0, 0, pad_down, pad_right, 0)
 
     # DOPAD
     DOPAD = (HPAD != 0 or WPAD != 0)
     if DOPAD:
-        data_pad = pad(data, (0, 0, HPAD, WPAD, 0), name="data_pad")
+        data_pad = pad(data, pad_before, pad_after, name="data_pad")
     else:
         data_pad = data
@@ -780,8 +808,10 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp
     stride : int or a list/tuple of two ints
         Stride size, or [stride_height, stride_width]
 
-    padding : int or str
-        Padding size, or ['VALID', 'SAME']
+    padding : int or a list/tuple of 2 or 4 ints
+        padding size, or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
 
     dilation : int or a list/tuple of two ints
         dilation size, or [dilation_height, dilation_width]
......
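For the blocked NCHW[x]c layouts the tensor is 5-D, so the asymmetric amounts become separate five-element pad_before/pad_after tuples with zeros on the batch and channel axes, exactly as conv2d_NCHWc_compute builds them above. A minimal standalone sketch (the shapes are assumptions):

import tvm
from topi.nn.pad import pad

# 5-D NCHW8c input: (batch, channel_chunk, height, width, channel_block)
data = tvm.placeholder((1, 4, 56, 56, 8), name="data")
pad_before = (0, 0, 1, 1, 0)    # pad_top / pad_left on the spatial axes
pad_after = (0, 0, 2, 2, 0)     # pad_down / pad_right may differ
data_pad = pad(data, pad_before, pad_after, name="data_pad")
print(data_pad.shape)            # [1, 4, 59, 59, 8]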
@@ -23,6 +23,7 @@ from tvm.contrib import miopen
 from .. import nn, generic
 from ..util import get_const_tuple
 from ..cuda.conv2d import conv2d_cuda, schedule_conv2d_nchw_cuda
+from ..nn.util import get_pad_tuple
 
 
 @autotvm.register_topi_compute(nn.conv2d, 'rocm', ['direct', 'winograd'])
 def conv2d_rocm(cfg, data, kernel, strides, padding, dilation, layout='NCHW', out_dtype='float32'):
@@ -42,8 +43,10 @@ def conv2d_rocm(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
     strides : int or a list/tuple of two ints
         stride size, or [stride_height, stride_width]
 
-    padding : int or a list/tuple of two ints
-        padding size, or [pad_height, pad_width]
+    padding : int or a list/tuple of 2 or 4 ints
+        padding size, or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
 
     layout : str
         layout of data
@@ -62,7 +65,8 @@ def conv2d_rocm(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
         # handle dilation
         stride_h, stride_w = (strides, strides) if isinstance(strides, int) else strides
-        pad_h, pad_w = (padding, padding) if isinstance(padding, int) else padding
+        pt, pl, pb, pr = get_pad_tuple(padding, (KH, KW))
+        pad_h, pad_w = pt + pb, pl + pr
         dilation_h, dilation_w = (dilation, dilation) if isinstance(dilation, int) else dilation
 
         OH = (H + 2 * pad_h - KH) // stride_h + 1
......
@@ -18,6 +18,7 @@
 """Convolution in python"""
 import numpy as np
 import scipy.signal
+from topi.nn.util import get_pad_tuple
 
 
 def conv2d_hwcn_python(a_np, w_np, stride, padding):
@@ -34,8 +35,10 @@ def conv2d_hwcn_python(a_np, w_np, stride, padding):
     stride : int or a list/tuple of two ints
         Stride size, or [stride_height, stride_width]
 
-    padding : int or str
-        Padding size, or ['VALID', 'SAME']
+    padding : int or str or a list/tuple of 2 or 4 ints
+        Padding size, or ['VALID', 'SAME'], or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
 
     Returns
     -------
@@ -48,18 +51,10 @@ def conv2d_hwcn_python(a_np, w_np, stride, padding):
         stride_h = stride_w = stride
     else:
         stride_h, stride_w = stride
-    if isinstance(padding, int):
-        pad_h = pad_w = padding * 2
-    elif padding == 'VALID':
-        pad_h = 0
-        pad_w = 0
-    else:  # 'SAME'
-        pad_h = kernel_h - 1
-        pad_w = kernel_w - 1
-    pad_top = int(np.ceil(float(pad_h) / 2))
-    pad_bottom = pad_h - pad_top
-    pad_left = int(np.ceil(float(pad_w) / 2))
-    pad_right = pad_w - pad_left
+    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))
+    pad_h = pad_top + pad_bottom
+    pad_w = pad_left + pad_right
     # compute the output shape
     out_channel = num_filter
     out_height = (in_height - kernel_h + pad_h) // stride_h + 1
@@ -72,9 +67,9 @@ def conv2d_hwcn_python(a_np, w_np, stride, padding):
     for n in range(batch):
         for f in range(out_channel):
             for c in range(in_channel):
-                if pad_h > 0:
+                if pad_h > 0 or pad_w > 0:
                     apad = np.zeros((in_height + pad_h, in_width + pad_w))
-                    apad[pad_top:-pad_bottom, pad_left:-pad_right] = at[n, c]
+                    apad[pad_top:pad_top + in_height, pad_left:pad_left + in_width] = at[n, c]
                 else:
                     apad = at[n, c]
                 out = scipy.signal.convolve2d(
......
@@ -18,6 +18,7 @@
 """Convolution in python"""
 import numpy as np
 import scipy.signal
+from topi.nn.util import get_pad_tuple
 
 
 def _conv2d_nchw_python(a_np, w_np, stride, padding):
@@ -34,8 +35,10 @@ def _conv2d_nchw_python(a_np, w_np, stride, padding):
     stride : int or a list/tuple of two ints
         Stride size, or [stride_height, stride_width]
 
-    padding : int or str or a list/tuple of two ints
-        Padding size, or ['VALID', 'SAME'], or [pad_height, pad_width]
+    padding : int or str or a list/tuple of 2 or 4 ints
+        Padding size, or ['VALID', 'SAME'], or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
 
     Returns
     -------
@@ -48,17 +51,9 @@ def _conv2d_nchw_python(a_np, w_np, stride, padding):
         stride_h = stride_w = stride
     else:
         stride_h, stride_w = stride
-    if isinstance(padding, int):
-        pad_h = pad_w = padding * 2
-    elif isinstance(padding, (list, tuple)):
-        pad_h, pad_w = padding[0] * 2, padding[1] * 2
-    else:
-        pad_h = 0 if padding == 'VALID' else kernel_h - 1
-        pad_w = 0 if padding == 'VALID' else kernel_w - 1
-    pad_top = int(np.ceil(float(pad_h) / 2))
-    pad_bottom = pad_h - pad_top
-    pad_left = int(np.ceil(float(pad_w) / 2))
-    pad_right = pad_w - pad_left
+    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))
+    pad_h = pad_top + pad_bottom
+    pad_w = pad_left + pad_right
     # compute the output shape
     out_channel = num_filter
     out_height = (in_height - kernel_h + pad_h) // stride_h + 1
@@ -70,12 +65,7 @@ def _conv2d_nchw_python(a_np, w_np, stride, padding):
             for c in range(in_channel):
                 if pad_h > 0 or pad_w > 0:
                     apad = np.zeros((in_height + pad_h, in_width + pad_w))
-                    if pad_h == 0:
-                        apad[:, pad_left:-pad_right] = a_np[n, c]
-                    elif pad_w == 0:
-                        apad[pad_top:-pad_bottom, :] = a_np[n, c]
-                    else:
-                        apad[pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, c]
+                    apad[pad_top:pad_top + in_height, pad_left:pad_left + in_width] = a_np[n, c]
                 else:
                     apad = a_np[n, c]
                 out = scipy.signal.convolve2d(
@@ -98,8 +88,10 @@ def conv2d_nchw_python(a_np, w_np, stride, padding, groups=1):
     stride : int or a list/tuple of two ints
         Stride size, or [stride_height, stride_width]
 
-    padding : int or str or a list/tuple of two ints
-        Padding size, or ['VALID', 'SAME'], or [pad_height, pad_width]
+    padding : int or str or a list/tuple of 2 or 4 ints
+        Padding size, or ['VALID', 'SAME'], or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
 
     groups : int
         Number of groups
......
@@ -18,6 +18,7 @@
 """Convolution in python"""
 import numpy as np
 import scipy.signal
+from topi.nn.util import get_pad_tuple
 
 
 def conv2d_nhwc_python(a_np, w_np, stride, padding):
@@ -34,8 +35,10 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding):
     stride : int or a list/tuple of two ints
         Stride size, or [stride_height, stride_width]
 
-    padding : int or str
-        Padding size, or ['VALID', 'SAME']
+    padding : int or str or a list/tuple of 2 or 4 ints
+        Padding size, or ['VALID', 'SAME'], or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
 
     Returns
     -------
@@ -48,18 +51,11 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding):
        stride_h = stride_w = stride
     else:
         stride_h, stride_w = stride
-    if isinstance(padding, int):
-        pad_h = pad_w = padding * 2
-    elif padding == 'VALID':
-        pad_h = 0
-        pad_w = 0
-    else:  # 'SAME'
-        pad_h = kernel_h - 1
-        pad_w = kernel_w - 1
-    pad_top = int(np.ceil(float(pad_h) / 2))
-    pad_bottom = pad_h - pad_top
-    pad_left = int(np.ceil(float(pad_w) / 2))
-    pad_right = pad_w - pad_left
+
+    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel_h, kernel_w))
+    pad_h = pad_top + pad_bottom
+    pad_w = pad_left + pad_right
     # compute the output shape
     out_channel = num_filter
     out_height = (in_height - kernel_h + pad_h) // stride_h + 1
@@ -72,9 +68,9 @@ def conv2d_nhwc_python(a_np, w_np, stride, padding):
     for n in range(batch):
         for f in range(out_channel):
             for c in range(in_channel):
-                if pad_h > 0:
+                if pad_h > 0 or pad_w > 0:
                     apad = np.zeros((in_height + pad_h, in_width + pad_w))
-                    apad[pad_top:-pad_bottom, pad_left:-pad_right] = at[n, c]
+                    apad[pad_top:pad_top + in_height, pad_left:pad_left + in_width] = at[n, c]
                 else:
                     apad = at[n, c]
                 out = scipy.signal.convolve2d(
......
@@ -18,7 +18,7 @@
 """Deformable convolution in python"""
 import itertools
 import numpy as np
-
+from topi.nn.util import get_pad_tuple
 
 def deformable_conv2d_nchw_python(a_np, offset_np, w_np, stride, padding, dilation,
                                   deformable_groups, groups):
@@ -39,8 +39,10 @@ def deformable_conv2d_nchw_python(a_np, offset_np, w_np, stride, padding, dilati
     stride : int or a list/tuple of two ints
         Stride size, or [stride_height, stride_width]
 
-    padding : int or str or a list/tuple of two ints
-        Padding size, or ['VALID', 'SAME'], or [pad_height, pad_width]
+    padding : int or str or a list/tuple of 2 or 4 ints
+        Padding size, or ['VALID', 'SAME'], or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
 
     dilation : int or a list/tuple of two ints
         Dilation size, or [dilate_height, dilate_width]
@@ -67,15 +69,9 @@ def deformable_conv2d_nchw_python(a_np, offset_np, w_np, stride, padding, dilati
         stride_h = stride_w = stride
     else:
         stride_h, stride_w = stride
-    if isinstance(padding, int):
-        pad_h = pad_w = padding * 2
-    elif isinstance(padding, (list, tuple)):
-        pad_h, pad_w = padding[0] * 2, padding[1] * 2
-    else:
-        pad_h = 0 if padding == 'VALID' else kernel_h - 1
-        pad_w = 0 if padding == 'VALID' else kernel_w - 1
-    pad_top = int(np.ceil(float(pad_h) / 2))
-    pad_left = int(np.ceil(float(pad_w) / 2))
+
+    pad_top, pad_left, _, _ = get_pad_tuple(padding, (kernel_h, kernel_w))
     if isinstance(dilation, int):
         dilation_h = dilation_w = dilation
     else:
......
...@@ -30,6 +30,7 @@ from ..nn.conv2d import conv2d, conv2d_NCHWc, \ ...@@ -30,6 +30,7 @@ from ..nn.conv2d import conv2d, conv2d_NCHWc, \
conv2d_infer_layout, _get_workload as _get_conv2d_workload conv2d_infer_layout, _get_workload as _get_conv2d_workload
from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload from ..nn.depthwise_conv2d import _get_workload as _get_depthwise_conv2d_workload
from ..nn.pad import pad from ..nn.pad import pad
from ..nn.util import get_pad_tuple
from ..util import get_const_tuple from ..util import get_const_tuple
from . import conv2d_avx_1x1, conv2d_avx_common from . import conv2d_avx_1x1, conv2d_avx_common
...@@ -84,10 +85,10 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): ...@@ -84,10 +85,10 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
"schedule template.".format(layout)) "schedule template.".format(layout))
is_kernel_1x1 = kh == 1 and kw == 1 is_kernel_1x1 = kh == 1 and kw == 1
ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding) pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw))
sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides) sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
oh = (h - kh + 2 * ph) // sh + 1 oh = (h - kh + pt + pb) // sh + 1
ow = (w - kw + 2 * pw) // sw + 1 ow = (w - kw + pl + pr) // sw + 1
# Create schedule config # Create schedule config
cfg.define_split("tile_ic", ic, num_outputs=2) cfg.define_split("tile_ic", ic, num_outputs=2)
...@@ -102,7 +103,6 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout): ...@@ -102,7 +103,6 @@ def _create_tuning_space(cfg, data, kernel, strides, padding, dilation, layout):
@autotvm.register_topi_compute(conv2d, 'cpu', ['direct']) @autotvm.register_topi_compute(conv2d, 'cpu', ['direct'])
def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out_dtype): def _declaration_conv(cfg, data, kernel, strides, padding, dilation, layout, out_dtype):
out_dtype = data.dtype if out_dtype is None else out_dtype out_dtype = data.dtype if out_dtype is None else out_dtype
padding = padding if isinstance(padding, (tuple, list)) else (padding, padding)
strides = strides if isinstance(strides, (tuple, list)) else (strides, strides) strides = strides if isinstance(strides, (tuple, list)) else (strides, strides)
dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) dilation = dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation)
...@@ -141,24 +141,27 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout ...@@ -141,24 +141,27 @@ def _declaration_conv_impl(cfg, data, kernel, strides, padding, dilation, layout
else: else:
dilation_h, dilation_w = dilation dilation_h, dilation_w = dilation
HPAD, WPAD = padding
HSTR, WSTR = strides HSTR, WSTR = strides
batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape) batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape)
num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape) num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape)
pad_height = in_height + 2 * HPAD pad_top, pad_left, pad_down, pad_right = get_pad_tuple(padding, (kernel_height, kernel_width))
pad_width = in_width + 2 * WPAD pad_h = pad_top + pad_down
pad_w = pad_left + pad_right
pad_height = in_height + pad_h
pad_width = in_width + pad_w
dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
out_height = (in_height + 2 * HPAD - dilated_kernel_h) // HSTR + 1 out_height = (in_height + pad_h - dilated_kernel_h) // HSTR + 1
out_width = (in_width + 2 * WPAD - dilated_kernel_w) // WSTR + 1 out_width = (in_width + pad_w - dilated_kernel_w) // WSTR + 1
# pack data # pack data
DOPAD = (HPAD != 0 or WPAD != 0) DOPAD = (pad_h != 0 or pad_w != 0)
if DOPAD: if DOPAD:
data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad") data_pad = pad(data, (0, 0, pad_top, pad_left), (0, 0, pad_down, pad_right),
name="data_pad")
else: else:
data_pad = data data_pad = data
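With the rewritten _declaration_conv_impl, the output extents come from the summed top/bottom and left/right pads rather than 2 * HPAD. A worked check of that arithmetic, with values taken from one of the new test cases ((3, 3, 2, 2) padding on a 56x56 input); all names mirror the hunk above:

in_height = in_width = 56
kernel_height = kernel_width = 3
HSTR = WSTR = 1
dilation_h = dilation_w = 1
pad_top, pad_left, pad_down, pad_right = 3, 3, 2, 2
pad_h = pad_top + pad_down                                  # 5
pad_w = pad_left + pad_right                                # 5
dilated_kernel_h = (kernel_height - 1) * dilation_h + 1     # 3
dilated_kernel_w = (kernel_width - 1) * dilation_w + 1      # 3
out_height = (in_height + pad_h - dilated_kernel_h) // HSTR + 1   # 59
out_width = (in_width + pad_w - dilated_kernel_w) // WSTR + 1     # 59
assert (out_height, out_width) == (59, 59)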
...@@ -353,8 +356,9 @@ def _conv2d_infer_layout(workload, cfg): ...@@ -353,8 +356,9 @@ def _conv2d_infer_layout(workload, cfg):
out_channel, _, k_height, k_width = kernel[:-1] out_channel, _, k_height, k_width = kernel[:-1]
idxdiv = tvm.indexdiv idxdiv = tvm.indexdiv
out_height = idxdiv(in_height + 2 * padding[0] - k_height, strides[0]) + 1 pt, pl, pb, pr = get_pad_tuple(padding, (k_height, k_width))
out_width = idxdiv(in_width + 2 * padding[1] - k_width, strides[1]) + 1 out_height = idxdiv(in_height + pt + pb - k_height, strides[0]) + 1
out_width = idxdiv(in_width + pl + pr - k_width, strides[1]) + 1
tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] tile_ic, tile_oc = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic) in_shape = (batch_size, idxdiv(in_channel, tile_ic), in_height, in_width, tile_ic)
in_layout = "NCHW%dc" % tile_ic in_layout = "NCHW%dc" % tile_ic
......
...@@ -28,6 +28,7 @@ from ..util import get_const_tuple, get_shape ...@@ -28,6 +28,7 @@ from ..util import get_const_tuple, get_shape
from ..nn import conv2d_legalize from ..nn import conv2d_legalize
from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_NCHWc_int8, conv2d_alter_layout from ..nn.conv2d import conv2d, conv2d_NCHWc, conv2d_NCHWc_int8, conv2d_alter_layout
from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw from ..nn.depthwise_conv2d import depthwise_conv2d_NCHWc, depthwise_conv2d_nchw
from ..nn.util import get_pad_tuple
logger = logging.getLogger('topi') logger = logging.getLogger('topi')
...@@ -221,12 +222,14 @@ def _conv2d_legalize(attrs, inputs, arg_types): ...@@ -221,12 +222,14 @@ def _conv2d_legalize(attrs, inputs, arg_types):
if data_tensor.dtype == 'int8' and kernel_tensor.dtype == 'int8': if data_tensor.dtype == 'int8' and kernel_tensor.dtype == 'int8':
is_int8_inputs = True is_int8_inputs = True
padding = attrs.get_int_tuple("padding") padding = attrs.get_int_tuple("padding")
kh, kw = attrs.get_int_tuple("kernel_size")
pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw))
if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO': if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO':
adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(0, 1, 2)) adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(0, 1, 2))
pad_width = ((0, 0), (padding[0], padding[0]), (padding[1], padding[1]), (0, 0)) pad_width = ((0, 0), (pt, pb), (pl, pr), (0, 0))
elif attrs['data_layout'] == 'NCHW' and attrs['kernel_layout'] == 'OIHW': elif attrs['data_layout'] == 'NCHW' and attrs['kernel_layout'] == 'OIHW':
pad_width = ((0, 0), (0, 0), (padding[0], padding[0]), (padding[1], padding[1])) pad_width = ((0, 0), (0, 0), (pt, pb), (pl, pr))
adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(1, 2, 3)) adjust_shift = relay.sum(relay.cast(kernel, dtype='int32'), axis=(1, 2, 3))
adjust_shift = relay.expand_dims(adjust_shift, axis=1, num_newaxis=2) adjust_shift = relay.expand_dims(adjust_shift, axis=1, num_newaxis=2)
else: else:
......
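The legalize hunk maps the normalized 4-tuple onto relay-style pad_width pairs, whose position depends on the data layout. A sketch of that mapping; make_pad_width is a hypothetical helper, the pass itself inlines these tuples directly:

def make_pad_width(data_layout, pt, pl, pb, pr):
    """Place the (before, after) pad pairs on the H and W axes of the layout."""
    if data_layout == "NHWC":
        return ((0, 0), (pt, pb), (pl, pr), (0, 0))
    if data_layout == "NCHW":
        return ((0, 0), (0, 0), (pt, pb), (pl, pr))
    raise ValueError("unsupported layout %s" % data_layout)

assert make_pad_width("NCHW", 1, 2, 3, 4) == ((0, 0), (0, 0), (1, 3), (2, 4))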
...@@ -25,6 +25,7 @@ from tvm.autotvm.task.topi_integration import deserialize_args ...@@ -25,6 +25,7 @@ from tvm.autotvm.task.topi_integration import deserialize_args
from ..nn.conv2d import _get_workload as _get_conv2d_workload from ..nn.conv2d import _get_workload as _get_conv2d_workload
from .. import generic, tag from .. import generic, tag
from ..generic import conv2d as conv2d_generic from ..generic import conv2d as conv2d_generic
from ..nn.util import get_pad_tuple
from ..util import get_const_tuple from ..util import get_const_tuple
from ..nn.conv2d import conv2d_NCHWc_int8 from ..nn.conv2d import conv2d_NCHWc_int8
from .. import nn from .. import nn
...@@ -92,10 +93,10 @@ def _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, lay ...@@ -92,10 +93,10 @@ def _create_tuning_space_int8(cfg, data, kernel, strides, padding, dilation, lay
"schedule template.".format(layout)) "schedule template.".format(layout))
is_kernel_1x1 = kh == 1 and kw == 1 is_kernel_1x1 = kh == 1 and kw == 1
ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding) pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw))
sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides) sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
oh = (h - kh + 2 * ph) // sh + 1 oh = (h - kh + pt + pb) // sh + 1
ow = (w - kw + 2 * pw) // sw + 1 ow = (w - kw + pl + pr) // sw + 1
# Create schedule config # Create schedule config
cfg.define_split('tile_ic', ic, num_outputs=2, filter=lambda y: y.size[-1] % 4 == 0) cfg.define_split('tile_ic', ic, num_outputs=2, filter=lambda y: y.size[-1] % 4 == 0)
......
...@@ -204,10 +204,10 @@ def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs): ...@@ -204,10 +204,10 @@ def _topi_nn_depthwise_conv2d_NCHWc(*args, **kwargs):
batch, in_channel, height, width = get_const_tuple(data.shape) batch, in_channel, height, width = get_const_tuple(data.shape)
filter_channel, channel_multiplier, kh, kw = get_const_tuple(kernel.shape) filter_channel, channel_multiplier, kh, kw = get_const_tuple(kernel.shape)
ph, pw = padding if isinstance(padding, (tuple, list)) else (padding, padding) pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw))
sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides) sh, sw = strides if isinstance(strides, (tuple, list)) else (strides, strides)
out_height = (height - kh + 2 * ph) // sh + 1 out_height = (height - kh + pt + pb) // sh + 1
out_width = (width - kw + 2 * pw) // sw + 1 out_width = (width - kw + pl + pr) // sw + 1
out_channel = filter_channel * channel_multiplier out_channel = filter_channel * channel_multiplier
# get config here # get config here
......
...@@ -22,6 +22,7 @@ from tvm import autotvm ...@@ -22,6 +22,7 @@ from tvm import autotvm
import topi import topi
import topi.testing import topi.testing
from tvm.contrib.pickle_memoize import memoize from tvm.contrib.pickle_memoize import memoize
from topi.nn.util import get_pad_tuple
from topi.util import get_const_tuple from topi.util import get_const_tuple
from common import get_all_backend from common import get_all_backend
...@@ -49,10 +50,11 @@ def _transform_bias(bias, bn): ...@@ -49,10 +50,11 @@ def _transform_bias(bias, bn):
def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride, def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
padding, dilation=1, add_bias=False, add_relu=False, dtype="float32"): padding, dilation=1, add_bias=False, add_relu=False, dtype="float32"):
print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
(batch, in_channel, in_size, num_filter, kernel, stride, padding)) padding_sum = pad_top + pad_left + pad_bottom + pad_right
in_height = in_width = in_size in_height = in_width = in_size
print("Workload: (%d, %d, %d, %d, %d, %d, %d)" %
(batch, in_channel, in_size, num_filter, kernel, stride, padding_sum))
# for testing functionality, # for testing functionality,
# we choose arbitrary block size that can divide the channel, # we choose arbitrary block size that can divide the channel,
...@@ -96,7 +98,7 @@ def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride, ...@@ -96,7 +98,7 @@ def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
return return
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
C = topi.nn.conv2d_NCHWc(A, W, (stride, stride), (padding, padding), C = topi.nn.conv2d_NCHWc(A, W, (stride, stride), padding,
(dilation, dilation), (dilation, dilation),
layout='NCHW%dc'%ic_block, layout='NCHW%dc'%ic_block,
out_layout="NCHW%dc"%oc_block, out_layout="NCHW%dc"%oc_block,
...@@ -114,12 +116,12 @@ def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride, ...@@ -114,12 +116,12 @@ def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
if add_bias: if add_bias:
func = tvm.build(s, [A, W, bias, C], device, func = tvm.build(s, [A, W, bias, C], device,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func(a, w, b, c) func(a, w, b, c)
else: else:
func = tvm.build(s, [A, W, C], device, func = tvm.build(s, [A, W, C], device,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func(a, w, c) func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3)
...@@ -217,5 +219,22 @@ def test_conv2d_NCHWc(): ...@@ -217,5 +219,22 @@ def test_conv2d_NCHWc():
verify_conv2d_NCHWc(1, 512, 5, 126, 3, 1, 1) verify_conv2d_NCHWc(1, 512, 5, 126, 3, 1, 1)
verify_conv2d_NCHWc(1, 256, 3, 126, 3, 1, 1) verify_conv2d_NCHWc(1, 256, 3, 126, 3, 1, 1)
# Asymmetric padding
verify_conv2d_NCHWc(1, 3, 224, 64, 7, 2, (0, 0, 1, 1))
verify_conv2d_NCHWc(1, 64, 56, 128, 3, 1, (3, 3, 2, 2))
verify_conv2d_NCHWc(1, 64, 56, 64, 1, 1, (1, 2, 2, 1))
verify_conv2d_NCHWc(1, 64, 288, 192, 1, 1, (1, 2))
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, (3, 1))
verify_conv2d_NCHWc(1, 128, 56, 384, 3, 1, (0, 2))
verify_conv2d_NCHWc(1, 64, 56, 64, 1, 1, "VALID")
verify_conv2d_NCHWc(1, 388, 56, 64, 3, 1, "VALID")
verify_conv2d_NCHWc(1, 512, 19, 64, 1, 1, "SAME")
verify_conv2d_NCHWc(1, 64, 2048, 32, 2, 1, "SAME")
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, (1, 2, 2, 1), add_relu=True)
verify_conv2d_NCHWc(1, 64, 56, 64, 5, 2, (1, 3), add_bias=True)
verify_conv2d_NCHWc(1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True)
verify_conv2d_NCHWc(1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True)
if __name__ == "__main__": if __name__ == "__main__":
test_conv2d_NCHWc() test_conv2d_NCHWc()
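For readers checking the new NCHWc cases, this is how the different padding spellings are expected to normalize; the exact values assume the sketch of get_pad_tuple given earlier:

from topi.nn.util import get_pad_tuple

assert get_pad_tuple((1, 2, 2, 1), (3, 3)) == (1, 2, 2, 1)   # four ints: as given
assert get_pad_tuple((1, 2), (3, 3)) == (1, 2, 1, 2)         # two ints: mirrored
assert get_pad_tuple(1, (3, 3)) == (1, 1, 1, 1)              # one int: all sides
assert get_pad_tuple("VALID", (3, 3)) == (0, 0, 0, 0)        # no padding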
...@@ -23,6 +23,7 @@ from tvm.autotvm.task.space import FallbackConfigEntity ...@@ -23,6 +23,7 @@ from tvm.autotvm.task.space import FallbackConfigEntity
import topi import topi
import topi.testing import topi.testing
from tvm.contrib.pickle_memoize import memoize from tvm.contrib.pickle_memoize import memoize
from topi.nn.util import get_pad_tuple
from topi.util import get_const_tuple from topi.util import get_const_tuple
from common import get_all_backend, Int8Fallback from common import get_all_backend, Int8Fallback
...@@ -31,7 +32,9 @@ oc_block_factor = 4 ...@@ -31,7 +32,9 @@ oc_block_factor = 4
def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False): def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
padding_sum = pad_top + pad_left + pad_bottom + pad_right
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
in_height = in_width = in_size in_height = in_width = in_size
...@@ -79,7 +82,7 @@ def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, str ...@@ -79,7 +82,7 @@ def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, str
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
C = topi.nn.conv2d(A, W, (stride, stride), (padding, padding), (dilation, dilation), C = topi.nn.conv2d(A, W, (stride, stride), padding, (dilation, dilation),
layout='NCHW', out_dtype=dtype) layout='NCHW', out_dtype=dtype)
if add_bias: if add_bias:
C = topi.add(C, bias) C = topi.add(C, bias)
...@@ -92,11 +95,11 @@ def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, str ...@@ -92,11 +95,11 @@ def verify_conv2d_NCHWc_int8(batch, in_channel, in_size, num_filter, kernel, str
b = tvm.nd.array(b_np, ctx) b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias: if add_bias:
tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func(a, w, b, c) func(a, w, b, c)
else: else:
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func(a, w, c) func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
...@@ -184,5 +187,22 @@ def test_conv2d_nchw(): ...@@ -184,5 +187,22 @@ def test_conv2d_nchw():
verify_conv2d_NCHWc_int8(8, 32, 149, 32, 3, 1, 0) verify_conv2d_NCHWc_int8(8, 32, 149, 32, 3, 1, 0)
verify_conv2d_NCHWc_int8(32, 32, 149, 32, 3, 1, 0) verify_conv2d_NCHWc_int8(32, 32, 149, 32, 3, 1, 0)
# Asymmetric padding
verify_conv2d_NCHWc_int8(1, 32, 224, 64, 7, 2, (0, 0, 1, 1))
verify_conv2d_NCHWc_int8(1, 64, 56, 128, 3, 1, (3, 3, 2, 2))
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 1, 1, (1, 2, 2, 1))
verify_conv2d_NCHWc_int8(1, 64, 288, 192, 1, 1, (1, 2))
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, (3, 1))
verify_conv2d_NCHWc_int8(1, 128, 56, 384, 3, 1, (0, 2))
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 1, 1, "VALID")
verify_conv2d_NCHWc_int8(1, 388, 56, 64, 3, 1, "VALID")
verify_conv2d_NCHWc_int8(1, 512, 19, 64, 1, 1, "SAME")
verify_conv2d_NCHWc_int8(1, 64, 2048, 32, 2, 1, "SAME")
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, (1, 2, 2, 1), add_relu=True)
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 5, 2, (1, 3), add_bias=True)
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True)
verify_conv2d_NCHWc_int8(1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True)
if __name__ == "__main__": if __name__ == "__main__":
test_conv2d_nchw() test_conv2d_nchw()
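"SAME" is itself a source of asymmetric padding whenever kernel - 1 is odd: the total pad splits ceil-first onto top/left, as in the kernel-size-24 case above (split direction assumed from the np.ceil in the removed reference code):

from topi.nn.util import get_pad_tuple

pt, pl, pb, pr = get_pad_tuple("SAME", (24, 24))   # total pad = 23 per axis
assert (pt, pl, pb, pr) == (12, 12, 11, 11)        # ceil half lands on top/left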
...@@ -22,12 +22,17 @@ from tvm import autotvm ...@@ -22,12 +22,17 @@ from tvm import autotvm
import topi import topi
import topi.testing import topi.testing
from tvm.contrib.pickle_memoize import memoize from tvm.contrib.pickle_memoize import memoize
from topi.nn.util import get_pad_tuple
from topi.util import get_const_tuple from topi.util import get_const_tuple
from common import get_all_backend from common import get_all_backend
def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False): def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False,
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) use_cudnn=False):
pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
padding_sum = pad_top + pad_left + pad_bottom + pad_right
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
in_height = in_width = in_size in_height = in_width = in_size
...@@ -62,7 +67,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p ...@@ -62,7 +67,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
return return
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
C = topi.nn.conv2d(A, W, (stride, stride), (padding, padding), C = topi.nn.conv2d(A, W, (stride, stride), padding,
(dilation, dilation), layout='NCHW', out_dtype=dtype) (dilation, dilation), layout='NCHW', out_dtype=dtype)
if add_bias: if add_bias:
C = topi.add(C, bias) C = topi.add(C, bias)
...@@ -75,10 +80,10 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p ...@@ -75,10 +80,10 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
b = tvm.nd.array(b_np, ctx) b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias: if add_bias:
func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func(a, w, b, c) func(a, w, b, c)
else: else:
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func(a, w, c) func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
...@@ -86,6 +91,9 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p ...@@ -86,6 +91,9 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
with autotvm.tophub.context(device): # load tophub pre-tuned parameters with autotvm.tophub.context(device): # load tophub pre-tuned parameters
check_device(device) check_device(device)
if use_cudnn:
check_device("cuda -model=unknown -libs=cudnn")
def test_conv2d_nchw(): def test_conv2d_nchw():
# ResNet18 workloads # ResNet18 workloads
...@@ -176,6 +184,25 @@ def test_conv2d_nchw(): ...@@ -176,6 +184,25 @@ def test_conv2d_nchw():
verify_conv2d_nchw(1, 512, 5, 126, 3, 1, 1) verify_conv2d_nchw(1, 512, 5, 126, 3, 1, 1)
verify_conv2d_nchw(1, 256, 3, 126, 3, 1, 1) verify_conv2d_nchw(1, 256, 3, 126, 3, 1, 1)
# Asymmetric padding
verify_conv2d_nchw(1, 3, 224, 64, 7, 2, (0, 0, 1, 1))
verify_conv2d_nchw(1, 64, 56, 128, 3, 1, (3, 3, 2, 2))
verify_conv2d_nchw(1, 64, 56, 64, 1, 1, (1, 2, 2, 1))
verify_conv2d_nchw(1, 64, 288, 192, 1, 1, (1, 2))
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, (3, 1))
verify_conv2d_nchw(1, 128, 56, 384, 3, 1, (0, 2))
verify_conv2d_nchw(1, 64, 384, 64, 3, 1, (1, 2), use_cudnn=True)
verify_conv2d_nchw(1, 64, 56, 64, 1, 1, "VALID")
verify_conv2d_nchw(1, 388, 56, 64, 3, 1, "VALID")
verify_conv2d_nchw(1, 64, 1280, 48, 3, 1, "VALID", use_cudnn=True)
verify_conv2d_nchw(1, 512, 19, 64, 1, 1, "SAME")
verify_conv2d_nchw(1, 64, 2048, 32, 2, 1, "SAME")
verify_conv2d_nchw(1, 64, 8, 64, 3, 1, "SAME", use_cudnn=True)
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, (1, 2, 2, 1), add_relu=True)
verify_conv2d_nchw(1, 64, 56, 64, 5, 2, (1, 3), add_bias=True)
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, "VALID", add_bias=True, add_relu=True)
verify_conv2d_nchw(1, 64, 56, 64, 24, 1, "SAME", add_bias=True, add_relu=True)
if __name__ == "__main__": if __name__ == "__main__":
test_conv2d_nchw() test_conv2d_nchw()
...@@ -71,8 +71,13 @@ def test_conv2d_nhwc(): ...@@ -71,8 +71,13 @@ def test_conv2d_nhwc():
verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "VALID") verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "VALID")
verify_conv2d_nhwc(4, 128, 16, 128, 5, 2, "VALID") verify_conv2d_nhwc(4, 128, 16, 128, 5, 2, "VALID")
verify_conv2d_nhwc(4, 128, 16, 256, 5, 2, "VALID") verify_conv2d_nhwc(4, 128, 16, 256, 5, 2, "VALID")
verify_conv2d_nhwc(1, 128, 16, 256, 3, 2, (0, 0, 1, 1))
verify_conv2d_nhwc(1, 128, 16, 256, 3, 2, (1, 1, 2, 2))
verify_conv2d_nhwc(1, 128, 16, 128, 5, 2, (3, 3, 2, 2))
verify_conv2d_nhwc(1, 128, 16, 256, 3, 2, (0, 1, 2, 3))
# dilation = 2 # dilation = 2
verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "SAME", dilation=2) verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, "SAME", dilation=2)
verify_conv2d_nhwc(1, 256, 32, 256, 3, 1, (1, 1, 2, 2), dilation=2)
if __name__ == "__main__": if __name__ == "__main__":
......
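The dilated NHWC case added above combines both features: the effective kernel grows to (k - 1) * dilation + 1 before the padded extent is divided by the stride. A quick check for the (1, 1, 2, 2) / dilation=2 workload:

kernel, dilation, stride = 3, 2, 1
in_size = 32
pt, pl, pb, pr = 1, 1, 2, 2
dilated_k = (kernel - 1) * dilation + 1                  # 5
out_h = (in_size + pt + pb - dilated_k) // stride + 1    # (32 + 3 - 5) + 1 = 31
out_w = (in_size + pl + pr - dilated_k) // stride + 1    # 31
assert (out_h, out_w) == (31, 31)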
...@@ -23,12 +23,15 @@ from tvm.autotvm.task.space import FallbackConfigEntity ...@@ -23,12 +23,15 @@ from tvm.autotvm.task.space import FallbackConfigEntity
import topi import topi
import topi.testing import topi.testing
from tvm.contrib.pickle_memoize import memoize from tvm.contrib.pickle_memoize import memoize
from topi.nn.util import get_pad_tuple
from topi.util import get_const_tuple from topi.util import get_const_tuple
def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False, def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False,
devices=['cuda', 'llvm -device=arm_cpu', 'opencl -device=mali']): devices=['cuda', 'llvm -device=arm_cpu', 'opencl -device=mali']):
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
padding_sum = pad_top + pad_left + pad_bottom + pad_right
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
in_height = in_width = in_size in_height = in_width = in_size
...@@ -76,14 +79,13 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p ...@@ -76,14 +79,13 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
b = tvm.nd.array(b_np, ctx) b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias: if add_bias:
func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func(a, w, b, c) func(a, w, b, c)
else: else:
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func(a, w, c) func(a, w, c)
rtol = 1e-3 rtol = 1e-3
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol) tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=rtol)
...@@ -133,5 +135,20 @@ def test_conv2d_nchw(): ...@@ -133,5 +135,20 @@ def test_conv2d_nchw():
verify_conv2d_nchw(3, 3, 3, 3, 3, 1, 1) verify_conv2d_nchw(3, 3, 3, 3, 3, 1, 1)
verify_conv2d_nchw(2, 13, 71, 59, 3, 1, 1) verify_conv2d_nchw(2, 13, 71, 59, 3, 1, 1)
# Asymmetric padding
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, (1, 1, 1, 1))
verify_conv2d_nchw(1, 128, 28, 128, 3, 1, (1, 1, 1, 1))
verify_conv2d_nchw(1, 256, 14, 256, 3, 1, (1, 1))
verify_conv2d_nchw(1, 512, 7, 512, 3, 1, "SAME")
verify_conv2d_nchw(2, 13, 71, 59, 3, 1, (1, 1, 1, 1))
verify_conv2d_nchw(2, 64, 56, 64, 3, 1, (1, 1, 1, 1), add_bias=True)
verify_conv2d_nchw(2, 64, 56, 64, 3, 1, (1, 1), add_relu=True)
verify_conv2d_nchw(2, 64, 56, 64, 3, 1, "SAME", add_relu=True, add_bias=True)
verify_conv2d_nchw(1, 128, 17, 192, 7, 1, (3, 1), devices=['cuda'])
verify_conv2d_nchw(1, 128, 17, 128, 7, 1, (3, 3, 2, 2), devices=['cuda'])
verify_conv2d_nchw(1, 160, 17, 160, 7, 1, "SAME", devices=['cuda'])
verify_conv2d_nchw(1, 48, 35, 64, 5, 1, "VALID", devices=['cuda'])
if __name__ == "__main__": if __name__ == "__main__":
test_conv2d_nchw() test_conv2d_nchw()
...@@ -90,7 +90,6 @@ def verify_conv3d_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride, ...@@ -90,7 +90,6 @@ def verify_conv3d_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride,
with autotvm.tophub.context(device): # load tophub pre-tuned parameters with autotvm.tophub.context(device): # load tophub pre-tuned parameters
check_device(device) check_device(device)
def test_conv3d_ncdhw(): def test_conv3d_ncdhw():
#3DCNN workloads #3DCNN workloads
verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, 0) verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, 0)
...@@ -122,6 +121,5 @@ def test_conv3d_ncdhw(): ...@@ -122,6 +121,5 @@ def test_conv3d_ncdhw():
verify_conv3d_ncdhw(1, 32, 32, 1, 3, 1, "VALID") verify_conv3d_ncdhw(1, 32, 32, 1, 3, 1, "VALID")
verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, "VALID") verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, "VALID")
if __name__ == "__main__": if __name__ == "__main__":
test_conv3d_ncdhw() test_conv3d_ncdhw()