Commit d2b22a6f by Siva Committed by Tianqi Chen

Bugfix #1692. Constant folding and result comparision allowance. (#1708)

parent 0565fcc8
"""TVM operator upsampling compute."""
from __future__ import absolute_import
import topi
from ..util import simplify
def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'):
......@@ -31,9 +32,9 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'):
"""
if layout == "NCHW":
out_shape = (data.shape[2] * scale, data.shape[3] * scale)
out_shape = (simplify(data.shape[2] * scale), simplify(data.shape[3] * scale))
elif layout == "NHWC":
out_shape = (data.shape[1] * scale, data.shape[2] * scale)
out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale))
else:
raise ValueError("not support this layout {} yet".format(layout))
......
......@@ -5,7 +5,7 @@ import topi
import topi.testing
import math
def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW'):
def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCHW', method="NEAREST_NEIGHBOR"):
if layout == 'NCHW':
......@@ -22,9 +22,13 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCH
raise NotImplementedError(
'Layout not supported {} '.format(layout))
B = topi.nn.upsampling(A, scale, layout=layout)
B = topi.nn.upsampling(A, scale, layout=layout, method=method)
b_np = topi.testing.upsampling_python(a_np, scale, layout)
if method == "BILINEAR":
out_size = (in_height*scale, in_width*scale)
b_np = topi.testing.bilinear_resize_python(a_np, out_size, layout)
else:
b_np = topi.testing.upsampling_python(a_np, scale, layout)
def check_device(device):
ctx = tvm.context(device, 0)
......@@ -39,18 +43,27 @@ def verify_upsampling(batch, in_channel, in_height, in_width, scale, layout='NCH
f = tvm.build(s, [A, B], device)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
for device in ['llvm', 'cuda', 'vulkan', 'nvptx']:
check_device(device)
def test_upsampling():
# NCHW
# NEAREST_NEIGHBOR - NCHW
verify_upsampling(8, 16, 32, 32, 2)
verify_upsampling(12, 32, 64, 64, 3)
# NHWC
verify_upsampling(8, 16, 32, 32, 2, "NHWC")
verify_upsampling(12, 32, 64, 64, 3, "NHWC")
# NEAREST_NEIGHBOR - NHWC
verify_upsampling(8, 16, 32, 32, 2, layout="NHWC")
verify_upsampling(12, 32, 64, 64, 3, layout="NHWC")
# BILINEAR - NCHW
verify_upsampling(2, 2, 32, 32, 2, method="BILINEAR")
verify_upsampling(2, 2, 32, 32, 3, method="BILINEAR")
# BILINEAR - NHWC
verify_upsampling(2, 2, 32, 32, 2, layout="NHWC", method="BILINEAR")
verify_upsampling(2, 2, 32, 32, 3, layout="NHWC", method="BILINEAR")
if __name__ == "__main__":
test_upsampling()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment