Commit 1ef1605a by optima2005 Committed by masahi

[FRONTEND][TF] Add conv3d (#4604)

* [FRONTEND][TF] Add conv3d

* fix high rtol
parent f096c06f
......@@ -212,7 +212,11 @@ struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"three int : back, bottom, right will use same padding as front, top, left"
"six int : padding width in the order of (front, top, left, back, bottom,"
"right)");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1)
......
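For context, the three accepted padding forms map onto relay.nn.conv3d calls roughly as follows (editor's sketch, not part of this patch; shapes are made up and the expressions are only constructed, not compiled):

# Editor's sketch of the three padding forms accepted by conv3d (hypothetical shapes).
from tvm import relay

x = relay.var("x", shape=(1, 8, 16, 16, 16))            # NCDHW data
w = relay.var("w", shape=(4, 8, 3, 3, 3))                # OIDHW weight
y1 = relay.nn.conv3d(x, w, padding=(1,))                 # one int: same padding on all six sides
y2 = relay.nn.conv3d(x, w, padding=(1, 2, 2))            # three ints: symmetric per dimension
y3 = relay.nn.conv3d(x, w, padding=(0, 1, 1, 2, 2, 2))   # six ints: (front, top, left, back, bottom, right)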
......@@ -66,16 +66,18 @@ def _dimension_picker(prefix, surfix=''):
kernel = attr['kernel_shape']
if len(kernel) == 2:
return prefix + '2d' + surfix
if len(kernel) == 3:
return prefix + '3d' + surfix
raise tvm.error.OpAttributeInvalid(
'Only 2D kernels are supported for operator {}'.format(prefix + '2d'))
'Only 2D or 3D kernels are supported for operator {}'.format(prefix + '2d or 3d'))
return _impl
def _dimension_constraint():
def _dim_check(attrs):
if len(attrs['kernel_shape']) == 2:
if len(attrs['kernel_shape']) in (2, 3):
return True
return False
return _dim_check, "Only 2d kernel supported."
return _dim_check, "Only 2d or 3d kernel supported."
def _get_param(params, input_node):
if isinstance(input_node, _expr.Constant):
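Taken in isolation, the picker/constraint pair now behaves roughly like this (editor's standalone sketch; the real helpers close over attr inside _impl):

# Standalone sketch of the dimension picker after this change (hypothetical attr dicts).
def pick(prefix, attr, surfix=''):
    rank = len(attr['kernel_shape'])
    if rank == 2:
        return prefix + '2d' + surfix
    if rank == 3:
        return prefix + '3d' + surfix
    raise ValueError('Only 2D or 3D kernels are supported')

assert pick('conv', {'kernel_shape': (3, 3)}) == 'conv2d'
assert pick('conv', {'kernel_shape': (3, 3, 3)}, surfix='_transpose') == 'conv3d_transpose'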
......@@ -425,6 +427,130 @@ def _conv(opname):
return out
return _impl
def _conv3d(opname):
def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
flip_layout = False
inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2]
# NCDHW Layout require weights transpose
if attr['data_format'] == 'NCDHW':
tmp_shape = attr['_input_shapes'][inputs[1]]
tmp_shape = [tmp_shape[ii] for ii in (4, 3, 0, 1, 2)]
inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2))
attr['_input_shapes'][inputs[1]] = tmp_shape
input_shape = attr['_input_shapes'][inputs_data]
weights_shape = attr['_input_shapes'][inputs[1]]
if attr['_target_layout'] == "NCDHW" and attr['data_format'] == "NDHWC":
input_shape = [input_shape[ii] for ii in (0, 4, 1, 2, 3)]
inputs_data = _op.transpose(inputs_data, axes=(0, 4, 1, 2, 3))
weights_shape = [weights_shape[ii] for ii in (4, 3, 0, 1, 2)]
inputs[1] = _op.transpose(inputs[1], axes=(4, 3, 0, 1, 2))
attr['data_format'] = "NCDHW"
attr['strides'] = [attr['strides'][ii] for ii in (0, 4, 1, 2, 3)]
flip_layout = True
if attr['data_format'] == 'NDHWC':
kernel_d, kernel_h, kernel_w, _, _ = weights_shape
attr['kernel_shape'] = (kernel_d, kernel_h, kernel_w)
if opname == 'conv':
attr['channels'] = weights_shape[4]
elif opname == 'conv_transpose':
attr['channels'] = weights_shape[3]
if 'dilations' in attr:
attr['dilations'] =\
(attr['dilations'][1], attr['dilations'][2], attr['dilations'][3])
attr['strides'] = (attr['strides'][1], attr['strides'][2], attr['strides'][3])
elif attr['data_format'] == 'NCDHW':
_, _, kernel_d, kernel_h, kernel_w = weights_shape
attr['kernel_shape'] = (kernel_d, kernel_h, kernel_w)
if opname == 'conv':
attr['channels'] = weights_shape[0]
elif opname == 'conv_transpose':
attr['channels'] = weights_shape[1]
if 'dilations' in attr:
attr['dilations'] =\
(attr['dilations'][2], attr['dilations'][3], attr['dilations'][4])
attr['strides'] = (attr['strides'][2], attr['strides'][3], attr['strides'][4])
else:
msg = 'Value {} in attribute "data_format" of operator Conv is ' \
'not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
# Fix padding
attr['padding'] = attr['padding'].decode("utf-8")
if attr['padding'] == 'VALID':
attr['padding'] = [0, 0, 0]
elif attr['padding'] == 'SAME':
stride_d, stride_h, stride_w = attr['strides']
kernel_d, kernel_h, kernel_w = attr['kernel_shape']
pdata_shape = input_shape
if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
pdata_shape = attr['_output_shapes'][0]
if attr['data_format'] == 'NDHWC':
in_d = pdata_shape[1]
in_h = pdata_shape[2]
in_w = pdata_shape[3]
else:
in_d = pdata_shape[2]
in_h = pdata_shape[3]
in_w = pdata_shape[4]
dilation_d = attr['dilations'][0]
dilation_h = attr['dilations'][1]
dilation_w = attr['dilations'][2]
dilated_kernel_d = (kernel_d - 1) * dilation_d + 1
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_d = _get_pad_pair(in_d, dilated_kernel_d, stride_d)
pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
attr['padding'] = [pad_d[0], pad_v[0], pad_h[0], pad_d[1], pad_v[1], pad_h[1]]
else:
msg = 'Value {} in attribute "padding" of operator Conv is not ' \
'valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
if 'kernel_layout' not in attr:
attr['kernel_layout'] = 'DHWIO' if attr['data_format'] == 'NDHWC' else 'OIDHW'
use_bias = len(inputs) == (3 if opname != 'conv_transpose' else 4)
channel_axis = 1 if attr['data_format'] == "NCDHW" else 4
# Ignore the new attributes from TF2.0, for now.
out = AttrCvt(
op_name=_dimension_picker('conv', \
surfix="_transpose" if opname == 'conv_transpose' else ""),
ignores=['explicit_paddings'],
transforms={
'kernel_shape': 'kernel_size',
'data_format': 'data_layout',
'dilations': ('dilation', (0, 0)),
'group': ('groups', 1)},
custom_check=_dimension_constraint())([inputs_data, inputs[1]], attr)
if use_bias:
out = _op.nn.bias_add(out,
inputs[2] if opname != 'conv_transpose' else inputs[3],
axis=channel_axis)
if flip_layout:
out = _op.transpose(out, axes=(0, 2, 3, 4, 1))
return out
return _impl
def _decode_image():
def _impl(inputs, attr, params):
# Image decode wrapper: Expecting user to feed decoded input to next layer, drop this layer.
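The 'SAME' branch in _conv3d above reuses the frontend's existing _get_pad_pair helper per axis. Its behaviour is roughly the following (editor's sketch of the pre-existing 2D helper, shown here only to make the padding list above easier to follow):

# Rough sketch of the existing _get_pad_pair helper (TensorFlow 'SAME' rule per axis).
def _get_pad_pair(input1d, kernel1d, stride1d):
    if input1d % stride1d == 0:
        pad = max(kernel1d - stride1d, 0)
    else:
        pad = max(kernel1d - (input1d % stride1d), 0)
    pad_before = pad // 2
    pad_after = pad - pad_before
    return [pad_before, pad_after]

# A 17-wide axis with a 3-wide (dilated) kernel and stride 2 needs one cell of
# padding on each side, matching TF's ceil(17 / 2) == 9 output columns.
assert _get_pad_pair(17, 3, 2) == [1, 1]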
......@@ -1442,6 +1568,7 @@ _convert_map = {
'Concat' : _concat(),
'ConcatV2' : _concatV2(),
'Conv2D' : _conv('conv'),
'Conv3D' : _conv3d('conv'),
'Conv2DBackpropInput' : _conv('conv_transpose'),
'CropAndResize' : _crop_and_resize(),
'DecodeJpeg' : _decode_image(),
......
......@@ -173,6 +173,7 @@ def compute_conv2d(attrs, inputs, out_type, target):
assert len(weight_shape) == 5
C, M, _, _, VC = weight_shape
return C * VC * M
if groups == 1:
out = topi.nn.conv2d(
inputs[0], inputs[1], strides, padding,
......@@ -330,7 +331,7 @@ def compute_conv3d(attrs, inputs, out_type, target):
out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
else out_dtype)
assert layout in ["NCDHW"]
assert layout in ["NCDHW", "NDHWC"]
(dilation_d, dilation_h, dilation_w) = dilation
if dilation_d < 1 or dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
......@@ -353,6 +354,8 @@ def schedule_conv3d(attrs, outs, target):
with target:
if groups == 1 and layout == "NCDHW":
return topi.generic.schedule_conv3d_ncdhw(outs)
elif groups == 1 and layout == "NDHWC":
return topi.generic.schedule_conv3d_ndhwc(outs)
raise ValueError("No compatible schedule")
......
......@@ -38,7 +38,7 @@ namespace relay {
TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
template<typename T>
Array<Array<Layout> > Conv2DInferCorrectLayout(
Array<Array<Layout> > ConvInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
......@@ -105,7 +105,7 @@ with the layer input to produce a tensor of outputs.
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>);
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
// relay.nn.conv3d
TVM_REGISTER_NODE_TYPE(Conv3DAttrs);
......@@ -163,7 +163,8 @@ with the layer input to produce a tensor of outputs.
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.add_type_rel("Conv3D", Conv3DRel<Conv3DAttrs>);
.add_type_rel("Conv3D", Conv3DRel<Conv3DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv3DAttrs>);
// relay.nn.conv2d_transpose
......@@ -337,7 +338,7 @@ v (batch_size, channels, out_height, out_width) if `layout` is `NCHW`
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DTransposeAttrs>)
ConvInferCorrectLayout<Conv2DTransposeAttrs>)
.add_type_rel("Conv2DTranspose", Conv2DTransposeRel);
......@@ -635,7 +636,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
.set_support_level(10)
.add_type_rel("Conv2DWinograd", Conv2DWinogradRel<Conv2DWinogradAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DWinogradAttrs>);
ConvInferCorrectLayout<Conv2DWinogradAttrs>);
// relay.nn.contrib_conv2d_winograd_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradWeightTransformAttrs);
......@@ -744,7 +745,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2DWinogradNNPACKRel", Conv2DWinogradRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>);
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
// relay.nn.contrib_conv2d_winograd_nnpack_weight_transform
TVM_REGISTER_NODE_TYPE(Conv2DWinogradNNPACKWeightTransformAttrs);
......@@ -854,7 +855,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc_int8")
.set_support_level(10)
.add_type_rel("Conv2DNCHWcInt8", Conv2DWinogradRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DAttrs>);
ConvInferCorrectLayout<Conv2DAttrs>);
// Positional relay function to create conv2d NCHWc operator
// used by frontend FFI.
......@@ -903,7 +904,7 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
.set_support_level(10)
.add_type_rel("Conv2DNCHWc", Conv2DWinogradRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DAttrs>);
ConvInferCorrectLayout<Conv2DAttrs>);
// Positional relay function to create depthwise conv2d NCHWc operator
......@@ -953,7 +954,7 @@ RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
.set_support_level(10)
.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DAttrs>);
ConvInferCorrectLayout<Conv2DAttrs>);
bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
......
......@@ -28,6 +28,8 @@
#include <string>
#include <utility>
#include "../op_common.h"
namespace tvm {
namespace relay {
......@@ -187,7 +189,7 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
param->kernel_size[1], param->kernel_size[2]}};
}
/*wshape = trans_kernel_layout.BackwardShape(wshape); */
wshape = trans_kernel_layout.BackwardShape(wshape);
channels = param->channels;
dilated_ksize_z = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
......@@ -196,6 +198,7 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
if (weight != nullptr) {
weight_dtype = weight->dtype;
}
// assign result to reporter
reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
} else {
......@@ -225,22 +228,24 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// dilation
Array<IndexExpr> oshape({dshape_ncdhw[0], channels, 0, 0, 0});
IndexExpr pad_d, pad_h, pad_w;
GetPaddingDepthHeightWidth(param->padding, &pad_d, &pad_h, &pad_w);
if (!dshape_ncdhw[2].as<ir::Any>()) {
oshape.Set(2, indexdiv(dshape_ncdhw[2] + param->padding[0] * 2 - dilated_ksize_z,
oshape.Set(2, indexdiv(dshape_ncdhw[2] + pad_d - dilated_ksize_z,
param->strides[0]) + 1);
} else {
oshape.Set(2, dshape_ncdhw[2]);
}
if (!dshape_ncdhw[3].as<ir::Any>()) {
oshape.Set(3, indexdiv(dshape_ncdhw[3] + param->padding[1] * 2 - dilated_ksize_y,
oshape.Set(3, indexdiv(dshape_ncdhw[3] + pad_h - dilated_ksize_y,
param->strides[1]) + 1);
} else {
oshape.Set(3, dshape_ncdhw[3]);
}
if (!dshape_ncdhw[4].as<ir::Any>()) {
oshape.Set(4, indexdiv(dshape_ncdhw[4] + param->padding[2] * 2 - dilated_ksize_x,
oshape.Set(4, indexdiv(dshape_ncdhw[4] + pad_w - dilated_ksize_x,
param->strides[2]) + 1);
} else {
oshape.Set(4, dshape_ncdhw[4]);
......
......@@ -162,6 +162,45 @@ inline void GetPaddingWidth(const Array<IndexExpr>& padding, IndexExpr* pad_w) {
}
}
/*! \brief A utility function to get total padding height and width from a tuple of 1, 2 or 4 ints. */
inline void GetPaddingHeightWidth(const Array<IndexExpr>& padding, IndexExpr* pad_h,
IndexExpr* pad_w) {
if (padding.size() == 1) {
*pad_h = padding[0] * 2;
*pad_w = padding[0] * 2;
} else if (padding.size() == 2) {
*pad_h = padding[0] * 2;
*pad_w = padding[1] * 2;
} else if (padding.size() == 4) {
*pad_h = padding[0] + padding[2];
*pad_w = padding[1] + padding[3];
} else {
CHECK_EQ(padding.size(), 4) << " Padding size should be 1, 2 or 4, but got "
<< padding.size();
}
}
/*! \brief A utility function to get total padding depth, height and width from a tuple of 1, 3 or 6 ints. */
inline void GetPaddingDepthHeightWidth(const Array<IndexExpr>& padding, IndexExpr* pad_d,
IndexExpr* pad_h, IndexExpr* pad_w) {
if (padding.size() == 1) {
*pad_d = padding[0] * 2;
*pad_h = padding[0] * 2;
*pad_w = padding[0] * 2;
} else if (padding.size() == 3) {
*pad_d = padding[0] * 2;
*pad_h = padding[1] * 2;
*pad_w = padding[2] * 2;
} else if (padding.size() == 6) {
*pad_d = padding[0] + padding[3];
*pad_h = padding[1] + padding[4];
*pad_w = padding[2] + padding[5];
} else {
CHECK_EQ(padding.size(), 6) << " Padding size should be 1, 3 or 6, but got "
<< padding.size();
}
}
} // namespace relay
} // namespace tvm
......
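The same normalization expressed in Python, for a quick sanity check (editor's mirror of the branches above):

# Editor's Python mirror of GetPaddingDepthHeightWidth: returns per-dimension pad totals.
def pad_totals_3d(padding):
    if len(padding) == 1:
        return padding[0] * 2, padding[0] * 2, padding[0] * 2
    if len(padding) == 3:
        return padding[0] * 2, padding[1] * 2, padding[2] * 2
    if len(padding) == 6:
        return (padding[0] + padding[3],
                padding[1] + padding[4],
                padding[2] + padding[5])
    raise ValueError("Padding size should be 1, 3 or 6, but got %d" % len(padding))

assert pad_totals_3d((1,)) == (2, 2, 2)
assert pad_totals_3d((1, 2, 3)) == (2, 4, 6)
assert pad_totals_3d((0, 1, 1, 2, 2, 2)) == (2, 3, 3)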
......@@ -94,13 +94,14 @@ def vmobj_to_list(o):
def run_tvm_graph(graph_def, input_data, input_node, num_output=1,
target='llvm', out_names=None, opt_level=3, mode='graph_runtime'):
target='llvm', out_names=None, opt_level=3, mode='graph_runtime',
cuda_layout="NCHW"):
""" Generic function to compile on relay and execute on tvm """
input_data = convert_to_list(input_data)
input_node = convert_to_list(input_node)
layout = None
if target == "cuda":
layout = "NCHW"
layout = cuda_layout
target_host = None
shape_dict = {e: i.shape for e, i in zip(input_node, input_data)}
mod, params = relay.frontend.from_tensorflow(graph_def,
......@@ -160,7 +161,8 @@ def run_tf_graph(sess, input_data, input_node, output_node):
def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
no_gpu=False, opt_level=3, mode='graph_runtime'):
no_gpu=False, opt_level=3, mode='graph_runtime',
cuda_layout="NCHW"):
"""Generic function to generate and compare tensorflow and TVM output"""
def name_without_num(name):
return name.split(':')[0] if ":" in name else name
......@@ -191,7 +193,8 @@ def compare_tf_with_tvm(in_data, in_name, out_name, init_global_variables=False,
tvm_output = run_tvm_graph(final_graph_def, in_data, in_node,
target=device, out_names=out_name,
num_output=len(out_name), opt_level=opt_level, mode=mode)
num_output=len(out_name), opt_level=opt_level, mode=mode,
cuda_layout=cuda_layout)
# since the names from tensorflow and relay runs are not exactly same,
# first len(tf_output) will be compared
for i in range(len(tf_output)):
......@@ -470,6 +473,57 @@ def test_forward_convolution():
#######################################################################
# Convolution3D
# -----------
def _test_convolution3d(opname, tensor_in_sizes, filter_in_sizes,
dilations, strides, padding, data_format,
deconv_output_shape=[]):
""" One iteration of 3D convolution with given shapes and attributes """
total_size_1 = np.prod(tensor_in_sizes)
total_size_2 = np.prod(filter_in_sizes)
# Initializes the input tensor with array containing incrementing
# numbers from 1.
data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
in_filter = constant_op.constant(
filter_array, shape=filter_in_sizes, dtype='float32')
if data_format == 'NDHWC':
strides = [1] + strides + [1]
dilations = [1] + dilations + [1]
else:
strides = [1, 1] + strides
dilations = [1, 1] + dilations
if opname == 'conv':
nn_ops.conv3d(in_data,
in_filter,
strides=strides,
dilations=dilations,
padding=padding,
data_format=data_format)
compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
'Placeholder:0', 'Conv3D:0', cuda_layout="NCDHW")
def test_forward_convolution3d():
if is_gpu_available():
_test_convolution3d('conv', [4, 176, 8, 8, 8], [1, 1, 1, 176, 32], [1, 1, 1], [1, 1, 1], 'SAME', 'NCDHW')
_test_convolution3d('conv', [4, 19, 17, 17, 17], [3, 3, 3, 19, 19], [1, 1, 1], [2, 2, 2], 'VALID', 'NCDHW')
_test_convolution3d('conv', [4, 124, 17, 17, 17], [1, 1, 1, 124, 19], [1, 1, 1], [1, 1, 1], 'SAME', 'NCDHW')
_test_convolution3d('conv', [4, 12, 17, 17, 17], [3, 3, 3, 12, 32], [1, 1, 1], [2, 2, 2], 'VALID', 'NCDHW')
_test_convolution3d('conv', [4, 8, 8, 8, 176], [1, 1, 1, 176, 32], [1, 1, 1], [1, 1, 1], 'SAME', 'NDHWC')
_test_convolution3d('conv', [4, 17, 17, 17, 19], [3, 3, 3, 19, 19], [1, 1, 1], [2, 2, 2], 'VALID', 'NDHWC')
_test_convolution3d('conv', [4, 17, 17, 17, 124], [1, 1, 1, 124, 19], [1, 1, 1], [1, 1, 1], 'SAME', 'NDHWC')
_test_convolution3d('conv', [4, 17, 17, 17, 12], [3, 3, 3, 12, 32], [1, 1, 1], [2, 2, 2], 'VALID', 'NDHWC')
#######################################################################
# BiasAdd
# -----------
......
......@@ -294,6 +294,56 @@ def test_conv2d_winograd():
padding=(2, 2), channels=192, kernel_size=(7, 7))
def test_conv3d_infer_type():
# symbolic in batch dimension
n, c, d, h, w = tvm.var("n"), 10, 224, 224, 224
x = relay.var("x", relay.ty.TensorType((n, c, d, h, w), "float32"))
w = relay.var("w")
y = relay.nn.conv3d(x, w,
kernel_size=(3, 3, 3),
padding=(1, 1, 1),
channels=2)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(
(n, 2, 224, 224, 224), "float32")
assert yy.args[1].checked_type == relay.TensorType(
(2, 10, 3, 3, 3), "float32")
# infer by shape of w, mixed precision
n, c, d, h, w = tvm.var("n"), 10, 224, 224, 224
x = relay.var("x", relay.TensorType((n, c, d, h, w), "int8"))
w = relay.var("w", relay.TensorType((2, 10, 3, 3, 3), "int8"))
y = relay.nn.conv3d(x, w, out_dtype="int32")
assert "out_dtype=\"int32\"" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(
(n, 2, 222, 222, 222), "int32")
# infer shape in case of different dtypes for input and weight.
n, c, d, h, w = tvm.var("n"), 10, 224, 224, 224
x = relay.var("x", relay.TensorType((n, c, d, h, w), "uint8"))
w = relay.var("w", relay.TensorType((2, 10, 3, 3, 3), "int8"))
y = relay.nn.conv3d(x, w, out_dtype="int32")
assert "out_dtype=\"int32\"" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(
(n, 2, 222, 222, 222), "int32")
# Infer with NDHWC
n, c, d, h, w = 4, 32, 224, 224, 224
x = relay.var("x", relay.TensorType((n, d, h, w, c), "int8"))
wt = relay.var("w")
y = relay.nn.conv3d(x, wt,
kernel_size=(3, 3, 3),
padding=(1, 1, 1),
channels=16,
data_layout="NDHWC",
out_dtype="int32")
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(
(n, d, h, w, 16), "int32")
def test_conv3d_run():
def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape,
padding=(1, 1, 1),
......@@ -338,6 +388,50 @@ def test_conv3d_run():
run_test_conv3d("float32", "float32", 1, dshape, kshape,
padding=(1, 1, 1), channels=10, kernel_size=(3, 3, 3))
def test_conv3d_ndhwc_run():
def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape,
padding=(1, 1, 1),
fref=None,
groups=1,
dilation=(1, 1, 1),
except_targets=None,
**attrs):
if except_targets is None:
except_targets = []
x = relay.var("x", shape=dshape, dtype=dtype)
w = relay.var("w", dtype=dtype)
y = relay.nn.conv3d(x, w,
padding=padding,
dilation=dilation,
groups=groups,
data_layout="NDHWC", kernel_layout="DHWIO",
**attrs)
func = relay.Function([x, w], y)
data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
dkernel = topi.testing.dilate_python(kernel, (1, 1) + dilation)
if fref is None:
ref_res = topi.testing.conv3d_ndhwc_python(
data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding)
else:
ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype))
for target, ctx in ctx_list():
if target in except_targets:
continue
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(data, kernel)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
# normal conv3d
dshape = (1, 5, 224, 224, 6)
kshape = (3, 3, 3, 6, 10)
run_test_conv3d("float32", "float32", 1, dshape, kshape,
padding=(1, 1, 1), channels=10, kernel_size=(3, 3, 3), except_targets=["cuda"])
def test_conv2d_transpose_infer_type():
# symbolic in batch dimension
......@@ -993,6 +1087,7 @@ if __name__ == "__main__":
test_lrn()
test_l2_normalize()
test_conv2d_infer_type()
test_conv3d_infer_type()
test_bitpack_infer_type()
test_upsampling_infer_type()
test_upsampling3d_infer_type()
......@@ -1006,6 +1101,7 @@ if __name__ == "__main__":
test_conv2d_run()
test_conv2d_winograd()
test_conv3d_run()
test_conv3d_ndhwc_run()
test_bitserial_conv2d_infer_type()
test_batch_flatten()
test_upsampling()
......
......@@ -21,6 +21,7 @@ from tvm import autotvm
from tvm.contrib import cudnn
from .. import nn, generic
from ..nn.util import get_pad_tuple3d
from ..util import get_const_tuple, traverse_inline
from .conv3d_direct import schedule_direct_3d_cuda
......@@ -44,8 +45,10 @@ def conv3d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCDHW', o
strides : int or a list/tuple of three ints
stride size, or [stride_depth, stride_height, stride_width]
padding : int or a list/tuple of three ints
padding size, or [pad_depth, pad_height, pad_width]
padding : int or a list/tuple of 3 or 6 ints
padding size, or
[pad_depth, pad_height, pad_width] for 3 ints, or
[pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right] for 6 ints
dilation: int or a list/tuple of three ints
dilation size, or [dilation_depth, dilation_height, dilation_width]
......@@ -77,25 +80,27 @@ def conv3d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCDHW', o
# handle dilation
stride_d, stride_h, stride_w = (strides, strides, strides) if isinstance(strides, int) \
else strides
pad_d, pad_h, pad_w = (padding, padding, padding) if isinstance(padding, int) else padding
if isinstance(padding, (list, tuple)) and len(padding) > 3:
raise ValueError("Cudnn doesn't support asymmetric padding.")
pf, pt, pl, pk, pb, pr = get_pad_tuple3d(padding, (KD, KH, KW))
dilation_d, dilation_h, dilation_w = (dilation, dilation, dilation) if \
isinstance(dilation, int) else dilation
OD = (D + 2 * pad_d - KD) // stride_d + 1
OH = (H + 2 * pad_h - KH) // stride_h + 1
OW = (W + 2 * pad_w - KW) // stride_w + 1
cfg.add_flop(2 * N * OD * OH * OW * CO * CI * ((DH - 1) * dilation_d + 1) *\
OD = (D + pf + pk - KD) // stride_d + 1
OH = (H + pt + pb - KH) // stride_h + 1
OW = (W + pl + pr - KW) // stride_w + 1
cfg.add_flop(2 * N * OD * OH * OW * CO * CI * ((KD - 1) * dilation_d + 1) *\
((KH - 1) * dilation_h + 1) * ((KW - 1) * dilation_w + 1))
return cudnn.conv_forward(data,
kernel,
[pad_d, pad_h, pad_w],
[pf, pt, pl], # cudnn padding pt, pl on both sides of input
[stride_d, stride_h, stride_w],
[dilation_d, dilation_h, dilation_w],
conv_mode=1,
tensor_format=tensor_format,
algo=-1, # let CUDNN choose the best algo
conv_dtype=dtype)
conv_dtype=data.dtype)
if layout == 'NCDHW':
return nn.conv3d_ncdhw(data, kernel, strides, padding, dilation, out_dtype)
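Note that the cudnn branch rejects the 6-int form and forwards only the leading (pf, pt, pl) triple, which is lossless because a symmetric 3-int padding splits into equal before/after pairs; a rough check (editor's sketch, assuming the pre-existing symmetric split in get_pad_tuple3d):

# Editor's sketch: symmetric padding expands to equal pairs, so cudnn sees the full amount.
from topi.nn.util import get_pad_tuple3d

pf, pt, pl, pk, pb, pr = get_pad_tuple3d((1, 1, 1), (3, 3, 3))
assert (pf, pt, pl) == (pk, pb, pr)
D, KD, stride_d = 16, 3, 1
assert (D + pf + pk - KD) // stride_d + 1 == 16   # same formula as OD above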
......@@ -134,3 +139,37 @@ def schedule_conv3d_ncdhw_cuda(cfg, outs):
traverse_inline(s, outs[0].op, _callback)
return s
@autotvm.register_topi_schedule(generic.schedule_conv3d_ndhwc, ["cuda", "gpu"],
["direct"])
def schedule_conv3d_ndhwc_cuda(cfg, outs):
"""TOPI schedule callback of conv3d for cuda gpu
Parameters
----------
cfg: ConfigEntity
The config for this template
outs: Array of Tensor
The computation graph description of conv3d
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for conv3d.
"""
target = tvm.target.current_target()
if 'cudnn' in target.libs:
return generic.schedule_extern(outs)
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _callback(op):
if op.tag == 'conv3d_ndhwc':
schedule_direct_3d_cuda(cfg, s, op.output(0))
traverse_inline(s, outs[0].op, _callback)
return s
......@@ -242,6 +242,22 @@ def schedule_conv3d_ncdhw(outs):
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_conv3d_ndhwc(outs):
"""Schedule for conv3d_ndhwc
Parameters
----------
outs: Array of Tensor
The computation graph description of conv3d_ndhwc
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_conv2d_transpose_nchw(outs):
......
......@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=unused-argument, redefined-builtin
# pylint: disable=unused-argument, redefined-builtin, no-else-return
"""Conv3D operators"""
from __future__ import absolute_import as _abs
import tvm
......@@ -58,6 +58,8 @@ def conv3d(input, filter, strides, padding, dilation, layout='NCDHW', out_dtype=
# default declaration
if layout == 'NCDHW':
return conv3d_ncdhw(input, filter, strides, padding, dilation, out_dtype)
elif layout == 'NDHWC':
return conv3d_ndhwc(input, filter, strides, padding, dilation, out_dtype)
raise ValueError("not support this layout {} yet".format(layout))
......@@ -128,3 +130,71 @@ def conv3d_ncdhw(Input, Filter, stride, padding, dilation, out_dtype=None):
xx * stride_w + rx * dilation_w].astype(out_dtype) *
Filter[ff, rc, rz, ry, rx].astype(out_dtype),
axis=[rc, rz, ry, rx]), tag="conv3d_ncdhw")
def conv3d_ndhwc(Input, Filter, stride, padding, dilation, out_dtype='float32'):
"""Convolution operator in NDHWC layout.
Parameters
----------
Input : tvm.Tensor
5-D with shape [batch, in_depth, in_height, in_width, in_channel]
Filter : tvm.Tensor
5-D with shape [filter_depth, filter_height, filter_width, in_channel, num_filter]
stride : int or a list/tuple of three ints
Stride size, or [stride_depth, stride_height, stride_width]
padding : int or str
Padding size, or ['VALID', 'SAME']
dilation: int or a list/tuple of three ints
dilation size, or [dilation_depth, dilation_height, dilation_width]
Returns
-------
Output : tvm.Tensor
5-D with shape [batch, out_depth, out_height, out_width, out_channel]
"""
assert isinstance(stride, int) or len(stride) == 3
assert isinstance(dilation, int) or len(dilation) == 3
if isinstance(stride, int):
stride_d = stride_h = stride_w = stride
else:
stride_d, stride_h, stride_w = stride
if isinstance(dilation, int):
dilation_d = dilation_h = dilation_w = dilation
else:
dilation_d, dilation_h, dilation_w = dilation
batch, in_depth, in_height, in_width, in_channel = Input.shape
kernel_d, kernel_h, kernel_w, channel, num_filter = Filter.shape
# compute the output shape
dilated_kernel_d = (kernel_d - 1) * dilation_d + 1
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w))
out_channel = num_filter
out_depth = simplify((in_depth - dilated_kernel_d + pad_front + pad_back) // stride_d + 1)
out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
pad_before = [0, pad_front, pad_top, pad_left, 0]
pad_after = [0, pad_back, pad_down, pad_right, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
rc = tvm.reduce_axis((0, in_channel), name='rc')
rz = tvm.reduce_axis((0, kernel_d), name='rz')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')
Output = tvm.compute(
(batch, out_depth, out_height, out_width, out_channel),
lambda nn, zz, yy, xx, ff: tvm.sum(
PaddedInput[nn, zz * stride_d + rz * dilation_d, yy * stride_h + ry * dilation_h,
xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
Filter[rz, ry, rx, rc, ff].astype(out_dtype), axis=[rz, ry, rx, rc]),
name="Conv3dOutput", tag="conv3d_ndhwc")
return Output
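A minimal shape-level construction sketch (editor's example; the topi test added below exercises the full numeric path):

# Editor's sketch: build the NDHWC compute and check the inferred output shape.
import tvm
import topi
from topi.util import get_const_tuple

A = tvm.placeholder((1, 8, 16, 16, 4), name='A')    # NDHWC input
W = tvm.placeholder((3, 3, 3, 4, 10), name='W')     # DHWIO filter
B = topi.nn.conv3d_ndhwc(A, W, stride=1, padding='SAME', dilation=1)
assert get_const_tuple(B.shape) == (1, 8, 16, 16, 10)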
......@@ -158,9 +158,15 @@ def get_pad_tuple3d(padding, kernel):
"""
# compute the padding size
if isinstance(padding, (tuple, list)):
pad_h = padding[0] * 2
pad_w = padding[1] * 2
pad_d = padding[2] * 2
if len(padding) == 3:
pad_d = padding[0] * 2
pad_h = padding[1] * 2
pad_w = padding[2] * 2
elif len(padding) == 6:
return padding[0], padding[1], padding[2], padding[3], \
padding[4], padding[5]
else:
raise ValueError("Size of padding can only be 3 or 6")
elif isinstance(padding, int):
pad_d = pad_w = pad_h = padding * 2
elif padding == "VALID":
......
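A quick usage sketch of the extended helper (editor's example; only the 6-int branch is shown in this hunk, while an int or a 3-int tuple falls through to the pre-existing symmetric split):

# Editor's sketch: the 6-int form is returned verbatim as
# (pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right).
from topi.nn.util import get_pad_tuple3d

assert get_pad_tuple3d((0, 1, 1, 2, 2, 2), (3, 3, 3)) == (0, 1, 1, 2, 2, 2)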
......@@ -25,6 +25,7 @@ from .conv2d_hwcn_python import conv2d_hwcn_python
from .conv2d_nchw_python import conv2d_nchw_python
from .conv2d_nhwc_python import conv2d_nhwc_python
from .conv3d_ncdhw_python import conv3d_ncdhw_python
from .conv3d_ndhwc_python import conv3d_ndhwc_python
from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python
from .conv1d_transpose_ncw_python import conv1d_transpose_ncw_python
from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python
......
......@@ -18,6 +18,7 @@
"""Convolution 3D in python"""
import numpy as np
import scipy.signal
from topi.nn.util import get_pad_tuple3d
def _conv3d_ncdhw_python(a_np, w_np, stride, padding):
......@@ -27,20 +28,13 @@ def _conv3d_ncdhw_python(a_np, w_np, stride, padding):
stride_d = stride_h = stride_w = stride
else:
stride_d, stride_h, stride_w = stride
if isinstance(padding, int):
pad_d = pad_h = pad_w = padding * 2
elif isinstance(padding, (list, tuple)):
pad_d, pad_h, pad_w = padding[0] * 2, padding[1] * 2, padding[2] * 2
else:
pad_d = 0 if padding == 'VALID' else kernel_d - 1
pad_h = 0 if padding == 'VALID' else kernel_h - 1
pad_w = 0 if padding == 'VALID' else kernel_w - 1
pad_front = int(np.ceil(float(pad_d) / 2))
pad_back = pad_d - pad_front
pad_top = int(np.ceil(float(pad_h) / 2))
pad_bottom = pad_h - pad_top
pad_left = int(np.ceil(float(pad_w) / 2))
pad_right = pad_w - pad_left
pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = \
get_pad_tuple3d(padding, (kernel_d, kernel_h, kernel_w))
pad_d = pad_front + pad_back
pad_h = pad_top + pad_bottom
pad_w = pad_left + pad_right
# compute the output shape
out_channel = num_filter
out_depth = (in_depth - kernel_d + pad_d) // stride_d + 1
......@@ -53,19 +47,8 @@ def _conv3d_ncdhw_python(a_np, w_np, stride, padding):
for c in range(in_channel):
if pad_d > 0 or pad_h > 0 or pad_w > 0:
apad = np.zeros((in_depth + pad_d, in_height + pad_h, in_width + pad_w))
if pad_d == 0 and pad_h == 0:
apad[:, :, pad_left:-pad_right] = a_np[n, c]
elif pad_d == 0 and pad_w == 0:
apad[:, pad_top:-pad_bottom, :] = a_np[n, c]
elif pad_d == 0 and pad_h != 0 and pad_w != 0:
apad[:, pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, c]
elif pad_d != 0 and pad_h == 0:
apad[pad_front:-pad_back, :, pad_left:-pad_right] = a_np[n, c]
elif pad_d != 0 and pad_w == 0:
apad[pad_front:-pad_back, pad_top:-pad_bottom, :] = a_np[n, c]
elif pad_d != 0 and pad_h != 0 and pad_w != 0:
apad[pad_front:-pad_back, pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, c]
apad[pad_front:pad_front + in_depth, pad_top:pad_top + in_height,\
pad_left:pad_left + in_width] = a_np[n, c]
else:
apad = a_np[n, c]
out = scipy.signal.convolve(
......
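The single slice assignment above is equivalent to the removed if/elif ladder, because a zero pad simply makes the corresponding slice start at 0; a quick numpy check (editor's sketch):

# Editor's check: the unified slice also handles dimensions with zero padding.
import numpy as np

a = np.arange(2 * 3 * 4, dtype='float32').reshape(2, 3, 4)
pad_front, pad_top, pad_left = 0, 1, 2          # no depth padding at all
apad = np.zeros((2, 3 + 2, 4 + 4), dtype='float32')
apad[pad_front:pad_front + 2, pad_top:pad_top + 3, pad_left:pad_left + 4] = a
assert np.array_equal(apad[:, 1:4, 2:6], a)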
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Convolution 3D in python"""
import numpy as np
import scipy.signal
from topi.nn.util import get_pad_tuple3d
def conv3d_ndhwc_python(a_np, w_np, stride, padding):
"""Convolution 3D operator in NDHWC layout.
Parameters
----------
a_np : numpy.ndarray
5-D with shape [batch, in_depth, in_height, in_width, in_channel]
w_np : numpy.ndarray
5-D with shape [filter_depth, filter_height, filter_width, in_channel, num_filter]
stride : int or a list/tuple of three ints
Stride size, or [stride_depth, stride_height, stride_width]
padding : int or str or a list/tuple of three ints
Padding size, or ['VALID', 'SAME'], or [pad_depth, pad_height, pad_width]
Returns
-------
b_np : np.ndarray
5-D with shape [batch, out_depth, out_height, out_width, out_channel]
"""
batch, in_depth, in_height, in_width, in_channel = a_np.shape
kernel_d, kernel_h, kernel_w, _, num_filter = w_np.shape
if isinstance(stride, int):
stride_d = stride_h = stride_w = stride
else:
stride_d, stride_h, stride_w = stride
pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = \
get_pad_tuple3d(padding, (kernel_d, kernel_h, kernel_w))
pad_d = pad_front + pad_back
pad_h = pad_top + pad_bottom
pad_w = pad_left + pad_right
# compute the output shape
out_channel = num_filter
out_depth = (in_depth - kernel_d + pad_d) // stride_d + 1
out_height = (in_height - kernel_h + pad_h) // stride_h + 1
out_width = (in_width - kernel_w + pad_w) // stride_w + 1
# change the layout from NDHWC to NCDHW
at = a_np.transpose((0, 4, 1, 2, 3))
wt = w_np.transpose((4, 3, 0, 1, 2))
bt = np.zeros((batch, out_channel, out_depth, out_height, out_width))
# computation
for n in range(batch):
for f in range(out_channel):
for c in range(in_channel):
if pad_d > 0 or pad_h > 0 or pad_w > 0:
apad = np.zeros((in_depth + pad_d, in_height + pad_h, in_width + pad_w))
apad[pad_front:pad_front + in_depth, pad_top:pad_top + in_height,\
pad_left:pad_left + in_width] = at[n, c]
else:
apad = at[n, c]
out = scipy.signal.convolve(
apad, np.flip(wt[f, c]), mode='valid')
bt[n, f] += out[::stride_d, ::stride_h, ::stride_w]
return bt.transpose((0, 2, 3, 4, 1))
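A tiny shape-level sanity check of the reference implementation (editor's example; the topi test file below does the real numerical comparison against the schedule):

# Editor's sketch: NDHWC reference output shape for a VALID 3x3x3 convolution.
import numpy as np
from topi.testing import conv3d_ndhwc_python

a_np = np.random.uniform(size=(1, 8, 8, 8, 4)).astype('float32')   # NDHWC
w_np = np.random.uniform(size=(3, 3, 3, 4, 6)).astype('float32')   # DHWIO
b_np = conv3d_ndhwc_python(a_np, w_np, stride=1, padding='VALID')
assert b_np.shape == (1, 6, 6, 6, 6)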
......@@ -22,12 +22,16 @@ from tvm import autotvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.nn.util import get_pad_tuple3d
from topi.util import get_const_tuple
from common import get_all_backend
def verify_conv3d_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right = get_pad_tuple3d(padding, (kernel, kernel, kernel))
padding_sum = pad_front + pad_back + pad_top + pad_left + pad_bottom + pad_right
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride,
padding_sum, dilation))
in_depth = in_height = in_width = in_size
......@@ -62,7 +66,7 @@ def verify_conv3d_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride,
return
print("Running on target: %s" % device)
with tvm.target.create(device):
C = topi.nn.conv3d(A, W, (stride, stride, stride), (padding, padding, padding),
C = topi.nn.conv3d(A, W, (stride, stride, stride), padding,
(dilation, dilation, dilation), layout='NCDHW', out_dtype=dtype)
if add_bias:
C = topi.add(C, bias)
......@@ -75,10 +79,10 @@ def verify_conv3d_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride,
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias:
func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func(a, w, b, c)
else:
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding_sum, dilation))
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
......@@ -109,6 +113,14 @@ def test_conv3d_ncdhw():
verify_conv3d_ncdhw(2, 2, 2, 2, 2, 2, 2)
verify_conv3d_ncdhw(3, 3, 3, 3, 3, 3, 3)
# Asymmetric padding
verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, (0, 0, 0, 1, 1, 1))
verify_conv3d_ncdhw(1, 32, 32, 1, 1, 1, (2, 1, 2, 1, 2, 1))
verify_conv3d_ncdhw(1, 64, 56, 3, 3, 1, (2, 2, 2, 1, 1, 1), dilation=2)
verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, (0, 1, 1))
verify_conv3d_ncdhw(1, 32, 32, 1, 1, 1, (2, 1, 0))
verify_conv3d_ncdhw(1, 32, 32, 1, 3, 1, "VALID")
verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, "VALID")
if __name__ == "__main__":
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example code to do convolution."""
import os
import numpy as np
import tvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
def verify_conv3d_ndhwc(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
in_depth = in_height = in_width = in_size
A = tvm.placeholder((batch, in_depth, in_height, in_width, in_channel), name='A')
W = tvm.placeholder((kernel, kernel, kernel, in_channel, num_filter), name='W')
B = topi.nn.conv3d_ndhwc(A, W, stride, padding, dilation)
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv3d_ndhwc.verify_ndhwc.v2")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, dilation, 1, 1))
b_np = topi.testing.conv3d_ndhwc_python(a_np, dw_np, stride, padding)
return a_np, w_np, b_np
a_np, w_np, b_np = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_conv3d_ndhwc([B])
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, W, B], device)
func(a, w, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['llvm']:
check_device(device)
def test_conv3d_ndhwc():
verify_conv3d_ndhwc(1, 16, 32, 16, 3, 1, "SAME")
verify_conv3d_ndhwc(4, 32, 16, 32, 5, 2, "SAME")
verify_conv3d_ndhwc(4, 32, 16, 64, 5, 2, "SAME")
verify_conv3d_ndhwc(1, 64, 32, 64, 3, 1, "VALID")
verify_conv3d_ndhwc(1, 64, 32, 64, 3, 1, "VALID")
verify_conv3d_ndhwc(4, 32, 16, 32, 5, 2, "VALID")
verify_conv3d_ndhwc(4, 32, 16, 64, 5, 2, "VALID")
# dilation = 2
verify_conv3d_ndhwc(1, 64, 32, 64, 3, 1, "SAME", dilation=2)
if __name__ == "__main__":
test_conv3d_ndhwc()