Commit ce02ee3b by Pariksheet Pinjari, committed by Tianqi Chen

[TOPI] LRN & L2norm Operator (#1051)

parent b3f09b01
@@ -16,3 +16,4 @@ from .pooling import schedule_pool, schedule_global_pool
from .conv2d_transpose_nchw import schedule_conv2d_transpose_nchw
from .extern import schedule_extern
from .vision import schedule_region
from .nn import schedule_lrn, schedule_l2norm
# pylint: disable=invalid-name
"""scheduler functions for cuda backend"""
from __future__ import absolute_import as _abs
import tvm
from .. import generic
from .. import tag
from .reduction import _schedule_reduce
@generic.schedule_lrn.register(["cuda"])
def schedule_lrn(outs):
"""Schedule for LRN
Parameters
----------
outs: Array of Tensor
The computation graph description of LRN
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
num_thread = 64
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
lrn = outs[0]
sqr_sum_up = lrn.op.input_tensors[1]
sqr_sum = sqr_sum_up.op.input_tensors[0]
set_pad = sqr_sum.op.input_tensors[0]
s[set_pad].bind(set_pad.op.axis[0], block_x)
rxk = sqr_sum.op.reduce_axis[0]
_, xki = s[sqr_sum].split(rxk, factor=num_thread)
srf = s.rfactor(sqr_sum, xki)
s[sqr_sum].bind(s[sqr_sum].op.axis[0], block_x)
s[sqr_sum].bind(s[sqr_sum].op.reduce_axis[0], thread_x)
s[srf].compute_at(s[sqr_sum], s[sqr_sum].op.reduce_axis[0])
s[sqr_sum_up].bind(sqr_sum_up.op.axis[0], block_x)
xto, _ = s[lrn].split(lrn.op.axis[1], nparts=num_thread)
s[lrn].bind(lrn.op.axis[0], block_x)
s[lrn].bind(xto, thread_x)
return s
@generic.schedule_l2norm.register(["cuda"])
def schedule_l2norm(outs):
"""Schedule for L2norm
Parameters
----------
outs: Array of Tensor
The computation graph description of L2norm
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
def traverse(OP):
'''inline all one-to-one-mapping operators
except the last stage (output)'''
if tag.is_injective(OP.tag) or OP.tag == 'l2norm':
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
elif OP.tag == 'comm_reduce':
_schedule_reduce(OP, s, is_idx_reduce=False)
for tensor in OP.input_tensors:
traverse(tensor.op)
else:
raise RuntimeError("Unsupported operator tag: %s" % OP.tag)
traverse(outs[0].op)
num_thread = 64
l2norm = outs[0]
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
xto, _ = s[l2norm].split(l2norm.op.axis[1], nparts=num_thread)
s[l2norm].bind(l2norm.op.axis[0], block_x)
s[l2norm].bind(xto, thread_x)
return s
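As a usage sketch only (not part of this commit; assumes a CUDA-enabled build and the TVM 0.x-era API used throughout this diff), the schedules above can be exercised directly like this:

import numpy as np
import tvm
import topi

A = tvm.placeholder((1, 3, 5, 5), name='A')
B = topi.nn.lrn(A, size=3)
with tvm.target.create("cuda"):
    s = topi.cuda.schedule_lrn(B)
f = tvm.build(s, [A, B], "cuda")
ctx = tvm.gpu(0)
a = tvm.nd.array(np.random.uniform(size=(1, 3, 5, 5)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros((1, 3, 5, 5), dtype=A.dtype), ctx)
f(a, b)  # b now holds the LRN output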
@@ -229,3 +229,39 @@ def schedule_binary_dense(outs):
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_lrn(outs):
"""Schedule for lrn
Parameters
----------
outs: Array of Tensor
The computation graph description of lrn
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_l2norm(outs):
"""Schedule for l2norm
Parameters
----------
outs: Array of Tensor
The computation graph description of l2norm
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
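These generic_func fallbacks are what the target-specific registrations above (and the rocm ones below) override. A minimal dispatch sketch, assuming the targets are enabled in the local build:

import tvm
import topi

A = tvm.placeholder((1, 3, 20, 20), name='A')
B = topi.nn.l2norm_instance(A, eps=0.001)

with tvm.target.create("llvm"):
    s_cpu = topi.generic.schedule_l2norm(B)   # falls back to _default_schedule
with tvm.target.create("cuda"):
    s_gpu = topi.generic.schedule_l2norm(B)   # dispatches to topi.cuda.schedule_l2norm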
@@ -15,3 +15,5 @@ from .softmax import *
from .conv2d_transpose import *
from .bnn import *
from .upsampling import *
from .local_response_norm import *
from .l2_norm import *
# pylint: disable=invalid-name
"""TVM operator for l2norm"""
from __future__ import absolute_import
import tvm
import topi
@tvm.target.generic_func
def l2norm_instance(data, eps, axis=None):
"""Perform L2norm on the input data
For axis=None, y(i, j) = x(i, j) / sqrt(max(sum(x^2), eps))
Parameters
----------
data : tvm.Tensor
4-D with NCHW or NHWC layout
eps : float
epsilon value
axis : list of int
axes over which the normalization is applied
Returns
-------
output : tvm.Tensor
4-D output with same shape
"""
assert len(data.shape) == 4, "only support 4-dim l2norm"
dot_value = topi.cpp.pow(data, 2.0)
sum_value = topi.sum(dot_value, axis=axis, keepdims=True)
expand_sum = topi.broadcast_to(sum_value, data.shape)
return topi.broadcast_div(data, topi.sqrt(tvm.compute(
    expand_sum.shape,
    lambda i, j, k, l: tvm.max(expand_sum[i, j, k, l], eps),
    tag='l2norm')))
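For reference, the axis=None case of the docstring formula reduces to the following NumPy one-liner (an illustration only; the tests further down use the same reference):

import numpy as np

x = np.random.uniform(size=(1, 3, 4, 4)).astype("float32")
eps = 0.001
# y = x / sqrt(max(sum(x^2), eps)); the sum runs over the whole tensor when axis is None
y = x / np.sqrt(np.maximum(np.sum(x * x, axis=None, keepdims=True), eps))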
# pylint: disable=invalid-name
"""TVM operator for local response norm compute."""
from __future__ import absolute_import
import tvm
import topi
from .pad import pad
@tvm.target.generic_func
def lrn(data, size, axis=1, alpha=0.0001, beta=0.75, bias=2):
"""Perform the across channels local response normalisation
on the input data.
sum_sqr_up^i{x, y} = (bias + (alpha/size) *
    sum_{j=max(0, i-size/2)}^{min(N-1, i+size/2)} (data^j{x, y})^2)^beta
output^i{x, y} = data^i{x, y} / sum_sqr_up^i{x, y}
N is the number of input channels
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, channel, height, width]
size : int
normalisation window size
axis : int
input data layout channel axis
default value is 1 for NCHW format
alpha : float
scaling factor applied to the squared sum
beta : float
exponent
bias : float
offset to avoid dividing by 0
Returns
-------
output : tvm.Tensor
4-D output with same shape
"""
assert len(data.shape) == 4, "only support 4-dim lrn"
assert (size % 2) == 1, "size should be an odd number"
assert (axis == 1) or (axis == 3), "axis should be 1 or 3 for NCHW or NHWC layout"
# Pad the normalization axis by radius = size // 2 on both sides first.
# pad_before and pad_after intentionally alias the same list: the window is
# symmetric, so both sides receive the same padding.
pad_after = pad_before = [0, 0, 0, 0]
pad_after[axis] = pad_before[axis] = (size // 2)
pad_data = pad(data, pad_before, pad_after, name="pad_data")
rxs = tvm.reduce_axis((0, size), name='rxs')
if axis == 1:
#NCHW layout
sqr_sum = tvm.compute(data.shape, lambda i, j, k, l: tvm.sum(
pad_data[i, j + rxs, k, l] * pad_data[i, j + rxs, k, l],
axis=rxs))
elif axis == 3:
#NHWC layout
sqr_sum = tvm.compute(data.shape, lambda i, j, k, l: tvm.sum(
pad_data[i, j, k, l + rxs] * pad_data[i, j, k, l + rxs],
axis=rxs))
sqr_sum_up = tvm.compute(data.shape, lambda i, j, k, l: tvm.power(
(bias + (alpha * sqr_sum[i, j, k, l] / size)), beta))
return topi.broadcast_div(data, sqr_sum_up)
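To make the windowed sum concrete, here is a small NumPy sketch of the same computation for a single NCHW element (lrn_ref_point is a hypothetical helper for illustration; the full reference used by the tests appears below):

import numpy as np

def lrn_ref_point(x, n, c, h, w, size=3, alpha=0.0001, beta=0.75, bias=2):
    # Square-sum over a window of `size` channels centred on channel c,
    # clipped to the valid channel range, matching the padded compute above.
    lo = max(0, c - size // 2)
    hi = min(x.shape[1], c + size // 2 + 1)
    sqr_sum = np.sum(x[n, lo:hi, h, w] ** 2)
    return x[n, c, h, w] / (bias + alpha * sqr_sum / size) ** beta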
@@ -5,3 +5,4 @@ from __future__ import absolute_import as _abs
from .conv2d import *
from .dense import *
from .vision import *
from .nn import *
"""scheduler for normalization functions on rocm backend"""
from __future__ import absolute_import as _abs
import topi
from .. import generic
@generic.schedule_lrn.register(["rocm", "gpu"])
def schedule_lrn(outs):
return topi.cuda.schedule_lrn(outs)
@generic.schedule_l2norm.register(["rocm", "gpu"])
def schedule_l2norm(outs):
return topi.cuda.schedule_l2norm(outs)
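These registrations simply reuse the CUDA schedules for ROCm and generic GPU targets; a minimal sketch of how they get picked up (assumes a ROCm-enabled build):

import tvm
import topi

A = tvm.placeholder((1, 3, 5, 5), name='A')
B = topi.nn.lrn(A, size=3)
with tvm.target.create("rocm"):
    s = topi.generic.schedule_lrn(B)   # resolves to topi.cuda.schedule_lrn via the registration above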
"""Test code for L2 norm"""
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
def l2norm_instance_python(a_np, eps, axis=None):
"""L2 norm operator in NCHW layout.
Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]
eps : float
epsilon constant value
axis : list of int
axes over which the normalization is applied
Returns
-------
l2norm_out : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
dot_value = np.power(a_np, 2.0)
sqr_sum = np.sum(dot_value, axis, keepdims=True)
sqrt_sum = np.sqrt(np.maximum(np.broadcast_to(sqr_sum, a_np.shape), eps))
return np.divide(a_np, sqrt_sum)
def verify_l2norm(n, c, h, w, eps, axis=None):
A = tvm.placeholder((n, c, h, w), name='A')
B = topi.nn.l2norm_instance(A, eps, axis)
dtype = A.dtype
a_np = np.random.uniform(size=(n, c, h, w)).astype(dtype)
b_np = l2norm_instance_python(a_np, eps, axis)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_l2norm(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
check_device(device)
def test_l2norm():
verify_l2norm(1, 3, 20, 20, 0.001)
verify_l2norm(1, 3, 20, 20, 0.001, 1)
verify_l2norm(1, 3, 20, 20, 0.001, (1, 2))
verify_l2norm(1, 3, 20, 20, 0.001, (2, 3))
verify_l2norm(1, 3, 20, 20, 0.001, (0, 3))
verify_l2norm(1, 3, 20, 20, 0.001, (0, 2, 3))
if __name__ == "__main__":
test_l2norm()
"""Test code for local response normalization"""
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
def lrn_python(a_np, size, axis, bias, alpha, beta):
"""Local response norm operator in NCHW layout.
Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]
size : int
normalisation window size
axis : int
input data layout channel axis
bias : float
offset to avoid dividing by 0; constant value
alpha : float
constant scaling value
beta : float
exponent constant value
Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
axis0, axis1, axis2, axis3 = a_np.shape
radius = size // 2
sqr_sum = np.zeros(shape=a_np.shape).astype(a_np.dtype)
def sum_dot_values(i, j, k, l):
axis_size = a_np.shape[axis]
if (axis == 1):
#NCHW layout
sum_start = j-radius if j-radius >= 0 else 0
sum_end = j+radius+1 if j+radius+1 < axis_size else axis_size
sqr_sum[i, j, k, l] = sum(a_np[i, sum_start:sum_end, k, l] * \
a_np[i, sum_start:sum_end, k, l])
elif (axis == 3):
#NHWC layout
sum_start = l-radius if l-radius >= 0 else 0
sum_end = l+radius+1 if l+radius+1 < axis_size else axis_size
sqr_sum[i, j, k, l] = sum(a_np[i, j, k, sum_start:sum_end] * \
a_np[i, j, k, sum_start:sum_end])
for i in range(axis0):
for j in range(axis1):
for k in range(axis2):
for l in range(axis3):
sum_dot_values(i, j, k, l)
sqr_sum_up = np.power(bias + (alpha * sqr_sum / size), beta)
return np.divide(a_np, sqr_sum_up)
def verify_lrn(shape, size, axis, bias, alpha, beta):
A = tvm.placeholder(shape, name='A')
B = topi.nn.lrn(A, size, axis, alpha, beta, bias)
dtype = A.dtype
a_np = np.random.uniform(size=shape).astype(dtype)
b_np = lrn_python(a_np, size, axis, bias, alpha, beta)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
s = topi.generic.schedule_lrn(B)
ctx = tvm.context(device, 0)
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B], device)
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan']:
check_device(device)
def test_lrn():
verify_lrn((1, 3, 5, 5), 3, 1, 1, 1, 0.5)
verify_lrn((1, 3, 5, 5), 3, 3, 1, 1, 0.5)
verify_lrn((1, 3, 20, 20), 3, 1, 2, 1, 0.75)
if __name__ == "__main__":
test_lrn()