Commit 56416ed0 by Yong Wu Committed by Zhi

[TOPI][RELAY][OP] add op crop_and_resize (#4417)

* [TOPI][RELAY][OP] add op crop_and_resize

* fix pylint

* incorporate comments

* fix ci
parent 12e51e6c
......@@ -104,6 +104,7 @@ List of operators
topi.ndarray_size
topi.layout_transform
topi.image.resize
topi.image.crop_and_resize
topi.argsort
topi.topk
topi.sequence_mask
......@@ -207,6 +208,7 @@ topi.nn
topi.image
~~~~~~~~~~
.. autofunction:: topi.image.resize
.. autofunction:: topi.image.crop_and_resize
topi.sparse
~~~~~~~~~~~
......
......@@ -169,6 +169,7 @@ This level enables additional math and transform operators.
:nosignatures:
tvm.relay.image.resize
tvm.relay.image.crop_and_resize
tvm.relay.vision.multibox_prior
tvm.relay.vision.multibox_transform_loc
tvm.relay.vision.nms
......@@ -335,6 +336,7 @@ Level 4 Definitions
Level 5 Definitions
-------------------
.. autofunction:: tvm.relay.image.resize
.. autofunction:: tvm.relay.image.crop_and_resize
.. autofunction:: tvm.relay.vision.multibox_prior
.. autofunction:: tvm.relay.vision.multibox_transform_loc
.. autofunction:: tvm.relay.vision.nms
......
......@@ -63,6 +63,34 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> {
}
};
/*! \brief Attributes used in image crop_and_resize operator */
struct CropAndResizeAttrs : public tvm::AttrsNode<CropAndResizeAttrs> {
  /*! \brief Target size of every cropped box. */
  Array<IndexExpr> crop_size;
  /*! \brief Dimension ordering of the input data, e.g. "NCHW" or "NHWC". */
  std::string layout;
  /*! \brief Scaling mode: "nearest_neighbor" or "bilinear". */
  std::string method;
  /*! \brief Value used for extrapolation. */
  double extrapolation_value;
  /*! \brief Output data type. */
  DataType out_dtype;

  TVM_DECLARE_ATTRS(CropAndResizeAttrs, "relay.attrs.CropAndResizeAttrs") {
    TVM_ATTR_FIELD(crop_size).set_default(NullValue<Array<IndexExpr> >())
        .describe("Target Size.");
    // Adjacent string literals are concatenated verbatim, so every fragment
    // except the last must end with an explicit separator.
    TVM_ATTR_FIELD(layout).set_default("NCHW")
        .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc. "
                  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width "
                  "dimensions respectively. Resize is applied on the 'H' and "
                  "'W' dimensions.");
    TVM_ATTR_FIELD(method).set_default("bilinear")
        .describe("Specify the mode to use for scaling. "
                  "nearest_neighbor - Nearest Neighbor; "
                  "bilinear - Bilinear Interpolation");
    TVM_ATTR_FIELD(extrapolation_value).set_default(0.0)
        .describe("Specify value for extrapolation.");
    TVM_ATTR_FIELD(out_dtype)
        .set_default(NullValue<DataType>())
        .describe("Output data type.");
  }
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_IMAGE_H_
......@@ -546,47 +546,20 @@ def _crop_and_resize():
# input image is a 4-D tensor of shape [batch, image_height, image_width, depth]
# boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2]
try:
boxes = _get_list_param(params, inputs[1])
box_ind = _get_list_param(params, inputs[2])
crop_size = _get_list_param(params, inputs[3])
except (IndexError, KeyError):
boxes = _infer_value(inputs[1], params).asnumpy().tolist()
box_ind = _infer_value(inputs[2], params).asnumpy().tolist()
crop_size = _infer_value(inputs[3], params).asnumpy().tolist()
data_shape = attr['_input_shapes'][inputs[0]]
data_dim = len(data_shape)
method = attr['method'].decode()
attrs = {}
attrs['size'] = crop_size
attrs['layout'] = 'NHWC'
if method.lower() == 'nearest':
method = 'nearest_neighbor' if method == 'nearest' else method
if method not in ['bilinear', 'nearest_neighbor']:
raise tvm.error.OpAttributeUnImplemented(
'Attribute method=nearest is not supported')
else:
attrs['coordinate_transformation_mode'] = 'align_corners'
attrs['method'] = 'bilinear'
out = None
begin = [0] * data_dim
size = data_shape[:]
for idx in box_ind:
# 1) Crop
# y is mapped to the image coordinate at y * (image_height - 1)
# x is mapped to the image coordinate at x * (image_width - 1)
begin[0] = idx
begin[1] = int(round(boxes[idx][0] * (data_shape[1] - 1)))
begin[2] = int(round(boxes[idx][1] * (data_shape[2] - 1)))
size[0] = idx + 1
size[1] = int(round((data_shape[1] - 1) * boxes[idx][2])) + 1
size[2] = int(round((data_shape[2] - 1) * boxes[idx][3])) + 1
res_crop = _op.strided_slice(inputs[0], begin=begin, end=size)
# 2) Resize
res_resize = get_relay_op('resize')(res_crop, **attrs)
out = _op.concatenate([out, res_resize], axis=0) if out else res_resize
return out
'Method {} is not supported'.format(method))
layout = attr['layout'] if 'layout' in attr else 'NHWC'
extrapolation_value = attr['extrapolation_value']
return get_relay_op("crop_and_resize")(inputs[0], inputs[1], inputs[2], crop_size,
layout, method, extrapolation_value)
return _impl
def _cast():
......
......@@ -25,7 +25,6 @@ from ..op import schedule_injective
# resize
reg.register_schedule("image.resize", schedule_injective)
@reg.register_compute("image.resize")
def compute_resize(attrs, inputs, out_type, target):
size = attrs.size
......@@ -34,3 +33,18 @@ def compute_resize(attrs, inputs, out_type, target):
coord_trans = attrs.coordinate_transformation_mode
out_dtype = attrs.out_dtype
return [topi.image.resize(inputs[0], size, layout, method, coord_trans, out_dtype)]
# crop and resize
reg.register_schedule("image.crop_and_resize", schedule_injective)

@reg.register_compute("image.crop_and_resize")
def compute_crop_and_resize(attrs, inputs, out_type, target):
    """Compute for image.crop_and_resize: forward inputs and attributes to TOPI.

    inputs holds (data, boxes, box_indices) in that order; the result is
    returned as a single-element list.
    """
    data, boxes, box_indices = inputs[0], inputs[1], inputs[2]
    result = topi.image.crop_and_resize(data, boxes, box_indices,
                                        attrs.crop_size,
                                        attrs.layout,
                                        attrs.method,
                                        attrs.extrapolation_value,
                                        attrs.out_dtype)
    return [result]
......@@ -31,7 +31,7 @@ def resize(data,
with data of shape (n, c, h, w)
out will have a shape (n, c, size[0], size[1])
method indicates the algorithm to be used while calculating the out value
and method can be one of ("bilinear", "nearest_neighbor", "bicubic")
Parameters
......@@ -63,3 +63,53 @@ def resize(data,
The resized result.
"""
return _make.resize(data, size, layout, method, coordinate_transformation_mode, out_dtype)
def crop_and_resize(data,
                    boxes,
                    box_indices,
                    crop_size,
                    layout,
                    method="bilinear",
                    extrapolation_value=0,
                    out_dtype=None):
    """Crop input images and resize them.

    method indicates the algorithm to be used while calculating the out value
    and method can be either "bilinear" or "nearest_neighbor".

    Parameters
    ----------
    data : relay.Expr
        The input data to the operator.

    boxes : relay.Expr
        A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies
        the coordinates of a box.

    box_indices : relay.Expr
        A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that
        the i-th box refers to.

    crop_size : Tuple of Expr
        The target size to which each box will be resized.

    layout : str
        Layout of the input. This argument is required (it has no default).

    method : str, optional
        Scale method, it can be either "nearest_neighbor" or "bilinear".

    extrapolation_value : float, optional
        Value used for extrapolation, when applicable.

    out_dtype : str, optional
        Type to return. If left None returns the same type as input.

    Returns
    -------
    result: relay.Expr
        The computed result.
    """
    # All arguments are passed positionally, in declaration order.
    return _make.crop_and_resize(data, boxes, box_indices, crop_size,
                                 layout, method, extrapolation_value, out_dtype)
......@@ -114,6 +114,9 @@ class DeformableConv2DAttrs(Attrs):
class ResizeAttrs(Attrs):
"""Attributes for image.resize"""
@register_relay_attr_node
class CropAndResizeAttrs(Attrs):
    """Attributes for image.crop_and_resize"""
    # No Python-side members are declared here; the class body is only the docstring.
@register_relay_attr_node
class ArgsortAttrs(Attrs):
......
......@@ -246,8 +246,8 @@ PrimExpr div(PrimExpr a, PrimExpr b) {
}
/*!
 * \brief Integer-only front end to div(a, b).
 *
 * Both operands must have a signed or unsigned integer dtype; otherwise the
 * check fails and the offending expression is printed with the failure.
 *
 * \param a The left operand.
 * \param b The right operand.
 * \return The result of div(a, b).
 */
PrimExpr truncdiv(PrimExpr a, PrimExpr b) {
  // One check per operand: a preceding message-less duplicate would fire
  // first and hide the expression that is streamed into the failure message.
  CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
  CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
  return div(a, b);
}
......@@ -276,8 +276,8 @@ PrimExpr indexmod(PrimExpr a, PrimExpr b) {
}
PrimExpr floordiv(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
BinaryOpMatchTypes(a, b);
PrimExpr ret = arith::TryConstFold<ir::FloorDivNode>(a, b);
if (ret.defined()) return ret;
......@@ -285,8 +285,8 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b) {
}
PrimExpr floormod(PrimExpr a, PrimExpr b) {
CHECK(a.dtype().is_int() || a.dtype().is_uint());
CHECK(b.dtype().is_int() || b.dtype().is_uint());
CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a;
CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b;
BinaryOpMatchTypes(a, b);
PrimExpr ret = arith::TryConstFold<ir::FloorModNode>(a, b);
if (ret.defined()) return ret;
......
......@@ -109,5 +109,89 @@ RELAY_REGISTER_OP("image.resize")
.add_type_rel("Resize", ResizeRel)
.set_attr<TOpPattern>("TOpPattern", kInjective);
TVM_REGISTER_NODE_TYPE(CropAndResizeAttrs);

/*!
 * \brief Type relation of image.crop_and_resize.
 *
 * types holds, in order, the types of data, boxes, box_indices and the output.
 * The output shape is the data shape with (in NCHW terms) axis 0 replaced by
 * box_indices->shape[0] and axes 2 and 3 replaced by crop_size[0] and
 * crop_size[1]; it is reported back in the layout given by the attributes.
 * The output dtype is out_dtype, or the data dtype when out_dtype has 0 bits.
 *
 * \param types The argument types followed by the output type.
 * \param num_inputs Number of inputs (not used by this relation).
 * \param attrs The operator attributes; must be CropAndResizeAttrs.
 * \param reporter Receives the resolved output type.
 * \return false while any input is not yet a tensor type, true once the
 *         output type has been assigned.
 */
bool CropAndResizeRel(const Array<Type>& types,
                      int num_inputs,
                      const Attrs& attrs,
                      const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 4);
  const auto* data = types[0].as<TensorTypeNode>();
  const auto* boxes = types[1].as<TensorTypeNode>();
  const auto* box_indices = types[2].as<TensorTypeNode>();
  if (data == nullptr || boxes == nullptr ||
      box_indices == nullptr) return false;

  const CropAndResizeAttrs* param = attrs.as<CropAndResizeAttrs>();
  CHECK(param != nullptr);
  const auto& crop_size = param->crop_size;
  // crop_size[0] and crop_size[1] are read below; reject anything else up
  // front instead of indexing past the end of the array.
  CHECK_EQ(crop_size.size(), 2)
      << "crop_size must contain exactly two values (height, width)";

  DataType out_dtype = param->out_dtype;
  if (out_dtype.bits() == 0) {
    out_dtype = data->dtype;
  }

  // Compute the shape in NCHW, then convert it back to the requested layout.
  static const Layout kNCHW("NCHW");
  const Layout in_layout(param->layout);
  auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
  CHECK(layout_converter.defined())
      << "CropAndResize only support input layouts that are convertible from NCHW."
      << " But got " << param->layout;

  auto oshape = layout_converter.ForwardShape(data->shape);
  oshape.Set(0, box_indices->shape[0]);
  oshape.Set(2, crop_size[0]);
  oshape.Set(3, crop_size[1]);

  // assign output type
  reporter->Assign(types[3],
                   TensorTypeNode::make(layout_converter.BackwardShape(oshape),
                                        out_dtype));
  return true;
}
/*!
 * \brief Build a call to the image.crop_and_resize operator.
 *
 * \param data The input data expression.
 * \param boxes The boxes expression.
 * \param box_indices The box indices expression.
 * \param crop_size Stored in CropAndResizeAttrs::crop_size.
 * \param layout Stored in CropAndResizeAttrs::layout.
 * \param method Stored in CropAndResizeAttrs::method.
 * \param extrapolation_value Stored in CropAndResizeAttrs::extrapolation_value.
 * \param out_dtype Stored in CropAndResizeAttrs::out_dtype.
 * \return A call node applying the operator to (data, boxes, box_indices).
 */
Expr MakeCropAndResize(Expr data,
                       Expr boxes,
                       Expr box_indices,
                       Array<IndexExpr> crop_size,
                       std::string layout,
                       std::string method,
                       double extrapolation_value,
                       DataType out_dtype) {
  auto attrs = make_object<CropAndResizeAttrs>();
  attrs->crop_size = std::move(crop_size);
  attrs->layout = std::move(layout);
  attrs->method = std::move(method);
  // A double is copied either way, so it is assigned directly without std::move.
  attrs->extrapolation_value = extrapolation_value;
  attrs->out_dtype = out_dtype;
  static const Op& op = Op::Get("image.crop_and_resize");
  return CallNode::make(op, {data, boxes, box_indices}, Attrs(attrs), {});
}
// Register MakeCropAndResize as the typed body of the global function
// "relay.op.image._make.crop_and_resize".
TVM_REGISTER_GLOBAL("relay.op.image._make.crop_and_resize")
.set_body_typed(MakeCropAndResize);
// Operator registration for image.crop_and_resize. The type relation sets the
// leading output dimension to the number of box indices, so the description
// states num_boxes rather than batch_size for the output.
RELAY_REGISTER_OP("image.crop_and_resize")
.describe(R"code(Perform crop and resize to input array with nearest neighbour or bilinear interpolation.

- **data**: data is 4D array of shape
            (batch_size, channels, in_height, in_width) for NCHW
            (batch_size, in_height, in_width, channels) for NHWC

- **boxes**: boxes is 2D array of shape (num_boxes, 4),
             each row holds the coordinates of one box

- **box_indices**: box_indices is 1D array of shape (num_boxes),
             entry i selects the data element that the i-th box refers to

- **out**: Output is 4D array of shape
           for layout NCHW
           (num_boxes, channels, crop_size[0], crop_size[1])

           for layout NHWC
           (num_boxes, crop_size[0], crop_size[1], channels)
)code" TVM_ADD_FILELINE)
.set_num_inputs(3)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("boxes", "Tensor", "The boxes tensor.")
.add_argument("box_indices", "Tensor", "The box indices tensor.")
.set_attrs_type<CropAndResizeAttrs>()
.set_support_level(5)
.add_type_rel("CropAndResize", CropAndResizeRel)
.set_attr<TOpPattern>("TOpPattern", kInjective);
} // namespace relay
} // namespace tvm
......@@ -1706,39 +1706,47 @@ def test_forward_crop():
# CropAndResize
# -------------
def _test_forward_crop_and_resize(img_shape, boxes, box_idx, crop_size,
                                  extrapolation_value=0.0, method='bilinear', dtype="float32"):
    """Build a TF graph with one crop_and_resize op and compare TF against TVM.

    A random image of img_shape is fed through the placeholder "in_data"; the
    output compared is "crop_and_resize:0".
    """
    image = np.random.uniform(0, 10, size=img_shape).astype(dtype)
    tf.reset_default_graph()
    in_data = tf.placeholder(dtype, image.shape, name="in_data")
    # Exactly one op may be created under the name "crop_and_resize": the
    # comparison below looks the output up by that name, so this call must be
    # the one that carries extrapolation_value.
    tf.image.crop_and_resize(in_data, boxes=boxes, box_ind=box_idx,
                             crop_size=crop_size, method=method,
                             extrapolation_value=extrapolation_value,
                             name="crop_and_resize")
    compare_tf_with_tvm([image], ['in_data:0'], 'crop_and_resize:0')
def test_forward_crop_and_resize():
    """ CropAndResize """
    # Each entry is (positional arguments, keyword arguments) for
    # _test_forward_crop_and_resize; the cases run in the order listed.
    cases = [
        (([1, 11, 11, 3], [[0, 0, 1, 1]], [0], [5, 5]), {}),
        (([1, 11, 11, 3], [[0, 0, .9, .9]], [0], [5, 5]), {}),
        (([1, 11, 11, 3], [[.1, .2, 1, 1]], [0], [5, 5]), {}),
        (([1, 21, 21, 3], [[.2, .3, .7, .9]], [0], [3, 4]), {}),
        (([1, 41, 41, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [3, 3]), {}),
        (([10, 11, 11, 3],
          [[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]],
          [0, 1],
          [5, 5]), {}),
        (([3, 11, 11, 3],
          [[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8], [0, 0, 1, 1]],
          [0, 1, 2],
          [3, 3]), {}),
        (([3, 11, 11, 3],
          [[0, 0, 1, 0.8], [0, 0, 0.9, 0.9], [0, 0, 1, 0.8]],
          [2, 1, 0],
          [3, 3]), {}),
        (([1, 6, 6, 3], [[0, 0, 1, 1]], [0], [3, 3]), {}),
        (([1, 6, 6, 3], [[0, 0, 1, 1]], [0], [3, 3], 0.2), {}),
        (([1, 6, 6, 3], [[0, 0, 1, 1]], [0], [3, 3], 0.2, 'nearest'), {}),
        (([1, 11, 11, 3], [[.3, .3, 1, 1]], [0], [21, 21]), {}),
        (([1, 41, 41, 3], [[.2, .4, .8, .8]], [0], [21, 11]), {}),
        (([1, 100, 100, 3], [[0, 0, .9, .9]], [0], [30, 30]), {}),
        (([1, 224, 224, 3], [[.1, .2, 1, 1]], [0], [9, 9]), {}),
        (([1, 249, 249, 3], [[0, 0, 1, 1]], [0], [9, 9]), {}),
        (([1, 201, 301, 3], [[.2, .3, .7, .8]], [0], [51, 51]), {}),
        ((), dict(img_shape=[10, 11, 11, 3],
                  boxes=[[0, 0, .9, .9],
                         [.2, .2, .8, .8]],
                  box_idx=[0, 1], crop_size=[5, 5])),
        ((), dict(img_shape=[20, 576, 576, 3],
                  boxes=[[0, 0, 1, 1],
                         [0, 0, .8, .8],
                         [.1, .2, .9, 1],
                         [.2, 0, 1, 1]],
                  box_idx=[1, 0, 2, 3], crop_size=[24, 24],
                  extrapolation_value=0.3)),
        ((), dict(img_shape=[20, 229, 229, 3],
                  boxes=[[0, 0, .9, .9],
                         [.3, .3, 1, 1],
                         [.2, .1, .7, .8],
                         [0, 0, 1, 1]],
                  box_idx=[3, 0, 2, 1], crop_size=[58, 58],
                  extrapolation_value=0.2, method='nearest')),
    ]
    for args, kwargs in cases:
        _test_forward_crop_and_resize(*args, **kwargs)
#######################################################################
......
......@@ -72,6 +72,47 @@ def test_resize():
for layout in ["NHWC", "NCHW"]:
verify_resize((1, 4, 4, 4), 2, method, layout)
def test_crop_and_resize():
    """Check relay.image.crop_and_resize against the python reference implementation."""
    def verify_crop_and_resize(img_shape, boxes, box_indices, crop_size,
                               layout, method, extrapolation_value=0.0):
        image_data = np.random.uniform(size=img_shape).astype("float32")

        # Expected values come from the pure-python reference.
        expected = topi.testing.crop_and_resize_python(
            image_data, boxes, box_indices, crop_size,
            layout, method, extrapolation_value)

        img_var = relay.var("img", relay.TensorType(img_shape, 'float32'))
        boxes_var = relay.var('bx', relay.TensorType(boxes.shape, 'float32'))
        indices_var = relay.var('bx_idx', relay.TensorType(box_indices.shape, 'int32'))
        call = relay.image.crop_and_resize(img_var, boxes_var, indices_var,
                                           list(crop_size), layout, method,
                                           extrapolation_value)

        # The inferred type must carry the reference output shape.
        inferred = run_infer_type(call)
        assert inferred.checked_type == relay.TensorType(expected.shape, "float32")

        func = relay.Function([img_var, boxes_var, indices_var], call)
        for target, ctx in ctx_list():
            for kind in ["graph", "debug"]:
                executor = relay.create_executor(kind, ctx=ctx, target=target)
                result = executor.evaluate(func)(image_data, boxes, box_indices)
                tvm.testing.assert_allclose(result.asnumpy(), expected,
                                            rtol=1e-3, atol=1e-04)

    # NHWC fixtures
    boxes_nhwc = np.array([[.1, .2, .8, .7], [.2, 0, 1, .6]]).astype("float32")
    indices_nhwc = np.array([1, 0]).astype("int32")
    size_nhwc = np.array([20, 30]).astype("int32")

    # NCHW fixtures
    boxes_nchw = np.array([[0, 0, 1, 1], [.2, .1, 1, .9]]).astype("float32")
    indices_nchw = np.array([0, 1]).astype("int32")
    size_nchw = np.array([30, 30]).astype("int32")

    for method in ["bilinear", "nearest_neighbor"]:
        verify_crop_and_resize((10, 224, 224, 3), boxes_nhwc, indices_nhwc,
                               size_nhwc, 'NHWC', method)
        verify_crop_and_resize((5, 3, 255, 255), boxes_nchw, indices_nchw,
                               size_nchw, 'NCHW', method, 0.1)
def test_multibox_prior():
def get_ref_result(dshape, sizes=(1.0,),
......@@ -639,6 +680,7 @@ def test_space_to_depth():
if __name__ == "__main__":
test_resize_infer_type()
test_resize()
test_crop_and_resize()
test_multibox_prior()
test_multibox_transform_loc()
test_get_valid_counts()
......@@ -650,4 +692,4 @@ if __name__ == "__main__":
test_non_max_suppression()
test_deformable_conv2d()
test_depth_to_space()
test_space_to_depth()
\ No newline at end of file
test_space_to_depth()
......@@ -51,3 +51,4 @@ from .pool_grad_python import pool_grad_nchw
from .one_hot import one_hot
from .depth_to_space import depth_to_space_python
from .space_to_depth import space_to_depth_python
from .crop_and_resize_python import crop_and_resize_python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-nested-blocks
"""crop and resize in python"""
import math
import numpy as np
def _source_coords(start, end, image_extent, target_extent):
    """Map every output index along one axis to its source-image coordinate.

    start and end are the box edges, scaled by (image_extent - 1). With more
    than one output sample the box is sampled at evenly spaced points that
    include both edges. With a single output sample the centre of the box is
    used, which avoids dividing by target_extent - 1 == 0.
    """
    if target_extent > 1:
        span = (image_extent - 1) * (end - start)
        scale = np.float32(span) / np.float32(target_extent - 1)
        return [start * (image_extent - 1) + scale * i for i in range(target_extent)]
    return [0.5 * (start + end) * (image_extent - 1)] * target_extent


def _bilinear_sample(plane, in_y, in_x):
    """Bilinearly interpolate all channels of an (H, W, C) image at (in_y, in_x)."""
    top, bottom = math.floor(in_y), math.ceil(in_y)
    left, right = math.floor(in_x), math.ceil(in_x)
    y_lerp = in_y - top
    x_lerp = in_x - left
    upper = plane[top, left] + (plane[top, right] - plane[top, left]) * x_lerp
    lower = plane[bottom, left] + (plane[bottom, right] - plane[bottom, left]) * x_lerp
    return upper + (lower - upper) * y_lerp


def crop_and_resize_python(image, boxes, box_indices, crop_size, layout,
                           method='bilinear', extrapolation_value=0):
    """Crop and resize using python

    Parameters
    ----------
    image : numpy.ndarray
        4-D input; channel-last when layout is 'NHWC', otherwise treated as
        channel-first (N, C, H, W).

    boxes : numpy.ndarray
        2-D array with one row [y1, x1, y2, x2] per box; each coordinate is
        multiplied by (image extent - 1) to obtain a pixel position.

    box_indices : numpy.ndarray
        box_indices[n] is the index into the first axis of image that box n
        is cropped from.

    crop_size : tuple of int
        (target_h, target_w) of every output crop.

    layout : str
        'NHWC' or 'NCHW'.

    method : str, optional
        'bilinear' or 'nearest_neighbor'.

    extrapolation_value : float, optional
        Value written wherever the sampling position falls outside the image.

    Returns
    -------
    numpy.ndarray
        float64 array with one crop per box, in the same layout as the input.

    Raises
    ------
    ValueError
        If method is neither 'bilinear' nor 'nearest_neighbor'.
    """
    if method not in ('bilinear', 'nearest_neighbor'):
        # Fail loudly instead of returning an array that was never filled in.
        raise ValueError("Unsupported method: {}".format(method))

    (target_h, target_w) = crop_size
    channels_last = layout == 'NHWC'
    # Work on a channel-last view so that one code path serves both layouts.
    pixels = image if channels_last else np.transpose(image, (0, 2, 3, 1))
    image_height, image_width, channel = pixels.shape[1], pixels.shape[2], pixels.shape[3]

    num_boxes = boxes.shape[0]
    # Pre-fill with the extrapolation value; only in-range samples are overwritten.
    cropped = np.full((num_boxes, target_h, target_w, channel),
                      extrapolation_value, dtype="float64")

    for n in range(num_boxes):
        plane = pixels[box_indices[n]]
        y1, x1 = boxes[n][0], boxes[n][1]
        y2, x2 = boxes[n][2], boxes[n][3]
        in_ys = _source_coords(y1, y2, image_height, target_h)
        in_xs = _source_coords(x1, x2, image_width, target_w)

        for y, in_y in enumerate(in_ys):
            if in_y < 0 or in_y > image_height - 1:
                continue
            for x, in_x in enumerate(in_xs):
                if in_x < 0 or in_x > image_width - 1:
                    continue
                if method == 'bilinear':
                    cropped[n, y, x, :] = _bilinear_sample(plane, in_y, in_x)
                else:
                    closest_y_index = int(np.round(in_y))
                    closest_x_index = int(np.round(in_x))
                    cropped[n, y, x, :] = plane[closest_y_index, closest_x_index]

    if channels_last:
        return cropped
    return np.ascontiguousarray(np.transpose(cropped, (0, 3, 1, 2)))
......@@ -19,7 +19,6 @@ import numpy as np
import tvm
import topi
import topi.testing
import math
from common import get_all_backend
......@@ -99,7 +98,7 @@ def verify_resize3d(batch, in_channel, in_depth, in_height, in_width, out_depth,
'Layout not supported {} '.format(layout))
B = topi.image.resize3d(A, (out_depth, out_height, out_width), layout=layout,
coordinate_transformation_mode=coordinate_transformation_mode, method=method)
coordinate_transformation_mode=coordinate_transformation_mode, method=method)
if method == "trilinear":
b_np = topi.testing.trilinear_resize3d_python(a_np, (out_depth, out_height, out_width), layout,
......@@ -143,6 +142,68 @@ def test_resize3d():
verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, 'NDHWC', method="nearest_neighbor")
def test_crop_and_resize():
    """Compare topi.image.crop_and_resize with the python reference on every backend."""
    def verify_crop_and_resize(image_shape, np_boxes, np_box_indices, np_crop_size, layout='NHWC',
                               method="bilinear", extrapolation_value=0.0):
        # Symbolic inputs and the matching random image.
        images = tvm.placeholder(image_shape, name='images', dtype='float32')
        np_images = np.random.uniform(size=image_shape).astype("float32")
        boxes = tvm.placeholder(np_boxes.shape, name="boxes", dtype="float32")
        box_ind = tvm.placeholder(np_box_indices.shape, name="box_ind", dtype="int32")

        # The output has one crop per box index.
        num_crops = len(np_box_indices)
        crop_h, crop_w = np_crop_size[0], np_crop_size[1]
        if layout == 'NHWC':
            out_shape = (num_crops, crop_h, crop_w, image_shape[3])
        elif layout == 'NCHW':
            out_shape = (num_crops, image_shape[1], crop_h, crop_w)
        else:
            raise NotImplementedError(
                'Layout {} is not supported.'.format(layout))

        out = topi.image.crop_and_resize(images, boxes, box_ind, np_crop_size, layout=layout,
                                         method=method, extrapolation_value=extrapolation_value)

        expected = topi.testing.crop_and_resize_python(np_images, np_boxes, np_box_indices,
                                                       np_crop_size, layout, method,
                                                       extrapolation_value)

        def check_device(device):
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                return
            print("Running on target: %s" % device)
            with tvm.target.create(device):
                schedule = topi.generic.schedule_injective(out)
            dev_images = tvm.nd.array(np_images, ctx)
            dev_boxes = tvm.nd.array(np_boxes, ctx)
            dev_indices = tvm.nd.array(np_box_indices, ctx)
            dev_out = tvm.nd.array(np.zeros(out_shape, dtype="float32"), ctx)
            func = tvm.build(schedule, [images, boxes, box_ind, out], device,
                             name="crop_and_resize")
            func(dev_images, dev_boxes, dev_indices, dev_out)

            tvm.testing.assert_allclose(dev_out.asnumpy(), expected, rtol=1e-3, atol=1e-3)

        for device in get_all_backend():
            check_device(device)

    # Fixtures: a single box and a pair of boxes with swapped batch indices.
    boxes_1 = np.array([[.2, .3, .7, .9]], dtype="float32")
    boxes_2 = np.array([[.2, .3, .7, .9], [0, .1, .8, 1]], dtype="float32")
    indices_1 = np.array([0], dtype="int32")
    indices_2 = np.array([1, 0], dtype="int32")
    size_1 = (7, 11)
    size_2 = (90, 60)

    verify_crop_and_resize((1, 255, 255, 3), boxes_1, indices_1, size_1, layout="NHWC")
    verify_crop_and_resize((10, 224, 224, 5), boxes_2, indices_2,
                           size_2, extrapolation_value=0.3, layout="NHWC")
    verify_crop_and_resize((1, 100, 100, 3), boxes_1, indices_1,
                           size_1, method='nearest_neighbor')
    verify_crop_and_resize((1, 3, 224, 224), boxes_1, indices_1, size_1, layout="NCHW")
# Run every test in this file when it is executed directly as a script.
if __name__ == "__main__":
    test_resize()
    test_resize3d()
    test_crop_and_resize()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment