Commit 7742123a by Pariksheet Pinjari Committed by Tianqi Chen

Yolo2 operators (#911)

parent 38274115
......@@ -25,6 +25,7 @@ from . import opengl
from . import util
from . import rocm
from . import cpp
from . import vision
# not import testing by default
# because testing can have extra deps that are not necessary
# we can import them from test cases explicitly
......
......@@ -15,3 +15,4 @@ from .dense import dense_cuda, schedule_dense
from .pooling import schedule_pool, schedule_global_pool
from .conv2d_transpose_nchw import schedule_conv2d_transpose_nchw
from .extern import schedule_extern
from .vision import schedule_region
# pylint: disable=invalid-name, unused-variable
"""Schedule for vision operators"""
from __future__ import absolute_import as _abs
import tvm
from .. import tag
from .. import generic
@generic.schedule_region.register(["cuda", "gpu"])
def schedule_region(outs):
"""Schedule for region operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of region
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for region.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
output = outs[0].op.output(0)
#thread = 64 for higher size tensors, give resource_unavailable error for higher values
num_thread = 64
def _schedule_softmax(softmax_op):
softmax = softmax_op.input_tensors[0]
max_elem = softmax_op.input_tensors[1]
expsum = softmax_op.input_tensors[2]
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
s[max_elem].bind(max_elem.op.axis[0], block_x)
k = expsum.op.reduce_axis[0]
ko, ki = s[expsum].split(k, factor=num_thread)
ef = s.rfactor(expsum, ki)
s[expsum].bind(s[expsum].op.axis[0], block_x)
s[expsum].bind(s[expsum].op.reduce_axis[0], thread_x)
s[ef].compute_at(s[expsum], s[expsum].op.reduce_axis[0])
s[expsum].set_store_predicate(thread_x.var.equal(0))
tx, xi = s[softmax_op].split(softmax_op.axis[1], nparts=num_thread)
s[softmax_op].bind(softmax_op.axis[0], block_x)
s[softmax_op].bind(tx, thread_x)
return max_elem.op.input_tensors[0]
def _traverse(op):
if tag.is_injective(op.tag):
if op not in s.outputs:
s[op].compute_inline()
for tensor in op.input_tensors:
if tensor.op.input_tensors:
_traverse(tensor.op)
elif op.tag == 'softmax_output':
tensor = _schedule_softmax(op)
if tensor.op.input_tensors:
_traverse(tensor.op)
else:
raise RuntimeError("Unsupported operator: %s" % op.tag)
_traverse(outs[0].op)
k = output.op.axis[0]
bx, tx = s[output].split(k, factor=num_thread)
s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
return s
......@@ -18,3 +18,4 @@ from __future__ import absolute_import as _abs
from .nn import *
from .injective import *
from .extern import *
from .vision import *
"""Generic vision operators"""
from __future__ import absolute_import as _abs
import tvm
def _default_schedule(outs, auto_inline):
"""Default schedule for llvm."""
target = tvm.target.current_target(allow_none=False)
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
if target.target_name != "llvm":
raise RuntimeError("schedule not registered for '%s'" % target)
s = tvm.create_schedule([x.op for x in outs])
if auto_inline:
x = outs[0]
tvm.schedule.AutoInlineInjective(s)
s[x].fuse(s[x].op.axis)
return s
@tvm.target.generic_func
def schedule_shortcut(outs):
"""Schedule for shortcut
Parameters
----------
outs: Array of Tensor
The computation graph description of shortcut
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_reorg(outs):
"""Schedule for reorg
Parameters
----------
outs: Array of Tensor
The computation graph description of reorg
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_region(outs):
"""Schedule for region
Parameters
----------
outs: Array of Tensor
The computation graph description of region
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
......@@ -4,3 +4,4 @@ from __future__ import absolute_import as _abs
from .conv2d import *
from .dense import *
from .vision import *
# pylint: disable=invalid-name, unused-variable
"""Schedule for vision operator"""
from __future__ import absolute_import as _abs
import topi
from .. import generic
@generic.schedule_region.register(["rocm"])
def schedule_region(outs):
"""Schedule for region operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of region
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for region.
"""
return topi.cuda.schedule_region(outs)
......@@ -12,3 +12,6 @@ from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_con
from .dilate_python import dilate_python
from .softmax_python import softmax_python, log_softmax_python
from .upsampling_python import upsampling_python
from .reorg_python import reorg_python
from .region_python import region_python
from .shortcut_python import shortcut_python
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Region in python"""
import numpy as np
def entry_index(batch, w, h, outputs, classes, coords, location, entry):
n = int(location/(w*h))
loc = location%(w*h)
return batch*outputs + n*w*h*(coords+classes+1) + entry*w*h + loc
def region_python(a_np, N, classes, coords, background, softmax):
"""Region operator
Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]
N : int
Darknet layer parameter n
classes : int
Darknet layer parameter classes
coords : int
Darknet layer parameter coords
background : int
Darknet layer parameter background
softmax : int
Darknet layer parameter softmax
Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, in_channel, in_height, in_width = a_np.shape
a_np_temp = np.reshape(a_np, batch*in_channel*in_height*in_width)
outputs = batch*in_channel*in_height*in_width
b_np = np.zeros(batch*in_channel*in_height*in_width)
for i in range(batch*in_channel*in_height*in_width):
b_np[i] = a_np_temp[i]
for b in range(batch):
for n in range(N):
index = entry_index(b, in_width, in_height, outputs, classes, coords, n*in_width*in_height, 0)
b_np[index: index+2*in_width*in_height] = 1/(1+np.exp(-1*b_np[index: index+2*in_width*in_height]))
index = entry_index(b, in_width, in_height, outputs, classes, coords, n*in_width*in_height, coords)
if not background:
b_np[index: index+in_width*in_height] = 1/(1+np.exp(-1*b_np[index: index+in_width*in_height]))
b_np = np.reshape(b_np, (batch, in_channel, in_height, in_width))
def local_softmax(data_in):
data_c, data_h, data_w = data_in.shape
largest = np.max(data_in, axis=1)
data_out = np.zeros((data_c, data_h, data_w))
for i in range(data_h):
for j in range(data_w):
data_out[:, i, j] = np.exp(data_in[:, i, j] - largest[i, j])
return data_out/data_out.sum(axis=0)
if softmax:
index = coords + int(not background)
for b in range(batch):
for i in range(N):
b_np_index = int(i*(in_channel/N) + index)
b_np[b, b_np_index: b_np_index + classes+background, :, :] = local_softmax(b_np[b, b_np_index:b_np_index + classes+background, :, :])
return b_np
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Reorg in python"""
import numpy as np
def reorg_python(a_np, stride):
"""Reorg operator
Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]
stride : int
Stride size
Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, in_channel, in_height, in_width = a_np.shape
a_np = np.reshape(a_np, batch*in_channel*in_height*in_width)
out_c = int(in_channel/(stride*stride))
out_channel = in_channel*stride*stride
out_height = int(in_height/stride)
out_width = int(in_width/stride)
b_np = np.zeros(batch*out_channel*out_height*out_width)
cnt = 0
for b in range(batch):
for k in range(in_channel):
for j in range(in_height):
for i in range(in_width):
c2 = k % out_c
offset = int(k / out_c)
w2 = int(i*stride + offset % stride)
h2 = int(j*stride + offset / stride)
out_index = int(w2 + in_width*stride*(h2 + in_height*stride*(c2 + out_c*b)))
b_np[cnt] = a_np[int(out_index)]
cnt = cnt+1
b_np = np.reshape(b_np, (batch, out_channel, out_height, out_width))
return b_np
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Shortcut in python"""
import numpy as np
def shortcut_python(a_np1, a_np2):
"""Reorg operator
Parameters
----------
a_np1 : numpy.ndarray
4-D with shape [batch1, in_channel1, in_height1, in_width1]
a_np2 : numpy.ndarray
4-D with shape [batch2, in_channel2, in_height2, in_width2]
Returns
-------
b_np : np.ndarray
4-D with shape [batch1, out_channel1, out_height1, out_width1]
"""
batch1, in_channel1, in_height1, in_width1 = a_np1.shape
batch2, in_channel2, in_height2, in_width2 = a_np2.shape
a_np1_temp = np.reshape(a_np1, batch1*in_channel1*in_height1*in_width1)
a_np2_temp = np.reshape(a_np2, batch2*in_channel2*in_height2*in_width2)
b_np = np.zeros(batch1*in_channel1*in_height1*in_width1)
stride = int(in_width1/in_width2)
sample = int(in_width2/in_width1)
if stride < 1:
stride = 1
if sample < 1:
sample = 1
minw = min(in_width1, in_width2)
minh = min(in_height1, in_height2)
minc = min(in_channel1, in_channel2)
for i in range((batch1*in_channel1*in_height1*in_width1)):
b_np[i] = a_np1_temp[i]
for b in range(batch1):
for k in range(minc):
for j in range(minh):
for i in range(minw):
out_index = i*sample + in_width2*(j*sample + in_height2*(k + in_channel2*b))
add_index = i*stride + in_width1*(j*stride + in_height1*(k + in_channel1*b))
b_np[out_index] = a_np1_temp[out_index] + a_np2_temp[add_index]
b_np = np.reshape(b_np, (batch1, in_channel1, in_height1, in_width1))
return b_np
# pylint: disable=wildcard-import
"""VISION network operators"""
from __future__ import absolute_import as _abs
from . import yolo2
from .shortcut import *
from .reorg import *
"""
REORG Operator
====================
Reorg operator, used in darknet.
"""
from __future__ import absolute_import as _abs
import tvm
from .. import util
from .. import transform
@tvm.target.generic_func
def reorg(data, stride):
"""Reorg forward operators.
Parameters
----------
Input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
stride : int
Stride value for reorganization
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, c_in, h_in, w_in = util.get_const_tuple(data.shape)
out_c = int(c_in / (stride * stride))
out = tvm.compute((batch, c_in, h_in, w_in), lambda b, k, j, i:
data[b * stride * stride,
(k % out_c) * stride * stride,
(j*stride + (k / out_c) / stride) * stride,
(i*stride + (k / out_c) % stride)],
tag="reorg")
out_c = int(c_in * stride * stride)
out_h = int(h_in / stride)
out_w = int(w_in / stride)
return transform.reshape(out, (batch, out_c, out_h, out_w))
"""Shortcut operators (short-cut connections)."""
from __future__ import absolute_import as _abs
import tvm
from .. import util
from .. import transform
@tvm.target.generic_func
def shortcut(inp1, inp2):
"""Shortcut forward operators.
Parameters
----------
First Input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
Second Input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
_, inp1_c, inp1_h, inp1_w = util.get_const_tuple(inp1.shape)
batch, inp2_c, inp2_h, inp2_w = util.get_const_tuple(inp2.shape)
stride = int(max(inp2_w / inp1_w, 1))
sample = int(max(inp1_w / inp2_w, 1))
minc = min(inp2_c, inp1_c)
minh = min(inp2_h, inp1_h)
minw = min(inp2_w, inp1_w)
out = tvm.compute((batch, minc, minh, minw), lambda b, c, h, w:
inp1[b, c, h * sample, w * sample] +
inp2[b, c, h * stride, w * stride],
tag="shortcut")
split_indices = int(inp1_c / minc)
if split_indices > 1:
split_res = transform.split(inp1, split_indices, 1)
split_res[0] = out
out = transform.concatenate(split_res, 1)
return out
# pylint: disable=wildcard-import
"""VISION network operators"""
from __future__ import absolute_import as _abs
from .region import *
# pylint: disable=invalid-name, unused-variable
"""
REGION Operator
====================
Region operator, used in darknet.
"""
from __future__ import absolute_import as _abs
import tvm
from ... import transform
from ... import util
from ... import math
from ... import nn
@tvm.target.generic_func
def region(data, num, classes, coords, background, softmax=True):
"""Region forward operators.
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, c_in, h_in, w_in]
num : int
Darknet layer parameter n
classes : int
Darknet layer parameter classes
coords : int
Darknet layer parameter coords
background : int
Darknet layer parameter background
softmax : boolean
Darknet layer parameter softmax
Returns
-------
out : tvm.Tensor
4-D with shape [batch, c_in, h_in, w_in]
"""
batch, c_in, h_in, w_in = util.get_const_tuple(data.shape)
split_indices = classes+coords+1
data_block = transform.reshape(data, (batch, num, split_indices, h_in, w_in))
split_res = transform.split(data_block, split_indices, 2)
split_res[0] = math.sigmoid(split_res[0])
split_res[1] = math.sigmoid(split_res[1])
if not background:
split_res[coords] = math.sigmoid(split_res[coords])
if softmax:
offset = coords + int(not background)
data_block_1 = []
data_block_1.append(transform.concatenate(split_res[0:offset], 2))
temp_out = transform.concatenate(split_res[offset:split_indices], 2)
temp_out = nn.softmax(temp_out, axis=2)
data_block_1.append(temp_out)
split_res = data_block_1
out = transform.concatenate(split_res, 2)
out = transform.reshape(out, data.shape)
return out
"""Example code to do region."""
import numpy as np
import topi
from topi.util import get_const_tuple
import tvm
def verify_region(batch, in_size, in_channel, n, classes, coords, background, l_softmax):
'''Verify region operator by comparing outputs from tvm and numpy implementation'''
in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
B = topi.vision.yolo2.region(A, n, classes, coords, background, l_softmax)
a_shape = get_const_tuple(A.shape)
dtype = A.dtype
def get_ref_data_region():
a_np = np.random.uniform(size=a_shape).astype(dtype)
b_np = topi.testing.region_python(a_np, n, classes, coords, background, l_softmax)
return a_np, b_np
a_np, b_np = get_ref_data_region()
def check_device(device):
'''Cheching devices is enabled or not'''
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.vision.schedule_region([B])
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, B], device)
func(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['llvm', 'cuda']:
check_device(device)
def test_region():
verify_region(1, 19, 425, 5, 80, 4, 0, 1)
if __name__ == "__main__":
test_region()
"""Example code to do reorg."""
import numpy as np
import topi
from topi.util import get_const_tuple
import tvm
def verify_reorg(batch, in_size, in_channel, stride):
'''Verify reorg operator by comparing outputs from tvm and numpy implementation'''
in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
B = topi.vision.reorg(A, stride)
a_shape = get_const_tuple(A.shape)
dtype = A.dtype
def get_ref_data_reorg():
a_np = np.random.uniform(size=a_shape).astype(dtype)
b_np = topi.testing.reorg_python(a_np, stride)
return a_np, b_np
a_np, b_np = get_ref_data_reorg()
def check_device(device):
'''Cheching devices is enabled or not'''
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective([B])
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A, B], device)
func(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['llvm', 'cuda']:
check_device(device)
def test_reorg():
verify_reorg(1, 20, 8, 2)
if __name__ == "__main__":
test_reorg()
"""Example code to do shortcut."""
import numpy as np
import topi
from topi.util import get_const_tuple
import tvm
def verify_shortcut(batch, in_size, in_channel):
'''Verify shortcut operator by comparing outputs from tvm and numpy implementation'''
in_height = in_width = in_size
A1 = tvm.placeholder((batch, in_channel, in_height, in_width), name='A1')
A2 = tvm.placeholder((batch, in_channel, in_height, in_width), name='A2')
B = topi.vision.shortcut(A1, A2)
a_shape = get_const_tuple(A1.shape)
dtype = A1.dtype
def get_ref_data_shortcut():
a_np1 = np.random.uniform(size=a_shape).astype(dtype)
a_np2 = np.random.uniform(size=a_shape).astype(dtype)
b_np = topi.testing.shortcut_python(a_np1, a_np2)
return a_np1, a_np2, b_np
a_np1, a_np2, b_np = get_ref_data_shortcut()
def check_device(device):
'''Cheching devices is enabled or not'''
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_injective([B])
a1 = tvm.nd.array(a_np1, ctx)
a2 = tvm.nd.array(a_np2, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
func = tvm.build(s, [A1, A2, B], device)
func(a1, a2, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['llvm', 'cuda']:
check_device(device)
def test_shortcut():
verify_shortcut(1, 144, 32)
if __name__ == "__main__":
test_shortcut()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment