Commit 7e32f373 by optima2005, committed by masahi

implement conv3d op (#4400)

* implement conv3d op

* add back missed conv2d_output_shape by mistake

* fix typo and docs, add topi test

* rebase to master and merge 2d/3d unification

* use cudnn.conv_forward
parent 279a8eba
@@ -48,6 +48,7 @@ struct BiasAddAttrs : public tvm::AttrsNode<BiasAddAttrs> {
}
};
/*! \brief Attributes used in convolution operators */
struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
Array<IndexExpr> strides;
@@ -193,6 +194,61 @@ struct Conv2DWinogradNNPACKWeightTransformAttrs
}
};
/*! \brief Attributes used in convolution operators */
struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> dilation;
int groups;
IndexExpr channels;
Array<IndexExpr> kernel_size;
std::string data_layout;
std::string kernel_layout;
std::string out_layout;
DataType out_dtype;
TVM_DECLARE_ATTRS(Conv3DAttrs, "relay.attrs.Conv3DAttrs") {
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the strides of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1)
.describe("Controls the connections between inputs and outputs."
"At groups=1, all inputs are convolved to all outputs."
"At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated.");
TVM_ATTR_FIELD(channels)
.describe("The number of output channels in the convolution."
" If it is not set, inferred by shape of the weight.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
TVM_ATTR_FIELD(data_layout).set_default("NCDHW")
.describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Convolution is applied on the 'D', 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(kernel_layout).set_default("OIDHW")
.describe("Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc."
"'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height,"
"and width dimensions respectively.");
TVM_ATTR_FIELD(out_layout).set_default("")
.describe("Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc."
"'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
"dimensions respectively. Default to be same as input layout.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
/*! \brief Attributes used in softmax operators */
struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
int axis;
......
@@ -142,7 +142,6 @@ def _find_conv2d_op(op):
return op_
return None
@reg.register_compute("nn.conv2d")
def compute_conv2d(attrs, inputs, out_type, target):
"""Compute definition of conv2d"""
@@ -278,6 +277,48 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype, target):
return [out]
@reg.register_compute("nn.conv3d")
def compute_conv3d(attrs, inputs, out_type, target):
"""Compute definition of conv3d"""
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
layout = attrs.data_layout
out_dtype = attrs.out_dtype
out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
else out_dtype)
assert layout in ["NCDHW"]
(dilation_d, dilation_h, dilation_w) = dilation
if dilation_d < 1 or dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
if groups == 1:
out = topi.nn.conv3d(
inputs[0], inputs[1], strides, padding,
dilation, layout, out_dtype)
else:
raise ValueError("not support arbitrary group number for now")
return [out]
@reg.register_schedule("nn.conv3d")
def schedule_conv3d(attrs, outs, target):
"""Schedule definition of conv3d"""
groups = attrs.groups
layout = attrs.data_layout
with target:
if groups == 1 and layout == "NCDHW":
return topi.generic.schedule_conv3d_ncdhw(outs)
raise ValueError("No compatible schedule")
reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_schedule("nn.conv2d_transpose")
def schedule_conv2d_transpose(attrs, outs, target):
"""Schedule definition of conv2d_transpose"""
......
@@ -106,6 +106,91 @@ def conv2d(data,
kernel_layout, out_layout, out_dtype)
def conv3d(data,
weight,
strides=(1, 1, 1),
padding=(0, 0, 0),
dilation=(1, 1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCDHW",
kernel_layout="OIDHW",
out_layout="",
out_dtype=""):
r"""3D convolution.
This operator takes the weight as the convolution kernel
and convolves it with data to produce an output.
In the default case, where the data_layout is `NCDHW`
and kernel_layout is `OIDHW`, conv3d takes in
a data Tensor with shape `(batch_size, in_channels, depth, height, width)`,
and a weight Tensor with shape `(channels, in_channels, kernel_size[0], kernel_size[1],
kernel_size[2])` to produce an output Tensor with the following rule:
.. math::
\mbox{out}[b, c, z, y, x] = \sum_{dz, dy, dx, k}
\mbox{data}[b, k, \mbox{strides}[0] * z + dz, \mbox{strides}[1] * y + dy,
\mbox{strides}[2] * x + dx] * \mbox{weight}[c, k, dz, dy, dx]
Padding and dilation are applied to data and weight respectively before the computation.
This operator accepts data layout specification.
Semantically, the operator will convert the layout to the canonical layout
(`NCDHW` for data and `OIDHW` for weight), perform the computation,
then convert to the out_layout.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
weight : tvm.relay.Expr
The weight expressions.
strides : Optional[Tuple[int]]
The strides of convolution.
padding : Optional[Tuple[int]]
The padding applied to both sides of the input before convolution.
dilation : Optional[Tuple[int]]
Specifies the dilation rate to be used for dilated convolution.
groups : Optional[int]
Number of groups for grouped convolution.
channels : Optional[int]
Number of output channels of this convolution.
kernel_size : Optional[Tuple[int]]
The spatial dimensions of the convolution kernel.
data_layout : Optional[str]
Layout of the input.
kernel_layout : Optional[str]
Layout of the weight.
out_layout : Optional[str]
Layout of the output. By default, out_layout is the same as data_layout.
out_dtype : Optional[str]
Specifies the output data type for mixed precision conv3d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.conv3d(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
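
For reference, a minimal usage sketch of this wrapper; the shapes and the relay.Module / InferType calls below are illustrative assumptions about the surrounding API rather than anything added by this patch. Type inference fills in the OIDHW weight shape from channels and kernel_size:

from tvm import relay

data = relay.var("data", shape=(1, 3, 16, 224, 224), dtype="float32")
weight = relay.var("weight")  # no shape annotation; inferred as (8, 3, 3, 3, 3)
out = relay.nn.conv3d(data, weight,
                      padding=(1, 1, 1),
                      channels=8,
                      kernel_size=(3, 3, 3))
mod = relay.Module.from_expr(out)
mod = relay.transform.InferType()(mod)
print(mod["main"].body.checked_type)  # Tensor[(1, 8, 16, 224, 224), float32]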
def conv2d_transpose(data,
weight,
strides=(1, 1),
......
@@ -106,6 +106,64 @@ with the layer input to produce a tensor of outputs.
.add_type_rel("Conv2D", Conv2DRel<Conv2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", Conv2DInferCorrectLayout<Conv2DAttrs>);
// relay.nn.conv3d
TVM_REGISTER_NODE_TYPE(Conv3DAttrs);
// Positional relay function to create conv3d operator
// used by frontend FFI.
Expr MakeConv3D(Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_node<Conv3DAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = std::move(channels);
attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.conv3d");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.conv3d")
.set_body_typed(MakeConv3D);
RELAY_REGISTER_OP("nn.conv3d")
.describe(R"code(3D convolution layer (e.g. convolution over 3D image data,
like Magnetic Resonance Imaging (MRI) data in medicine).
This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.
- **data**: This depends on the `layout` parameter. Input is 5D array of shape
(batch_size, in_channels, depth, height, width) if `layout` is `NCDHW`.
- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2])
- **out**: This depends on the `layout` parameter. Output is 5D array of shape
(batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`.
)code" TVM_ADD_FILELINE)
.set_attrs_type<Conv3DAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.add_type_rel("Conv3D", Conv3DRel<Conv3DAttrs>);
// relay.nn.conv2d_transpose
TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
......
@@ -138,6 +138,123 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return true;
}
template <typename AttrType>
bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr) return false;
static const Layout kNCDHW("NCDHW");
static const Layout kOIDHW("OIDHW");
const AttrType* param = attrs.as<AttrType>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCDHW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCDHW."
<< " But got " << in_layout;
const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIDHW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIDHW."
<< " But got " << kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCDHW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCDHW."
<< " But got " << out_layout;
Array<IndexExpr> dshape_ncdhw = trans_in_layout.ForwardShape(data->shape);
IndexExpr channels, dilated_ksize_z, dilated_ksize_y, dilated_ksize_x;
// infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 3);
CHECK_EQ(param->dilation.size(), 3);
Array<IndexExpr> wshape;
if (tvm::ir::Equal(param->channels, param->groups) && !tvm::ir::Equal(param->channels, 1)) {
// infer weight's shape for depthwise convolution
wshape = {{dshape_ncdhw[1], indexdiv(param->groups, dshape_ncdhw[1]), param->kernel_size[0],
param->kernel_size[1], param->kernel_size[2]}};
} else {
wshape = {{param->channels, indexdiv(dshape_ncdhw[1], param->groups), param->kernel_size[0],
param->kernel_size[1], param->kernel_size[2]}};
}
/*wshape = trans_kernel_layout.BackwardShape(wshape); */
channels = param->channels;
dilated_ksize_z = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2];
DataType weight_dtype = data->dtype;
if (weight != nullptr) {
weight_dtype = weight->dtype;
}
// assign result to reporter
reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
if (param->kernel_size.defined()) {
CHECK_EQ(param->kernel_size.size(), 3);
// check the size
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
reporter->AssertEQ(param->kernel_size[1], wshape[3]) &&
reporter->AssertEQ(param->kernel_size[2], wshape[4]))
<< "Conv3D: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size << " wshape=" << wshape;
}
if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[0]))
<< "Conv3D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " wshape=" << wshape;
}
CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[1]));
channels = wshape[0];
dilated_ksize_z = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_y = 1 + (wshape[3] - 1) * param->dilation[1];
dilated_ksize_x = 1 + (wshape[4] - 1) * param->dilation[2];
}
// dilation
Array<IndexExpr> oshape({dshape_ncdhw[0], channels, 0, 0, 0});
if (!dshape_ncdhw[2].as<ir::Any>()) {
oshape.Set(2, indexdiv(dshape_ncdhw[2] + param->padding[0] * 2 - dilated_ksize_z,
param->strides[0]) + 1);
} else {
oshape.Set(2, dshape_ncdhw[2]);
}
if (!dshape_ncdhw[3].as<ir::Any>()) {
oshape.Set(3, indexdiv(dshape_ncdhw[3] + param->padding[1] * 2 - dilated_ksize_y,
param->strides[1]) + 1);
} else {
oshape.Set(3, dshape_ncdhw[3]);
}
if (!dshape_ncdhw[4].as<ir::Any>()) {
oshape.Set(4, indexdiv(dshape_ncdhw[4] + param->padding[2] * 2 - dilated_ksize_x,
param->strides[2]) + 1);
} else {
oshape.Set(4, dshape_ncdhw[4]);
}
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
return true;
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_NN_CONVOLUTION_H_
@@ -294,6 +294,51 @@ def test_conv2d_winograd():
padding=(2, 2), channels=192, kernel_size=(7, 7))
def test_conv3d_run():
def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape,
padding=(1, 1, 1),
fref=None,
groups=1,
dilation=(1, 1, 1),
except_targets=None,
**attrs):
if except_targets is None:
except_targets = []
x = relay.var("x", shape=dshape, dtype=dtype)
w = relay.var("w", dtype=dtype)
y = relay.nn.conv3d(x, w,
padding=padding,
dilation=dilation,
groups=groups,
**attrs)
func = relay.Function([x, w], y)
data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
dkernel = topi.testing.dilate_python(kernel, (1, 1) + dilation)
if fref is None:
ref_res = topi.testing.conv3d_ncdhw_python(
data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding,
groups=groups)
else:
ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype))
for target, ctx in ctx_list():
if target in except_targets:
continue
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(data, kernel)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
# normal conv3d
dshape = (1, 3, 5, 224, 224)
kshape = (10, 3, 3, 3, 3)
run_test_conv3d("float32", "float32", 1, dshape, kshape,
padding=(1, 1, 1), channels=10, kernel_size=(3, 3, 3))
def test_conv2d_transpose_infer_type():
# symbolic in batch dimension
n, c, h, w = tvm.var("n"), 10, 10, 12
@@ -850,6 +895,7 @@ if __name__ == "__main__":
test_conv2d_transpose_nhwc_run()
test_conv2d_run()
test_conv2d_winograd()
test_conv3d_run()
test_bitserial_conv2d_infer_type()
test_batch_flatten()
test_upsampling()
......
@@ -21,6 +21,7 @@ from __future__ import absolute_import as _abs
from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw, deformable_conv2d, \
group_conv2d_nchw, dense
from . import conv3d
from .conv2d_hwcn import schedule_conv2d_hwcn
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Compute definition for conv3d with cuda backend"""
import tvm
from tvm import autotvm
from tvm.contrib import cudnn
from .. import nn, generic
from ..util import get_const_tuple, traverse_inline
from .conv3d_direct import schedule_direct_3d_cuda
@autotvm.register_topi_compute(nn.conv3d, ['cuda', 'gpu'], ['direct'])
def conv3d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCDHW', out_dtype='float32'):
"""Conv3D operator for cuda backend.
Parameters
----------
cfg: ConfigEntity
The config for this template
data : tvm.Tensor
5-D with shape [batch, in_channel, in_depth, in_height, in_width]
kernel : tvm.Tensor
5-D with shape [num_filter, in_channel, filter_depth, filter_height, filter_width]
strides : int or a list/tuple of three ints
stride size, or [stride_depth, stride_height, stride_width]
padding : int or a list/tuple of three ints
padding size, or [pad_depth, pad_height, pad_width]
dilation: int or a list/tuple of three ints
dilation size, or [dilation_depth, dilation_height, dilation_width]
layout : str
layout of data
out_dtype: str
The output type. This is used for mixed precision.
Returns
-------
output : tvm.Tensor
5-D with shape [batch, out_channel, out_depth, out_height, out_width]
"""
target = tvm.target.current_target()
if "cudnn" in target.libs:
if layout == 'NCDHW':
tensor_format = 0 # CUDNN_TENSOR_NCHW
N, _, D, H, W = get_const_tuple(data.shape)
elif layout == 'NDHWC':
tensor_format = 1 # CUDNN_TENSOR_NHWC
N, D, H, W, _ = get_const_tuple(data.shape)
else:
raise ValueError("Unsupported layout %s in cudnn" % layout)
CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
# handle dilation
stride_d, stride_h, stride_w = (strides, strides, strides) if isinstance(strides, int) \
else strides
pad_d, pad_h, pad_w = (padding, padding, padding) if isinstance(padding, int) else padding
dilation_d, dilation_h, dilation_w = (dilation, dilation, dilation) if \
isinstance(dilation, int) else dilation
OD = (D + 2 * pad_d - KD) // stride_d + 1
OH = (H + 2 * pad_h - KH) // stride_h + 1
OW = (W + 2 * pad_w - KW) // stride_w + 1
cfg.add_flop(2 * N * OD * OH * OW * CO * CI * ((KD - 1) * dilation_d + 1) *\
((KH - 1) * dilation_h + 1) * ((KW - 1) * dilation_w + 1))
return cudnn.conv_forward(data,
kernel,
[pad_d, pad_h, pad_w],
[stride_d, stride_h, stride_w],
[dilation_d, dilation_h, dilation_w],
conv_mode=1,
tensor_format=tensor_format,
algo=-1, # let CUDNN choose the best algo
conv_dtype=data.dtype)
if layout == 'NCDHW':
return nn.conv3d_ncdhw(data, kernel, strides, padding, dilation, out_dtype)
raise ValueError("not support this layout {} yet".format(layout))
@autotvm.register_topi_schedule(generic.schedule_conv3d_ncdhw, ["cuda", "gpu"],
["direct"])
def schedule_conv3d_ncdhw_cuda(cfg, outs):
"""TOPI schedule callback of conv3d for cuda gpu
Parameters
----------
cfg: ConfigEntity
The config for this template
outs: Array of Tensor
The computation graph description of conv3d
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for conv3d.
"""
target = tvm.target.current_target()
if 'cudnn' in target.libs:
return generic.schedule_extern(outs)
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def _callback(op):
if op.tag == 'conv3d_ncdhw':
schedule_direct_3d_cuda(cfg, s, op.output(0))
traverse_inline(s, outs[0].op, _callback)
return s
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""The templates for cuda conv3d operators"""
import tvm
from tvm import autotvm
from ..util import get_const_tuple
def schedule_direct_3d_cuda(cfg, s, conv):
"""schedule optimized for batch size = 1"""
##### space definition begin #####
n, f, d, y, x = s[conv].op.axis
rc, rd, ry, rx = s[conv].op.reduce_axis
cfg.define_split("tile_f", f, num_outputs=4)
cfg.define_split("tile_d", d, num_outputs=4)
cfg.define_split("tile_y", y, num_outputs=4)
cfg.define_split("tile_x", x, num_outputs=4)
cfg.define_split("tile_rc", rc, num_outputs=2)
cfg.define_split("tile_rd", ry, num_outputs=2)
cfg.define_split("tile_ry", ry, num_outputs=2)
cfg.define_split("tile_rx", rx, num_outputs=2)
cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
target = tvm.target.current_target()
if target.target_name in ['nvptx', 'rocm']:
cfg.define_knob("unroll_explicit", [1])
else:
cfg.define_knob("unroll_explicit", [0, 1])
# fallback support
if cfg.is_fallback:
ref_log = autotvm.tophub.load_reference_log(
target.target_name, target.model, 'conv3d', 'direct')
cfg.fallback_with_reference_log(ref_log)
##### space definition end #####
pad_data, kernel = s[conv].op.input_tensors
s[pad_data].compute_inline()
if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
s[kernel].compute_inline()
if conv.op in s.outputs:
output = conv
OL = s.cache_write(conv, 'local')
else:
output = s.outputs[0].output(0)
s[conv].set_scope('local')
OL = conv
# create cache stage
AA = s.cache_read(pad_data, 'shared', [OL])
WW = s.cache_read(kernel, 'shared', [OL])
# tile and bind spatial axes
n, f, d, y, x = s[output].op.axis
kernel_scope, n = s[output].split(n, nparts=1)
bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
bd, vd, td, di = cfg["tile_d"].apply(s, output, d)
by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
bf = s[output].fuse(n, bf)
s[output].reorder(bf, bd, by, bx, vf, vd, vy, vx, tf, td, ty, tx, fi, di, yi, xi)
s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
s[output].bind(s[output].fuse(bd, by), tvm.thread_axis("blockIdx.y"))
s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
s[output].bind(vf, tvm.thread_axis("vthread"))
s[output].bind(vd, tvm.thread_axis("vthread"))
s[output].bind(vy, tvm.thread_axis("vthread"))
s[output].bind(vx, tvm.thread_axis("vthread"))
s[output].bind(s[output].fuse(td, tf), tvm.thread_axis("threadIdx.z"))
s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
s[OL].compute_at(s[output], tx)
# tile reduction axes
n, f, d, y, x = s[OL].op.axis
rc, rd, ry, rx = s[OL].op.reduce_axis
rco, rci = cfg['tile_rc'].apply(s, OL, rc)
rdo, rdi = cfg['tile_rd'].apply(s, OL, rd)
ryo, ryi = cfg['tile_ry'].apply(s, OL, ry)
rxo, rxi = cfg['tile_rx'].apply(s, OL, rx)
s[OL].reorder(rco, rdo, ryo, rxo, rci, rdi, ryi, rxi, n, f, d, y, x)
s[AA].compute_at(s[OL], rxo)
s[WW].compute_at(s[OL], rxo)
# cooperative fetching
for load in [AA, WW]:
n, f, d, y, x = s[load].op.axis
fused = s[load].fuse(n, f, d, y, x)
tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
td, fused = s[load].split(fused, nparts=cfg["tile_d"].size[2])
ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
s[load].bind(tz, tvm.thread_axis("threadIdx.z"))
s[load].bind(s[load].fuse(td, ty), tvm.thread_axis("threadIdx.y"))
s[load].bind(tx, tvm.thread_axis("threadIdx.x"))
# unroll
s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
N, CO, OD, OH, OW = get_const_tuple(output.shape)
_, CI, KD, KH, KW = get_const_tuple(kernel.shape)
cfg.add_flop(2 * N * OD * OH * OW * CO * CI * KD * KH * KW)
@@ -226,6 +226,24 @@ def schedule_conv2d_winograd_nnpack_without_weight_transform(outs):
@tvm.target.generic_func
def schedule_conv3d_ncdhw(outs):
"""Schedule for conv3d_ncdhw
Parameters
----------
outs: Array of Tensor
The computation graph description of conv3d_ncdhw
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_conv2d_transpose_nchw(outs):
"""Schedule for conv2d_transpose_nchw
......
@@ -20,6 +20,7 @@
from __future__ import absolute_import as _abs
from .conv2d import *
from .conv3d import *
from .deformable_conv2d import *
from .depthwise_conv2d import *
from .elemwise import *
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=unused-argument, redefined-builtin
"""Conv3D operators"""
from __future__ import absolute_import as _abs
import tvm
from .pad import pad
from .util import get_pad_tuple3d
from ..util import simplify
@tvm.target.generic_func
def conv3d(input, filter, strides, padding, dilation, layout='NCDHW', out_dtype=None):
"""Conv3D operator.
Parameters
----------
input : tvm.Tensor
5-D with shape [batch, in_channel, in_depth, in_height, in_width]
filter : tvm.Tensor
5-D with shape [num_filter, in_channel, filter_depth, filter_height, filter_width]
strides : int or a list/tuple of three ints
stride size, or [stride_depth, stride_height, stride_width]
padding : int or a list/tuple of three ints
padding size, or [pad_depth, pad_height, pad_width]
dilation: int or a list/tuple of three ints
dilation size, or [dilation_depth, dilation_height, dilation_width]
layout : str
layout of data
Returns
-------
output : tvm.Tensor
5-D with shape [batch, out_channel, out_depth, out_height, out_width]
"""
# default declaration; target-specific implementations are dispatched
# through the tvm.target.generic_func mechanism
if layout == 'NCDHW':
return conv3d_ncdhw(input, filter, strides, padding, dilation, out_dtype)
raise ValueError("not support this layout {} yet".format(layout))
def conv3d_ncdhw(Input, Filter, stride, padding, dilation, out_dtype=None):
"""Convolution operator in NCDHW layout.
Parameters
----------
Input : tvm.Tensor
5-D with shape [batch, in_channel, in_depth, in_height, in_width]
Filter : tvm.Tensor
5-D with shape [num_filter, in_channel, filter_depth, filter_height, filter_width]
stride : int or a list/tuple of three ints
Stride size, or [stride_depth, stride_height, stride_width]
padding : int or str or a list/tuple of three ints
Padding size, or ['VALID', 'SAME'], or [pad_depth, pad_height, pad_width]
dilation: int or a list/tuple of three ints
dilation size, or [dilation_depth, dilation_height, dilation_width]
Returns
-------
Output : tvm.Tensor
5-D with shape [batch, out_channel, out_depth, out_height, out_width]
"""
if out_dtype is None:
out_dtype = Input.dtype
assert isinstance(stride, int) or len(stride) == 3
assert isinstance(dilation, int) or len(dilation) == 3
if isinstance(stride, int):
stride_d = stride_h = stride_w = stride
else:
stride_d, stride_h, stride_w = stride
if isinstance(dilation, int):
dilation_d = dilation_h = dilation_w = dilation
else:
dilation_d, dilation_h, dilation_w = dilation
batch, in_channel, in_depth, in_height, in_width = Input.shape
num_filter, channel, kernel_d, kernel_h, kernel_w = Filter.shape
# compute the output shape
dilated_kernel_d = (kernel_d - 1) * dilation_d + 1
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w))
out_channel = num_filter
out_depth = simplify((in_depth - dilated_kernel_d + pad_front + pad_back) // stride_d + 1)
out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
# compute graph
pad_before = [0, 0, pad_front, pad_top, pad_left]
pad_after = [0, 0, pad_back, pad_down, pad_right]
temp = pad(Input, pad_before, pad_after, name="pad_temp")
rc = tvm.reduce_axis((0, in_channel), name='rc')
rz = tvm.reduce_axis((0, kernel_d), name='rz')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')
return tvm.compute(
(batch, out_channel, out_depth, out_height, out_width),
lambda nn, ff, zz, yy, xx: tvm.sum(
temp[nn, rc, zz * stride_d + rz * dilation_d, yy * stride_h + ry * dilation_h,
xx * stride_w + rx * dilation_w].astype(out_dtype) *
Filter[ff, rc, rz, ry, rx].astype(out_dtype),
axis=[rc, rz, ry, rx]), tag="conv3d_ncdhw")
@@ -118,3 +118,57 @@ def get_pad_tuple(padding, kernel):
pad_top = (pad_h + 1) // 2
pad_left = (pad_w + 1) // 2
return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left
def get_pad_tuple3d(padding, kernel):
"""Common code to get the pad option
Parameters
----------
padding : int or str or a list/tuple of three ints
Padding size, or ['VALID', 'SAME'], or [pad_depth, pad_height, pad_width]
kernel : tuple of int
Conv kernel size
Returns
-------
pad_front : int
Padding size on front.
pad_top : int
Padding size on top
pad_left : int
Padding size on left
pad_back : int
Padding size on back.
pad_down : int
Padding size on down.
pad_right : int
Padding size on right.
"""
# compute the padding size; padding and kernel follow (depth, height, width) order
if isinstance(padding, (tuple, list)):
pad_d = padding[0] * 2
pad_h = padding[1] * 2
pad_w = padding[2] * 2
elif isinstance(padding, int):
pad_d = pad_h = pad_w = padding * 2
elif padding == "VALID":
pad_d = 0
pad_h = 0
pad_w = 0
elif padding == "SAME":
pad_d = kernel[0] - 1
pad_h = kernel[1] - 1
pad_w = kernel[2] - 1
else:
raise ValueError("Unknown padding option %s" % padding)
pad_top = (pad_h + 1) // 2
pad_left = (pad_w + 1) // 2
pad_front = (pad_d + 1) // 2
return pad_front, pad_top, pad_left, pad_d - pad_front, pad_h - pad_top, pad_w - pad_left
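
A small hand-worked illustration of the helper (assuming it is importable as topi.nn.util.get_pad_tuple3d); 'SAME' splits kernel - 1 points per axis, with the front/top/left side taking the extra point when the total is odd:

from topi.nn.util import get_pad_tuple3d

# 3x3x3 kernel: 2 points of padding per axis, split 1 front / 1 back
print(get_pad_tuple3d("SAME", (3, 3, 3)))    # (1, 1, 1, 1, 1, 1)
# even kernel: the front/top/left side receives the extra point
print(get_pad_tuple3d("SAME", (2, 2, 2)))    # (1, 1, 1, 0, 0, 0)
# explicit (pad_depth, pad_height, pad_width) is applied symmetrically
print(get_pad_tuple3d((1, 2, 3), (3, 3, 3)))  # (1, 2, 3, 1, 2, 3)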
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example code to do convolution."""
import numpy as np
import tvm
from tvm import autotvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from common import get_all_backend
def verify_conv3d_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False):
print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
in_depth = in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_depth, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel, kernel), name='W')
bias = tvm.placeholder((num_filter, 1, 1, 1), name='bias')
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
bias_shape = get_const_tuple(bias.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv3d_ncdhw.verify_conv3d_ncdhw")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = np.random.uniform(size=bias_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation, dilation))
c_np = topi.testing.conv3d_ncdhw_python(a_np, dw_np, stride, padding)
if add_bias:
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
C = topi.nn.conv3d(A, W, (stride, stride, stride), (padding, padding, padding),
(dilation, dilation, dilation), layout='NCDHW', out_dtype=dtype)
if add_bias:
C = topi.add(C, bias)
if add_relu:
C = topi.nn.relu(C)
s = topi.generic.schedule_conv3d_ncdhw([C])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
if add_bias:
func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, b, c)
else:
func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4)
for device in get_all_backend():
with autotvm.tophub.context(device): # load tophub pre-tuned parameters
check_device(device)
def test_conv3d_ncdhw():
# 3DCNN workloads
verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, 0)
verify_conv3d_ncdhw(1, 32, 32, 1, 1, 1, 0)
verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, 1)
verify_conv3d_ncdhw(1, 32, 32, 1, 1, 1, 1)
# bias, relu
verify_conv3d_ncdhw(1, 64, 56, 3, 1, 1, 1, add_relu=True)
verify_conv3d_ncdhw(1, 64, 56, 3, 1, 1, 1, add_bias=True)
verify_conv3d_ncdhw(1, 64, 56, 3, 1, 1, 1, add_bias=True, add_relu=True)
# dilation = 2
verify_conv3d_ncdhw(1, 64, 56, 3, 3, 1, 1, dilation=2)
# batch size
verify_conv3d_ncdhw(4, 64, 56, 5, 3, 1, 1)
# weird workloads
verify_conv3d_ncdhw(2, 2, 2, 2, 2, 2, 2)
verify_conv3d_ncdhw(3, 3, 3, 3, 3, 3, 3)
if __name__ == "__main__":
test_conv3d_ncdhw()