Unverified Commit 974195de by Hua Jiang Committed by GitHub

[TOPI] upsample operator 'NCHWinic' format support. (#4791)

* [TOPI] upsample operator 'NCHWinic' format support.

Some hardware accelerators require packed data formats such as NCHWinic to
fit their hardware resources. This change adds NCHWinic format support to
the upsample operator to meet that requirement.

* address review comments, add assert for 'else must be NCHWxc' logic.
parent c39ab93d
...@@ -18,8 +18,42 @@ ...@@ -18,8 +18,42 @@
"""TVM operator input resize compute.""" """TVM operator input resize compute."""
from __future__ import absolute_import from __future__ import absolute_import
import tvm import tvm
from topi.util import nchw_pack_layout, nchw_xc_layout
from .. import tag from .. import tag
def get_2d_indices(indices, layout='NCHW'):
    """Unpack an index tuple into the common (n, c, y, x, cc, inum, ic) form.

    Parameters
    ----------
    indices : sequence
        Index values to unpack. Four entries are expected for 'NHWC' and
        'NCHW', six for a layout accepted by ``nchw_pack_layout`` and five
        for a layout accepted by ``nchw_xc_layout``.

    layout : str, optional
        Layout string selecting how ``indices`` is interpreted.

    Returns
    -------
    tuple
        ``(n, c, y, x, cc, inum, ic)``. ``cc`` is None for 'NHWC'/'NCHW' and
        0 for the packed layout; ``inum`` and ``ic`` are 0 for every layout
        except the packed one.
    """
    if layout == 'NHWC':
        batch, row, col, chan = indices
        return batch, chan, row, col, None, 0, 0

    if layout == 'NCHW':
        batch, chan, row, col = indices
        return batch, chan, row, col, None, 0, 0

    if nchw_pack_layout(layout):
        batch, chan, row, col, batch_inner, chan_inner = indices
        return batch, chan, row, col, 0, batch_inner, chan_inner

    # Anything that reaches this point has to be an NCHWxc layout.
    assert nchw_xc_layout(layout)
    batch, chan, row, col, chan_minor = indices
    return batch, chan, row, col, chan_minor, 0, 0
def get_2d_pixel(data, layout, boxes, image_height, image_width, n, c, y, x, cc, ib, ic):
    """Read one element of ``data`` at the given position, converted to 'float'.

    Parameters
    ----------
    data : callable
        Called with the index arguments in the order required by ``layout``;
        the returned object must provide ``astype``.

    layout : str
        'NHWC', 'NCHW', a layout accepted by ``nchw_pack_layout`` or a layout
        accepted by ``nchw_xc_layout``.

    boxes : object or None
        When None, ``y`` and ``x`` are clamped to ``[0, image_height - 1]``
        and ``[0, image_width - 1]`` before the read; otherwise they are
        used unchanged.

    image_height, image_width : int or expression
        Extents used for the clamping described above.

    n, c, y, x : index values
        Indices passed to ``data`` (reordered for 'NHWC').

    cc : index value
        Extra trailing index, used only for the NCHWxc layout.

    ib, ic : index values
        Extra trailing indices, used only for the packed layout.

    Returns
    -------
    The value produced by ``data(...)`` after ``astype('float')``.
    """
    if boxes is None:
        # Without boxes, keep the coordinates inside the image bounds.
        last_row = image_height - 1
        last_col = image_width - 1
        y = tvm.max(tvm.min(y, last_row), 0)
        x = tvm.max(tvm.min(x, last_col), 0)

    if layout == 'NHWC':
        pixel = data(n, y, x, c)
    elif layout == 'NCHW':
        pixel = data(n, c, y, x)
    elif nchw_pack_layout(layout):
        pixel = data(n, c, y, x, ib, ic)
    else:
        # Anything that reaches this point has to be an NCHWxc layout.
        assert nchw_xc_layout(layout)
        pixel = data(n, c, y, x, cc)
    return pixel.astype('float')
def resize_nearest_neighbor(indices, data, image_height, image_width, def resize_nearest_neighbor(indices, data, image_height, image_width,
target_height, target_width, boxes=None, target_height, target_width, boxes=None,
...@@ -89,29 +123,7 @@ def resize_nearest_neighbor(indices, data, image_height, image_width, ...@@ -89,29 +123,7 @@ def resize_nearest_neighbor(indices, data, image_height, image_width,
dtype = data_dtype dtype = data_dtype
return value.astype(dtype) return value.astype(dtype)
def _get_indices(indices, layout='NCHW'): n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout)
if layout == 'NHWC':
n, y, x, c = indices
cc = None
elif layout == 'NCHW':
n, c, y, x = indices
cc = None
else:
n, c, y, x, cc = indices
return n, c, y, x, cc
def _get_pixel(data, layout, n, c, y, x, cc):
if boxes is None:
y = tvm.max(tvm.min(y, image_height - 1), 0)
x = tvm.max(tvm.min(x, image_width - 1), 0)
if layout == 'NHWC':
return data(n, y, x, c).astype('float')
if layout == 'NCHW':
return data(n, c, y, x).astype('float')
# else must be NCHWxc
return data(n, c, y, x, cc).astype('float')
n, c, y, x, cc = _get_indices(indices, layout)
box_idx = box_indices(n) if box_indices is not None else n box_idx = box_indices(n) if box_indices is not None else n
if boxes is not None: if boxes is not None:
y1, x1 = boxes(n, 0), boxes(n, 1) y1, x1 = boxes(n, 0), boxes(n, 1)
...@@ -146,7 +158,8 @@ def resize_nearest_neighbor(indices, data, image_height, image_width, ...@@ -146,7 +158,8 @@ def resize_nearest_neighbor(indices, data, image_height, image_width,
closest_y_index = tvm.floor(in_y + epsilon).astype('int32') closest_y_index = tvm.floor(in_y + epsilon).astype('int32')
closest_x_index = tvm.floor(in_x + epsilon).astype('int32') closest_x_index = tvm.floor(in_x + epsilon).astype('int32')
value = _get_pixel(data, layout, box_idx, c, closest_y_index, closest_x_index, cc) value = get_2d_pixel(data, layout, boxes, image_height, image_width,
box_idx, c, closest_y_index, closest_x_index, cc, inum, ic)
if extrapolation_value is not None: if extrapolation_value is not None:
out = tvm.if_then_else(in_y < 0, out = tvm.if_then_else(in_y < 0,
...@@ -234,29 +247,7 @@ def resize_bilinear(indices, data, image_height, image_width, ...@@ -234,29 +247,7 @@ def resize_bilinear(indices, data, image_height, image_width,
def _lerp(A, B, t): def _lerp(A, B, t):
return A * (1.0 - t) + B * t return A * (1.0 - t) + B * t
def _get_indices(indices, layout='NCHW'): n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout=layout)
if layout == 'NHWC':
n, y, x, c = indices
cc = None
elif layout == 'NCHW':
n, c, y, x = indices
cc = None
else:
n, c, y, x, cc = indices
return n, c, y, x, cc
def _get_pixel(data, layout, n, c, y, x, cc):
if boxes is None:
y = tvm.max(tvm.min(y, image_height - 1), 0)
x = tvm.max(tvm.min(x, image_width - 1), 0)
if layout == 'NHWC':
return data(n, y, x, c).astype('float')
if layout == 'NCHW':
return data(n, c, y, x).astype('float')
# else must be NCHWxc
return data(n, c, y, x, cc).astype('float')
n, c, y, x, cc = _get_indices(indices, layout=layout)
box_idx = box_indices(n) if box_indices is not None else n box_idx = box_indices(n) if box_indices is not None else n
if boxes is not None: if boxes is not None:
...@@ -296,10 +287,14 @@ def resize_bilinear(indices, data, image_height, image_width, ...@@ -296,10 +287,14 @@ def resize_bilinear(indices, data, image_height, image_width,
right_x_index = tvm.ceil(in_x).astype('int32') right_x_index = tvm.ceil(in_x).astype('int32')
x_lerp = in_x - left_x_index x_lerp = in_x - left_x_index
top_left = _get_pixel(data, layout, box_idx, c, top_y_index, left_x_index, cc) top_left = get_2d_pixel(data, layout, boxes, image_height, image_width,
top_right = _get_pixel(data, layout, box_idx, c, top_y_index, right_x_index, cc) box_idx, c, top_y_index, left_x_index, cc, inum, ic)
bottom_left = _get_pixel(data, layout, box_idx, c, bottom_y_index, left_x_index, cc) top_right = get_2d_pixel(data, layout, boxes, image_height, image_width,
bottom_right = _get_pixel(data, layout, box_idx, c, bottom_y_index, right_x_index, cc) box_idx, c, top_y_index, right_x_index, cc, inum, ic)
bottom_left = get_2d_pixel(data, layout, boxes, image_height, image_width,
box_idx, c, bottom_y_index, left_x_index, cc, inum, ic)
bottom_right = get_2d_pixel(data, layout, boxes, image_height, image_width,
box_idx, c, bottom_y_index, right_x_index, cc, inum, ic)
top = _lerp(top_left, top_right, x_lerp) top = _lerp(top_left, top_right, x_lerp)
bottom = _lerp(bottom_left, bottom_right, x_lerp) bottom = _lerp(bottom_left, bottom_right, x_lerp)
...@@ -394,29 +389,7 @@ def resize_bicubic(indices, data, image_height, image_width, ...@@ -394,29 +389,7 @@ def resize_bicubic(indices, data, image_height, image_width,
dtype = data_dtype dtype = data_dtype
return value.astype(dtype) return value.astype(dtype)
def _get_indices(indices, layout='NCHW'): n, c, y, x, cc, inum, ic = get_2d_indices(indices, layout)
if layout == 'NHWC':
n, y, x, c = indices
cc = None
elif layout == 'NCHW':
n, c, y, x = indices
cc = None
else:
n, c, y, x, cc = indices
return n, c, y, x, cc
def _get_pixel(data, layout, n, c, y, x, cc):
if boxes is None:
y = tvm.max(tvm.min(y, image_height - 1), 0)
x = tvm.max(tvm.min(x, image_width - 1), 0)
if layout == 'NHWC':
return data(n, y, x, c).astype('float')
if layout == 'NCHW':
return data(n, c, y, x).astype('float')
# else must be NCHWxc
return data(n, c, y, x, cc).astype('float')
n, c, y, x, cc = _get_indices(indices, layout)
box_idx = box_indices(n) if box_indices is not None else n box_idx = box_indices(n) if box_indices is not None else n
if boxes is not None: if boxes is not None:
...@@ -455,28 +428,44 @@ def resize_bicubic(indices, data, image_height, image_width, ...@@ -455,28 +428,44 @@ def resize_bicubic(indices, data, image_height, image_width,
yfract = in_y - tvm.floor(in_y) yfract = in_y - tvm.floor(in_y)
# 1st row # 1st row
p00 = _get_pixel(data, layout, box_idx, c, yint - 1, xint - 1, cc) p00 = _get_pixel(data, layout, boxes, image_height, image_width,
p10 = _get_pixel(data, layout, box_idx, c, yint - 1, xint + 0, cc) box_idx, c, yint - 1, xint - 1, cc, inum, ic)
p20 = _get_pixel(data, layout, box_idx, c, yint - 1, xint + 1, cc) p10 = _get_pixel(data, layout, boxes, image_height, image_width,
p30 = _get_pixel(data, layout, box_idx, c, yint - 1, xint + 2, cc) box_idx, c, yint - 1, xint + 0, cc, inum, ic)
p20 = _get_pixel(data, layout, boxes, image_height, image_width,
box_idx, c, yint - 1, xint + 1, cc, inum, ic)
p30 = _get_pixel(data, layout, boxes, image_height, image_width,
box_idx, c, yint - 1, xint + 2, cc, inum, ic)
# 2nd row # 2nd row
p01 = _get_pixel(data, layout, box_idx, c, yint + 0, xint - 1, cc) p01 = _get_pixel(data, layout, boxes, image_height, image_width,
p11 = _get_pixel(data, layout, box_idx, c, yint + 0, xint + 0, cc) box_idx, c, yint + 0, xint - 1, cc, inum, ic)
p21 = _get_pixel(data, layout, box_idx, c, yint + 0, xint + 1, cc) p11 = _get_pixel(data, layout, boxes, image_height, image_width,
p31 = _get_pixel(data, layout, box_idx, c, yint + 0, xint + 2, cc) box_idx, c, yint + 0, xint + 0, cc, inum, ic)
p21 = _get_pixel(data, layout, boxes, image_height, image_width,
box_idx, c, yint + 0, xint + 1, cc, inum, ic)
p31 = _get_pixel(data, layout, boxes, image_height, image_width,
box_idx, c, yint + 0, xint + 2, cc, inum, ic)
# 3rd row # 3rd row
p02 = _get_pixel(data, layout, box_idx, c, yint + 1, xint - 1, cc) p02 = _get_pixel(data, layout, boxes, image_height, image_width,
p12 = _get_pixel(data, layout, box_idx, c, yint + 1, xint + 0, cc) box_idx, c, yint + 1, xint - 1, cc, inum, ic)
p22 = _get_pixel(data, layout, box_idx, c, yint + 1, xint + 1, cc) p12 = _get_pixel(data, layout, boxes, image_height, image_width,
p32 = _get_pixel(data, layout, box_idx, c, yint + 1, xint + 2, cc) box_idx, c, yint + 1, xint + 0, cc, inum, ic)
p22 = _get_pixel(data, layout, boxes, image_height, image_width,
box_idx, c, yint + 1, xint + 1, cc, inum, ic)
p32 = _get_pixel(data, layout, boxes, image_height, image_width,
box_idx, c, yint + 1, xint + 2, cc, inum, ic)
# 4th row # 4th row
p03 = _get_pixel(data, layout, box_idx, c, yint + 2, xint - 1, cc) p03 = _get_pixel(data, layout, boxes, image_height, image_width,
p13 = _get_pixel(data, layout, box_idx, c, yint + 2, xint + 0, cc) box_idx, c, yint + 2, xint - 1, cc, inum, ic)
p23 = _get_pixel(data, layout, box_idx, c, yint + 2, xint + 1, cc) p13 = _get_pixel(data, layout, boxes, image_height, image_width,
p33 = _get_pixel(data, layout, box_idx, c, yint + 2, xint + 2, cc) box_idx, c, yint + 2, xint + 0, cc, inum, ic)
p23 = _get_pixel(data, layout, boxes, image_height, image_width,
box_idx, c, yint + 2, xint + 1, cc, inum, ic)
p33 = _get_pixel(data, layout, boxes, image_height, image_width,
box_idx, c, yint + 2, xint + 2, cc, inum, ic)
# Interpolate bicubically # Interpolate bicubically
col0 = _cubic_kernel(p00, p10, p20, p30, xfract) col0 = _cubic_kernel(p00, p10, p20, p30, xfract)
...@@ -536,6 +525,7 @@ def resize(data, size, layout="NCHW", method="bilinear", ...@@ -536,6 +525,7 @@ def resize(data, size, layout="NCHW", method="bilinear",
or [batch, in_height*scale, in_width*scale, channel] or [batch, in_height*scale, in_width*scale, channel]
or 5-D with shape [batch, channel-major, in_height*scale, in_width*scale, channel-minor] or 5-D with shape [batch, channel-major, in_height*scale, in_width*scale, channel-minor]
""" """
method = method.lower() method = method.lower()
if layout == 'NHWC': if layout == 'NHWC':
...@@ -544,7 +534,10 @@ def resize(data, size, layout="NCHW", method="bilinear", ...@@ -544,7 +534,10 @@ def resize(data, size, layout="NCHW", method="bilinear",
elif layout == 'NCHW': elif layout == 'NCHW':
in_n, in_c, in_h, in_w = data.shape in_n, in_c, in_h, in_w = data.shape
output_shape = [in_n, in_c, size[0], size[1]] output_shape = [in_n, in_c, size[0], size[1]]
elif layout.startswith("NCHW"):# for NCHWxc elif nchw_pack_layout(layout):# for NCHWinic
in_n, in_c, in_h, in_w, in_inum, in_ic = data.shape
output_shape = [in_n, in_c, size[0], size[1], in_inum, in_ic]
elif nchw_xc_layout(layout):# for NCHWxc
in_n, in_c, in_h, in_w, in_cc = data.shape in_n, in_c, in_h, in_w, in_cc = data.shape
output_shape = [in_n, in_c, size[0], size[1], in_cc] output_shape = [in_n, in_c, size[0], size[1], in_cc]
else: else:
......
...@@ -18,14 +18,20 @@ ...@@ -18,14 +18,20 @@
"""Bilinear Scale in python""" """Bilinear Scale in python"""
import math import math
import numpy as np import numpy as np
from topi.util import nchw_pack_layout
def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mode="align_corners"): def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mode="align_corners"):
""" Bilinear scaling using python""" """ Bilinear scaling using python"""
(new_h, new_w) = out_size (new_h, new_w) = out_size
(ib, ic) = (1, 1)
if layout == 'NHWC': if layout == 'NHWC':
(batch, h, w, channel) = image.shape (batch, h, w, channel) = image.shape
scaled_image = np.ones((batch, new_h, new_w, channel)) scaled_image = np.ones((batch, new_h, new_w, channel))
# NCHWinic
elif nchw_pack_layout(layout):
(batch, channel, h, w, ib, ic) = image.shape
scaled_image = np.ones((batch, channel, new_h, new_w, ib, ic))
else: else:
(batch, channel, h, w) = image.shape (batch, channel, h, w) = image.shape
scaled_image = np.ones((batch, channel, new_h, new_w)) scaled_image = np.ones((batch, channel, new_h, new_w))
...@@ -40,8 +46,7 @@ def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mo ...@@ -40,8 +46,7 @@ def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mo
def _lerp(A, B, t): def _lerp(A, B, t):
return A * (1.0 - t) + B * t return A * (1.0 - t) + B * t
for b in range(batch): def _img_scale(b, m, i, n):
for i in range(channel):
for j in range(new_h): for j in range(new_h):
for k in range(new_w): for k in range(new_w):
if coordinate_transformation_mode == "half_pixel": if coordinate_transformation_mode == "half_pixel":
...@@ -67,6 +72,11 @@ def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mo ...@@ -67,6 +72,11 @@ def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mo
B = image[b][y0][x1][i] B = image[b][y0][x1][i]
C = image[b][y1][x0][i] C = image[b][y1][x0][i]
D = image[b][y1][x1][i] D = image[b][y1][x1][i]
elif nchw_pack_layout(layout):
A = image[b][i][y0][x0][m][n]
B = image[b][i][y0][x1][m][n]
C = image[b][i][y1][x0][m][n]
D = image[b][i][y1][x1][m][n]
else: else:
A = image[b][i][y0][x0] A = image[b][i][y0][x0]
B = image[b][i][y0][x1] B = image[b][i][y0][x1]
...@@ -80,7 +90,15 @@ def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mo ...@@ -80,7 +90,15 @@ def bilinear_resize_python(image, out_size, layout, coordinate_transformation_mo
if layout == 'NHWC': if layout == 'NHWC':
scaled_image[b][j][k][i] = pixel scaled_image[b][j][k][i] = pixel
elif nchw_pack_layout(layout):
scaled_image[b][i][j][k][m][n] = pixel
else: else:
scaled_image[b][i][j][k] = pixel scaled_image[b][i][j][k] = pixel
for b in range(batch):
for m in range(ib):
for i in range(channel):
for n in range(ic):
_img_scale(b, m, i, n)
return scaled_image return scaled_image
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
"""Upsampling in python""" """Upsampling in python"""
import math import math
import numpy as np import numpy as np
from topi.util import nchw_pack_layout
def upsample_nearest(arr, scale): def upsample_nearest(arr, scale):
""" Populate the array by scale factor""" """ Populate the array by scale factor"""
...@@ -44,6 +46,18 @@ def upsampling_python(data, scale, layout='NCHW'): ...@@ -44,6 +46,18 @@ def upsampling_python(data, scale, layout='NCHW'):
for c in range(oshape[1]): for c in range(oshape[1]):
output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale) output_np[b, c, :, :] = upsample_nearest(data[b, c, :, :], scale)
return output_np return output_np
# NCHWinic
if nchw_pack_layout(layout):
oshape = (ishape[0], ishape[1], int(round(ishape[2]*scale[0])),
int(round(ishape[3]*scale[1])), ishape[4], ishape[5])
output_np = np.zeros(oshape, dtype=data.dtype)
for b in range(oshape[0]):
for ib in range(oshape[4]):
for c in range(oshape[1]):
for ic in range(oshape[5]):
output_np[b, c, :, :, ib, ic] = upsample_nearest(data[b, c, :, :, ib, ic], scale)
return output_np
if layout == 'NHWC': if layout == 'NHWC':
oshape = (ishape[0], int(round(ishape[1]*scale[0])), oshape = (ishape[0], int(round(ishape[1]*scale[0])),
int(round(ishape[2]*scale[1])), ishape[3]) int(round(ishape[2]*scale[1])), ishape[3])
......
...@@ -27,6 +27,14 @@ class InvalidShapeError(ValueError): ...@@ -27,6 +27,14 @@ class InvalidShapeError(ValueError):
"""Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)""" """Invalid shape for a topi function. i.e. call winograd template for non-3x3 kernel)"""
pass pass
def nchw_pack_layout(layout_info):
    """Check whether the layout type is NCHWinic.

    Parameters
    ----------
    layout_info : str
        Layout string to inspect.

    Returns
    -------
    bool
        True when the first four characters are 'NCHW' and the string
        contains both a lowercase 'c' and a lowercase 'n'.
    """
    if layout_info[:4] != 'NCHW':
        return False
    has_channel_block = 'c' in layout_info
    return has_channel_block and 'n' in layout_info
def nchw_xc_layout(layout_info):
    """Check whether the layout type is NCHWxc.

    Parameters
    ----------
    layout_info : str
        Layout string to inspect.

    Returns
    -------
    bool
        True when the first four characters are 'NCHW', the string contains
        a lowercase 'c', and the characters between that prefix and the
        final character are all numeric (and there is at least one of them).
    """
    prefix = layout_info[:4]
    factor = layout_info[4:-1]
    return prefix == 'NCHW' and 'c' in layout_info and factor.isnumeric()
def traverse_inline(s, final_op, callback): def traverse_inline(s, final_op, callback):
"""Traverse computation graph and do auto inline """Traverse computation graph and do auto inline
......
...@@ -20,16 +20,26 @@ import tvm ...@@ -20,16 +20,26 @@ import tvm
import topi import topi
import topi.testing import topi.testing
import math import math
from topi.util import nchw_pack_layout
from common import get_all_backend from common import get_all_backend
def verify_upsampling(batch, in_channel, in_height, in_width, scale_h, scale_w, def verify_upsampling(batch, in_channel, in_height, in_width, scale_h, scale_w,
layout='NCHW', method="nearest_neighbor"): layout='NCHW', method="nearest_neighbor",
in_batch_block = 0, in_channel_block = 0):
if layout == 'NCHW': if layout == 'NCHW':
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A') A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
dtype = A.dtype dtype = A.dtype
out_shape = (batch, in_channel, int(round(in_height*scale_h)), int(round(in_width*scale_w))) out_shape = (batch, in_channel, int(round(in_height*scale_h)), int(round(in_width*scale_w)))
a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype) a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width)).astype(dtype)
elif nchw_pack_layout(layout):
A = tvm.placeholder((batch, in_channel, in_height, in_width, in_batch_block, in_channel_block),
name='A')
dtype = A.dtype
out_shape = (batch, in_channel, int(round(in_height*scale_h)), int(round(in_width*scale_w)),
in_batch_block, in_channel_block)
a_np = np.random.uniform(size=(batch, in_channel, in_height, in_width,
in_batch_block, in_channel_block)).astype(dtype)
elif layout == 'NHWC': elif layout == 'NHWC':
A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A') A = tvm.placeholder((batch, in_height, in_width, in_channel), name='A')
dtype = A.dtype dtype = A.dtype
...@@ -81,6 +91,22 @@ def test_upsampling(): ...@@ -81,6 +91,22 @@ def test_upsampling():
verify_upsampling(2, 2, 32, 32, 3.0, 3.0, method="bilinear") verify_upsampling(2, 2, 32, 32, 3.0, 3.0, method="bilinear")
verify_upsampling(1, 64, 22, 32, 1.954545497894287, 2.0, method="bilinear") verify_upsampling(1, 64, 22, 32, 1.954545497894287, 2.0, method="bilinear")
# nearest_neighbor - NCHWinic
verify_upsampling(2, 2, 32, 32, in_batch_block=4, in_channel_block=8,
scale_h=2.0, scale_w=2.0)
verify_upsampling(2, 2, 64, 64, in_batch_block=1, in_channel_block=16,
scale_h=3.0, scale_w=3.0)
verify_upsampling(1, 4, 22, 32, in_batch_block=1, in_channel_block=16,
scale_h=1.954545497894287, scale_w=2.0)
# bilinear - NCHWinic
verify_upsampling(2, 2, 32, 32, in_batch_block=1, in_channel_block=1,
scale_h=2.0, scale_w=2.0, method="bilinear")
verify_upsampling(2, 2, 32, 32, in_batch_block=1, in_channel_block=1,
scale_h=3.0, scale_w=3.0, method="bilinear")
verify_upsampling(2, 4, 22, 32, in_batch_block=1, in_channel_block=16,
scale_h=1.954545497894287, scale_w=2.0, layout="NCHW1n16c", method="bilinear")
# bilinear - NHWC # bilinear - NHWC
verify_upsampling(2, 2, 32, 32, 2.0, 2.0, layout="NHWC", method="bilinear") verify_upsampling(2, 2, 32, 32, 2.0, 2.0, layout="NHWC", method="bilinear")
verify_upsampling(2, 2, 32, 32, 3.0, 3.0, layout="NHWC", method="bilinear") verify_upsampling(2, 2, 32, 32, 3.0, 3.0, layout="NHWC", method="bilinear")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment