Unverified Commit 646cfc63 by Mahesh Ambule Committed by GitHub

[Relay, TF Frontend] Dilation2D operator support (#5033)

* update docs for dilation 2d

* dilation2d compute

* dilation2d register

* dilation2d rel compute

* dilation2d strategy

* dilation2d attrs

* dilation2d generic schedule

* dilation2d tf frontend support

* dilation2d tf frontend test case

* dilation2d test cases

* pylint fixes

* add exception for cuda target

* Update docstring

* Update docstring

* change rates to dilations

* removed unused param

* merge master

* Update nn.py

* Update nn.py
parent 5eab475d
......@@ -57,6 +57,7 @@ List of operators
topi.nn.relu
topi.nn.leaky_relu
topi.nn.dilate
topi.nn.dilation2d
topi.nn.pool
topi.nn.global_pool
topi.nn.adaptive_pool
......@@ -197,6 +198,7 @@ topi.nn
.. autofunction:: topi.nn.upsampling
.. autofunction:: topi.nn.softmax
.. autofunction:: topi.nn.dense
.. autofunction:: topi.nn.dilation2d
.. autofunction:: topi.nn.batch_matmul
.. autofunction:: topi.nn.log_softmax
.. autofunction:: topi.nn.conv2d_nchw
......
......@@ -140,6 +140,7 @@ Supported Ops
- DecodeJpeg
- DepthwiseConv2dNative
- DepthToSpace
- Dilation2D
- Equal
- Elu
- Enter
......
......@@ -70,6 +70,7 @@ This level enables typical convnet models.
tvm.relay.nn.conv2d
tvm.relay.nn.conv2d_transpose
tvm.relay.nn.dense
tvm.relay.nn.dilation2d
tvm.relay.nn.max_pool2d
tvm.relay.nn.max_pool3d
tvm.relay.nn.avg_pool2d
......@@ -249,6 +250,7 @@ Level 2 Definitions
.. autofunction:: tvm.relay.nn.conv2d
.. autofunction:: tvm.relay.nn.conv2d_transpose
.. autofunction:: tvm.relay.nn.dense
.. autofunction:: tvm.relay.nn.dilation2d
.. autofunction:: tvm.relay.nn.max_pool2d
.. autofunction:: tvm.relay.nn.max_pool3d
.. autofunction:: tvm.relay.nn.avg_pool2d
......
......@@ -156,6 +156,42 @@ struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
};
/*! \brief Attributes used in dilation operators */
struct Dilation2DAttrs : public tvm::AttrsNode<Dilation2DAttrs> {
Array<IndexExpr> strides;
Array<IndexExpr> padding;
Array<IndexExpr> dilations;
std::string data_layout;
std::string kernel_layout;
DataType out_dtype;
TVM_DECLARE_ATTRS(Dilation2DAttrs, "relay.attrs.Dilation2DAttrs") {
TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the strides of the sliding window. [stride_height, stride_width].");
TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
.describe("If padding is non-zero, then the input is implicitly zero-padded"
"Padding support both symmetric and asymmetric as"
"one int : same padding used on all sides"
"two int : bottom, right will use same padding as top, left"
"four int : padding width in the order of (top, left, bottom, right)");
TVM_ATTR_FIELD(dilations).set_default(Array<IndexExpr>({1, 1}))
.describe("Specifies the dilation rate to use. [dilation_height, dilation_width]");
TVM_ATTR_FIELD(data_layout).set_default("NCHW")
.describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Convolution is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(kernel_layout).set_default("IHW")
.describe("Dimension ordering of weight. Can be 'IHW', 'HWI', etc."
"'I', 'H', 'W' stands for input_channel, height, and width"
"dimensions respectively.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type, set to explicit type under mixed precision setting");
}
};
/*! \brief Attributes used in winograd weight transformation operators */
struct Conv2DWinogradWeightTransformAttrs :
public tvm::AttrsNode<Conv2DWinogradWeightTransformAttrs> {
......
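The padding attribute above accepts one, two, or four integers. As a quick illustration of how those forms expand to (top, left, bottom, right) per the description, here is a minimal Python sketch (the helper name is illustrative only, not part of this patch):

def expand_dilation2d_padding(padding):
    # Expand a 1-, 2-, or 4-element padding into (top, left, bottom, right),
    # mirroring the Dilation2DAttrs::padding semantics described above.
    if len(padding) == 1:        # same padding on all sides
        return (padding[0],) * 4
    if len(padding) == 2:        # bottom, right reuse top, left
        top, left = padding
        return (top, left, top, left)
    if len(padding) == 4:        # explicit (top, left, bottom, right)
        return tuple(padding)
    raise ValueError("padding must contain 1, 2 or 4 values")

assert expand_dilation2d_padding([1]) == (1, 1, 1, 1)
assert expand_dilation2d_padding([1, 2]) == (1, 2, 1, 2)
assert expand_dilation2d_padding([0, 0, 1, 1]) == (0, 0, 1, 1)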
......@@ -410,6 +410,91 @@ def _conv(opname):
return out
return _impl
# Dilation2d
def _dilation2d():
def _impl(inputs, attr, params):
if 'data_format' not in attr:
attr['data_format'] = 'NHWC'
input_shape = attr['_input_shapes'][inputs[0]]
weights_shape = attr['_input_shapes'][inputs[1]]
if attr['_target_layout'] == "NCHW" and attr['data_format'] == "NHWC":
input_shape = [input_shape[ii] for ii in (0, 3, 1, 2)]
inputs[0] = _op.transpose(inputs[0], axes=(0, 3, 1, 2))
weights_shape = [weights_shape[ii] for ii in (2, 0, 1)]
inputs[1] = _op.transpose(inputs[1], axes=(2, 0, 1))
attr['data_format'] = "NCHW"
if attr['data_format'] in ['NHWC', 'NCHW']:
if 'rates' in attr:
attr['dilations'] = attr['rates']
if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][1], attr['dilations'][2])
attr['strides'] = (attr['strides'][1], attr['strides'][2])
else:
msg = 'Value {} in attribute "data_format" of operator Dilation2D is ' \
'not valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))
attr['padding'] = attr['padding'].decode("utf-8")
if attr['padding'] == 'VALID':
attr['padding'] = [0, 0]
elif attr['padding'] == 'SAME':
stride_h, stride_w = attr['strides']
if attr['data_format'] == 'NHWC':
kernel_h, kernel_w = weights_shape[0], weights_shape[1]
else:
kernel_h, kernel_w = weights_shape[1], weights_shape[2]
if attr['data_format'] == 'NHWC':
in_h = input_shape[1]
in_w = input_shape[2]
else:
in_h = input_shape[2]
in_w = input_shape[3]
dilation_h = attr['dilations'][0]
dilation_w = attr['dilations'][1]
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
if attr['data_format'] == 'NHWC':
inputs[0] = _op.nn.pad(data=inputs[0],
pad_width=((0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]),
(0, 0)))
else:
inputs[0] = _op.nn.pad(data=inputs[0],
pad_width=((0, 0),
(0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1])))
attr['padding'] = [0, 0]
else:
msg = 'Value {} in attribute "padding" of operator Dilation2d is not ' \
'valid.'
raise tvm.error.OpAttributeInvalid(msg.format(attr['padding']))
attr['kernel_layout'] = 'HWI' if attr['data_format'] == 'NHWC' else 'IHW'
out = AttrCvt(
op_name='dilation2d',
ignores=['explicit_paddings', 'rates'],
transforms={
'data_format': 'data_layout',
})([inputs[0], inputs[1]], attr)
if attr['_target_layout'] == "NCHW":
out = _op.transpose(out, axes=(0, 2, 3, 1))
return out
return _impl
def _conv3d(opname):
def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
......@@ -1550,6 +1635,7 @@ _convert_map = {
'DecodeJpeg' : _decode_image(),
'DepthwiseConv2dNative' : _conv('depthwise'),
'DepthToSpace' : _depth_to_space(),
'Dilation2D' : _dilation2d(),
'Equal' : _broadcast('equal'),
'Elu' : _elu(),
'Erf' : AttrCvt('erf'),
......
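For 'SAME' padding the TensorFlow converter above pads the input explicitly (via _get_pad_pair) before lowering to the op. A hedged sketch of the per-dimension rule it relies on, assuming the usual TensorFlow SAME semantics:

def same_pad_pair(in_size, dilated_kernel, stride):
    # Total padding so that ceil(in_size / stride) output positions fit;
    # the extra pixel, if any, goes to the bottom/right side.
    out_size = (in_size + stride - 1) // stride
    total = max((out_size - 1) * stride + dilated_kernel - in_size, 0)
    return [total // 2, total - total // 2]

# 15-wide input, 4-tap kernel with dilation 2 (effective size 7), stride 1
assert same_pad_pair(15, (4 - 1) * 2 + 1, 1) == [3, 3]
# asymmetric case: 5-wide input, 2-tap kernel, stride 1
assert same_pad_pair(5, 2, 1) == [0, 1]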
......@@ -186,6 +186,9 @@ def legalize_conv2d_transpose(attrs, inputs, types):
reg.register_strategy("nn.conv3d", strategy.conv3d_strategy)
reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE)
# dilation2d
reg.register_strategy("nn.dilation2d", strategy.dilation2d_strategy)
reg.register_pattern("nn.dilation2d", OpPattern.OUT_ELEMWISE_FUSABLE)
# conv1d_transpose
reg.register_strategy("nn.conv1d_transpose", strategy.conv1d_transpose_strategy)
......
......@@ -2463,3 +2463,60 @@ def adaptive_avg_pool3d(data,
"""
output_size = [] or output_size
return _make.adaptive_avg_pool3d(data, output_size, layout)
def dilation2d(data,
weight,
strides=(1, 1),
padding=(0, 0),
dilations=(1, 1),
data_layout="NCHW",
kernel_layout="IHW",
out_dtype=""):
r"""Dilation 2D.
This operator takes the weight as the dilation kernel and dilates it with
data to produce an output. In the default case, where the data_layout is `NCHW`
    and kernel_layout is `IHW`, dilation2d takes in a data Tensor with shape
`(batch_size, in_channels, height, width)`, and a weight Tensor with shape
`(channels, kernel_height, kernel_width)` to produce an output Tensor
with the following rule:
.. math::
\mbox{out}[b, c, y, x] = \max_{dy, dx}
\mbox{data}[b, c, \mbox{strides}[0] * y + dy, \mbox{strides}[1] * x + dx] +
\mbox{weight}[c, dy, dx]
Padding and dilation are applied to data and weight respectively before the computation.
This operator accepts data layout specification. Semantically, the operator
will convert the layout to the canonical layout
(`NCHW` for data and `IHW` for weight) and perform the computation.
    Parameters
    ----------
    data : tvm.relay.Expr
        The input data to the operator.
    weight : tvm.relay.Expr
        The weight expressions.
    strides : Optional[Tuple[int]]
        The strides of the sliding window.
    padding : Optional[Tuple[int]]
        The padding applied to both sides of the input before computation.
    dilations : Optional[Tuple[int]]
        Specifies the dilation rate of the kernel.
data_layout : Optional[str]
Layout of the input.
kernel_layout : Optional[str]
Layout of the weight.
out_dtype : Optional[str]
Specifies the output data type.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.dilation2d(data, weight, strides, padding, dilations, data_layout,
kernel_layout, out_dtype)
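A short usage sketch of the Relay API defined above, assuming a TVM build where tvm.IRModule and relay.transform.InferType are available; with these shapes the inferred output type is Tensor[(1, 3, 4, 4), float32]:

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 3, 5, 5), dtype="float32")
weight = relay.var("weight", shape=(3, 2, 2), dtype="float32")  # IHW kernel
out = relay.nn.dilation2d(data, weight, strides=(1, 1), padding=(0, 0),
                          dilations=(1, 1), data_layout="NCHW",
                          kernel_layout="IHW")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
print(relay.transform.InferType()(mod))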
......@@ -44,6 +44,9 @@ class Conv2DWinogradWeightTransformAttrs(Attrs):
class Conv2DWinogradNNPACKWeightTransformAttrs(Attrs):
"""Attributes for nn.contrib_conv2d_winograd_nnpack_weight_transform"""
@register_relay_attr_node
class Dilation2DAttrs(Attrs):
"""Attributes for nn.dilation2d"""
@register_relay_attr_node
class GlobalPool2DAttrs(Attrs):
......
......@@ -442,6 +442,57 @@ def conv1d_transpose_strategy(attrs, inputs, out_type, target):
name="conv1d_transpose_ncw.generic")
return strategy
# dilation2d
def wrap_compute_dilation2d(topi_compute, need_data_layout=False):
"""Wrap dilation2d topi compute"""
def _compute_dilation2d(attrs, inputs, out_type):
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
dilations = get_const_tuple(attrs.dilations)
data_layout = attrs.get_str("data_layout")
out_dtype = attrs.out_dtype
out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
else out_dtype)
args = [inputs[0], inputs[1], strides, padding, dilations]
if need_data_layout:
args.append(data_layout)
args.append(out_dtype)
return [topi_compute(*args)]
return _compute_dilation2d
@override_native_generic_func("dilation2d_strategy")
def dilation2d_strategy(attrs, inputs, out_type, target):
"""dilation2d_strategy generic strategy"""
logger.warning("dilation2d_strategy is not optimized for this platform.")
strategy = _op.OpStrategy()
dilations = get_const_tuple(attrs.dilations)
layout = attrs.data_layout
kernel_layout = attrs.kernel_layout
assert layout in ["NCHW", "NHWC"]
(dilation_h, dilation_w) = dilations
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
if layout == "NCHW":
assert kernel_layout == "IHW"
strategy.add_implementation(
wrap_compute_dilation2d(topi.nn.dilation2d_nchw),
wrap_topi_schedule(topi.generic.schedule_dilation2d_nchw),
name="dilation2d_nchw.generic")
elif layout == "NHWC":
assert kernel_layout == "HWI"
strategy.add_implementation(
wrap_compute_dilation2d(topi.nn.dilation2d_nhwc),
wrap_topi_schedule(topi.generic.schedule_dilation2d_nhwc),
name="dilation2d_nhwc.generic")
else:
raise RuntimeError("Unsupported dilation2d layout {}".format(layout))
return strategy
# dense
def wrap_compute_dense(topi_compute):
"""wrap dense topi compute"""
......
......@@ -1040,6 +1040,66 @@ Expr MakeDeformableConv2D(Expr data,
TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d")
.set_body_typed(MakeDeformableConv2D);
// relay.nn.dilation2d
TVM_REGISTER_NODE_TYPE(Dilation2DAttrs);
template<typename T>
Array<Array<Layout> > Dilation2DInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
const T* params = attrs.as<T>();
  // We always make other operators fit the layouts of convolution layers,
  // so this inference ignores all inputs.
return Array<Array<Layout> >{{params->data_layout, params->kernel_layout},
{params->data_layout}};
}
// Positional relay function to create dilation2d operator
// used by frontend FFI.
Expr MakeDilation2D(Expr data,
Expr weight,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilations,
std::string data_layout,
std::string kernel_layout,
DataType out_dtype) {
auto attrs = make_object<Dilation2DAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilations = std::move(dilations);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.dilation2d");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.dilation2d")
.set_body_typed(MakeDilation2D);
RELAY_REGISTER_OP("nn.dilation2d")
.describe(R"code(Computes grayscale dilation of 4D input and 3D filter.
- **data**: This depends on the `layout` parameter. Input is 4D array of shape
(batch_size, in_channels, height, width) if `layout` is `NCHW`.
- **weight**: (in_channels, height, width)
- **out**: This depends on the `layout` parameter. Output is 4D array of shape
(batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
)code" TVM_ADD_FILELINE)
.set_attrs_type<Dilation2DAttrs>()
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(2)
.add_type_rel("Dilation2D", Dilation2DRel<Dilation2DAttrs>)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Dilation2DInferCorrectLayout<Dilation2DAttrs>);
} // namespace relay
} // namespace tvm
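Once registered as above, the operator is discoverable from Python; a quick sanity check (sketch):

from tvm import relay

op = relay.op.get("nn.dilation2d")
print(op.name, op.num_inputs, op.support_level)  # nn.dilation2d 2 2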
......@@ -360,6 +360,77 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
return true;
}
template <typename AttrType>
bool Dilation2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight = types[1].as<TensorTypeNode>();
if (data == nullptr) return false;
static const Layout kNCHW("NCHW");
static const Layout kOIHW("IHW");
const AttrType* param = attrs.as<AttrType>();
CHECK(param != nullptr);
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
  CHECK(trans_in_layout.defined())
    << "Dilation2D only supports input layouts that are convertible from NCHW."
    << " But got " << in_layout;
  const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kIHW);
  CHECK(trans_kernel_layout.defined())
    << "Dilation2D only supports kernel layouts that are convertible from IHW."
    << " But got " << kernel_layout;
Layout out_layout(param->data_layout);
const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
  CHECK(trans_out_layout.defined())
    << "Dilation2D only supports output layouts that are convertible from NCHW."
    << " But got " << out_layout;
Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
// use weight to infer the conv shape.
if (weight == nullptr) return false;
auto wshape = trans_kernel_layout.ForwardShape(weight->shape);
channels = wshape[0];
dilated_ksize_y = 1 + (wshape[1] - 1) * param->dilations[0];
dilated_ksize_x = 1 + (wshape[2] - 1) * param->dilations[1];
// dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
IndexExpr pad_h, pad_w;
GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
if (!dshape_nchw[2].as<tir::AnyNode>()) {
oshape.Set(2, indexdiv(dshape_nchw[2] + pad_h - dilated_ksize_y,
param->strides[0]) + 1);
} else {
oshape.Set(2, dshape_nchw[2]);
}
if (!dshape_nchw[3].as<tir::AnyNode>()) {
oshape.Set(3, indexdiv(dshape_nchw[3] + pad_w - dilated_ksize_x,
param->strides[1]) + 1);
} else {
oshape.Set(3, dshape_nchw[3]);
}
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
oshape = trans_out_layout.BackwardShape(oshape);
// assign output type
reporter->Assign(types[2], TensorType(oshape, out_dtype));
return true;
}
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_NN_CONVOLUTION_H_
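The spatial output size computed by Dilation2DRel above follows the standard windowed-reduction formula. A small worked check matching the 224 -> 217 case exercised by the infer-type test added in this commit:

def dilation2d_out_dim(in_size, kernel, stride=1, dilation=1, pad=0):
    # Same arithmetic as Dilation2DRel: dilate the kernel, then count strided windows.
    dilated_kernel = (kernel - 1) * dilation + 1
    return (in_size + pad - dilated_kernel) // stride + 1

assert dilation2d_out_dim(224, 8) == 217
assert dilation2d_out_dim(5, 2) == 4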
......@@ -3037,7 +3037,51 @@ def test_forward_add_n():
_test_forward_add_n(in5)
#######################################################################
def _test_dilation2d(tensor_in_sizes, filter_in_sizes,
strides, dilations, padding):
""" One iteration of dilation2d with given shapes and attributes """
total_size_1 = np.prod(tensor_in_sizes)
total_size_2 = np.prod(filter_in_sizes)
# Initializes the input tensor with array containing incrementing
# numbers from 1.
data_array = [f * 1.0 for f in range(1, total_size_1 + 1)]
filter_array = [f * 1.0 for f in range(1, total_size_2 + 1)]
with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=tensor_in_sizes, dtype='float32')
in_filter = constant_op.constant(
filter_array, shape=filter_in_sizes, dtype='float32')
nn_ops.dilation2d(in_data,
in_filter,
strides=strides,
rates=dilations,
padding=padding)
compare_tf_with_tvm(np.reshape(data_array, tensor_in_sizes).astype('float32'),
'Placeholder:0', 'Dilation2D:0', no_gpu=True)
def test_forward_dilation():
_test_dilation2d([1, 18, 18, 32], [4, 4, 32], [1, 1, 1, 1], [1, 2, 1, 1], "VALID")
_test_dilation2d([1, 15, 15, 32], [4, 4, 32], [1, 1, 1, 1], [1, 2, 1, 1], "SAME")
_test_dilation2d([1, 5, 5, 1], [2, 2, 1], [1, 1, 1, 1], [1, 1, 1, 1], "VALID")
_test_dilation2d([1, 5, 5, 1], [3, 3, 1], [1, 1, 1, 1], [1, 2, 2, 1], "VALID")
_test_dilation2d([1, 5, 5, 3], [3, 3, 3], [1, 1, 1, 1], [1, 1, 1, 1], "SAME")
_test_dilation2d([1, 28, 28, 3], [5, 5, 3], [1, 2, 2, 1], [1, 1, 1, 1], "VALID")
_test_dilation2d([1, 224, 224, 10], [8, 8, 10], [1, 1, 1, 1], [1, 1, 1, 1], "VALID")
_test_dilation2d([1, 18, 18, 32], [4, 4, 32], [1, 1, 1, 1], [1, 2, 1, 1], "SAME")
_test_dilation2d([1, 15, 15, 32], [4, 4, 32], [1, 1, 1, 1], [1, 2, 1, 1], "VALID")
_test_dilation2d([1, 5, 5, 1], [7, 2, 1], [1, 3, 1, 1], [1, 1, 1, 1], "SAME")
_test_dilation2d([1, 5, 5, 1], [3, 4, 1], [1, 2, 1, 1], [1, 2, 2, 1], "SAME")
_test_dilation2d([1, 5, 5, 3], [3, 3, 3], [1, 1, 4, 1], [1, 1, 1, 1], "VALID")
_test_dilation2d([1, 28, 28, 3], [5, 6, 3], [1, 1, 2, 1], [1, 1, 1, 1], "SAME")
_test_dilation2d([1, 224, 224, 10], [8, 8, 10], [1, 3, 1, 1], [1, 1, 1, 1], "SAME")
_test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 2, 2, 1], "SAME")
_test_dilation2d([1, 3, 3, 1], [2, 2, 1], [1, 1, 1, 1], [1, 1, 2, 1], "VALID")
# #######################################################################
# Main
# ----
if __name__ == '__main__':
......@@ -3131,6 +3175,7 @@ if __name__ == '__main__':
test_forward_l2_normalize()
test_forward_space_to_batch_nd()
test_forward_batch_to_space_nd()
test_forward_dilation()
# End to End
test_forward_inception_v3()
......
......@@ -1219,6 +1219,113 @@ def test_depthwise_conv2d_int8():
graph, lib, params = relay.build(func, target, params=parameters)
def test_dilation2d_infer_type():
# symbolic in batch dimension
n, h, w, c = te.var("n"), 224, 224, 10
x = relay.var("x", relay.ty.TensorType((n, c, h, w), "float32"))
kc, kh, kw = 10, 8, 8
w = relay.var("w", relay.ty.TensorType((kc, kw, kh), "float32"))
y = relay.nn.dilation2d(x, w,
strides=[1, 1, 1, 1],
dilations=[1, 1, 1, 1],
padding=[0, 0, 0, 0])
yy = run_infer_type(y)
assert yy.checked_type == relay.TensorType(
(n, 10, 217, 217), "float32")
def test_dilation2d_run():
def run_test_dilation2d(indata, kernel, out,
dtype='float32',
strides=[1, 1],
padding=[0, 0],
dilations=[1, 1],
except_targets=['cuda'],
**attrs):
dshape = indata.shape
kshape = kernel.shape
if except_targets is None:
except_targets = []
x = relay.var("x", shape=dshape, dtype=dtype)
w = relay.var("w", shape=kshape, dtype=dtype)
y = relay.nn.dilation2d(x, w,
strides=strides,
dilations=dilations,
padding=padding,
**attrs)
func = relay.Function([x, w], y)
for target, ctx in ctx_list():
if target in except_targets:
continue
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(indata, kernel)
tvm.testing.assert_allclose(op_res.asnumpy(), out, rtol=1e-5, atol=1e-5)
def _convert_data(indata, kernel, out, layout=None):
indata = np.asarray(indata)
kernel = np.asarray(kernel)
out = np.asarray(out)
if layout == 'NCHW':
indata = indata.transpose([0, 3, 1, 2])
kernel = kernel.transpose([2, 0, 1])
out = out.transpose([0, 3, 1, 2])
return indata, kernel, out
image = [[[[.1], [.2]], [[.3], [.4]]]]
kernel = [[[.4], [.3]], [[.1], [.0]]]
out = [[[[.5]]]]
run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'))
run_test_dilation2d(*_convert_data(image, kernel, out), data_layout='NHWC', kernel_layout='HWI')
image = [[[[.1], [.2]], [[.3], [.4]]]]
kernel = [[[.4], [.3]], [[.1], [.0]]]
out = [[[[.5], [.6]], [[.7], [.8]]]]
run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'), padding=[0, 0, 1, 1])
run_test_dilation2d(*_convert_data(image, kernel, out), padding=[0, 0, 1, 1],
data_layout='NHWC', kernel_layout='HWI')
image = [[[[.1, .2, .0], [.2, .3, .1]], [[.3, .4, .2], [.4, .5, .3]]]]
kernel = [[[.4, .5, .3], [.3, .4, .2]], [[.1, .2, .0], [.0, .1, -.1]]]
out = [[[[.5, .7, .3], [.6, .8, .4]], [[.7, .9, .5], [.8, 1., .6]]]]
run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'), padding=[0, 0, 1, 1])
run_test_dilation2d(*_convert_data(image, kernel, out), padding=[0, 0, 1, 1],
data_layout='NHWC', kernel_layout='HWI')
image = [[[[.1], [.2]], [[.3], [.4]]], [[[.2], [.3]], [[.4], [.5]]]]
kernel = [[[.4], [.3]], [[.1], [.0]]]
out = [[[[.5], [.6]], [[.7], [.8]]], [[[.6], [.7]], [[.8], [.9]]]]
run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'), padding=[0, 0, 1, 1])
run_test_dilation2d(*_convert_data(image, kernel, out), padding=[0, 0, 1, 1],
data_layout='NHWC', kernel_layout='HWI')
image = [[[[.1], [.2]], [[.3], [.4]]]]
kernel = [[[.4], [.3]]]
out = [[[[.5]], [[.7]]]]
run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'))
run_test_dilation2d(*_convert_data(image, kernel, out),
data_layout='NHWC', kernel_layout='HWI')
image = [[[[.1], [.2], [.3]], [[.4], [.5], [.6]], [[.7], [.8], [.9]]]]
kernel = [[[.4], [.3]], [[.1], [.2]]]
out = [[[[.7], [.8], [.6]], [[1.0], [1.1], [.9]], [[.8], [.9], [.9]]]]
run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'), padding=[1, 1], dilations=[2, 2])
run_test_dilation2d(*_convert_data(image, kernel, out), padding=[1, 1], dilations=[2, 2],
data_layout='NHWC', kernel_layout='HWI')
image = [[[[.1], [.2], [.3], [.4]], [[.5], [.6], [.7], [.8]],
[[.9], [1.0], [1.1], [1.2]]]]
kernel = [[[.4], [.3]], [[.1], [.2]]]
out = [[[[.8], [1.0]], [[1.2], [1.4]]]]
run_test_dilation2d(*_convert_data(image, kernel, out, layout='NCHW'), strides=[1, 2])
run_test_dilation2d(*_convert_data(image, kernel, out), strides=[1, 2],
data_layout='NHWC', kernel_layout='HWI')
def test_bitserial_conv2d_infer_type():
# Basic shape test with ambiguous batch.
n, c, h, w = te.size_var("n"), 32, 224, 224
......@@ -1274,3 +1381,5 @@ if __name__ == "__main__":
test_upsampling3d()
test_conv2d_int8_intrinsics()
test_depthwise_conv2d_int8()
test_dilation2d_infer_type()
test_dilation2d_run()
......@@ -648,3 +648,33 @@ def schedule_batch_matmul(outs):
The computation schedule for the op.
"""
return _default_schedule(outs, False)
def schedule_dilation2d_nchw(outs):
"""Schedule for dilation2d
Parameters
----------
outs : Array of Tensor
The computation graph description of dilation2d
in the format of an array of tensors.
Returns
-------
sch : Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
def schedule_dilation2d_nhwc(outs):
"""Schedule for dilation2d
Parameters
----------
outs : Array of Tensor
The computation graph description of dilation2d
in the format of an array of tensors.
Returns
-------
sch : Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
......@@ -24,6 +24,7 @@ from .conv2d import *
from .conv3d import *
from .deformable_conv2d import *
from .depthwise_conv2d import *
from .dilation2d import *
from .elemwise import *
from .dilate import *
from .flatten import *
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, too-many-locals
# pylint: disable=unused-argument, redefined-builtin
"""Dilation2D operators"""
from __future__ import absolute_import as _abs
from tvm import te
from topi.util import simplify
from .pad import pad
from .util import get_pad_tuple
def dilation2d_nchw(input, filter, stride, padding, dilations, out_dtype=None):
"""Dilation2D operator in NCHW layout.
Parameters
----------
input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
filter : tvm.Tensor
3-D with shape [ in_channel, filter_height, filter_width]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str
Padding size
dilations: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
out_dtype : Optional[str]
Specifies the output data type.
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, in_channel, out_height, out_width]
"""
if out_dtype is None:
out_dtype = input.dtype
assert isinstance(stride, int) or len(stride) == 2
assert isinstance(dilations, int) or len(dilations) == 2
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
if isinstance(dilations, int):
dilation_h = dilation_w = dilations
else:
dilation_h, dilation_w = dilations
batch, in_channel, in_height, in_width = input.shape
channel, kernel_h, kernel_w = filter.shape
    assert in_channel.value == channel.value, \
        "For Dilation2D, input and filter channels should be the same."
# compute the output shape
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))
out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
# compute graph
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
temp = pad(input, pad_before, pad_after, name="pad_temp")
ry = te.reduce_axis((0, kernel_h), name='ry')
rx = te.reduce_axis((0, kernel_w), name='rx')
return te.compute(
(batch, in_channel, out_height, out_width),
lambda nn, ff, yy, xx: te.max(
temp[nn, ff, yy * stride_h + ry * dilation_h,
xx * stride_w + rx * dilation_w].astype(out_dtype) +
filter[ff, ry, rx].astype(out_dtype),
axis=[ry, rx]), tag="dilation2d_nchw")
def dilation2d_nhwc(input, filter, stride, padding, dilations, out_dtype=None):
"""Dilation2D operator in NHWC layout.
Parameters
----------
input : tvm.Tensor
4-D with shape [batch, in_height, in_width, in_channel]
filter : tvm.Tensor
3-D with shape [filter_height, filter_width, in_channel]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
    padding : int or str
Padding size
dilations: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
out_dtype : Optional[str]
Specifies the output data type.
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_height, out_width, in_channel]
"""
if out_dtype is None:
out_dtype = input.dtype
assert isinstance(stride, int) or len(stride) == 2
assert isinstance(dilations, int) or len(dilations) == 2
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
if isinstance(dilations, int):
dilation_h = dilation_w = dilations
else:
dilation_h, dilation_w = dilations
batch, in_height, in_width, in_channel = input.shape
kernel_h, kernel_w, channel = filter.shape
    assert in_channel.value == channel.value, \
        "For Dilation2D, input and filter channels should be the same."
# compute the output shape
dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
padding, (dilated_kernel_h, dilated_kernel_w))
out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
pad_before = [0, pad_top, pad_left, 0]
pad_after = [0, pad_down, pad_right, 0]
padded_input = pad(input, pad_before, pad_after, name="padded_input")
ry = te.reduce_axis((0, kernel_h), name='ry')
rx = te.reduce_axis((0, kernel_w), name='rx')
return te.compute(
(batch, out_height, out_width, in_channel),
lambda nn, yy, xx, ff: te.max(
padded_input[nn, yy * stride_h + ry * dilation_h,
xx * stride_w + rx * dilation_w, ff].astype(out_dtype) +
filter[ry, rx, ff].astype(out_dtype),
axis=[ry, rx]), tag="dilation2d_nhcw")
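For reference, a NumPy sketch of the NHWC compute above (VALID padding only, float inputs assumed); it reproduces the first case used in test_dilation2d_run:

import numpy as np

def dilation2d_nhwc_ref(data, kernel, stride=(1, 1), dilation=(1, 1)):
    # out[n, y, x, c] = max over (ry, rx) of
    #     data[n, y*sh + ry*dh, x*sw + rx*dw, c] + kernel[ry, rx, c]
    n, in_h, in_w, channels = data.shape
    k_h, k_w, _ = kernel.shape
    s_h, s_w = stride
    d_h, d_w = dilation
    out_h = (in_h - ((k_h - 1) * d_h + 1)) // s_h + 1
    out_w = (in_w - ((k_w - 1) * d_w + 1)) // s_w + 1
    out = np.full((n, out_h, out_w, channels), -np.inf, dtype=data.dtype)
    for ry in range(k_h):
        for rx in range(k_w):
            window = data[:, ry * d_h:ry * d_h + out_h * s_h:s_h,
                          rx * d_w:rx * d_w + out_w * s_w:s_w, :]
            out = np.maximum(out, window + kernel[ry, rx, :])
    return out

image = np.array([[[[.1], [.2]], [[.3], [.4]]]], dtype="float32")
kern = np.array([[[.4], [.3]], [[.1], [.0]]], dtype="float32")
np.testing.assert_allclose(dilation2d_nhwc_ref(image, kern), [[[[.5]]]], rtol=1e-5)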