Commit e8a2c9b3 by masahi Committed by Yizhi Liu

[TOPI, Relay] Add half_pixel option to Resize op (#4610)

* add onnx resize converter

* update frontends

* updating topi

* adding onnx resize tests

* fixed NHWC test by casting size dtype to int32

* fix tests

* fix lint

* update existing test cases

* fix tensorflow frontend

* fix lint

* remove NHWC stuff

* update topi resize test for half_pixel

* update doc

* fix doc

* remove onnx resize bits
parent 203ca7a0
...@@ -36,7 +36,7 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> { ...@@ -36,7 +36,7 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> {
Array<IndexExpr> size; Array<IndexExpr> size;
std::string layout; std::string layout;
std::string method; std::string method;
bool align_corners; std::string coordinate_transformation_mode;
DataType out_dtype; DataType out_dtype;
TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") { TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") {
...@@ -52,8 +52,11 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> { ...@@ -52,8 +52,11 @@ struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> {
"nearest_neighbor - Nearest Neighbor" "nearest_neighbor - Nearest Neighbor"
"bilinear - Bilinear Interpolation" "bilinear - Bilinear Interpolation"
"bicubic - Bicubic Interpolation"); "bicubic - Bicubic Interpolation");
TVM_ATTR_FIELD(align_corners).set_default(true) TVM_ATTR_FIELD(coordinate_transformation_mode).set_default("half_pixel")
.describe("Should be true to preserve the values at the corner pixels"); .describe("Describes how to transform the coordinate in the resized tensor"
"to the coordinate in the original tensor."
"Refer to the ONNX Resize operator specification for details"
"Available options are half_pixel, align_corners and asymmetric");
TVM_ATTR_FIELD(out_dtype) TVM_ATTR_FIELD(out_dtype)
.set_default(NullValue<DataType>()) .set_default(NullValue<DataType>())
.describe("Output data type."); .describe("Output data type.");
......
...@@ -676,7 +676,7 @@ def _mx_resize(inputs, attrs): ...@@ -676,7 +676,7 @@ def _mx_resize(inputs, attrs):
if scale_width is not None: if scale_width is not None:
width = (scale_width * shape[3]).astype("int32") width = (scale_width * shape[3]).astype("int32")
size = (height, width) size = (height, width)
return _op.image.resize(inputs[0], size, align_corners=True) return _op.image.resize(inputs[0], size, coordinate_transformation_mode="align_corners")
def _mx_roi_pooling(inputs, attrs): def _mx_roi_pooling(inputs, attrs):
new_attrs = {} new_attrs = {}
......
...@@ -1091,6 +1091,7 @@ class Or(Elemwise): ...@@ -1091,6 +1091,7 @@ class Or(Elemwise):
def _impl_v7(cls, inputs, attr, params): def _impl_v7(cls, inputs, attr, params):
return _op.logical_or(inputs[0], inputs[1]) return _op.logical_or(inputs[0], inputs[1])
class Expand(OnnxOpConverter): class Expand(OnnxOpConverter):
""" Operator converter for Expand. """ Operator converter for Expand.
""" """
...@@ -1138,6 +1139,7 @@ class Expand(OnnxOpConverter): ...@@ -1138,6 +1139,7 @@ class Expand(OnnxOpConverter):
shape = expand_shape(in_shape, shape) shape = expand_shape(in_shape, shape)
return _op.broadcast_to(inputs[0], shape=tuple(shape)) return _op.broadcast_to(inputs[0], shape=tuple(shape))
# compatible operators that do NOT require any conversion. # compatible operators that do NOT require any conversion.
_identity_list = [] _identity_list = []
...@@ -1263,7 +1265,7 @@ def _get_convert_map(opset): ...@@ -1263,7 +1265,7 @@ def _get_convert_map(opset):
'Tile': Tile.get_converter(opset), 'Tile': Tile.get_converter(opset),
'Erf': Erf.get_converter(opset), 'Erf': Erf.get_converter(opset),
'Where': Where.get_converter(opset), 'Where': Where.get_converter(opset),
'Or': Or.get_converter(opset) 'Or': Or.get_converter(opset),
} }
......
...@@ -582,7 +582,7 @@ def _crop_and_resize(): ...@@ -582,7 +582,7 @@ def _crop_and_resize():
raise tvm.error.OpAttributeUnImplemented( raise tvm.error.OpAttributeUnImplemented(
'Attribute method=nearest is not supported') 'Attribute method=nearest is not supported')
else: else:
attrs['align_corners'] = True attrs['coordinate_transformation_mode'] = 'align_corners'
attrs['method'] = 'bilinear' attrs['method'] = 'bilinear'
out = None out = None
...@@ -632,6 +632,10 @@ def _resize(method): ...@@ -632,6 +632,10 @@ def _resize(method):
inputs.pop(1) inputs.pop(1)
# NHWC # NHWC
attr['layout'] = 'NHWC' attr['layout'] = 'NHWC'
if attr.pop('align_corners') is True:
attr['coordinate_transformation_mode'] = 'align_corners'
else:
attr['coordinate_transformation_mode'] = 'asymmetric'
# Ignore the new attributes from TF2.0, for now. # Ignore the new attributes from TF2.0, for now.
return AttrCvt(op_name='resize', return AttrCvt(op_name='resize',
......
...@@ -330,7 +330,9 @@ class OperatorConverter(object): ...@@ -330,7 +330,9 @@ class OperatorConverter(object):
align_corners = resize_options.AlignCorners() align_corners = resize_options.AlignCorners()
# Use layout NHWC # Use layout NHWC
out = _op.image.resize(in_expr, target_size, "NHWC", method, align_corners) coord_trans = "align_corners" if align_corners else "asymmetric"
out = _op.image.resize(in_expr, target_size, "NHWC", method,
coordinate_transformation_mode=coord_trans)
return out return out
def convert_resize_bilinear(self, op): def convert_resize_bilinear(self, op):
......
...@@ -31,6 +31,6 @@ def compute_resize(attrs, inputs, out_type, target): ...@@ -31,6 +31,6 @@ def compute_resize(attrs, inputs, out_type, target):
size = attrs.size size = attrs.size
layout = attrs.layout layout = attrs.layout
method = attrs.method method = attrs.method
align_corners = attrs.align_corners coord_trans = attrs.coordinate_transformation_mode
out_dtype = attrs.out_dtype out_dtype = attrs.out_dtype
return [topi.image.resize(inputs[0], size, layout, method, align_corners, out_dtype)] return [topi.image.resize(inputs[0], size, layout, method, coord_trans, out_dtype)]
...@@ -22,7 +22,7 @@ def resize(data, ...@@ -22,7 +22,7 @@ def resize(data,
size, size,
layout="NCHW", layout="NCHW",
method="bilinear", method="bilinear",
align_corners=True, coordinate_transformation_mode="half_pixel",
out_dtype=None): out_dtype=None):
"""Image resize operator. """Image resize operator.
...@@ -48,8 +48,11 @@ def resize(data, ...@@ -48,8 +48,11 @@ def resize(data,
method : str, optional method : str, optional
Scale method to used [nearest_neighbor, bilinear, bicubic]. Scale method to used [nearest_neighbor, bilinear, bicubic].
align_corners : int, optional coordinate_transformation_mode : string, optional
Should be true to preserve the values at the corner pixels Describes how to transform the coordinate in the resized tensor
to the coordinate in the original tensor.
Refer to the ONNX Resize operator specification for details.
[half_pixel, align_corners, asymmetric]
out_dtype : str, optional out_dtype : str, optional
Type to return. If left None returns the same type as input. Type to return. If left None returns the same type as input.
...@@ -59,4 +62,4 @@ def resize(data, ...@@ -59,4 +62,4 @@ def resize(data,
result: relay.Expr result: relay.Expr
The resized result. The resized result.
""" """
return _make.resize(data, size, layout, method, align_corners, out_dtype) return _make.resize(data, size, layout, method, coordinate_transformation_mode, out_dtype)
...@@ -71,13 +71,13 @@ Expr MakeResize(Expr data, ...@@ -71,13 +71,13 @@ Expr MakeResize(Expr data,
Array<IndexExpr> size, Array<IndexExpr> size,
std::string layout, std::string layout,
std::string method, std::string method,
bool align_corners, std::string coordinate_transformation_mode,
DataType out_dtype) { DataType out_dtype) {
auto attrs = make_object<ResizeAttrs>(); auto attrs = make_object<ResizeAttrs>();
attrs->size = std::move(size); attrs->size = std::move(size);
attrs->layout = std::move(layout); attrs->layout = std::move(layout);
attrs->method = std::move(method); attrs->method = std::move(method);
attrs->align_corners = align_corners; attrs->coordinate_transformation_mode = coordinate_transformation_mode;
attrs->out_dtype = out_dtype; attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("image.resize"); static const Op& op = Op::Get("image.resize");
return CallNode::make(op, {data}, Attrs(attrs), {}); return CallNode::make(op, {data}, Attrs(attrs), {});
......
...@@ -98,23 +98,6 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape): ...@@ -98,23 +98,6 @@ def verify_onnx_forward_impl(graph_file, data_shape, out_shape):
tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5) tvm.testing.assert_allclose(c2_out, tvm_out, rtol=1e-5, atol=1e-5)
def verify_super_resolution_example():
verify_onnx_forward_impl(
super_resolution, (1, 1, 224, 224), (1, 1, 672, 672))
def verify_squeezenet1_1():
verify_onnx_forward_impl(squeezenet1_1, (1, 3, 224, 224), (1, 1000))
def verify_lenet():
verify_onnx_forward_impl(lenet, (1, 1, 28, 28), (1, 10))
def verify_resnet18():
verify_onnx_forward_impl(resnet18_1_0, (1, 3, 224, 224), (1, 1000))
def test_reshape(): def test_reshape():
in_shape = (4, 3, 3, 4) in_shape = (4, 3, 3, 4)
ref_shape = (6, 2, 4, 3) ref_shape = (6, 2, 4, 3)
......
...@@ -39,7 +39,7 @@ def test_resize_infer_type(): ...@@ -39,7 +39,7 @@ def test_resize_infer_type():
assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8") assert zz.checked_type == relay.TensorType((n, c, th, tw), "int8")
x = relay.var("x", relay.TensorType((n, c, h, w), "int8")) x = relay.var("x", relay.TensorType((n, c, h, w), "int8"))
z= relay.image.resize(x, (100, 200), "NCHW", "bilinear", True) z= relay.image.resize(x, (100, 200), "NCHW", "bilinear", "align_corners")
assert "size=" in z.astext() assert "size=" in z.astext()
zz = run_infer_type(z) zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8") assert zz.checked_type == relay.TensorType((n, c, 100, 200), "int8")
...@@ -57,7 +57,7 @@ def test_resize(): ...@@ -57,7 +57,7 @@ def test_resize():
else: else:
ref_res = topi.testing.upsampling_python(x_data, (scale, scale), layout) ref_res = topi.testing.upsampling_python(x_data, (scale, scale), layout)
x = relay.var("x", relay.TensorType(dshape, "float32")) x = relay.var("x", relay.TensorType(dshape, "float32"))
z = relay.image.resize(x, size, layout, method, True) z = relay.image.resize(x, size, layout, method, "align_corners")
assert "size=" in z.astext() assert "size=" in z.astext()
zz = run_infer_type(z) zz = run_infer_type(z)
assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") assert zz.checked_type == relay.TensorType(ref_res.shape, "float32")
......
...@@ -21,7 +21,8 @@ import tvm ...@@ -21,7 +21,8 @@ import tvm
from .. import tag from .. import tag
def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out_dtype=None): def resize(data, size, layout="NCHW", method="bilinear",
coordinate_transformation_mode="half_pixel", out_dtype=None):
"""Perform resize operation on the data. """Perform resize operation on the data.
Parameters Parameters
...@@ -37,8 +38,11 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out ...@@ -37,8 +38,11 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out
layout: string, optional layout: string, optional
"NCHW", "NHWC", or "NCHWc". "NCHW", "NHWC", or "NCHWc".
align_corners: Boolean, optional coordinate_transformation_mode: string, optional
To preserve the values at the corner pixels. Describes how to transform the coordinate in the resized tensor
to the coordinate in the original tensor.
Refer to the ONNX Resize operator specification for details.
Available options are "half_pixel", "align_corners" and "asymmetric".
method: {"bilinear", "nearest_neighbor", "bicubic"} method: {"bilinear", "nearest_neighbor", "bicubic"}
Method to be used for resizing. Method to be used for resizing.
...@@ -66,12 +70,15 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out ...@@ -66,12 +70,15 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out
in_n, in_c, in_h, in_w, in_cc = data.shape in_n, in_c, in_h, in_w, in_cc = data.shape
output_shape = [in_n, in_c, size[0], size[1], in_cc] output_shape = [in_n, in_c, size[0], size[1], in_cc]
if align_corners: if coordinate_transformation_mode == "align_corners":
y_ratio = (in_h - 1).astype('float') / (size[0] - 1) y_ratio = (in_h - 1).astype('float') / (size[0] - 1)
x_ratio = (in_w - 1).astype('float') / (size[1] - 1) x_ratio = (in_w - 1).astype('float') / (size[1] - 1)
else: elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]:
y_ratio = (in_h).astype('float') / (size[0]) y_ratio = (in_h).astype('float') / (size[0])
x_ratio = (in_w).astype('float') / (size[1]) x_ratio = (in_w).astype('float') / (size[1])
else:
raise ValueError("Unsupported coordinate_transformation_mode: {}".format(
coordinate_transformation_mode))
def _get_pixel(n, c, y, x, cc): def _get_pixel(n, c, y, x, cc):
y = tvm.max(tvm.min(y, in_h - 1), 0) y = tvm.max(tvm.min(y, in_h - 1), 0)
...@@ -109,7 +116,7 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out ...@@ -109,7 +116,7 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out
in_y = y_ratio * y in_y = y_ratio * y
in_x = x_ratio * x in_x = x_ratio * x
if align_corners: if coordinate_transformation_mode == "align_corners":
yint = tvm.round(in_y).astype('int32') yint = tvm.round(in_y).astype('int32')
xint = tvm.round(in_x).astype('int32') xint = tvm.round(in_x).astype('int32')
else: else:
...@@ -127,8 +134,12 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out ...@@ -127,8 +134,12 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out
def _bilinear(*indices): def _bilinear(*indices):
n, c, y, x, cc = _get_indices(*indices) n, c, y, x, cc = _get_indices(*indices)
in_y = y_ratio * y if coordinate_transformation_mode == "half_pixel":
in_x = x_ratio * x in_y = y_ratio * (y + 0.5) - 0.5
in_x = x_ratio * (x + 0.5) - 0.5
else:
in_y = y_ratio * y
in_x = x_ratio * x
xint = tvm.floor(in_x).astype('int32') xint = tvm.floor(in_x).astype('int32')
xfract = in_x - tvm.floor(in_x) xfract = in_x - tvm.floor(in_x)
...@@ -158,8 +169,12 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out ...@@ -158,8 +169,12 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out
def _bicubic(*indices): def _bicubic(*indices):
n, c, y, x, cc = _get_indices(*indices) n, c, y, x, cc = _get_indices(*indices)
in_y = y_ratio * y if coordinate_transformation_mode == "half_pixel":
in_x = x_ratio * x in_y = y_ratio * (y + 0.5) - 0.5
in_x = x_ratio * (x + 0.5) - 0.5
else:
in_y = y_ratio * y
in_x = x_ratio * x
xint = tvm.floor(in_x).astype('int32') xint = tvm.floor(in_x).astype('int32')
xfract = in_x - tvm.floor(in_x) xfract = in_x - tvm.floor(in_x)
......
...@@ -61,8 +61,9 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor', ...@@ -61,8 +61,9 @@ def upsampling(data, scale_h, scale_w, layout="NCHW", method='nearest_neighbor',
else: else:
raise ValueError("not support this layout {} yet".format(layout)) raise ValueError("not support this layout {} yet".format(layout))
coord_trans = "align_corners" if align_corners else "asymmetric"
return topi.image.resize(data, out_shape, layout=layout, return topi.image.resize(data, out_shape, layout=layout,
method=method, align_corners=align_corners) method=method, coordinate_transformation_mode=coord_trans)
def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor', def upsampling3d(data, scale_d, scale_h, scale_w, layout="NCDHW", method='nearest_neighbor',
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
import math import math
import numpy as np import numpy as np
def bilinear_resize_python(image, out_size, layout, align_corners=True): def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mode="align_corners"):
""" Bilinear scaling using python""" """ Bilinear scaling using python"""
(new_h, new_w) = out_size (new_h, new_w) = out_size
...@@ -30,32 +30,37 @@ def bilinear_resize_python(image, out_size, layout, align_corners=True): ...@@ -30,32 +30,37 @@ def bilinear_resize_python(image, out_size, layout, align_corners=True):
(batch, channel, h, w) = image.shape (batch, channel, h, w) = image.shape
scaled_image = np.ones((batch, channel, new_h, new_w)) scaled_image = np.ones((batch, channel, new_h, new_w))
if align_corners: if coordinate_transformation_mode == "align_corners":
height_scale = np.float32(h-1) / np.float32(out_size[0]-1) height_scale = np.float32(h-1) / np.float32(out_size[0]-1)
width_scale = np.float32(w-1) / np.float32(out_size[1]-1) width_scale = np.float32(w-1) / np.float32(out_size[1]-1)
else: else:
height_scale = np.float32(h) / np.float32(out_size[0]) height_scale = np.float32(h) / np.float32(out_size[0])
width_scale = np.float32(w) / np.float32(out_size[1]) width_scale = np.float32(w) / np.float32(out_size[1])
def _lerp(A, B, t):
return A * (1.0 - t) + B * t
for b in range(batch): for b in range(batch):
for i in range(channel): for i in range(channel):
for j in range(new_h): for j in range(new_h):
for k in range(new_w): for k in range(new_w):
in_y = j * height_scale if coordinate_transformation_mode == "half_pixel":
y0 = math.floor(in_y) in_y = (j + 0.5) * height_scale - 0.5
y1 = min(math.ceil(in_y), h - 1) else:
y_lerp = in_y - y0 in_y = j * height_scale
y0 = int(math.floor(in_y))
y0 = int(y0) y1 = max(min(y0 + 1, h - 1), 0)
y1 = int(y1) y0 = max(y0, 0)
y_lerp = in_y - math.floor(in_y)
in_x = k * width_scale
x0 = math.floor(in_x)
x1 = min(math.ceil(in_x), w - 1)
x_lerp = in_x - x0
x0 = int(x0) if coordinate_transformation_mode == "half_pixel":
x1 = int(x1) in_x = (k + 0.5) * width_scale - 0.5
else:
in_x = k * width_scale
x0 = int(math.floor(in_x))
x1 = max(min(x0 + 1, w - 1), 0)
x0 = max(x0, 0)
x_lerp = in_x - math.floor(in_x)
if layout == 'NHWC': if layout == 'NHWC':
A = image[b][y0][x0][i] A = image[b][y0][x0][i]
...@@ -68,10 +73,10 @@ def bilinear_resize_python(image, out_size, layout, align_corners=True): ...@@ -68,10 +73,10 @@ def bilinear_resize_python(image, out_size, layout, align_corners=True):
C = image[b][i][y1][x0] C = image[b][i][y1][x0]
D = image[b][i][y1][x1] D = image[b][i][y1][x1]
top = A + (B - A) * x_lerp top = _lerp(A, B, x_lerp)
bottom = C + (D - C) * x_lerp bottom = _lerp(C, D, x_lerp)
pixel = np.float32(top + (bottom - top) * y_lerp) pixel = np.float32(_lerp(top, bottom, y_lerp))
if layout == 'NHWC': if layout == 'NHWC':
scaled_image[b][j][k][i] = pixel scaled_image[b][j][k][i] = pixel
......
...@@ -23,7 +23,8 @@ import math ...@@ -23,7 +23,8 @@ import math
from common import get_all_backend from common import get_all_backend
def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=True, method="bilinear"): def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width,
layout='NCHW', coord_trans="align_corners", method="bilinear"):
if layout == 'NCHW': if layout == 'NCHW':
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='float32') A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='float32')
dtype = A.dtype dtype = A.dtype
...@@ -37,11 +38,9 @@ def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, ...@@ -37,11 +38,9 @@ def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width,
else: else:
raise NotImplementedError( raise NotImplementedError(
'Layout not supported {} '.format(layout)) 'Layout not supported {} '.format(layout))
B = topi.image.resize(A, (out_height, out_width), layout=layout, coordinate_transformation_mode=coord_trans, method=method)
B = topi.image.resize(A, (out_height, out_width), layout=layout, align_corners=align_corners, method=method)
if method == "bilinear": if method == "bilinear":
b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, align_corners) b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, coord_trans)
else: else:
scale_h = out_height / in_height scale_h = out_height / in_height
scale_w = out_width / in_width scale_w = out_width / in_width
...@@ -70,14 +69,17 @@ def test_resize(): ...@@ -70,14 +69,17 @@ def test_resize():
# Scale NCHW # Scale NCHW
verify_resize(4, 16, 32, 32, 50, 50, 'NCHW') verify_resize(4, 16, 32, 32, 50, 50, 'NCHW')
# Scale NCHW + Align Corners # Scale NCHW + Align Corners
verify_resize(6, 32, 64, 64, 20, 20, 'NCHW', True) verify_resize(6, 32, 64, 64, 20, 20, 'NCHW')
# Scale NHWC # Scale NHWC
verify_resize(4, 16, 32, 32, 50, 50, "NHWC") verify_resize(4, 16, 32, 32, 50, 50, "NHWC")
# Scale NHWC + Align Corners # Scale NHWC + Align Corners
verify_resize(6, 32, 64, 64, 20, 20, "NHWC", True) verify_resize(6, 32, 64, 64, 20, 20, "NHWC")
# Nearest + Fractional # Nearest + Fractional
verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', method="nearest_neighbor", align_corners=False) verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', "asymmetric", method="nearest_neighbor")
verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', method="nearest_neighbor", align_corners=False) verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', "asymmetric", method="nearest_neighbor")
# half_pixel
verify_resize(4, 16, 16, 16, 32, 32, 'NCHW', "half_pixel", method="bilinear")
verify_resize(4, 16, 16, 16, 32, 32, 'NHWC', "half_pixel", method="bilinear")
def verify_resize3d(batch, in_channel, in_depth, in_height, in_width, out_depth, out_height, out_width, def verify_resize3d(batch, in_channel, in_depth, in_height, in_width, out_depth, out_height, out_width,
......
...@@ -43,7 +43,7 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale_h, scale_w, ...@@ -43,7 +43,7 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale_h, scale_w,
if method == "bilinear": if method == "bilinear":
out_size = (int(round(in_height*scale_h)), int(round(in_width*scale_w))) out_size = (int(round(in_height*scale_h)), int(round(in_width*scale_w)))
b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout, align_corners=False) b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout, "asymmetric")
else: else:
b_np = topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout) b_np = topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment