Commit 56416ed0 by Yong Wu Committed by Zhi

[TOPI][RELAY][OP] add op crop_and_resize (#4417)

* [TOPI][RELAY][OP] add op crop_and_resize

* fix pylint

* incorporate comments

* fix ci
parent 12e51e6c
...@@ -104,6 +104,7 @@ List of operators ...@@ -104,6 +104,7 @@ List of operators
topi.ndarray_size topi.ndarray_size
topi.layout_transform topi.layout_transform
topi.image.resize topi.image.resize
topi.image.crop_and_resize
topi.argsort topi.argsort
topi.topk topi.topk
topi.sequence_mask topi.sequence_mask
...@@ -207,6 +208,7 @@ topi.nn ...@@ -207,6 +208,7 @@ topi.nn
topi.image topi.image
~~~~~~~~~~ ~~~~~~~~~~
.. autofunction:: topi.image.resize .. autofunction:: topi.image.resize
.. autofunction:: topi.image.crop_and_resize
topi.sparse topi.sparse
~~~~~~~~~~~ ~~~~~~~~~~~
......
...@@ -169,6 +169,7 @@ This level enables additional math and transform operators. ...@@ -169,6 +169,7 @@ This level enables additional math and transform operators.
:nosignatures: :nosignatures:
tvm.relay.image.resize tvm.relay.image.resize
tvm.relay.image.crop_and_resize
tvm.relay.vision.multibox_prior tvm.relay.vision.multibox_prior
tvm.relay.vision.multibox_transform_loc tvm.relay.vision.multibox_transform_loc
tvm.relay.vision.nms tvm.relay.vision.nms
...@@ -335,6 +336,7 @@ Level 4 Definitions ...@@ -335,6 +336,7 @@ Level 4 Definitions
Level 5 Definitions Level 5 Definitions
------------------- -------------------
.. autofunction:: tvm.relay.image.resize .. autofunction:: tvm.relay.image.resize
.. autofunction:: tvm.relay.image.crop_and_resize
.. autofunction:: tvm.relay.vision.multibox_prior .. autofunction:: tvm.relay.vision.multibox_prior
.. autofunction:: tvm.relay.vision.multibox_transform_loc .. autofunction:: tvm.relay.vision.multibox_transform_loc
.. autofunction:: tvm.relay.vision.nms .. autofunction:: tvm.relay.vision.nms
......
...@@ -63,6 +63,34 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> { ...@@ -63,6 +63,34 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> {
} }
}; };
/*! \brief Attributes used in image crop_and_resize operator */
struct CropAndResizeAttrs : public tvm::AttrsNode<CropAndResizeAttrs> {
Array<IndexExpr> crop_size;
std::string layout;
std::string method;
double extrapolation_value;
DataType out_dtype;
TVM_DECLARE_ATTRS(CropAndResizeAttrs, "relay.attrs.CropAndResizeAttrs") {
TVM_ATTR_FIELD(crop_size).set_default(NullValue<Array<IndexExpr> >())
.describe("Target Size.");
TVM_ATTR_FIELD(layout).set_default("NCHW")
.describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
"dimensions respectively. Resize is applied on the 'H' and"
"'W' dimensions.");
TVM_ATTR_FIELD(method).set_default("bilinear")
.describe("Specify the mode to use for scaling."
"nearest_neighbor - Nearest Neighbor"
"bilinear - Bilinear Interpolation");
TVM_ATTR_FIELD(extrapolation_value).set_default(0.0)
.describe("Specify value for extrapolation.");
TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>())
.describe("Output data type.");
}
};
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_ATTRS_IMAGE_H_ #endif // TVM_RELAY_ATTRS_IMAGE_H_
...@@ -546,47 +546,20 @@ def _crop_and_resize(): ...@@ -546,47 +546,20 @@ def _crop_and_resize():
# input image is a 4-D tensor of shape [batch, image_height, image_width, depth] # input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
# boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2] # boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2]
try: try:
boxes = _get_list_param(params, inputs[1])
box_ind = _get_list_param(params, inputs[2])
crop_size = _get_list_param(params, inputs[3]) crop_size = _get_list_param(params, inputs[3])
except (IndexError, KeyError): except (IndexError, KeyError):
boxes = _infer_value(inputs[1], params).asnumpy().tolist()
box_ind = _infer_value(inputs[2], params).asnumpy().tolist()
crop_size = _infer_value(inputs[3], params).asnumpy().tolist() crop_size = _infer_value(inputs[3], params).asnumpy().tolist()
data_shape = attr['_input_shapes'][inputs[0]]
data_dim = len(data_shape)
method = attr['method'].decode() method = attr['method'].decode()
method = 'nearest_neighbor' if method == 'nearest' else method
attrs = {} if method not in ['bilinear', 'nearest_neighbor']:
attrs['size'] = crop_size
attrs['layout'] = 'NHWC'
if method.lower() == 'nearest':
raise tvm.error.OpAttributeUnImplemented( raise tvm.error.OpAttributeUnImplemented(
'Attribute method=nearest is not supported') 'Method {} is not supported'.format(method))
else: layout = attr['layout'] if 'layout' in attr else 'NHWC'
attrs['coordinate_transformation_mode'] = 'align_corners' extrapolation_value = attr['extrapolation_value']
attrs['method'] = 'bilinear'
return get_relay_op("crop_and_resize")(inputs[0], inputs[1], inputs[2], crop_size,
out = None layout, method, extrapolation_value)
begin = [0] * data_dim
size = data_shape[:]
for idx in box_ind:
# 1) Crop
# y is mapped to the image coordinate at y * (image_height - 1)
# x is mapped to the image coordinate at x * (image_width - 1)
begin[0] = idx
begin[1] = int(round(boxes[idx][0] * (data_shape[1] - 1)))
begin[2] = int(round(boxes[idx][1] * (data_shape[2] - 1)))
size[0] = idx + 1
size[1] = int(round((data_shape[1] - 1) * boxes[idx][2])) + 1
size[2] = int(round((data_shape[2] - 1) * boxes[idx][3])) + 1
res_crop = _op.strided_slice(inputs[0], begin=begin, end=size)
# 2) Resize
res_resize = get_relay_op('resize')(res_crop, **attrs)
out = _op.concatenate([out, res_resize], axis=0) if out else res_resize
return out
return _impl return _impl
def _cast(): def _cast():
......
...@@ -25,7 +25,6 @@ from ..op import schedule_injective ...@@ -25,7 +25,6 @@ from ..op import schedule_injective
# resize # resize
reg.register_schedule("image.resize", schedule_injective) reg.register_schedule("image.resize", schedule_injective)
@reg.register_compute("image.resize") @reg.register_compute("image.resize")
def compute_resize(attrs, inputs, out_type, target): def compute_resize(attrs, inputs, out_type, target):
size = attrs.size size = attrs.size
...@@ -34,3 +33,18 @@ def compute_resize(attrs, inputs, out_type, target): ...@@ -34,3 +33,18 @@ def compute_resize(attrs, inputs, out_type, target):
coord_trans = attrs.coordinate_transformation_mode coord_trans = attrs.coordinate_transformation_mode
out_dtype = attrs.out_dtype out_dtype = attrs.out_dtype
return [topi.image.resize(inputs[0], size, layout, method, coord_trans, out_dtype)] return [topi.image.resize(inputs[0], size, layout, method, coord_trans, out_dtype)]
# crop and resize
reg.register_schedule("image.crop_and_resize", schedule_injective)
@reg.register_compute("image.crop_and_resize")
def compute_crop_and_resize(attrs, inputs, out_type, target):
crop_size = attrs.crop_size
layout = attrs.layout
method = attrs.method
extrapolation_value = attrs.extrapolation_value
out_dtype = attrs.out_dtype
return [topi.image.crop_and_resize(inputs[0], inputs[1], inputs[2],
crop_size, layout, method,
extrapolation_value, out_dtype)]
...@@ -31,7 +31,7 @@ def resize(data, ...@@ -31,7 +31,7 @@ def resize(data,
with data of shape (n, c, h, w) with data of shape (n, c, h, w)
out will have a shape (n, c, size[0], size[1]) out will have a shape (n, c, size[0], size[1])
method indicates the algorithm to be used while calculating ghe out value method indicates the algorithm to be used while calculating the out value
and method can be one of ("bilinear", "nearest_neighbor", "bicubic") and method can be one of ("bilinear", "nearest_neighbor", "bicubic")
Parameters Parameters
...@@ -63,3 +63,53 @@ def resize(data, ...@@ -63,3 +63,53 @@ def resize(data,
The resized result. The resized result.
""" """
return _make.resize(data, size, layout, method, coordinate_transformation_mode, out_dtype) return _make.resize(data, size, layout, method, coordinate_transformation_mode, out_dtype)
def crop_and_resize(data,
boxes,
box_indices,
crop_size,
layout,
method="bilinear",
extrapolation_value=0,
out_dtype=None):
"""Crop input images and resize them.
method indicates the algorithm to be used while calculating the out value
and method can be either "bilinear" or "nearest_neighbor".
Parameters
----------
data : relay.Expr
The input data to the operator.
boxes : relay.Expr
A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies
the coordinates of a box.
box_indices : relay.Expr
A 1-D tensor of shape [num_boxes], box_ind[i] specifies the data that
the i-th box refers to.
crop_size : Tuple of Expr
The target size to which each box will be resized.
layout : str, optional
Layout of the input.
method : str, optional
Scale method, it can be either "nearest_neighbor" or "bilinear".
extrapolation_value : float, optional
Value used for extrapolation, when applicable.
out_dtype : str, optional
Type to return. If left None returns the same type as input.
Returns
-------
result: relay.Expr
The computed result.
"""
return _make.crop_and_resize(data, boxes, box_indices, crop_size,
layout, method, extrapolation_value, out_dtype)
...@@ -114,6 +114,9 @@ class DeformableConv2DAttrs(Attrs): ...@@ -114,6 +114,9 @@ class DeformableConv2DAttrs(Attrs):
class ResizeAttrs(Attrs): class ResizeAttrs(Attrs):
"""Attributes for image.resize""" """Attributes for image.resize"""
@register_relay_attr_node
class CropAndResizeAttrs(Attrs):
"""Attributes for image.crop_and_resize"""
@register_relay_attr_node @register_relay_attr_node
class ArgsortAttrs(Attrs): class ArgsortAttrs(Attrs):
......
...@@ -246,8 +246,8 @@ PrimExpr div(PrimExpr a, PrimExpr b) { ...@@ -246,8 +246,8 @@ PrimExpr div(PrimExpr a, PrimExpr b) {
} }
PrimExpr truncdiv(PrimExpr a, PrimExpr b) { PrimExpr truncdiv(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
CHECK(b.dtype().is_int() || b.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
return div(a, b); return div(a, b);
} }
...@@ -276,8 +276,8 @@ PrimExpr indexmod(PrimExpr a, PrimExpr b) { ...@@ -276,8 +276,8 @@ PrimExpr indexmod(PrimExpr a, PrimExpr b) {
} }
PrimExpr floordiv(PrimExpr a, PrimExpr b) { PrimExpr floordiv(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
CHECK(b.dtype().is_int() || b.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
BinaryOpMatchTypes(a, b); BinaryOpMatchTypes(a, b);
PrimExpr ret = arith::TryConstFold<ir::FloorDivNode>(a, b); PrimExpr ret = arith::TryConstFold<ir::FloorDivNode>(a, b);
if (ret.defined()) return ret; if (ret.defined()) return ret;
...@@ -285,8 +285,8 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b) { ...@@ -285,8 +285,8 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b) {
} }
PrimExpr floormod(PrimExpr a, PrimExpr b) { PrimExpr floormod(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint()); CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
CHECK(b.dtype().is_int() || b.dtype().is_uint()); CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
BinaryOpMatchTypes(a, b); BinaryOpMatchTypes(a, b);
PrimExpr ret = arith::TryConstFold<ir::FloorModNode>(a, b); PrimExpr ret = arith::TryConstFold<ir::FloorModNode>(a, b);
if (ret.defined()) return ret; if (ret.defined()) return ret;
......
...@@ -109,5 +109,89 @@ RELAY_REGISTER_OP("image.resize") ...@@ -109,5 +109,89 @@ RELAY_REGISTER_OP("image.resize")
.add_type_rel("Resize", ResizeRel) .add_type_rel("Resize", ResizeRel)
.set_attr<TOpPattern>("TOpPattern", kInjective); .set_attr<TOpPattern>("TOpPattern", kInjective);
TVM_REGISTER_NODE_TYPE(CropAndResizeAttrs);
bool CropAndResizeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 4);
const auto* data = types[0].as<TensorTypeNode>();
const auto* boxes = types[1].as<TensorTypeNode>();
const auto* box_indices = types[2].as<TensorTypeNode>();
if (data == nullptr || boxes == nullptr ||
box_indices == nullptr) return false;
const CropAndResizeAttrs* param = attrs.as<CropAndResizeAttrs>();
CHECK(param != nullptr);
auto crop_size = param->crop_size;
DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
out_dtype = data->dtype;
}
// 4-D tensor of shape [num_boxes, crop_height, crop_width, depth]
static const Layout kNCHW("NCHW");
const Layout in_layout(param->layout);
auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
auto oshape = layout_converter.ForwardShape(data->shape);
oshape.Set(0, box_indices->shape[0]);
oshape.Set(2, crop_size[0]);
oshape.Set(3, crop_size[1]);
auto bshape = layout_converter.BackwardShape(oshape);
// assign output type
reporter->Assign(types[3],
TensorTypeNode::make(layout_converter.BackwardShape(oshape),
out_dtype));
return true;
}
Expr MakeCropAndResize(Expr data,
Expr boxes,
Expr box_indices,
Array<IndexExpr> crop_size,
std::string layout,
std::string method,
double extrapolation_value,
DataType out_dtype) {
auto attrs = make_object<CropAndResizeAttrs>();
attrs->crop_size = std::move(crop_size);
attrs->layout = std::move(layout);
attrs->method = std::move(method);
attrs->extrapolation_value = std::move(extrapolation_value);
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("image.crop_and_resize");
return CallNode::make(op, {data, boxes, box_indices}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.image._make.crop_and_resize")
.set_body_typed(MakeCropAndResize);
RELAY_REGISTER_OP("image.crop_and_resize")
.describe(R"code(Perform crop and resize to input array with nearest neighbour or bilinear interpolation.
- **data**: data is 4D array of shape
(batch_size, channels, in_height, in_width) for NCHW
(batch_size, in_height, in_width, channels) for NHWC
- **out**: Output is 4D array of shape
for layout NCHW
(batch_size, channels, crop_size[0], crop_size[1])
for layout NHWC
(batch_size, crop_size[0], crop_size[1], channels)
)code" TVM_ADD_FILELINE)
.set_num_inputs(3)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("boxes", "Tensor", "The boxes tensor.")
.add_argument("box_indices", "Tensor", "The box indices tensor.")
.set_attrs_type<CropAndResizeAttrs>()
.set_support_level(5)
.add_type_rel("CropAndResize", CropAndResizeRel)
.set_attr<TOpPattern>("TOpPattern", kInjective);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -1706,39 +1706,47 @@ def test_forward_crop(): ...@@ -1706,39 +1706,47 @@ def test_forward_crop():
# CropAndResize # CropAndResize
# ------------- # -------------
def _test_forward_crop_and_resize(img_shape, boxes, box_idx, crop_size, method='bilinear', dtype="float32"): def _test_forward_crop_and_resize(img_shape, boxes, box_idx, crop_size,
extrapolation_value=0.0, method='bilinear', dtype="float32"):
image = np.random.uniform(0, 10, size=img_shape).astype(dtype) image = np.random.uniform(0, 10, size=img_shape).astype(dtype)
tf.reset_default_graph() tf.reset_default_graph()
in_data = tf.placeholder(dtype, image.shape, name="in_data") in_data = tf.placeholder(dtype, image.shape, name="in_data")
tf.image.crop_and_resize(in_data, boxes=boxes, box_ind=box_idx, crop_size=crop_size, tf.image.crop_and_resize(in_data, boxes=boxes, box_ind=box_idx,
method=method, name="crop_and_resize") crop_size=crop_size, method=method,
extrapolation_value=extrapolation_value,
name="crop_and_resize")
compare_tf_with_tvm([image], ['in_data:0'], 'crop_and_resize:0') compare_tf_with_tvm([image], ['in_data:0'], 'crop_and_resize:0')
def test_forward_crop_and_resize(): def test_forward_crop_and_resize():
""" CropAndResize """ """ CropAndResize """
_test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, 1, 1]], [0], [5, 5]) _test_forward_crop_and_resize([1, 6, 6, 3], [[0, 0, 1, 1]], [0], [3, 3])
_test_forward_crop_and_resize( _test_forward_crop_and_resize([1, 6, 6, 3], [[0, 0, 1, 1]], [0], [3, 3], 0.2)
[1, 11, 11, 3], [[0, 0, .9, .9]], [0], [5, 5]) _test_forward_crop_and_resize([1, 6, 6, 3], [[0, 0, 1, 1]], [0], [3, 3], 0.2, 'nearest')
_test_forward_crop_and_resize( _test_forward_crop_and_resize([1, 11, 11, 3], [[.3, .3, 1, 1]], [0], [21, 21])
[1, 11, 11, 3], [[.1, .2, 1, 1]], [0], [5, 5]) _test_forward_crop_and_resize([1, 41, 41, 3], [[.2, .4, .8, .8]], [0], [21, 11])
_test_forward_crop_and_resize( _test_forward_crop_and_resize([1, 100, 100, 3], [[ 0, 0, .9, .9]], [0], [30, 30])
[1, 21, 21, 3], [[.2, .3, .7, .9]], [0], [3, 4]) _test_forward_crop_and_resize([1, 224, 224, 3], [[.1, .2, 1, 1]], [0], [9, 9])
_test_forward_crop_and_resize( _test_forward_crop_and_resize([1, 249, 249, 3], [[ 0, 0, 1, 1]], [0], [9, 9])
[1, 41, 41, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [3, 3]) _test_forward_crop_and_resize([1, 201, 301, 3], [[.2, .3, .7, .8]], [0], [51, 51])
_test_forward_crop_and_resize([10, 11, 11, 3], _test_forward_crop_and_resize(img_shape=[10, 11, 11, 3],
[[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]], boxes=[[ 0, 0, .9, .9],
[0, 1], [.2, .2, .8, .8]],
[5, 5]) box_idx=[0, 1], crop_size=[5, 5])
_test_forward_crop_and_resize([3, 11, 11, 3], _test_forward_crop_and_resize(img_shape=[20, 576, 576, 3],
[[0, 0, 0.9, 0.9], [ boxes=[[ 0, 0, 1, 1],
0.2, 0.2, 0.8, 0.8], [0, 0, 1, 1]], [ 0, 0, .8, .8],
[0, 1, 2], [.1, .2, .9, 1],
[3, 3]) [.2, 0, 1, 1]],
_test_forward_crop_and_resize([3, 11, 11, 3], box_idx=[1, 0, 2, 3], crop_size=[24, 24],
[[0, 0, 1, 0.8], [0, 0, 0.9, 0.9], [0, 0, 1, 0.8]], extrapolation_value=0.3)
[2, 1, 0], _test_forward_crop_and_resize(img_shape=[20, 229, 229, 3],
[3, 3]) boxes=[[ 0, 0, .9, .9],
[.3, .3, 1, 1],
[.2, .1, .7, .8],
[ 0, 0, 1, 1]],
box_idx=[3, 0, 2, 1], crop_size=[58, 58],
extrapolation_value=0.2, method='nearest')
####################################################################### #######################################################################
......
...@@ -72,6 +72,47 @@ def test_resize(): ...@@ -72,6 +72,47 @@ def test_resize():
for layout in ["NHWC", "NCHW"]: for layout in ["NHWC", "NCHW"]:
verify_resize((1, 4, 4, 4), 2, method, layout) verify_resize((1, 4, 4, 4), 2, method, layout)
def test_crop_and_resize():
def verify_crop_and_resize(img_shape, boxes, box_indices, crop_size,
layout, method, extrapolation_value=0.0):
image_data = np.random.uniform(size=img_shape).astype("float32")
ref_res = topi.testing.crop_and_resize_python(image_data,
boxes,
box_indices,
crop_size,
layout, method,
extrapolation_value)
img = relay.var("img", relay.TensorType(img_shape, 'float32'))
bx = relay.var('bx', relay.TensorType(boxes.shape, 'float32'))
bx_idx = relay.var('bx_idx', relay.TensorType(box_indices.shape, 'int32'))
z = relay.image.crop_and_resize(img, bx, bx_idx, list(crop_size),
layout, method, extrapolation_value)
zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType(ref_res.shape, "float32")
func = relay.Function([img, bx, bx_idx], z)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(image_data, boxes, box_indices)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-3, atol=1e-04)
boxes_nhwc = np.array([[.1, .2, .8, .7], [.2, 0, 1, .6]]).astype("float32")
indices_nhwc = np.array([1, 0]).astype("int32")
size_nhwc = np.array([20, 30]).astype("int32")
boxes_nchw = np.array([[0, 0, 1, 1], [.2, .1, 1, .9]]).astype("float32")
indices_nchw = np.array([0, 1]).astype("int32")
size_nchw = np.array([30, 30]).astype("int32")
for method in ["bilinear", "nearest_neighbor"]:
verify_crop_and_resize((10, 224, 224, 3), boxes_nhwc, indices_nhwc,
size_nhwc, 'NHWC', method)
verify_crop_and_resize((5, 3, 255, 255), boxes_nchw, indices_nchw,
size_nchw, 'NCHW', method, 0.1)
def test_multibox_prior(): def test_multibox_prior():
def get_ref_result(dshape, sizes=(1.0,), def get_ref_result(dshape, sizes=(1.0,),
...@@ -639,6 +680,7 @@ def test_space_to_depth(): ...@@ -639,6 +680,7 @@ def test_space_to_depth():
if __name__ == "__main__": if __name__ == "__main__":
test_resize_infer_type() test_resize_infer_type()
test_resize() test_resize()
test_crop_and_resize()
test_multibox_prior() test_multibox_prior()
test_multibox_transform_loc() test_multibox_transform_loc()
test_get_valid_counts() test_get_valid_counts()
...@@ -650,4 +692,4 @@ if __name__ == "__main__": ...@@ -650,4 +692,4 @@ if __name__ == "__main__":
test_non_max_suppression() test_non_max_suppression()
test_deformable_conv2d() test_deformable_conv2d()
test_depth_to_space() test_depth_to_space()
test_space_to_depth() test_space_to_depth()
\ No newline at end of file
...@@ -51,3 +51,4 @@ from .pool_grad_python import pool_grad_nchw ...@@ -51,3 +51,4 @@ from .pool_grad_python import pool_grad_nchw
from .one_hot import one_hot from .one_hot import one_hot
from .depth_to_space import depth_to_space_python from .depth_to_space import depth_to_space_python
from .space_to_depth import space_to_depth_python from .space_to_depth import space_to_depth_python
from .crop_and_resize_python import crop_and_resize_python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-nested-blocks
"""crop and resize in python"""
import math
import numpy as np
def crop_and_resize_python(image, boxes, box_indices, crop_size, layout,
method='bilinear', extrapolation_value=0):
"""Crop and resize using python"""
(target_h, target_w) = crop_size
if layout == 'NHWC':
batch = boxes.shape[0]
image_height, image_width, channel = image.shape[1], image.shape[2], image.shape[3]
scaled_image = np.ones((batch, target_h, target_w, channel))
else:
batch = boxes.shape[0]
channel, image_height, image_width = image.shape[1], image.shape[2], image.shape[3]
scaled_image = np.ones((batch, channel, target_h, target_w))
for n, box in enumerate(boxes):
b_in = box_indices[n]
y1, x1 = boxes[n][0], boxes[n][1]
y2, x2 = boxes[n][2], boxes[n][3]
in_h = (image_height - 1) * (y2 - y1)
in_w = (image_width - 1) * (x2 - x1)
h_scale = np.float32(in_h)/np.float32(target_h - 1)
w_scale = np.float32(in_w)/np.float32(target_w - 1)
for y in range(target_h):
in_y = y1 * (image_height - 1) + h_scale * y
if in_y < 0 or in_y > image_height - 1:
for x in range(target_w):
for d in range(channel):
if layout == 'NHWC':
scaled_image[n][y][x][d] = extrapolation_value
else:
scaled_image[n][d][y][x] = extrapolation_value
continue
if method == 'bilinear':
top_y_index = math.floor(in_y)
bottom_y_index = math.ceil(in_y)
y_lerp = in_y - top_y_index
for x in range(target_w):
in_x = x1 * (image_width - 1) + x * w_scale
if in_x < 0 or in_x > image_width - 1:
for d in range(channel):
if layout == 'NHWC':
scaled_image[n][y][x][d] = extrapolation_value
else:
scaled_image[n][d][y][x] = extrapolation_value
continue
left_x_index = math.floor(in_x)
right_x_index = math.ceil(in_x)
x_lerp = in_x - left_x_index
for d in range(channel):
if layout == "NHWC":
top_left = image[b_in][top_y_index][left_x_index][d]
top_right = image[b_in][top_y_index][right_x_index][d]
bottom_left = image[b_in][bottom_y_index][left_x_index][d]
bottom_right = image[b_in][bottom_y_index][right_x_index][d]
top = top_left + (top_right - top_left) * x_lerp
bottom = bottom_left + (bottom_right - bottom_left) * x_lerp
scaled_image[n][y][x][d] = top + (bottom - top) * y_lerp
else:
top_left = image[b_in][d][top_y_index][left_x_index]
top_right = image[b_in][d][top_y_index][right_x_index]
bottom_left = image[b_in][d][bottom_y_index][left_x_index]
bottom_right = image[b_in][d][bottom_y_index][right_x_index]
top = top_left + (top_right - top_left) * x_lerp
bottom = bottom_left + (bottom_right - bottom_left) * x_lerp
scaled_image[n][d][y][x] = top + (bottom - top) * y_lerp
elif method == 'nearest_neighbor':
for x in range(target_w):
in_x = x1 * (image_width - 1) + x * w_scale
if in_x < 0 or in_x > image_width - 1:
for d in range(channel):
if layout == 'NHWC':
scaled_image[n][y][x][d] = extrapolation_value
else:
scaled_image[n][d][y][x] = extrapolation_value
continue
closest_x_index = np.round(in_x).astype("int32")
closest_y_index = np.round(in_y).astype("int32")
for d in range(channel):
if layout == "NHWC":
scaled_image[n][y][x][d] = image[b_in][closest_y_index][closest_x_index][d]
else:
scaled_image[n][d][y][x] = image[b_in][d][closest_y_index][closest_x_index]
return scaled_image
...@@ -19,7 +19,6 @@ import numpy as np ...@@ -19,7 +19,6 @@ import numpy as np
import tvm import tvm
import topi import topi
import topi.testing import topi.testing
import math
from common import get_all_backend from common import get_all_backend
...@@ -99,7 +98,7 @@ def verify_resize3d(batch, in_channel, in_depth, in_height, in_width, out_depth, ...@@ -99,7 +98,7 @@ def verify_resize3d(batch, in_channel, in_depth, in_height, in_width, out_depth,
'Layout not supported {} '.format(layout)) 'Layout not supported {} '.format(layout))
B = topi.image.resize3d(A, (out_depth, out_height, out_width), layout=layout, B = topi.image.resize3d(A, (out_depth, out_height, out_width), layout=layout,
coordinate_transformation_mode=coordinate_transformation_mode, method=method) coordinate_transformation_mode=coordinate_transformation_mode, method=method)
if method == "trilinear": if method == "trilinear":
b_np = topi.testing.trilinear_resize3d_python(a_np, (out_depth, out_height, out_width), layout, b_np = topi.testing.trilinear_resize3d_python(a_np, (out_depth, out_height, out_width), layout,
...@@ -143,6 +142,68 @@ def test_resize3d(): ...@@ -143,6 +142,68 @@ def test_resize3d():
verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, 'NDHWC', method="nearest_neighbor") verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, 'NDHWC', method="nearest_neighbor")
def test_crop_and_resize():
def verify_crop_and_resize(image_shape, np_boxes, np_box_indices, np_crop_size, layout='NHWC',
method="bilinear", extrapolation_value=0.0):
images = tvm.placeholder(image_shape, name='images', dtype='float32')
np_images = np.random.uniform(size=image_shape).astype("float32")
boxes = tvm.placeholder(np_boxes.shape, name="boxes", dtype="float32")
box_ind = tvm.placeholder(np_box_indices.shape, name="box_ind", dtype="int32")
batch = len(np_box_indices)
target_height, target_width = np_crop_size[0], np_crop_size[1]
if layout == 'NHWC':
channel = image_shape[3]
out_shape = (batch, target_height, target_width, channel)
elif layout == 'NCHW':
channel = image_shape[1]
out_shape = (batch, channel, target_height, target_width)
else:
raise NotImplementedError(
'Layout {} is not supported.'.format(layout))
out = topi.image.crop_and_resize(images, boxes, box_ind, np_crop_size, layout=layout,
method=method, extrapolation_value=extrapolation_value)
baseline_np = topi.testing.crop_and_resize_python(np_images, np_boxes, np_box_indices,
np_crop_size, layout, method,
extrapolation_value)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective(out)
tvm_images = tvm.nd.array(np_images, ctx)
tvm_boxes = tvm.nd.array(np_boxes, ctx)
tvm_indices = tvm.nd.array(np_box_indices, ctx)
tvm_out = tvm.nd.array(np.zeros(out_shape, dtype="float32"), ctx)
f = tvm.build(s, [images, boxes, box_ind, out], device, name="crop_and_resize")
f(tvm_images, tvm_boxes, tvm_indices, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), baseline_np, rtol=1e-3, atol=1e-3)
for device in get_all_backend():
check_device(device)
boxes_1 = np.array([[.2, .3, .7, .9]], dtype="float32")
boxes_2 = np.array([[.2, .3, .7, .9], [0, .1, .8, 1]], dtype="float32")
indices_1 = np.array([0], dtype="int32")
indices_2 = np.array([1, 0], dtype="int32")
size_1 = (7, 11)
size_2 = (90, 60)
verify_crop_and_resize((1, 255, 255, 3), boxes_1, indices_1, size_1, layout="NHWC")
verify_crop_and_resize((10, 224, 224, 5), boxes_2, indices_2,
size_2, extrapolation_value=0.3, layout="NHWC")
verify_crop_and_resize((1, 100, 100, 3), boxes_1, indices_1,
size_1, method='nearest_neighbor')
verify_crop_and_resize((1, 3, 224, 224), boxes_1, indices_1, size_1, layout="NCHW")
if __name__ == "__main__": if __name__ == "__main__":
test_resize() test_resize()
test_resize3d() test_resize3d()
test_crop_and_resize()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment