Commit 8bd9d4d5 by Josh Fromm Committed by Thierry Moreau

[Relay] SpaceToDepth and MirrorPad Operators (#3718)

* Added relay and topi mirror_pad operator.

* Added mirror_padding to tensorflow frontend.

* Added mirror_pad testing in tensorflow frontend.

* Added space_to_depth in tf frontend.

* Added tests for spacetodepth.

* spacetodepth bug fix.

* Lint fix

* Added mirror pad python attrs.

* Pad code formatting.

* Syntax improvement

* Hopefully last lint fix
parent d482512d
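For context, a minimal sketch of how the new Relay operator can be exercised from Python (illustration only, not part of the patch; assumes a TVM build that already contains this change):

    from tvm import relay

    # Pad a 2x3 tensor by one element on every edge, mirroring the border values.
    x = relay.var("x", shape=(2, 3), dtype="float32")
    y = relay.nn.mirror_pad(x, pad_width=((1, 1), (1, 1)), mode="SYMMETRIC")
    func = relay.Function([x], y)  # result shape is (4, 5)

SpaceToDepth, by contrast, needs no new Relay op: the TensorFlow frontend lowers it to reshape/transpose/reshape, as shown in the diff below.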
......@@ -411,6 +411,19 @@ struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
  }
};

/*! \brief Attributes used for the MirrorPadding operator */
struct MirrorPadAttrs : public tvm::AttrsNode<MirrorPadAttrs> {
  std::string mode;
  Array<Array<IndexExpr> > pad_width;

  TVM_DECLARE_ATTRS(MirrorPadAttrs, "relay.attrs.MirrorPadAttrs") {
    TVM_ATTR_FIELD(mode).set_default("SYMMETRIC")
      .describe("Specifies how mirroring should be performed.");
    TVM_ATTR_FIELD(pad_width)
      .describe("Number of values padded to the edges of each axis, "
                "in the format of ((before_1, after_1), ..., (before_N, after_N))");
  }
};

/*! \brief Attributes for leaky relu operator */
struct LeakyReluAttrs : public tvm::AttrsNode<LeakyReluAttrs> {
......
......@@ -569,6 +569,44 @@ def _depth_to_space():
    return _impl

def _space_to_depth():
    def _impl(inputs, attr, params):
        # Need to handle data layouts differently.
        input_shape = attr['_input_shapes'][inputs[0]]
        block_size = int(attr['block_size'])
        if attr['data_format'].decode("utf-8") == 'NHWC':
            in_n, in_h, in_w, in_c = input_shape
            new_h = int(in_h / block_size)
            new_w = int(in_w / block_size)

            # First expand input to larger dimension.
            expanded = _op.reshape(
                inputs[0], newshape=(in_n, new_h, block_size, new_w, block_size, in_c))

            # Now reorder to expand spatial blocks.
            transposed = _op.transpose(expanded, axes=(0, 1, 3, 2, 4, 5))

            # Finally reshape to proper output.
            new_c = in_c * block_size * block_size
            newshape = (in_n, new_h, new_w, new_c)

        else:  # Handle NCHW layout
            in_n, in_c, in_h, in_w = input_shape
            new_h = int(in_h / block_size)
            new_w = int(in_w / block_size)

            expanded = _op.reshape(
                inputs[0], newshape=(in_n, in_c, new_h, block_size, new_w, block_size))
            transposed = _op.transpose(expanded, axes=(0, 3, 5, 1, 2, 4))
            new_c = int(in_c * block_size * block_size)
            newshape = (in_n, new_c, new_h, new_w)

        return AttrCvt(
            op_name="reshape",
            extras={'newshape': newshape},
            ignores=['data_format', 'block_size'])([transposed], attr)
    return _impl
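# --- Illustration only (not part of this patch) -------------------------------
# The conversion above lowers SpaceToDepth to reshape -> transpose -> reshape.
# A minimal NumPy sketch of the same decomposition for the NHWC case; the helper
# name is hypothetical and exists purely for this example.
def _space_to_depth_nhwc_reference(x, block_size):
    n, h, w, c = x.shape
    new_h, new_w = h // block_size, w // block_size
    # Expose the block_size x block_size spatial blocks as separate axes.
    expanded = x.reshape(n, new_h, block_size, new_w, block_size, c)
    # Move the block axes next to the channel axis.
    transposed = expanded.transpose(0, 1, 3, 2, 4, 5)
    # Collapse (block_h, block_w, c) into the new channel dimension.
    return transposed.reshape(n, new_h, new_w, c * block_size * block_size)
# -------------------------------------------------------------------------------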
def _bias_add():
    def _impl(inputs, attr, params):
        # Must expand for proper broadcasting in NCHW.
......@@ -851,6 +889,19 @@ def _pad(name):
            ignores=['Tpaddings'],)(new_inputs, attr)
    return _impl

def _mirror_pad():
    def _impl(inputs, attr, params):
        padlist = _get_param(params, inputs[1])
        paddings = tuple(tuple(l) for l in padlist)
        attr['pad_width'] = paddings
        mode = attr['mode'].decode('utf-8')
        attr['mode'] = mode
        new_inputs = [inputs[0]]
        return AttrCvt(
            op_name='mirror_pad',
            ignores=['Tpaddings'],)(new_inputs, attr)
    return _impl

def _transpose():
    def _impl(inputs, attr, params):
        # If perm is not specified, axes is left empty,
......@@ -1208,6 +1259,7 @@ _convert_map = {
    'Mean'          : _mean(),
    'Min'           : _reduce('min'),
    'Minimum'       : _elemwise('minimum'),
    'MirrorPad'     : _mirror_pad(),
    'Mod'           : _elemwise('mod'),
    'Mul'           : _elemwise('multiply'),
    'Neg'           : AttrCvt('negative'),
......@@ -1240,6 +1292,7 @@ _convert_map = {
    'Softmax'        : _softmax(),
    'Softplus'       : _softplus(),
    'SpaceToBatchND' : _space_to_batch_nd(),
    'SpaceToDepth'   : _space_to_depth(),
    'Split'          : _split(False),
    'SplitV'         : _split(True),
    'Sqrt'           : AttrCvt('sqrt'),
......
......@@ -378,6 +378,16 @@ def schedule_upsampling(_, outs, target):
# pad
reg.register_schedule("nn.pad", schedule_broadcast)

# mirror_pad
reg.register_schedule("nn.mirror_pad", schedule_broadcast)


@reg.register_compute("nn.mirror_pad")
def compute_mirror_pad(attrs, inputs, out_dtype, target):
    pad_before, pad_after = list(zip(*attrs.pad_width))
    mode = attrs.mode
    out = topi.nn.mirror_pad(inputs[0], pad_before=pad_before,
                             pad_after=pad_after, mode=mode)
    return [out]
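# Illustration only (not part of this patch): attrs.pad_width arrives as one
# (before, after) pair per axis; zip(*...) splits the pairs into the separate
# per-axis lists that topi.nn.mirror_pad expects.
_example_pad_width = ((1, 2), (3, 4))
_example_before, _example_after = list(zip(*_example_pad_width))
assert _example_before == (1, 3) and _example_after == (2, 4)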
# winograd related operators
@reg.register_compute("nn.contrib_conv2d_winograd_without_weight_transform")
def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, out_dtype, target):
......
......@@ -689,6 +689,32 @@ def pad(data,
    return _make.pad(data, pad_width, pad_value)


def mirror_pad(data,
               pad_width,
               mode="SYMMETRIC"):
    r"""MirrorPadding

    This operator takes in a tensor and pads each axis by the specified
    widths using mirroring of the border pixels.

    Parameters
    ----------
    data: tvm.relay.Expr
        The input data to the operator
    pad_width: tuple of <tuple of <int>>, required
        Number of values padded to the edges of each axis, in the format
        of ((before_1, after_1), ..., (before_N, after_N))
    mode: string, optional, default='SYMMETRIC'
        What type of mirroring to use, must be SYMMETRIC or REFLECT.

    Returns
    -------
    result : tvm.relay.Expr
        The computed result.
    """
    return _make.mirror_pad(data, pad_width, mode)
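# Illustration only (not part of this patch): the two modes follow NumPy's
# "symmetric" / "reflect" conventions -- SYMMETRIC repeats the border element,
# REFLECT mirrors about it without repeating. For a row [1, 2, 3] padded by (1, 1):
#   np.pad([1, 2, 3], (1, 1), mode="symmetric")  # -> [1, 1, 2, 3, 3]
#   np.pad([1, 2, 3], (1, 1), mode="reflect")    # -> [2, 1, 2, 3, 2]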
def lrn(data, size=5, axis=1, bias=2, alpha=.00001, beta=0.75):
    """This operator takes data as input and does local response normalization.
......
......@@ -62,6 +62,9 @@ class UpSamplingAttrs(Attrs):
class PadAttrs(Attrs):
    """Attributes for nn.pad"""


@register_relay_attr_node
class MirrorPadAttrs(Attrs):
    """Attributes for nn.mirror_pad"""


@register_relay_attr_node
class LeakyReluAttrs(Attrs):
......
......@@ -129,5 +129,77 @@ RELAY_REGISTER_OP("nn.pad")
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FTVMCompute>("FTVMCompute", PadCompute);
// relay.nn.mirror_pad
TVM_REGISTER_NODE_TYPE(MirrorPadAttrs);

bool MirrorPadRel(const Array<Type>& types,
                  int num_inputs,
                  const Attrs& attrs,
                  const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 2);
  const auto* data = types[0].as<TensorTypeNode>();
  if (data == nullptr) return false;

  const MirrorPadAttrs* param = attrs.as<MirrorPadAttrs>();
  CHECK(param != nullptr);

  // check that pad widths match lengths
  CHECK(data->shape.size() == param->pad_width.size())
    << "There should be as many pad width pairs as shape dimensions "
    << "but the shape has " << data->shape.size() << " dimensions "
    << "and there are " << param->pad_width.size() << " pad width pairs.";

  // each pad width element should be a pair of non-negative integers
  std::vector<IndexExpr> oshape;
  for (size_t i = 0; i < param->pad_width.size(); i++) {
    CHECK(param->pad_width[i].size() == 2)
      << "Each pad width element should be a pair but at index " << i
      << " there are " << param->pad_width[i].size() << " elements.";

    auto width1 = as_const_int(param->pad_width[i][0]);
    auto width2 = as_const_int(param->pad_width[i][1]);
    CHECK(width1 != nullptr);
    CHECK(width2 != nullptr);

    CHECK(*width1 >= 0)
      << "Pad width elements should be non-negative but the first pad width at "
      << "index " << i << " is " << *width1 << ".";
    CHECK(*width2 >= 0)
      << "Pad width elements should be non-negative but the second pad width at "
      << "index " << i << " is " << *width2 << ".";

    auto padding = make_const(data->shape[i].type(), *width1 + *width2);
    oshape.push_back(data->shape[i] + padding);
  }

  reporter->Assign(types[1], TensorTypeNode::make(Array<IndexExpr>(oshape),
                                                  data->dtype));
  return true;
}

// Handler to create a call to the mirror padding op used by front-end FFI
Expr MakeMirrorPad(Expr data, Array<Array<IndexExpr> > pad_width, std::string mode) {
  auto attrs = make_node<MirrorPadAttrs>();
  attrs->mode = mode;
  attrs->pad_width = std::move(pad_width);
  static const Op& op = Op::Get("nn.mirror_pad");
  return CallNode::make(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.mirror_pad")
.set_body_typed(MakeMirrorPad);

RELAY_REGISTER_OP("nn.mirror_pad")
.describe(R"code(MirrorPad for n-D tensor.

)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.MirrorPadAttrs")
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("MirrorPad", MirrorPadRel)
.set_attr<TOpPattern>("TOpPattern", kInjective);

}  // namespace relay
}  // namespace tvm
......@@ -469,6 +469,22 @@ def test_forward_depthtospace():
    _test_depthtospace(np.random.normal(size=[1, 32, 32, 4]), 2)
    _test_depthtospace(np.random.normal(size=[1, 16, 8, 32]), 4)


#######################################################################
# SpaceToDepth
# ------------

def _test_spacetodepth(data, block_size):
    """ One iteration of space_to_depth operation with given data and block size """

    with tf.Graph().as_default():
        in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
        array_ops.space_to_depth(in_data, block_size)

        compare_tf_with_tvm(data, 'Placeholder:0', 'SpaceToDepth:0')


def test_forward_spacetodepth():
    _test_spacetodepth(np.random.normal(size=[1, 32, 32, 4]), 2)
    _test_spacetodepth(np.random.normal(size=[1, 16, 8, 32]), 4)
#######################################################################
# Squeeze
......@@ -1330,6 +1346,8 @@ def _test_pad(input_shape, paddings, mode, **kwargs):
                out_name = 'PadV2:0'
            else:
                out_name = 'Pad:0'
        else:
            out_name = 'MirrorPad:0'

        compare_tf_with_tvm(x, 'Placeholder:0', out_name)
......@@ -1337,6 +1355,8 @@ def test_forward_pad():
""" Pad """
_test_pad((2, 3), [[1, 1], [2, 2]], mode="CONSTANT")
_test_pad((2, 3), [[1, 1], [2, 2]], mode="CONSTANT", constant_values=1.0)
_test_pad((2, 3), [[1, 1], [2, 2]], mode="SYMMETRIC")
_test_pad((2, 3), [[1, 1], [2, 2]], mode="REFLECT")
#######################################################################
# Logical operators
......@@ -2144,6 +2164,7 @@ if __name__ == '__main__':
    test_forward_transpose()
    test_forward_reshape()
    test_forward_depthtospace()
    test_forward_spacetodepth()
    test_forward_squeeze()
    test_forward_pack()
    test_forward_size()
......
......@@ -21,3 +21,4 @@ from .bitserial_dense import *
from .l2_normalize import *
from .batch_matmul import *
from .sparse import *
from .pad import *
......@@ -74,3 +74,72 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
            return tvm.if_then_else(not_zero, data(*index_tuple), pad_value)
        return data(*index_tuple)
    return tvm.compute(out_shape, _pad, name=name)


@tvm.tag_scope(tag=tag.INJECTIVE + ",pad")
def mirror_pad(data,
               pad_before,
               pad_after=None,
               mode='SYMMETRIC',
               name="MirrorPadInput"):
    """Pad Input with mirroring, either symmetric or reflected.

    Parameters
    ----------
    data : tvm.Tensor
        n-D input, can be any layout.

    pad_before : list / tuple of n ints
        Pad width on each dimension to pad before the axis begins.

    pad_after : list / tuple of n ints, optional
        Pad width on each dimension to pad after the axis ends.

    mode: str, optional
        Type of mirror padding to apply. Must be SYMMETRIC or REFLECT.

    name : str, optional
        The name prefix of the operators generated.

    Returns
    -------
    Output : tvm.Tensor
        n-D, the same layout as Input.
    """
    n = len(data.shape)
    pad_after = pad_after if pad_after else pad_before
    if len(pad_before) != n:
        raise ValueError("Input dimension and pad_before mismatch : %d vs %d" %
                         (n, len(pad_before)))
    if len(pad_after) != n:
        raise ValueError("Input dimension and pad_after mismatch : %d vs %d" %
                         (n, len(pad_after)))
    out_shape = tuple(
        tvm.ir_pass.Simplify((data.shape[i] + pad_before[i] + pad_after[i]))
        for i in range(n))
    assert mode in ('SYMMETRIC', 'REFLECT')
    mode = int(mode == 'SYMMETRIC')

    def _pad(*indices):
        index_tuple = []
        above = []
        below = []
        for i in range(n):
            if equal_const_int(pad_before[i], 0) and equal_const_int(
                    pad_after[i], 0):
                index_tuple.append(indices[i])
                above.append(False)
                below.append(False)
            else:
                index_tuple.append(indices[i] - pad_before[i])
                above.append(indices[i] >= data.shape[i] + pad_before[i])
                below.append(indices[i] < pad_before[i])
        mapped_tuple = []
        for i, axis in enumerate(index_tuple):
            mapped_axis = tvm.if_then_else(below[i], -axis - mode, axis)
            mapped_axis = tvm.if_then_else(
                above[i], (2 * (data.shape[i] - 1)) - axis + mode, mapped_axis)
            mapped_tuple.append(mapped_axis)
        return data(*mapped_tuple)

    return tvm.compute(out_shape, _pad, name=name)
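# --- Illustration only (not part of this patch) -------------------------------
# The _pad closure above maps every output index back into the input: indices
# below the original range use -axis - mode, indices above use
# 2 * (dim - 1) - axis + mode, with mode = 1 for SYMMETRIC and 0 for REFLECT.
# A hypothetical scalar re-implementation of that arithmetic:
def _mirror_index_reference(out_idx, dim, pad_before, symmetric=True):
    mode = 1 if symmetric else 0
    axis = out_idx - pad_before
    if axis < 0:            # padding region before the axis
        return -axis - mode
    if axis >= dim:         # padding region after the axis
        return 2 * (dim - 1) - axis + mode
    return axis             # interior: passthrough

# For a length-3 axis padded by (1, 1):
#   SYMMETRIC -> source indices [0, 0, 1, 2, 2]
#   REFLECT   -> source indices [1, 0, 1, 2, 1]
# -------------------------------------------------------------------------------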