Commit 020b6398 by Siva Committed by Tianqi Chen

[RELAY][OP] conv2d_transpose (#1862)

parent 3ff86009
......@@ -48,6 +48,7 @@ This level enables typical convnet models.
:nosignatures:
tvm.relay.nn.conv2d
tvm.relay.nn.conv2d_transpose
tvm.relay.nn.max_pool2d
tvm.relay.nn.avg_pool2d
tvm.relay.nn.global_max_pool2d
......@@ -129,6 +130,7 @@ Level 1 Definitions
Level 2 Definitions
-------------------
.. autofunction:: tvm.relay.nn.conv2d
.. autofunction:: tvm.relay.nn.conv2d_transpose
.. autofunction:: tvm.relay.nn.max_pool2d
.. autofunction:: tvm.relay.nn.avg_pool2d
.. autofunction:: tvm.relay.nn.global_max_pool2d
......
......@@ -77,6 +77,57 @@ struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
}
};
/*! \brief Attributes used in transposed convolution operator */
struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
IndexExpr channels;
Array<IndexExpr> kernel_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> output_padding;
Array<IndexExpr> dilation;
int groups;
std::string data_layout;
std::string weight_layout;
DataType out_dtype;
TVM_DECLARE_ATTRS(Conv2DTransposeAttrs, "relay.attrs.Conv2DTransposeAttrs") {
TVM_ATTR_FIELD(channels)
.set_default(NullValue<IndexExpr>())
.describe("The dimensionality of the output space"
"i.e. the number of output channels in the convolution.");
TVM_ATTR_FIELD(kernel_size)
.describe("The dimensions of the convolution window.")
.set_default(NullValue<Array<IndexExpr> >());
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
.describe("The strides of the convolution.");
TVM_ATTR_FIELD(output_padding).set_default(Array<IndexExpr>({0, 0}))
.describe("Zero-padding added to one side of the output.");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"on both sides for padding number of points");
TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use for dilated convolution.");
TVM_ATTR_FIELD(groups).set_default(1)
.describe("Controls the connections between inputs and outputs."
"At groups=1, all inputs are convolved to all outputs."
"At groups=2, the operation becomes equivalent to having two convolution"
"layers side by side, each seeing half the input channels, and producing"
"half the output channels, and both subsequently concatenated.");
TVM_ATTR_FIELD(data_layout).set_default("NCHW")
.describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(weight_layout).set_default("OIHW")
.describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc."
"'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(out_dtype)
.set_default(Int(0))
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
/*! \brief Attributes for max pool operator */
struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
Array<IndexExpr> pool_size;
......
......@@ -88,6 +88,62 @@ def conv2d(data,
weight_layout, out_layout, out_dtype)
def conv2d_transpose(data,
weight,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCHW",
weight_layout="OIHW",
output_padding=(0, 0),
out_dtype=""):
"""Two dimensional trnasposed convolution operator.
Parameters
----------
data : relay.Expr
The input data to the operator.
weight : relay.Expr
The weight expressions.
strides : Tuple[int], optional
The strides of convoltution.
padding : Tuple[int], optional
The padding of convolution on both sides of inputs.
dilation : Tuple[int], optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
data_layout : str, optional
Layout of the input.
weight_layout : str, optional
Layout of the weight.
output_padding : Tuple[int], optional
Additional zero-padding to be added to one side of the output.
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.conv2d_transpose(data, weight, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
weight_layout, output_padding, out_dtype)
def softmax(data, axis):
r"""Computes softmax.
......@@ -103,8 +159,12 @@ def softmax(data, axis):
axis: int
The axis to sum over when computing softmax
"""
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.softmax(data, axis)
......@@ -125,8 +185,12 @@ def log_softmax(data, axis):
axis: int
The axis to sum over when computing softmax
"""
Returns
-------
result : relay.Expr
The computed result.
"""
return _make.log_softmax(data, axis)
......
......@@ -154,5 +154,153 @@ with the layer input to produce a tensor of outputs.
.set_support_level(2)
.add_type_rel("Conv2D", Conv2DRel);
// Conv2DTranspose
TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
bool Conv2DTransposeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr) return false;
static const Layout kNCHW("NCHW");
static const Layout kOIHW("OIHW");
const Conv2DTransposeAttrs* param = attrs.as<Conv2DTransposeAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->weight_layout);
CHECK(in_layout.convertible(kNCHW))
<< "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
CHECK(kernel_layout.convertible(kOIHW))
<< "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout;
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
const auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW);
// infer weight if the kernel_size and channels are defined
if (param->kernel_size.defined() && param->channels.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
CHECK_EQ(param->dilation.size(), 2);
std::vector<IndexExpr> wshape({dshape_nchw[1],
param->channels / param->groups,
param->kernel_size[0],
param->kernel_size[1]});
wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
channels = param->channels;
// assign result to reporter
reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
} else {
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = ConvertLayout(weight->shape, kernel_layout, kOIHW);
if (param->kernel_size.defined()) {
CHECK_EQ(param->kernel_size.size(), 2);
// check the size
CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
reporter->AssertEQ(param->kernel_size[1], wshape[3]))
<< "Conv2D: shape of weight is inconsistent with kernel_size, "
<< " kernel_size=" << param->kernel_size
<< " wshape=" << Array<IndexExpr>(wshape);
}
if (param->channels.defined()) {
CHECK(reporter->AssertEQ(param->channels, wshape[1]))
<< "Conv2D: shape of weight is inconsistent with channels, "
<< " channels=" << param->channels
<< " wshape=" << Array<IndexExpr>(wshape);
}
CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[0]));
channels = wshape[1];
dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
}
// dilation
std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
oshape[2] = (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
2 * param->padding[0] + param->output_padding[0]);
oshape[3] = (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
2 * param->padding[1] + param->output_padding[1]);
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = ConvertLayout(oshape, kNCHW, in_layout);
reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
return true;
}
Expr MakeConv2DTranspose(Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string weight_layout,
Array<IndexExpr> output_padding,
DataType out_dtype) {
auto attrs = make_node<Conv2DTransposeAttrs>();
attrs->channels = channels;
attrs->kernel_size = kernel_size;
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->output_padding = std::move(output_padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->data_layout = std::move(data_layout);
attrs->weight_layout = std::move(weight_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.conv2d_transpose");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.conv2d_transpose")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 12>(MakeConv2DTranspose, args, rv);
});
RELAY_REGISTER_OP("nn.conv2d_transpose")
.describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution).
The need for transposed convolutions generally arises
from the desire to use a transformation going in the opposite direction
of a normal convolution, i.e., from something that has the shape of the
output of some convolution to something that has the shape of its input
while maintaining a connectivity pattern that is compatible with
said convolution.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (in_channels, channels, kernel_size[0], kernel_size[1])
- **bias**: (channels,)
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
v (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
out_height and out_width are calculated as::
out_height = (height-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]
out_width = (width-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1]
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.add_type_rel("Conv2DTranspose", Conv2DTransposeRel);
} // namespace relay
} // namespace tvm
......@@ -57,6 +57,42 @@ def test_conv2d_infer_type():
assert ftype.arg_types[1] == relay.ty.TensorType(
(4, 8, 3, 3, 4, 4), "int8")
def test_conv2d_transpose_infer_type():
# symbolic in batch dimension
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), 10, 10, 12
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
w = ib.param("w", relay.ty.IncompleteType())
with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d_transpose(x.var, w.var,
kernel_size=(3, 3),
padding=(1, 1),
channels=15))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(n, 15, 10, 12), "float32")
assert ftype.arg_types[1] == relay.ty.TensorType(
(10, 15, 3, 3), "float32")
# infer by shape of w, mixed precision
ib = relay.ir_builder.IRBuilder()
n, c, h, w = tvm.var("n"), 10, 10, 12
x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32"))
w = ib.param("w", relay.ty.TensorType((12, 11, 5, 5), "float32"))
with ib.function(x, w) as func:
ib.ret(relay.nn.conv2d_transpose(x.var, w.var,
output_padding=(1, 1),
channels=11,
data_layout="NHWC"))
ib.ret(func)
func = relay.ir_pass.infer_type(ib.env, func.to_func())
ftype = func.checked_type
assert ftype.ret_type == relay.ty.TensorType(
(n, 15, 15, 11), "float32")
def test_upsampling_infer_type():
ib = relay.ir_builder.IRBuilder()
n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
......@@ -166,3 +202,4 @@ if __name__ == "__main__":
test_pool2d_infer_type()
test_upsampling_infer_type()
test_flatten_infer_type()
test_conv2d_transpose_infer_type()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment