Commit 8dd928c7 by Wuwei Lin Committed by masahi

[TOPI] Add roi align (#2350)

* [TOPI] Add roi align

* Refactor bilinear in image resize

* Rename to roi_align_nchw

* Fix
parent 3708b311
......@@ -22,6 +22,45 @@ namespace image {
using namespace tvm;
/*!
* \brief Sample a point in a tensor using bilinear interpolation.
*
* \param input The input tensor.
* \param indices The index of the target point, which can be fractional
* \param max_y The maximum of y dimension
* \param max_x The maximum of x dimension
*
* \return The interpolated value in the given index.
*/
inline Expr bilinear_sample_nchw(const Tensor& input, const Array<Expr>& indices,
const Expr max_y, const Expr max_x) {
auto in_y = indices[2];
auto yf = tvm::floor(in_y);
auto yc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_y));
auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_y));
auto y1 = tvm::select((yc > max_y), max_y, yc);
auto y_lerp = in_y - yf;
auto in_x = indices[3];
auto xf = tvm::floor(in_x);
auto xc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_x));
auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_x));
auto x1 = tvm::select((xc > max_x), max_x, xc);
auto x_lerp = in_x - xf;
auto A = input(indices[0], indices[1], y0, x0);
auto B = input(indices[0], indices[1], y0, x1);
auto C = input(indices[0], indices[1], y1, x0);
auto D = input(indices[0], indices[1], y1, x1);
auto top = A + (B - A) * x_lerp;
auto bottom = C + (D - C) * x_lerp;
return (top + (bottom - top) * y_lerp);
}
/*!
* \brief Resize given tensor to given shape using nearest neighbour for NHWC
*
* \param input The input tensor.
......@@ -249,30 +288,8 @@ inline Tensor resize_bilinear_nchw(const Tensor& input,
return compute(
out_shape, [&](const Array<Var>& indices) {
auto in_y = indices[2] * y_ratio;
auto yf = tvm::floor(in_y);
auto yc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_y));
auto y0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_y));
auto y1 = tvm::select((yc > other_y), other_y, yc);
auto y_lerp = in_y - yf;
auto in_x = indices[3] * x_ratio;
auto xf = tvm::floor(in_x);
auto xc = HalideIR::Internal::Cast::make(Int(32), tvm::ceil(in_x));
auto x0 = HalideIR::Internal::Cast::make(Int(32), tvm::floor(in_x));
auto x1 = tvm::select((xc > other_x), other_x, xc);
auto x_lerp = in_x - xf;
auto A = input(indices[0], indices[1], y0, x0);
auto B = input(indices[0], indices[1], y0, x1);
auto C = input(indices[0], indices[1], y1, x0);
auto D = input(indices[0], indices[1], y1, x1);
auto top = A + (B - A) * x_lerp;
auto bottom = C + (D - C) * x_lerp;
return (top + (bottom - top) * y_lerp);
return bilinear_sample_nchw(input, {indices[0], indices[1], in_y, in_x}, other_y, other_x);
}, name, tag);
}
......
......@@ -5,6 +5,7 @@ import tvm
from .. import generic
from .. import cpp
from .. import tag
from .pooling import schedule_pool
def _default_schedule(outs):
"""Default schedule for gpu."""
......@@ -146,3 +147,7 @@ def schedule_multibox_detection(outs):
The computation schedule for multibox_detection.
"""
return _default_schedule(outs)
@generic.schedule_roi_align.register(["cuda", "gpu"])
def schedule_roi_align(outs):
return schedule_pool(outs, 'NCHW')
......@@ -140,3 +140,20 @@ def schedule_multibox_detection(outs):
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_roi_align(outs):
"""Schedule for roi_align
Parameters
----------
outs: Array of Tensor
The computation graph description of roi_align
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
......@@ -15,6 +15,7 @@ from .upsampling_python import upsampling_python
from .bilinear_resize_python import bilinear_resize_python
from .reorg_python import reorg_python
from .region_python import region_python
from .roi_align_python import roi_align_nchw_python
from .shortcut_python import shortcut_python
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
......
# pylint: disable=invalid-name, too-many-nested-blocks
"Roi align in python"
import math
import numpy as np
def roi_align_nchw_python(a_np, rois_np, pooled_size, spatial_scale, sample_ratio):
"""Roi align in python"""
_, channel, height, width = a_np.shape
num_roi = rois_np.shape[0]
b_np = np.zeros((num_roi, channel, pooled_size, pooled_size), dtype=a_np.dtype)
if isinstance(pooled_size, int):
pooled_size_h = pooled_size_w = pooled_size
else:
pooled_size_h, pooled_size_w = pooled_size
def _bilinear(b, c, y, x):
if y < -1 or y > height or x < -1 or x > width:
return 0
y = max(y, 0.0)
x = max(x, 0.0)
y_low = int(y)
x_low = int(x)
y_high = min(y_low + 1, height - 1)
x_high = min(x_low + 1, width - 1)
ly = y - y_low
lx = x - x_low
return (1 - ly) * (1 - lx) * a_np[b, c, y_low, x_low] + \
(1 - ly) * lx * a_np[b, c, y_low, x_high] + \
ly * (1 - lx) * a_np[b, c, y_high, x_low] + \
ly * lx * a_np[b, c, y_high, x_high]
for i in range(num_roi):
roi = rois_np[i]
batch_index = int(roi[0])
roi_start_w, roi_start_h, roi_end_w, roi_end_h = roi[1:] * spatial_scale
roi_h = max(roi_end_h - roi_start_h, 1.0)
roi_w = max(roi_end_w - roi_start_w, 1.0)
bin_h = roi_h / pooled_size_h
bin_w = roi_w / pooled_size_w
if sample_ratio > 0:
roi_bin_grid_h = roi_bin_grid_w = int(sample_ratio)
else:
roi_bin_grid_h = int(math.ceil(roi_h / pooled_size))
roi_bin_grid_w = int(math.ceil(roi_w / pooled_size))
count = roi_bin_grid_h * roi_bin_grid_w
for c in range(channel):
for ph in range(pooled_size_h):
for pw in range(pooled_size_w):
total = 0.
for iy in range(roi_bin_grid_h):
for ix in range(roi_bin_grid_w):
y = roi_start_h + ph * bin_h + (iy + 0.5) * bin_h / roi_bin_grid_h
x = roi_start_w + pw * bin_w + (ix + 0.5) * bin_w / roi_bin_grid_w
total += _bilinear(batch_index, c, y, x)
b_np[i, c, ph, pw] = total / count
return b_np
......@@ -6,3 +6,4 @@ from . import yolo, ssd
from .shortcut import *
from .reorg import *
from .nms import *
from .rcnn import *
# pylint: disable=wildcard-import
"""Faster R-CNN and Mask R-CNN operators"""
from .roi_align import *
# pylint: disable=invalid-name
"""Roi align operator"""
import tvm
from ...util import get_const_tuple
from ...cpp.image import bilinear_sample_nchw
@tvm.target.generic_func
def roi_align_nchw(data, rois, pooled_size, spatial_scale, sample_ratio=-1):
"""ROI align operator in NCHW layout.
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, channel, height, width]
rois : tvm.Tensor
2-D with shape [num_roi, 5]. The last dimension should be in format of
[batch_index, w_start, h_start, w_end, h_end]
pooled_size : int or list/tuple of two ints
output size, or [out_height, out_width]
spatial_scale : float
Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal
of total stride in convolutional layers, which should be in range (0.0, 1.0]
sample_ratio : int
Optional sampling ratio of ROI align, using adaptive size by default.
Returns
-------
output : tvm.Tensor
4-D with shape [num_roi, channel, pooled_size, pooled_size]
"""
dtype = rois.dtype
_, channel, height, width = get_const_tuple(data.shape)
num_roi, _ = get_const_tuple(rois.shape)
if isinstance(pooled_size, int):
pooled_size_h = pooled_size_w = pooled_size
else:
pooled_size_h, pooled_size_w = pooled_size
def _bilinear(i, c, y, x):
outside = tvm.any(y < -1.0, x < -1.0, y > height, x > width)
y = tvm.max(y, 0.0)
x = tvm.max(x, 0.0)
val = bilinear_sample_nchw(data, (i, c, y, x), height - 1, width - 1)
return tvm.select(outside, 0.0, val)
def _sample(i, c, ph, pw):
roi = rois[i]
batch_index = roi[0].astype('int32')
roi_start_w, roi_start_h, roi_end_w, roi_end_h = roi[1], roi[2], roi[3], roi[4]
roi_start_h *= spatial_scale
roi_end_h *= spatial_scale
roi_start_w *= spatial_scale
roi_end_w *= spatial_scale
# force malformed ROIs to be 1x1
roi_h = tvm.max(roi_end_h - roi_start_h, tvm.const(1.0, dtype))
roi_w = tvm.max(roi_end_w - roi_start_w, tvm.const(1.0, dtype))
bin_h = roi_h / pooled_size_h
bin_w = roi_w / pooled_size_w
if sample_ratio > 0:
roi_bin_grid_h = roi_bin_grid_w = tvm.const(sample_ratio, 'int32')
else:
roi_bin_grid_h = tvm.ceil(roi_h / pooled_size).astype('int32')
roi_bin_grid_w = tvm.ceil(roi_w / pooled_size).astype('int32')
count = roi_bin_grid_h * roi_bin_grid_w
rh = tvm.reduce_axis((0, roi_bin_grid_h))
rw = tvm.reduce_axis((0, roi_bin_grid_w))
roi_start_h += ph * bin_h
roi_start_w += pw * bin_w
return tvm.sum(_bilinear(batch_index, c,
roi_start_h + (rh + 0.5) * bin_h / roi_bin_grid_h,
roi_start_w + (rw + 0.5) * bin_w / roi_bin_grid_w) / count,
axis=[rh, rw])
return tvm.compute((num_roi, channel, pooled_size_h, pooled_size_w), _sample,
tag='pool,roi_align_nchw')
......@@ -430,6 +430,11 @@ TVM_REGISTER_GLOBAL("topi.vision.yolo.region")
});
/* Ops from image/resize.h */
TVM_REGISTER_GLOBAL("topi.image.bilinear_sample_nchw")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = image::bilinear_sample_nchw(args[0], args[1], args[2], args[3]);
});
TVM_REGISTER_GLOBAL("topi.image.resize")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = image::resize(args[0], args[1], args[2], args[3], args[4]);
......
"""Test code for vision package"""
import math
import numpy as np
import tvm
import topi
import math
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from topi.vision import ssd, nms
......@@ -154,7 +157,57 @@ def test_multibox_detection():
check_device(device)
def verify_roi_align(batch, in_channel, in_size, num_roi, pooled_size, spatial_scale, sample_ratio):
a_shape = (batch, in_channel, in_size, in_size)
rois_shape = (num_roi, 5)
a = tvm.placeholder(a_shape)
rois = tvm.placeholder(rois_shape)
@memoize("topi.tests.test_topi_vision.verify_roi_align")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype('float32')
rois_np = np.random.uniform(size=rois_shape).astype('float32') * in_size
rois_np[:, 0] = np.random.randint(low = 0, high = batch, size = num_roi)
b_np = topi.testing.roi_align_nchw_python(a_np, rois_np, pooled_size=pooled_size,
spatial_scale=spatial_scale,
sample_ratio=sample_ratio)
return a_np, rois_np, b_np
a_np, rois_np, b_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
b = topi.vision.rcnn.roi_align_nchw(a, rois, pooled_size=pooled_size,
spatial_scale=spatial_scale,
sample_ratio=sample_ratio)
s = topi.generic.schedule_roi_align(b)
tvm_a = tvm.nd.array(a_np, ctx)
tvm_rois = tvm.nd.array(rois_np, ctx)
tvm_b = tvm.nd.array(np.zeros(get_const_tuple(b.shape), dtype=b.dtype), ctx=ctx)
f = tvm.build(s, [a, rois, b], device)
f(tvm_a, tvm_rois, tvm_b)
tvm.testing.assert_allclose(tvm_b.asnumpy(), b_np, rtol=1e-3)
for device in ['llvm', 'cuda']:
check_device(device)
def test_roi_align():
verify_roi_align(1, 16, 32, 64, 7, 1.0, -1)
verify_roi_align(4, 16, 32, 64, 7, 0.5, 2)
if __name__ == "__main__":
test_nms()
test_multibox_prior()
test_multibox_detection()
test_roi_align()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment