Commit 35099e6a by Josh Fromm Committed by masahi

[Relay/Topi][Op] Conv1D (#4639)

* added conv1d operators to topi.

* Started to add python testing.

* Added python conv1d implementation for testing.

* Wrote test but need to add cuda schedule :(

* Cuda schedules working for both conv1d layouts.

* All topi tests passing.

* Formatting topi.

* Removed pad_method option as its probably overkill.

* Added relay op definition of conv1d.

* End2end conv1d working with onnx.

* Lint fixes.

* Formatting fixes.

* Rebase fix.

* Switched to array based attributes for consistency across convs.

* Improved onnx parsing and testing for convolutions.

* lint fix

* Tiny tweak.

* Bug fix

* Rebase fix.

* Add group ignore to onnx conv1d frontend.

* Unified MakeConv and fixed documentation.

* improved autopadding

* Addressed feedback and simplified onnx frontend.

* Format fix.

* Basic X86 NCW schedule working.

* Added nwc schedule.

* fixed name

* Added more tests and basic x86 schedules.

* Format fix.

* Added non power of two shape tests.
parent d8f06020
......@@ -49,6 +49,54 @@ struct BiasAddAttrs : public tvm::AttrsNode<BiasAddAttrs> {
};
/*! \brief Attributes used in 1D convolution operators */
struct Conv1DAttrs : public tvm::AttrsNode<Conv1DAttrs> {
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> dilation;
int groups;
IndexExpr channels;
Array<IndexExpr> kernel_size;
std::string data_layout;
std::string kernel_layout;
std::string out_layout;
DataType out_dtype;
TVM_DECLARE_ATTRS(Conv1DAttrs, "relay.attrs.Conv1DAttrs") {
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, }))
.describe("Specifies the stride of the convolution.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, }))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1)
.describe("Currently unused but may be added in the future.");
TVM_ATTR_FIELD(channels)
.describe("The number of output channels in the convolution."
" If it is not set, inferred by shape of the weight.")
.set_default(NullValue<IndexExpr>());
TVM_ATTR_FIELD(kernel_size)
.describe("Specifies the dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
TVM_ATTR_FIELD(data_layout).set_default("NCW")
.describe("Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
"'N', 'C', 'W' stands for batch, channel, and width"
"dimensions respectively. Convolution is applied on the 'W'"
"dimension.");
TVM_ATTR_FIELD(kernel_layout).set_default("OIW")
.describe("Dimension ordering of weight. Can be 'OIW', or 'WIO', etc."
"'O', 'I', 'W' stands for num_filter, input_channel, and width"
"dimensions respectively.");
// use 0 bits to indicate none.
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
/*! \brief Attributes used in convolution operators */
struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
Array<IndexExpr> strides;
......
......@@ -267,22 +267,25 @@ class Conv(OnnxOpConverter):
@classmethod
def _impl_v1(cls, inputs, attr, params):
# infer pads for auto_pad
# Use shape of input to determine convolution type.
input_shape = infer_shape(inputs[0])
if 'auto_pad' in attr:
attr['auto_pad'] = attr['auto_pad'].decode('utf-8')
if attr['auto_pad'] in ('SAME_UPPER', 'SAME_LOWER'):
input_shape = infer_shape(inputs[0])
in_h, in_w = input_shape[2], input_shape[3]
stride_h, stride_w = attr['strides']
kernel_h, kernel_w = attr['kernel_shape']
dilation_h, dilation_w = attr['dilations']
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_v = get_pad_pair(in_h, dilated_kernel_h, stride_h)
pad_h = get_pad_pair(in_w, dilated_kernel_w, stride_w)
attr['pads'] = (pad_v[0], pad_h[0], pad_v[1], pad_h[1])
pad_tuple = []
for axis in range(len(input_shape) - 2):
axis_shape = input_shape[2 + axis]
stride = attr['strides'][axis]
kernel = attr['kernel_shape'][axis]
dilation = attr['dilations'][axis]
dilated_kernel = (kernel - 1) * dilation + 1
pad = get_pad_pair(axis_shape, dilated_kernel, stride)
pad_tuple.append(pad)
pad_tuple = tuple([val for pair in zip(*pad_tuple) for val in pair])
attr['pads'] = pad_tuple
elif attr['auto_pad'] == 'VALID':
attr['pads'] = (0, 0)
attr['pads'] = tuple([0 for i in range(len(input_shape) - 2)])
elif attr['auto_pad'] == 'NOTSET':
pass
else:
......@@ -294,10 +297,12 @@ class Conv(OnnxOpConverter):
op_name=dimension_picker('conv'),
transforms={
'kernel_shape': 'kernel_size',
'dilations': ('dilation', (0, 0)),
'pads': ('padding', (0, 0), revert_caffe2_pad),
'group': ('groups', 1)},
'dilations': ('dilation', 1),
'pads': ('padding', 0),
'group': ('groups', 1)
},
custom_check=dimension_constraint())(inputs[:2], attr, params)
use_bias = len(inputs) == 3
if use_bias:
out = _op.nn.bias_add(out, inputs[2])
......@@ -713,8 +718,8 @@ class Upsample(OnnxOpConverter):
else:
raise tvm.error.OpAttributeInvalid(
'Value {} in attribute "mode" of operator Upsample is not valid.'.format(mode))
attr = {'scale_h':scales[-2], 'scale_w':scales[-1], 'method':method,
'layout':'NCHW', 'align_corners':True}
attr = {'scale_h': scales[-2], 'scale_w': scales[-1], 'method': method,
'layout': 'NCHW', 'align_corners': True}
return AttrCvt('upsampling')(inputs, attr)
......@@ -848,7 +853,7 @@ class Gather(OnnxOpConverter):
def _impl_v1(cls, inputs, attr, params):
axis = attr.get('axis', 0)
return AttrCvt('take',
extras={'axis':axis})(inputs, {})
extras={'axis': axis})(inputs, {})
class Greater(OnnxOpConverter):
......@@ -880,7 +885,7 @@ class LRN(OnnxOpConverter):
beta = attr.get('beta', 0.75)
bias = attr.get('bias', 1.0)
nsize = attr.get('size')
attr = {'size':nsize, 'axis':axis, 'alpha':alpha, 'beta':beta, 'bias':bias}
attr = {'size': nsize, 'axis': axis, 'alpha': alpha, 'beta': beta, 'bias': bias}
return AttrCvt('lrn')(inputs, attr)
class Maximum(OnnxOpConverter):
......@@ -926,7 +931,7 @@ class HardSigmoid(OnnxOpConverter):
alpha = attr.get('alpha', 0.2)
beta = attr.get('beta', 0.5)
transformX = (inputs[0] * _expr.const(alpha)) + _expr.const(beta)
attr = {'a_min':0, 'a_max':1}
attr = {'a_min': 0, 'a_max': 1}
return AttrCvt('clip')([transformX], attr)
class Reduce(OnnxOpConverter):
......@@ -940,7 +945,7 @@ class Reduce(OnnxOpConverter):
else:
axis_len = len(infer_shape(inputs[0]))
axis = list(range(axis_len))
attr = {'axis':axis, 'keepdims':attr.get('keepdims', True)}
attr = {'axis': axis, 'keepdims': attr.get('keepdims', True)}
return AttrCvt(cls.name)(inputs, attr)
class ReduceMax(Reduce):
......@@ -975,7 +980,7 @@ class ArgMax(OnnxOpConverter):
def _impl_v1(cls, inputs, attr, params):
axis = attr.get('axis', 0)
keepdims = attr.get('keepdims', True)
attr = {'axis':axis, 'keepdims':keepdims}
attr = {'axis': axis, 'keepdims': keepdims}
return AttrCvt('argmax')(inputs, attr)
class ArgMin(OnnxOpConverter):
......@@ -985,7 +990,7 @@ class ArgMin(OnnxOpConverter):
def _impl_v1(cls, inputs, attr, params):
axis = attr.get('axis', 0)
keepdims = attr.get('keepdims', True)
attr = {'axis':axis, 'keepdims':keepdims}
attr = {'axis': axis, 'keepdims': keepdims}
return AttrCvt('argmin')(inputs, attr)
class Softmax(OnnxOpConverter):
......
......@@ -131,6 +131,42 @@ def schedule_sparse_transpose(attrs, outputs, target):
reg.register_pattern("nn.sparse_transpose", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
# Conv1D
@reg.register_compute("nn.conv1d")
def compute_conv1d(attrs, inputs, out_type, target):
"""Compute definition of conv1d"""
strides = get_const_tuple(attrs.strides)
padding = get_const_tuple(attrs.padding)
dilation = get_const_tuple(attrs.dilation)
layout = attrs.data_layout
out_dtype = attrs.out_dtype
out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
else out_dtype)
assert layout in ["NCW", "NWC"]
if dilation[0] < 1:
raise ValueError("dilation should be a positive value")
return [topi.nn.conv1d(inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype)]
@reg.register_schedule("nn.conv1d")
def schedule_conv1d(attrs, outs, target):
"""Schedule definition of conv1d"""
layout = attrs.data_layout
with target:
if layout == "NCW":
return topi.generic.schedule_conv1d_ncw(outs)
elif layout == "NCW":
return topi.generic.schedule_conv1d_nwc(outs)
raise ValueError("No compatible schedule")
reg.register_pattern("nn.conv1d", OpPattern.OUT_ELEMWISE_FUSABLE)
# conv2d
def _find_conv2d_op(op):
"""Find the op with conv2d in its tag by traversing."""
......
......@@ -21,6 +21,99 @@ from ...expr import TupleWrapper
from . import _make
def conv1d(data,
weight,
strides=1,
padding=0,
dilation=1,
groups=1,
channels=None,
kernel_size=None,
data_layout="NCW",
kernel_layout="OIW",
out_layout="",
out_dtype=""):
r"""1D convolution.
This operator takes the weight as the convolution kernel
and convolves it with data to produce an output.
In the default case, where the data_layout is `NCW`
and kernel_layout is `OIW`, conv1d takes in
a data Tensor with shape `(batch_size, in_channels, width)`,
and a weight Tensor with shape `(channels, in_channels, kernel_size)`
to produce an output Tensor with the following rule:
.. math::
\mbox{out}[b, c, w] = \sum_{dw, k}
\mbox{data}[b, k, \mbox{strides}[0] * w + dw] *
\mbox{weight}[c, k, dw]
Padding and dilation are applied to data and weight respectively before the computation.
This operator accepts data layout specification.
Semantically, the operator will convert the layout to the canonical layout
(`NCW` for data and `OIW` for weight), perform the computation,
then convert to the out_layout.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
weight : tvm.relay.Expr
The weight expressions.
strides : Optional[int, Tuple[int]]
The strides of convolution.
padding : Optional[int, Tuple[int]]
The padding of convolution on both sides of the input before convolution.
dilation : Optional[int, Tuple[int]]
Specifies the dilation rate to be used for dilated convolution.
groups : Optional[int]
Currently unused for 1D convolution.
channels : Optional[int]
Number of output channels of this convolution.
kernel_size : Optional[int, Tuple[int]]
The spatial dimension of the convolution kernel.
data_layout : Optional[str]
Layout of the input.
kernel_layout : Optional[str]
Layout of the weight.
out_layout : Optional[str]
Layout of the output, by default, out_layout is the same as data_layout
out_dtype : Optional[str]
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
if isinstance(kernel_size, int):
kernel_size = (kernel_size, )
if isinstance(strides, int):
strides = (strides, )
if isinstance(dilation, int):
dilation = (dilation, )
if isinstance(padding, int):
padding = (padding, padding)
return _make.conv1d(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
def conv2d(data,
weight,
strides=(1, 1),
......@@ -66,13 +159,13 @@ def conv2d(data,
weight : tvm.relay.Expr
The weight expressions.
strides : Optional[Tuple[int]]
strides : Optional[int, Tuple[int]]
The strides of convolution.
padding : Optional[Tuple[int]]
padding : Optional[int, Tuple[int]]
The padding of convolution on both sides of inputs before convolution.
dilation : Optional[Tuple[int]]
dilation : Optional[int, Tuple[int]]
Specifies the dilation rate to be used for dilated convolution.
groups : Optional[int]
......@@ -81,7 +174,7 @@ def conv2d(data,
channels : Optional[int]
Number of output channels of this convolution.
kernel_size : Optional[Tuple[int]]
kernel_size : Optional[int, Tuple[int]]
The spatial of the convolution kernel.
data_layout : Optional[str]
......@@ -101,6 +194,15 @@ def conv2d(data,
result : tvm.relay.Expr
The computed result.
"""
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size)
if isinstance(strides, int):
strides = (strides, strides)
if isinstance(dilation, int):
dilation = (dilation, dilation)
if isinstance(padding, int):
padding = (padding, padding)
return _make.conv2d(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
......@@ -154,10 +256,10 @@ def conv3d(data,
strides : Optional[Tuple[int]]
The strides of convolution.
padding : Optional[Tuple[int]]
padding : Optional[int, Tuple[int]]
The padding of convolution on both sides of inputs before convolution.
dilation : Optional[Tuple[int]]
dilation : Optional[int, Tuple[int]]
Specifies the dilation rate to be used for dilated convolution.
groups : Optional[int]
......@@ -166,7 +268,7 @@ def conv3d(data,
channels : Optional[int]
Number of output channels of this convolution.
kernel_size : Optional[Tuple[int]]
kernel_size : Optional[int, Tuple[int]]
The spatial of the convolution kernel.
data_layout : Optional[str]
......@@ -186,6 +288,15 @@ def conv3d(data,
result : tvm.relay.Expr
The computed result.
"""
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(strides, int):
strides = (strides, strides, strides)
if isinstance(dilation, int):
dilation = (dilation, dilation, dilation)
if isinstance(padding, int):
padding = (padding, padding, padding)
return _make.conv3d(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)
......
......@@ -19,6 +19,12 @@
from ...attrs import Attrs
from ..base import register_relay_attr_node
@register_relay_attr_node
class Conv1DAttrs(Attrs):
"""Attributes for nn.conv1d"""
@register_relay_attr_node
class Conv2DAttrs(Attrs):
"""Attributes for nn.conv2d"""
......
......@@ -34,8 +34,6 @@
namespace tvm {
namespace relay {
// relay.nn.conv2d
TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
template<typename T>
Array<Array<Layout> > ConvInferCorrectLayout(
......@@ -52,21 +50,22 @@ Array<Array<Layout> > ConvInferCorrectLayout(
params->data_layout : params->out_layout}};
}
// Positional relay function to create conv2d operator
// used by frontend FFI.
Expr MakeConv2D(Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_object<Conv2DAttrs>();
template <typename T>
Expr MakeConv(Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype,
std::string op_name) {
auto attrs = make_object<T>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
......@@ -77,13 +76,77 @@ Expr MakeConv2D(Expr data,
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.conv2d");
static const Op& op = Op::Get(op_name);
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}
// relay.nn.conv1d
TVM_REGISTER_NODE_TYPE(Conv1DAttrs);
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv1d")
.set_body_typed([](Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
return MakeConv<Conv1DAttrs>(
data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype, "nn.conv1d");
});
RELAY_REGISTER_OP("nn.conv1d")
.describe(R"code(1D convolution layer (e.g. spatial convolution over sequences).
This layer creates a convolution kernel that is convolved
with the layer input to produce a tensor of outputs.
- **data**: This depends on the `layout` parameter. Input is 3D array of shape
(batch_size, in_channels, width) if `layout` is `NCW`.
- **weight**: (channels, in_channels, kernel_size)
- **out**: This depends on the `layout` parameter. Output is 3D array of shape
(batch_size, channels, out_width) if `layout` is `NCW`.
)code" TVM_ADD_FILELINE)
.set_attrs_type<Conv1DAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.add_type_rel("Conv1D", Conv1DRel<Conv1DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv1DAttrs>);
// relay.nn.conv2d
TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv2d")
.set_body_typed(MakeConv2D);
.set_body_typed([](Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
return MakeConv<Conv2DAttrs>(
data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype, "nn.conv2d");
});
RELAY_REGISTER_OP("nn.conv2d")
......@@ -110,38 +173,24 @@ with the layer input to produce a tensor of outputs.
// relay.nn.conv3d
TVM_REGISTER_NODE_TYPE(Conv3DAttrs);
// Positional relay function to create conv3d operator
// used by frontend FFI.
Expr MakeConv3D(Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_object<Conv3DAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = std::move(channels);
attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.conv3d");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.conv3d")
.set_body_typed(MakeConv3D);
.set_body_typed([](Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
return MakeConv<Conv3DAttrs>(
data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype, "nn.conv3d");
});
RELAY_REGISTER_OP("nn.conv3d")
......
......@@ -34,6 +34,94 @@ namespace tvm {
namespace relay {
template <typename AttrType>
bool Conv1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr) return false;
static const Layout kNCW("NCW");
static const Layout kOIW("OIW");
const AttrType* param = attrs.as<AttrType>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCW."
<< " But got " << in_layout;
const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIW."
<< " But got " << kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCW."
<< " But got " << out_layout;
Array<IndexExpr> dshape_ncw = trans_in_layout.ForwardShape(data->shape);
IndexExpr channels, dilated_ksize;
// infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
Array<IndexExpr> wshape;
wshape = {{param->channels, dshape_ncw[1], param->kernel_size[0]}};
wshape = trans_kernel_layout.BackwardShape(wshape);
channels = param->channels;
dilated_ksize = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
DataType weight_dtype = data->dtype;
if (weight != nullptr) {
weight_dtype = weight->dtype;
}
// assign result to reporter
reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
if (param->kernel_size.defined()) {
// check the size
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) )
<< "Conv1D: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size << " wshape=" << wshape;
}
if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[0]))
<< "Conv1D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels << " wshape=" << wshape;
}
CHECK(reporter->AssertEQ(dshape_ncw[1], wshape[1]));
channels = wshape[0];
dilated_ksize = 1 + (wshape[2] - 1) * param->dilation[0];
}
// dilation
Array<IndexExpr> oshape({dshape_ncw[0], channels, 0});
if (!dshape_ncw[2].as<ir::AnyNode>()) {
oshape.Set(2, indexdiv(dshape_ncw[2] + param->padding[0] + param->padding[1] - dilated_ksize,
param->strides[0]) + 1);
} else {
oshape.Set(2, dshape_ncw[2]);
}
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
return true;
}
template <typename AttrType>
bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
......
......@@ -1732,16 +1732,27 @@ def test_or():
verify_or(indata=[x, y], dtype=bool)
def verify_conv(x_shape, w_shape, y_shape, p):
node = helper.make_node('Conv',
inputs=['x', 'W'],
outputs=['y'],
kernel_shape=[3, 3],
# Default values for other attributes:
# strides=[1, 1],
# dilations=[1, 1],
# groups=1
pads=p,)
def verify_conv(x_shape, w_shape, y_shape, padding, kernel_shape, strides, dilations, auto_pad="NOTSET"):
if padding is None:
node = helper.make_node('Conv',
inputs=['x', 'W'],
outputs=['y'],
kernel_shape=kernel_shape,
# Default values for other attributes:
strides=strides,
dilations=dilations,
# groups=1
auto_pad=auto_pad)
else:
node = helper.make_node('Conv',
inputs=['x', 'W'],
outputs=['y'],
kernel_shape=kernel_shape,
# Default values for other attributes:
strides=strides,
dilations=dilations,
# groups=1
pads=padding)
graph = helper.make_graph([node],
'conv_test',
......@@ -1761,18 +1772,35 @@ def verify_conv(x_shape, w_shape, y_shape, p):
def test_conv():
# Convolution with padding
# (1, 1, 5, 5) input tensor
# (1, 1, 3, 3) tensor for convolution weights
# (1, 1, 5, 5) output tensor
# [1, 1, 1, 1] list for pads
verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 5, 5), [1, 1, 1, 1])
# Conv2D
verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 5, 5), [1, 1, 1, 1], [3, 3], [1, 1], [1, 1])
# Conv1D
verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 5), [1, 1], [3], [1], [1])
# Convolution without padding
# (1, 1, 5, 5) input tensor
# (1, 1, 3, 3) tensor for convolution weights
# (1, 1, 3, 3) output tensor
# [0, 0, 0, 0] list for pads
verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 3, 3), [0, 0, 0, 0])
# Conv2D
verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 3, 3), [0, 0, 0, 0], [3, 3], [1, 1], [1, 1])
# Conv1D
verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 3), [0, 0], [3], [1], [1])
# Convolution with autopadding
verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 5, 5),
None, [3, 3], [1, 1], [1, 1],
auto_pad="SAME_UPPER")
# Conv1D
verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 5), None, [3], [1], [1], auto_pad="SAME_UPPER")
# Convolution with non uniform stride
verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 3, 3),
None, [3, 3], [2, 2], [1, 1],
auto_pad="SAME_UPPER")
# Conv1D
verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 3), None, [3], [2], [1], auto_pad="SAME_UPPER")
# Convolution with dilation
verify_conv((1, 1, 5, 5), (1, 1, 3, 3), (1, 1, 5, 5), [2, 2, 2, 2], [3, 3], [1, 1], [2, 2])
# Conv1D
verify_conv((1, 1, 5), (1, 1, 3), (1, 1, 5), [2, 2], [3], [1], [2])
def verify_convtranspose(x_shape, w_shape, y_shape, p):
......@@ -1838,15 +1866,15 @@ def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_p
raise ValueError("Pool method {} is not supported.".format(mode))
if pads is None:
pool_node = helper.make_node(node_type,
inputs=["x"],
pool_node = helper.make_node(node_type,
inputs=["x"],
outputs=["y"],
kernel_shape=kernel_shape,
auto_pad=auto_pad,
strides=strides)
else:
pool_node = helper.make_node(node_type,
inputs=["x"],
pool_node = helper.make_node(node_type,
inputs=["x"],
outputs=["y"],
kernel_shape=kernel_shape,
pads=pads,
......@@ -1867,6 +1895,7 @@ def verify_pooling(x_shape, kernel_shape, strides, pads, out_shape, mode, auto_p
model, [x_np], target, ctx, out_shape)
tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-5, atol=1e-5)
def test_pooling():
for mode in ['max', 'average']:
# Pool1D
......
......@@ -31,6 +31,101 @@ def run_infer_type(expr):
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def test_conv1d_infer_type():
# symbolic in batch dimension
n, c, w = tvm.var("n"), 10, 224
x = relay.var("x", relay.ty.TensorType((n, c, w), "float32"))
w = relay.var("w")
y = relay.nn.conv1d(x, w,
kernel_size=3,
padding=(1, 1),
channels=2)
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(
(n, 2, 224), "float32")
assert yy.args[1].checked_type == relay.TensorType(
(2, 10, 3), "float32")
# infer by shape of w, mixed precision
n, c, w = tvm.var("n"), 10, 224
x = relay.var("x", relay.TensorType((n, c, w), "int8"))
w = relay.var("w", relay.TensorType((2, 10, 3), "int8"))
y = relay.nn.conv1d(x, w, out_dtype="int32")
assert "out_dtype=\"int32\"" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(
(n, 2, 222), "int32")
# infer shape in case of different dtypes for input and weight.
n, c, w = tvm.var("n"), 10, 224
x = relay.var("x", relay.TensorType((n, c, w), "uint8"))
w = relay.var("w", relay.TensorType((2, 10, 3), "int8"))
y = relay.nn.conv1d(x, w, out_dtype="int32")
assert "out_dtype=\"int32\"" in y.astext()
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(
(n, 2, 222), "int32")
# Infer with NWC
n, c, w = 4, 32, 224
x = relay.var("x", relay.TensorType((n, w, c), "int8"))
wt = relay.var("w")
y = relay.nn.conv1d(x, wt,
kernel_size=3,
padding=(1, 1),
channels=16,
data_layout="NWC",
out_dtype="int32")
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(
(n, w, 16), "int32")
def test_conv1d_run():
def run_test_conv1d(dtype, out_dtype, scale, dshape, kshape,
padding=(1, 1),
fref=None,
dilation=1,
except_targets=None,
**attrs):
if except_targets is None:
except_targets = []
x = relay.var("x", shape=dshape, dtype=dtype)
w = relay.var("w", dtype=dtype)
y = relay.nn.conv1d(x, w,
padding=padding,
dilation=dilation,
**attrs)
func = relay.Function([x, w], y)
data = np.random.uniform(-scale, scale, size=dshape).astype(dtype)
kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype)
ref_res = topi.testing.conv1d_ncw_python(
data.astype(out_dtype), kernel.astype(out_dtype), 1, padding, dilation)
for target, ctx in ctx_list():
if target in except_targets:
continue
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(data, kernel)
tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
# normal conv1d
dshape = (1, 3, 224)
kshape = (10, 3, 3)
run_test_conv1d("float32", "float32", 1, dshape, kshape,
padding=(1, 1), channels=10, kernel_size=3)
# mixed precision
run_test_conv1d("int8", "int32", 1, dshape, kshape,
padding=(1, 1), channels=10, kernel_size=3)
# dilated conv2d
dshape = (1, 3, 18)
kshape = (10, 3, 3)
run_test_conv1d("float32", "float32", 1, dshape, kshape,
padding=(1, 1), channels=10, kernel_size=3, dilation=3)
def test_conv2d_infer_type():
# symbolic in batch dimension
n, c, h, w = tvm.var("n"), 10, 224, 224
......@@ -1114,6 +1209,7 @@ if __name__ == "__main__":
test_avg_pool2d_no_count_pad()
test_lrn()
test_l2_normalize()
test_conv1d_infer_type()
test_conv2d_infer_type()
test_conv3d_infer_type()
test_bitpack_infer_type()
......@@ -1126,6 +1222,7 @@ if __name__ == "__main__":
test_conv2d_transpose_nchw_run()
test_conv2d_transpose_nhwc_run()
test_conv1d_transpose_ncw_run()
test_conv1d_run()
test_conv2d_run()
test_conv2d_winograd()
test_conv3d_run()
......
......@@ -19,8 +19,8 @@
"""CUDA specific declaration and schedules."""
from __future__ import absolute_import as _abs
from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw, deformable_conv2d, \
group_conv2d_nchw, dense, conv1d_transpose_ncw
from . import conv1d, conv2d, depthwise_conv2d, conv2d_transpose_nchw, \
deformable_conv2d, group_conv2d_nchw, dense, conv1d_transpose_ncw
from . import conv3d
from .conv2d_hwcn import schedule_conv2d_hwcn
from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc
......
......@@ -35,6 +35,42 @@ def _default_schedule(outs, auto_inline):
@tvm.target.generic_func
def schedule_conv1d_ncw(outs):
"""Schedule for conv1d_ncw
Parameters
----------
outs: Array of Tensor
The computation graph description of conv1d_ncw
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_conv1d_nwc(outs):
"""Schedule for conv1d_nwc
Parameters
----------
outs: Array of Tensor
The computation graph description of conv1d_nwc
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_conv2d_hwcn(outs):
"""Schedule for conv2d_hwcn
......
......@@ -19,6 +19,7 @@
"""Neural network operators"""
from __future__ import absolute_import as _abs
from .conv1d import *
from .conv2d import *
from .conv3d import *
from .deformable_conv2d import *
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, unused-argument
"""1D convolution operators."""
from __future__ import absolute_import as _abs
import tvm
from .pad import pad
from ..util import simplify
from .util import get_pad_tuple1d
@tvm.target.generic_func
def conv1d(data,
kernel,
strides=1,
padding='VALID',
dilation=1,
layout='NCW',
out_dtype=None):
""" 1D convolution forward operator.
Parameters
----------
data : tvm.Tensor
3-D input shape [batch, in_channel, in_width] for layout == 'NCW'
and [batch, in_width, in_channel] for layout == 'NWC'
kernel : tvm.Tensor
3-D kernel with shape [num_filter, in_channel, filter_size] for layout == 'NCW'
and [filter_size, in_channel, num_filter] for layout == 'NWC'
strides : int or tuple
The spatial stride along width
padding : int or str
Padding size, or ['VALID', 'SAME']
dilation : int or tuple
Dilation rate if convolution should be dilated.
layout : str
How input data is laid out, must be one of ['NCW', 'NWC']
out_dtype : str
The output data type. If None then output is same type as input.
"""
if out_dtype is None:
out_dtype = data.dtype
if isinstance(strides, (tuple, list)):
strides = strides[0]
if isinstance(dilation, (tuple, list)):
dilation = dilation[0]
if layout == 'NCW':
return conv1d_ncw(data, kernel, strides, padding, dilation, out_dtype)
if layout == 'NWC':
return conv1d_nwc(data, kernel, strides, padding, dilation, out_dtype)
raise ValueError("This layout is not yet supported: {}".format(layout))
def conv1d_ncw(data,
kernel,
strides=1,
padding='VALID',
dilation=1,
out_dtype=None):
""" 1D convolution forward operator for NCW layout.
Parameters
----------
data : tvm.Tensor
3-D with shape [batch, in_channel, in_width]
kernel : tvm.Tensor
3-D with shape [num_filter, in_channel, filter_size]
strides : int or tuple
The spatial stride along width
padding : int, tuple, or str
Padding size can be an integer for equal padding,
a tuple of (left, right) or a string in ['VALID', 'SAME'].
dilation : int or tuple
Dilation rate if convolution should be dilated.
out_dtype : str
The output data type. If None then output is same type as input.
"""
batch, in_channels, data_width = data.shape
out_channels, _, kernel_size = kernel.shape
# Compute the output shape
dilated_kernel_size = (kernel_size - 1) * dilation + 1
pad_left, pad_right = get_pad_tuple1d(padding, (dilated_kernel_size, ))
out_channels = simplify(out_channels)
out_width = simplify(
(data_width - dilated_kernel_size + pad_left + pad_right) // strides + 1)
# Apply padding
pad_before = [0, 0, pad_left]
pad_after = [0, 0, pad_right]
temp = pad(data, pad_before, pad_after, name='pad_temp')
# Compute graph
rc = tvm.reduce_axis((0, in_channels), name='rc')
rw = tvm.reduce_axis((0, kernel_size), name='rw')
return tvm.compute(
(batch, out_channels, out_width),
lambda b, c, w: tvm.sum(
temp[b, rc, w * strides + rw * dilation].astype(out_dtype)
* kernel[c, rc, rw].astype(out_dtype),
axis=[rc, rw]),
tag="conv1d_ncw")
def conv1d_nwc(data,
kernel,
strides=1,
padding='VALID',
dilation=1,
out_dtype=None):
""" 1D convolution forward operator for NWC layout.
Parameters
----------
data : tvm.Tensor
3-D with shape [batch, in_width, in_channel]
kernel : tvm.Tensor
3-D with shape [filter_size, in_channel, num_filter]
strides : int or tuple
The spatial stride along width
padding : int, tuple, or str
Padding size can be an integer for equal padding,
a tuple of (left, right) or a string in ['VALID', 'SAME'].
dilation : int or tuple
Dilation rate if convolution should be dilated.
out_dtype : str
The output data type. If None then output is same type as input.
"""
batch, data_width, in_channels = data.shape
kernel_size, _, out_channels = kernel.shape
# Compute the output shape
dilated_kernel_size = (kernel_size - 1) * dilation + 1
pad_left, pad_right = get_pad_tuple1d(padding, (dilated_kernel_size, ))
out_channels = simplify(out_channels)
out_width = simplify(
(data_width - dilated_kernel_size + pad_left + pad_right) // strides + 1)
# Apply padding
pad_before = [0, pad_left, 0]
pad_after = [0, pad_right, 0]
temp = pad(data, pad_before, pad_after, name='pad_temp')
# Compute graph
rc = tvm.reduce_axis((0, in_channels), name='rc')
rw = tvm.reduce_axis((0, kernel_size), name='rw')
return tvm.compute(
(batch, out_width, out_channels),
lambda b, w, c: tvm.sum(
temp[b, w * strides + rw * dilation, rc].astype(out_dtype)
* kernel[rw, rc, c].astype(out_dtype),
axis=[rc, rw]),
tag="conv1d_nwc")
......@@ -21,6 +21,7 @@ Used to verify the correctness of operators in TOPI .
"""
from __future__ import absolute_import as _abs
from .conv1d_ncw_python import conv1d_ncw_python
from .conv2d_hwcn_python import conv2d_hwcn_python
from .conv2d_nchw_python import conv2d_nchw_python
from .conv2d_nhwc_python import conv2d_nhwc_python
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-variable, invalid-name
"""1D convolution in python"""
import numpy as np
from topi.nn.util import get_pad_tuple1d
def dilate_np(x, dilation):
""" 1D dilation using numpy
Parameters
----------
x : numpy.ndarray
Array to dilate with shape [batch, in_channel, in_width]
dilation : int
dilation rate of output
Returns
-------
out : numpy.ndarray
Dilated output with shape [batch, in_channel, (in_width - 1) * dilation + 1]
"""
irange = range(len(x) - 1)
for d in range(dilation - 1):
indices = [(d + 1)*(i + 1) for i in irange]
x = np.insert(x, indices, 0)
return x
def conv1d_ncw_python(a_np, w_np, stride, padding, dilation):
"""1D convolution operator in NCW layout
Parameters
----------
a_np : numpy.ndarray
3-D with shape [batch, in_channel, in_width]
w_np : numpy.ndarray
3-D with shape [num_filter, in_channel, filter_width]
stride : int
Stride size
padding : int, tuple, or str
Single int for padding size or tuple of (left, right) padding
or a string in ['VALID', 'SAME']
dilation : int
Dilation rate of the kernel
Returns
-------
b_np : numpy.ndarray
3-D with shape [batch, out_channel, out_width]
"""
batch, in_c, in_w = a_np.shape
out_c, _, filter_w = w_np.shape
if isinstance(stride, (tuple, list)):
stride = stride[0]
if isinstance(dilation, (tuple, list)):
dilation = dilation[0]
dilated_filter_w = (filter_w - 1) * dilation + 1
pad_left, pad_right = get_pad_tuple1d(padding, (dilated_filter_w,))
out_w = ((in_w - dilated_filter_w + pad_left + pad_right) // stride) + 1
padded_a_np = np.zeros((batch, in_c, in_w + pad_left + pad_right))
padded_a_np[:, :, pad_left:(in_w + pad_left)] = a_np
b_np = np.zeros((batch, out_c, out_w))
for n in range(batch):
for f in range(out_c):
for c in range(in_c):
out = np.convolve(
padded_a_np[n, c], np.flip(dilate_np(w_np[f, c], dilation)), mode='valid')
b_np[n, f] += out[::stride]
return b_np
......@@ -19,6 +19,7 @@
"""x86 specific declaration and schedules."""
from __future__ import absolute_import as _abs
from .conv1d import schedule_conv1d_nwc
from .conv2d import schedule_conv2d, schedule_conv2d_nhwc
from .binarize_pack import schedule_binarize_pack
from .binary_dense import schedule_binary_dense
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,unused-variable,unused-argument,invalid-name
"""Conv1D schedule on for Intel CPU"""
from __future__ import absolute_import as _abs
import tvm
from .. import generic, tag
@generic.schedule_conv1d_ncw.register(["cpu"])
def schedule_conv1d_ncw(outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
output_op = outs[0].op
scheduled_ops = []
def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
else: # inject custom schedule
if len(op.axis) == 3: # schedule bias + bn + relu
n, c, w = op.axis
fused = s[op].fuse(n, c)
s[op].parallel(fused)
s[op].vectorize(w)
for tensor in op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv1d_ncw' in op.tag:
conv = op.output(0)
kernel = op.input_tensors[1]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
data = op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
n_pad, c_pad, w_pad = data_pad.op.axis
pad_fused = s[data_pad].fuse(n_pad, c_pad)
s[data_pad].parallel(pad_fused)
C = conv
n, c, w = C.op.axis
rc, rw = C.op.reduce_axis
n_out, c_out, w_out = output_op.axis
s[C].vectorize(w)
if op != output_op: # fuse bias + bn + relu into conv
s[C].compute_at(s[output_op], w_out)
else:
fused = s[C].fuse(n, c)
s[C].parallel(fused)
scheduled_ops.append(op)
traverse(output_op)
return s
@generic.schedule_conv1d_nwc.register(["cpu"])
def schedule_conv1d_nwc(outs):
"""Create schedule for tensors"""
s = tvm.create_schedule([x.op for x in outs])
output_op = outs[0].op
scheduled_ops = []
def traverse(op):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
else: # inject custom schedule
if len(op.axis) == 3: # schedule bias + bn + relu
n, w, c = op.axis
fused = s[op].fuse(n, w)
s[op].parallel(fused)
s[op].vectorize(c)
for tensor in op.input_tensors:
if isinstance(tensor.op, tvm.tensor.ComputeOp) and tensor.op not in scheduled_ops:
traverse(tensor.op)
if 'conv1d_nwc' in op.tag:
conv = op.output(0)
kernel = op.input_tensors[1]
if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
s[kernel].compute_inline()
data = op.input_tensors[0]
data_pad = None
if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
data_pad = data
data = data_pad.op.input_tensors[0]
n_pad, w_pad, c_pad = data_pad.op.axis
pad_fused = s[data_pad].fuse(n_pad, w_pad)
s[data_pad].parallel(pad_fused)
C = conv
n, w, c = C.op.axis
rc, rw = C.op.reduce_axis
n_out, w_out, c_out = output_op.axis
s[C].vectorize(c)
if op != output_op: # fuse bias + bn + relu into conv
s[C].compute_at(s[output_op], c_out)
else:
fused = s[C].fuse(n, w)
s[C].parallel(fused)
scheduled_ops.append(op)
traverse(output_op)
return s
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for transposed convolution."""
import numpy as np
import itertools
import tvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from common import get_all_backend
def verify_conv1d(batch,
in_channels,
in_width,
filters,
kernel_size=3,
stride=1,
dilation=1,
padding='VALID',
layout='NCW'):
if layout == 'NCW':
in_shape = [batch, in_channels, in_width]
kernel_shape = [filters, in_channels, kernel_size]
else:
in_shape = [batch, in_width, in_channels]
kernel_shape = [kernel_size, in_channels, filters]
dtype = 'float32'
A = tvm.placeholder(in_shape, name='A', dtype=dtype)
W = tvm.placeholder(kernel_shape, name='W', dtype=dtype)
def get_ref_data(layout):
a_np = np.random.uniform(size=in_shape).astype(dtype)
w_np = np.random.uniform(size=kernel_shape).astype(dtype)
if layout == 'NWC':
np_in = np.transpose(a_np, [0, 2, 1])
np_w = np.transpose(w_np, [2, 1, 0])
else:
np_in = a_np
np_w = w_np
b_np = topi.testing.conv1d_ncw_python(np_in, np_w, stride, padding, dilation)
if layout == 'NWC':
b_np = np.transpose(b_np, [0, 2, 1])
return a_np, w_np, b_np
a_np, w_np, b_np = get_ref_data(layout)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
with tvm.target.create(device):
B = topi.nn.conv1d(A, W, stride, padding, dilation, layout, 'float32')
if layout == 'NCW':
s = topi.generic.schedule_conv1d_ncw([B])
else:
s = topi.generic.schedule_conv1d_nwc([B])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
func = tvm.build(s, [A, W, B], device)
func(a, w, b)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in get_all_backend():
check_device(device)
def test_conv1d():
for layout in ["NCW", "NWC"]:
# Most basic test case
verify_conv1d(1, 1, 8, 1, 3, 1, 1, 'VALID', layout)
# With padding
verify_conv1d(1, 1, 8, 1, 3, 1, 1, 'SAME', layout)
# Realistic dimensions
verify_conv1d(1, 16, 32, 16, 3, 1, 1, 'SAME', layout)
# With stride
verify_conv1d(1, 16, 32, 16, 3, 2, 1, 'SAME', layout)
# With dilation
verify_conv1d(1, 16, 32, 16, 3, 1, 2, 'SAME', layout)
# Large batch size
verify_conv1d(8, 16, 32, 16, 3, 1, 1, 'SAME', layout)
# Other kernel sizes
verify_conv1d(1, 16, 32, 16, 3, 1, 1, 'SAME', layout)
verify_conv1d(1, 16, 32, 16, 2, 1, 1, 'SAME', layout)
verify_conv1d(1, 16, 32, 16, 1, 1, 1, 'SAME', layout)
# Non-power-of-two shape
verify_conv1d(1, 17, 12, 21, 3, 1, 1, 'SAME', layout)
verify_conv1d(1, 5, 27, 18, 3, 1, 1, 'VALID', layout)
if __name__ == "__main__":
test_conv1d()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment