Commit e6863990 by Leyuan Wang, committed by Tianqi Chen

[TOPI] Add GPU SSD (#1397)

parent 3a0b757c
@@ -15,6 +15,8 @@ from .dense import dense_cuda, schedule_dense
 from .pooling import schedule_pool, schedule_global_pool
 from .conv2d_transpose_nchw import schedule_conv2d_transpose_nchw
 from .extern import schedule_extern
-from .vision import schedule_region
-from .vision import schedule_reorg
 from .nn import schedule_lrn, schedule_l2_normalize
+from .vision import *
+from . import ssd
+from .ssd import *
+from .nms import *
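With these wildcard imports in place, the new GPU implementations become reachable directly under the `topi.cuda` namespace, which is exactly how the updated tests further down call them. A minimal sketch (the input shape is hypothetical; the `sizes`/`ratios` keyword names follow the defaults visible in `verify_multibox_prior` below):

```python
import tvm
import topi

data = tvm.placeholder((1, 3, 512, 512), name="data")   # hypothetical NCHW input
priors = topi.cuda.ssd.multibox_prior(data, sizes=(1.0,), ratios=(1.0,))
```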
+# pylint: disable=wildcard-import
+"""VISION network operators"""
+from __future__ import absolute_import as _abs
+from .multibox import *
-# pylint: disable=invalid-name, unused-variable, unused-argument
+# pylint: disable=invalid-name, unused-variable, unused-argument, no-member
 """Schedule for vision operators"""
 from __future__ import absolute_import as _abs
 import tvm
 from .. import generic
 from .. import cpp
+from .. import tag
+
+def _default_schedule(outs):
+    """Default schedule for gpu."""
+    target = tvm.target.current_target()
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+    def traverse(op):
+        """inline all one-to-one-mapping operators except the last stage (output)"""
+        if "nms" in op.tag:
+            sort = op.input_tensors[1]
+            score = s[sort].op.input_tensors[0]
+            fused = s[score].fuse(*s[score].op.axis)
+            num_thread = tvm.target.current_target(allow_none=False).max_num_threads
+            bx, tx = s[score].split(fused, factor=num_thread)
+            s[score].bind(bx, tvm.thread_axis("blockIdx.x"))
+            s[score].bind(tx, tvm.thread_axis("threadIdx.x"))
+        if tag.is_broadcast(op.tag):
+            if op not in s.outputs:
+                s[op].compute_inline()
+            else:
+                x = op.output(0)
+                fused = s[x].fuse(*s[x].op.axis)
+                num_thread = tvm.target.current_target(allow_none=False).max_num_threads
+                bx, tx = s[x].split(fused, factor=num_thread)
+                s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
+                s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
+        for tensor in op.input_tensors:
+            if tensor.op.input_tensors:
+                traverse(tensor.op)
+    traverse(outs[0].op)
+    return s
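`_default_schedule` handles two cases: an op tagged `nms`, where it reaches through the extern sort stage's input (`op.input_tensors[1]`) to bind the score computation onto GPU threads, and ordinary broadcast/elementwise stages, which are inlined when interior and otherwise fused, split, and bound. A standalone sketch of that fuse -> split -> bind pattern on a toy elementwise op, using the same 0.4-era TVM API as the patch (the factor 64 stands in for `max_num_threads`):

```python
import tvm

n = tvm.var("n")
A = tvm.placeholder((n, n), name="A")
B = tvm.compute((n, n), lambda i, j: A[i, j] + 1.0, name="B")

s = tvm.create_schedule(B.op)
fused = s[B].fuse(*s[B].op.axis)               # collapse both loop axes into one
bx, tx = s[B].split(fused, factor=64)          # 64 stands in for max_num_threads
s[B].bind(bx, tvm.thread_axis("blockIdx.x"))   # outer iterations -> GPU blocks
s[B].bind(tx, tvm.thread_axis("threadIdx.x"))  # inner iterations -> GPU threads
f = tvm.build(s, [A, B], "cuda")               # needs a CUDA-enabled TVM build
```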
 @generic.schedule_reorg.register(["cuda", "gpu"])
 def schedule_reorg(outs):
@@ -41,8 +74,25 @@ def schedule_region(outs):
     cpp_target = cpp.TEST_create_target(target.target_name)
     return cpp.cuda.schedule_region(cpp_target, outs)
 
+@generic.schedule_nms.register(["cuda", "gpu"])
+def schedule_nms(outs):
+    """Schedule for non-maximum suppression
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of nms
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs)
 @generic.schedule_multibox_prior.register(["cuda", "gpu"])
-def schedule_multibox_prior(out):
+def schedule_multibox_prior(outs):
     """Schedule for multibox_prior operator.
 
     Parameters
@@ -56,10 +106,28 @@ def schedule_multibox_prior(out):
     s: Schedule
         The computation schedule for multibox_prior.
     """
-    raise RuntimeError("Currently multibox_prior only supports CPU.")
+    return _default_schedule(outs)
+
+@generic.schedule_multibox_transform_loc.register(["cuda", "gpu"])
+def schedule_multibox_transform_loc(outs):
+    """Schedule for multibox_transform_loc
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of
+        multibox_transform_loc in the format
+        of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs)
 @generic.schedule_multibox_detection.register(["cuda", "gpu"])
-def schedule_multibox_detection(out):
+def schedule_multibox_detection(outs):
     """Schedule for multibox_detection operator.
 
     Parameters
@@ -73,4 +141,4 @@ def schedule_multibox_detection(out):
     s: Schedule
         The computation schedule for multibox_detection.
     """
-    raise RuntimeError("Currently multibox_detection only supports CPU.")
+    return _default_schedule(outs)
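All four registrations route to `_default_schedule`, so GPU callers build the compute op, enter a target context, and invoke the generic entry point, which dispatches to the `cuda`/`gpu` registration above. The pattern, as exercised by the updated tests below (`data` and `valid_count` placeholders assumed as in `test_nms`; the positional arguments are threshold, force_suppress, and topk, matching the test call):

```python
with tvm.target.create("cuda"):
    out = topi.cuda.nms(data, valid_count, 0.7, True, 2)
    s = topi.generic.schedule_nms(out)   # resolves to the cuda/gpu registration above
f = tvm.build(s, [data, valid_count, out], "cuda")
```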
@@ -20,23 +20,27 @@ def verify_clip(N, a_min, a_max, dtype):
     a_np, b_np = get_ref_data()
 
     def check_device(device):
-        if not tvm.module.enabled(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.cpu(0) if device == "llvm" else tvm.gpu(0)
+        print("Running on target: %s" % device)
+        with tvm.target.create(device):
+            s = topi.generic.schedule_injective(B)
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
         f = tvm.build(s, [A, B], device, name="clip")
         f(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['llvm']:
+    for device in ['llvm', 'opencl']:
         check_device(device)
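The test updates in this commit all apply the same two mechanical changes visible in this hunk: the backend probe switches from `tvm.module.enabled` to a `tvm.context(...).exist` check (which also yields the right `ctx` for array allocation), and the device loop gains an `'opencl'` entry. The idiom in isolation:

```python
import tvm

for device in ['llvm', 'opencl']:
    ctx = tvm.context(device, 0)   # ctx.exist is False if the runtime lacks this backend
    if not ctx.exist:
        print("Skip because %s is not enabled" % device)
        continue
    print("Running on target: %s" % device)
```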
 def test_clip():
-    verify_clip(1024, -127, 127, 'int8')
-    verify_clip(1024, -127, 127, 'int16')
     verify_clip(1024, -127, 127, 'float32')
+    verify_clip(1024, -127, 127, 'int16')
+    verify_clip(1024, -127, 127, 'int8')
 
 if __name__ == "__main__":
@@ -8,6 +8,8 @@ from tvm.contrib.pickle_memoize import memoize
 from topi.util import get_const_tuple
 
 def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1):
+    print("Workload: (%d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding))
+
     in_height = in_width = in_size
     A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
@@ -59,7 +61,7 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
 
 def test_conv2d_nchw():
-    # ResNet18 worklaods
+    # ResNet18 workloads
     verify_conv2d_nchw(1, 3, 224, 64, 7, 2, 3)
     verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
     verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0)
@@ -72,6 +74,21 @@ def test_conv2d_nchw():
     verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1)
     verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0)
     verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
+    # ResNet 50 workloads
+    verify_conv2d_nchw(1, 64, 56, 256, 1, 1, 0)
+    verify_conv2d_nchw(1, 256, 56, 64, 1, 1, 0)
+    verify_conv2d_nchw(1, 256, 56, 128, 1, 2, 0)
+    verify_conv2d_nchw(1, 128, 28, 512, 1, 1, 0)
+    verify_conv2d_nchw(1, 256, 56, 512, 1, 2, 0)
+    verify_conv2d_nchw(1, 512, 28, 128, 1, 1, 0)
+    verify_conv2d_nchw(1, 512, 28, 256, 1, 2, 0)
+    verify_conv2d_nchw(1, 256, 14, 1024, 1, 1, 0)
+    verify_conv2d_nchw(1, 512, 28, 1024, 1, 2, 0)
+    verify_conv2d_nchw(1, 1024, 14, 256, 1, 1, 0)
+    verify_conv2d_nchw(1, 1024, 14, 512, 1, 2, 0)
+    verify_conv2d_nchw(1, 512, 7, 2048, 1, 2, 0)
+    verify_conv2d_nchw(1, 1024, 14, 2048, 1, 2, 0)
+    verify_conv2d_nchw(1, 2048, 7, 512, 1, 1, 0)
     # Vgg16 workloads
     verify_conv2d_nchw(1, 128, 122, 128, 3, 1, 1)
     # Super resolution workloads
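A quick sanity check on the new ResNet-50 entries: for these NCHW workloads the output spatial size follows the usual `(in_size + 2*padding - kernel) // stride + 1` arithmetic, so the stride-2 1x1 entries halve the feature map while the stride-1 ones preserve it. For example:

```python
def conv_out_size(in_size, kernel, stride, padding):
    # Standard convolution output-size arithmetic (dilation = 1).
    return (in_size + 2 * padding - kernel) // stride + 1

assert conv_out_size(56, 1, 2, 0) == 28    # e.g. the (1, 256, 56, 128, 1, 2, 0) entry
assert conv_out_size(224, 7, 2, 3) == 112  # the ResNet stem workload above
```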
@@ -14,7 +14,6 @@ def test_nms():
     nms_threshold = 0.7
     force_suppress = True
     nms_topk = 2
-    out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
     np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80],
                          [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
@@ -31,6 +30,10 @@ def test_nms():
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
+            if device == 'llvm':
+                out = nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
+            else:
+                out = topi.cuda.nms(data, valid_count, nms_threshold, force_suppress, nms_topk)
             s = topi.generic.schedule_nms(out)
 
         tvm_data = tvm.nd.array(np_data, ctx)
@@ -40,13 +43,12 @@ def test_nms():
         f(tvm_data, tvm_valid_count, tvm_out)
         np.testing.assert_allclose(tvm_out.asnumpy(), np_result, rtol=1e-4)
 
-    for device in ['llvm']:
+    for device in ['llvm', 'opencl']:
         check_device(device)
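For reading `np_data` above: each row is one candidate box in the SSD convention `[class_id, score, xmin, ymin, xmax, ymax]`, and `valid_count` holds the number of live boxes per batch. The matching placeholders would look like this (a sketch; the `int32` dtype for `valid_count` is an assumption, the rest mirrors the test):

```python
batch, num_boxes = 1, 5
data = tvm.placeholder((batch, num_boxes, 6), name="data")  # [id, score, x1, y1, x2, y2]
valid_count = tvm.placeholder((batch,), dtype="int32", name="valid_count")
```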
 def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False):
     data = tvm.placeholder(dshape, name="data")
-    out = ssd.multibox_prior(data, sizes, ratios, steps, offsets, clip)
     dtype = data.dtype
     input_data = np.random.uniform(size=dshape).astype(dtype)
@@ -88,15 +90,19 @@ def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offse
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
+            if device == 'llvm':
+                out = ssd.multibox_prior(data, sizes, ratios, steps, offsets, clip)
+            else:
+                out = topi.cuda.ssd.multibox_prior(data, sizes, ratios, steps, offsets, clip)
             s = topi.generic.schedule_multibox_prior(out)
 
         tvm_input_data = tvm.nd.array(input_data, ctx)
         tvm_out = tvm.nd.array(np.zeros(oshape, dtype=dtype), ctx)
         f = tvm.build(s, [data, out], device)
         f(tvm_input_data, tvm_out)
-        np.testing.assert_allclose(tvm_out.asnumpy(), np_out, rtol=1e-4)
+        np.testing.assert_allclose(tvm_out.asnumpy(), np_out, rtol=1e-3)
 
-    for device in ['llvm']:
+    for device in ['llvm', 'opencl']:
         check_device(device)
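The `oshape` checked above comes from the MXNet `MultiBoxPrior` convention, which this TOPI op appears to follow: every feature-map location emits `num_sizes + num_ratios - 1` boxes of four corner coordinates. The arithmetic as a sketch (the convention itself is an assumption here, not stated in the diff):

```python
def prior_oshape(dshape, sizes, ratios):
    # dshape is NCHW; each of the H*W locations yields
    # len(sizes) + len(ratios) - 1 prior boxes of 4 coordinates.
    _, _, h, w = dshape
    return (1, h * w * (len(sizes) + len(ratios) - 1), 4)

assert prior_oshape((1, 3, 32, 32), (1.0, 0.5), (1.0, 2.0)) == (1, 3072, 4)
```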
@@ -113,7 +119,6 @@ def test_multibox_detection():
     cls_prob = tvm.placeholder((batch_size, num_anchors, num_classes), name="cls_prob")
     loc_preds = tvm.placeholder((batch_size, num_anchors * 4), name="loc_preds")
     anchors = tvm.placeholder((1, num_anchors, 4), name="anchors")
-    out = ssd.multibox_detection(cls_prob, loc_preds, anchors)
 
     # Manually create test case
     np_cls_prob = np.array([[[0.2, 0.5, 0.3], [0.25, 0.3, 0.45], [0.7, 0.1, 0.2]]])
@@ -131,6 +136,10 @@ def test_multibox_detection():
             return
         print("Running on target: %s" % device)
         with tvm.target.create(device):
+            if device == 'llvm':
+                out = ssd.multibox_detection(cls_prob, loc_preds, anchors)
+            else:
+                out = topi.cuda.ssd.multibox_detection(cls_prob, loc_preds, anchors)
             s = topi.generic.schedule_multibox_detection(out)
 
         tvm_cls_prob = tvm.nd.array(np_cls_prob.astype(cls_prob.dtype), ctx)
@@ -141,11 +150,11 @@ def test_multibox_detection():
         f(tvm_cls_prob, tvm_loc_preds, tvm_anchors, tvm_out)
         np.testing.assert_allclose(tvm_out.asnumpy(), expected_np_out, rtol=1e-4)
 
-    for device in ['llvm']:
+    for device in ['llvm', 'opencl']:
         check_device(device)
 
 if __name__ == "__main__":
     test_nms()
     test_multibox_prior()
-    test_multibox_detection()
\ No newline at end of file
+    test_multibox_detection()
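Condensed, the GPU path these tests now exercise looks like this (a summary sketch of the flow above, reusing the test's tensors; nothing beyond what the diff shows):

```python
with tvm.target.create("opencl"):
    out = topi.cuda.ssd.multibox_detection(cls_prob, loc_preds, anchors)
    s = topi.generic.schedule_multibox_detection(out)   # -> _default_schedule
f = tvm.build(s, [cls_prob, loc_preds, anchors, out], "opencl")
f(tvm_cls_prob, tvm_loc_preds, tvm_anchors, tvm_out)
```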