Commit a8275bdb by masahi Committed by Tianqi Chen

[TOPI] Fix resize nearest with fractional scaling (#3244)

parent a479432d
......@@ -305,7 +305,7 @@ def test_upsampling_nearest_neighbor():
data = tvm.nd.array(a_np)
m.run(x=data)
out = m.get_output(0, tvm.nd.empty(oshape, dtype))
b_np = topi.testing.upsampling_python(a_np, scale, "NCHW")
b_np = topi.testing.upsampling_python(a_np, (scale, scale), "NCHW")
tvm.testing.assert_allclose(out.asnumpy(), b_np, rtol=1e-5)
def test_upsampling_bilinear():
......
......@@ -195,7 +195,7 @@ def verify_UpsampleLayerParams(input_dim, scale, mode):
a_np = np.full(input_dim, 1, dtype=dtype)
if mode == 'NN':
b_np = topi.testing.upsampling_python(a_np, scale)
b_np = topi.testing.upsampling_python(a_np, (scale, scale))
else:
new_h = input_dim[2] * scale
new_w = input_dim[3] * scale
......
......@@ -405,7 +405,7 @@ def _test_upsample_nearest():
y = helper.make_node("Upsample", ['in'], ['out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0])
in_array = np.random.uniform(size=in_shape).astype(np.float32)
out_array = topi.testing.upsampling_python(in_array, scale, "NCHW")
out_array = topi.testing.upsampling_python(in_array, (scale, scale), "NCHW")
graph = helper.make_graph([y],
'upsample_nearest_test',
......
......@@ -179,7 +179,7 @@ def verify_UpsampleLayerParams(input_dim, scale, mode):
a_np = np.full(input_dim, 1, dtype=dtype)
if mode == 'NN':
b_np = topi.testing.upsampling_python(a_np, scale)
b_np = topi.testing.upsampling_python(a_np, (scale, scale))
else:
new_h = input_dim[2] * scale
new_w = input_dim[3] * scale
......
......@@ -417,7 +417,7 @@ def _test_upsample_nearest():
y = helper.make_node("Upsample", ['in'], ['out'], mode='nearest', scales=[1.0, 1.0, 2.0, 2.0])
in_array = np.random.uniform(size=in_shape).astype(np.float32)
out_array = topi.testing.upsampling_python(in_array, scale, "NCHW")
out_array = topi.testing.upsampling_python(in_array, (scale, scale), "NCHW")
graph = helper.make_graph([y],
'upsample_nearest_test',
......
......@@ -485,7 +485,7 @@ def _test_upsampling(layout, method):
func = relay.Function([x], y)
data = np.random.uniform(size=dshape).astype(dtype)
if method == "NEAREST_NEIGHBOR":
ref = topi.testing.upsampling_python(data, scale, layout)
ref = topi.testing.upsampling_python(data, (scale, scale), layout)
else:
ref = topi.testing.bilinear_resize_python(data, (h*scale, w*scale), layout)
for target, ctx in ctx_list():
......
......@@ -48,7 +48,7 @@ def test_resize():
if method == "BILINEAR":
ref_res = topi.testing.bilinear_resize_python(x_data, size, layout)
else:
ref_res = topi.testing.upsampling_python(x_data, scale, layout)
ref_res = topi.testing.upsampling_python(x_data, (scale, scale), layout)
x = relay.var("x", relay.TensorType(dshape, "float32"))
z = relay.image.resize(x, size, layout, method, False)
assert "size=" in z.astext()
......
......@@ -101,15 +101,12 @@ inline Tensor resize_nearest_neighbor_nhwc(const Tensor& input,
out_shape.push_back(shape[1]);
out_shape.push_back(input->shape[3]);
Expr h_ratio = shape[0] / input->shape[1];
Expr w_ratio = shape[1] / input->shape[2];
return compute(
out_shape, [&](const Array<Var>& indices) {
Array<Expr> idx;
idx.push_back(indices[0]);
idx.push_back(indices[1] / h_ratio);
idx.push_back(indices[2] / w_ratio);
idx.push_back(indices[1] * input->shape[1] / shape[0]);
idx.push_back(indices[2] * input->shape[2] / shape[1]);
idx.push_back(indices[3]);
return input(idx);
......@@ -138,16 +135,13 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input,
out_shape.push_back(shape[0]);
out_shape.push_back(shape[1]);
Expr h_ratio = shape[0] / input->shape[2];
Expr w_ratio = shape[1] / input->shape[3];
return compute(
out_shape, [&](const Array<Var>& indices) {
Array<Expr> idx;
idx.push_back(indices[0]);
idx.push_back(indices[1]);
idx.push_back(indices[2] / h_ratio);
idx.push_back(indices[3] / w_ratio);
idx.push_back(indices[2] * input->shape[2] / shape[0]);
idx.push_back(indices[3] * input->shape[3] / shape[1]);
return input(idx);
}, name, tag);
......@@ -176,16 +170,13 @@ inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input,
out_shape.push_back(shape[1]);
out_shape.push_back(input->shape[4]);
Expr h_ratio = shape[0] / input->shape[2];
Expr w_ratio = shape[1] / input->shape[3];
return compute(
out_shape, [&](const Array<Var>& indices) {
Array<Expr> idx;
idx.push_back(indices[0]);
idx.push_back(indices[1]);
idx.push_back(indices[2] / h_ratio);
idx.push_back(indices[3] / w_ratio);
idx.push_back(indices[2] * input->shape[2] / shape[0]);
idx.push_back(indices[3] * input->shape[3] / shape[1]);
idx.push_back(indices[4]);
return input(idx);
......
......@@ -53,5 +53,4 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'):
out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale))
else:
raise ValueError("not support this layout {} yet".format(layout))
return topi.cpp.nn.upsampling(data, out_shape, layout, method)
......@@ -16,25 +16,35 @@
# under the License.
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Upsampling in python"""
import math
import numpy as np
def upsample_nearest(arr, scale):
""" Populate the array by scale factor"""
return arr.repeat(scale, axis=0).repeat(scale, axis=1)
h, w = arr.shape
out_h = math.floor(h * scale[0])
out_w = math.floor(w * scale[1])
out = np.empty((out_h, out_w))
for y in range(out_h):
for x in range(out_w):
in_y = math.floor(y / scale[0])
in_x = math.floor(x / scale[1])
out[y, x] = arr[in_y, in_x]
return out
def upsampling_python(data, scale, layout='NCHW'):
""" Python version of scaling using nearest neighbour """
ishape = data.shape
if layout == 'NCHW':
oshape = (ishape[0], ishape[1], ishape[2]*scale, ishape[3]*scale)
oshape = (ishape[0], ishape[1], math.floor(ishape[2]*scale[0]), math.floor(ishape[3]*scale[1]))
output_np = np.zeros(oshape, dtype=data.dtype)
for b in range(oshape[0]):
for c in range(oshape[1]):
output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale)
return output_np
if layout == 'NHWC':
oshape = (ishape[0], ishape[1]*scale, ishape[1]*scale, ishape[3])
oshape = (ishape[0], math.floor(ishape[1]*scale[0]), math.floor(ishape[1]*scale[1]), ishape[3])
output_np = np.zeros(oshape, dtype=data.dtype)
for b in range(oshape[0]):
for c in range(oshape[3]):
......
......@@ -23,8 +23,7 @@ import math
from common import get_all_backend
def verify_bilinear_scale(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=False):
def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, layout='NCHW', align_corners=False, method="BILINEAR"):
if layout == 'NCHW':
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A', dtype='float32')
dtype = A.dtype
......@@ -39,9 +38,14 @@ def verify_bilinear_scale(batch, in_channel, in_height, in_width, out_height, ou
raise NotImplementedError(
'Layout not supported {} '.format(layout))
B = topi.image.resize(A, (out_height, out_width), layout=layout, align_corners=align_corners)
B = topi.image.resize(A, (out_height, out_width), layout=layout, align_corners=align_corners, method=method)
b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, align_corners)
if method == "BILINEAR":
b_np = topi.testing.bilinear_resize_python(a_np, (out_height, out_width), layout, align_corners)
else:
scale_h = out_height / in_height
scale_w = out_width / in_width
b_np = topi.testing.upsampling_python(a_np, (scale_h, scale_w), layout)
def check_device(device):
ctx = tvm.context(device, 0)
......@@ -61,15 +65,19 @@ def verify_bilinear_scale(batch, in_channel, in_height, in_width, out_height, ou
for device in get_all_backend():
check_device(device)
def test_resize():
# Scale NCHW
verify_bilinear_scale(4, 16, 32, 32, 50, 50, 'NCHW')
verify_resize(4, 16, 32, 32, 50, 50, 'NCHW')
# Scale NCHW + Align Corners
verify_bilinear_scale(6, 32, 64, 64, 20, 20, 'NCHW', True)
verify_resize(6, 32, 64, 64, 20, 20, 'NCHW', True)
# Scale NHWC
verify_bilinear_scale(4, 16, 32, 32, 50, 50, "NHWC")
verify_resize(4, 16, 32, 32, 50, 50, "NHWC")
# Scale NHWC + Align Corners
verify_bilinear_scale(6, 32, 64, 64, 20, 20, "NHWC", True)
verify_resize(6, 32, 64, 64, 20, 20, "NHWC", True)
# Nearest + Fractional
verify_resize(4, 16, 32, 32, 50, 50, 'NCHW', method="NEAREST_NEIGHBOR")
verify_resize(4, 16, 32, 32, 50, 50, 'NHWC', method="NEAREST_NEIGHBOR")
if __name__ == "__main__":
test_resize()
......@@ -46,7 +46,7 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCH
out_size = (in_height*scale, in_width*scale)
b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout)
else:
b_np = topi.testing.upsampling_python(a_np, scale, layout)
b_np = topi.testing.upsampling_python(a_np, (scale, scale), layout)
def check_device(device):
ctx = tvm.context(device, 0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment