Commit 0ad590c0 by ziheng Committed by Tianqi Chen

[TOPI] Add ops compute (#323)

* [TOPI] Add ops compute

Remove 'compute' and add assert for safety

Add document

fix lint

fix softmax

* fix batch norm
parent ce18b565
......@@ -2,6 +2,42 @@
from __future__ import absolute_import as _abs
import tvm
@tvm.tag_scope(tag='ewise')
def identity(x):
"""Take identity of input x.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
# pylint: disable=unnecessary-lambda
return tvm.compute(x.shape, lambda *i: x(*i))
@tvm.tag_scope(tag='ewise')
def negation(x):
"""Take negation of input x.
Parameters
----------
x : tvm.Tensor
Input argument.
Returns
-------
y : tvm.Tensor
The result.
"""
# pylint: disable=unnecessary-lambda
return tvm.compute(x.shape, lambda *i: -x(*i))
@tvm.tag_scope(tag="ewise")
def exp(x):
"""Take exponential of input x.
......
......@@ -2,7 +2,12 @@
"""Neural network operators"""
from __future__ import absolute_import as _abs
from .ewise import *
from .mapping import *
from .conv import *
from .batch_norm import *
from .convolution import *
from .elemwise import *
from .dilate import *
from .flatten import *
from .fully_connected import *
from .mapping import *
from .pooling import *
from .softmax import *
"""TVM operator batch normalization compute."""
from __future__ import absolute_import
import tvm
@tvm.tag_scope(tag='batch_norm')
def batch_norm(data, gamma, beta, moving_mean, moving_var, eps, fix_gamma):
"""Batch normalization operator in NCHW layout.
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, channel, height, width]
gamma : tvm.Tensor
1-D with shape [channel]
beta : tvm.Tensor
1-D with shape [channel]
moving_mean : tvm.Tensor
1-D with shape [channel]
moving_var : tvm.Tensor
1-D with shape [channel]
eps : float
Epsilon to prevent div 0.
fix_gamma : boolean
Fix gamma while training
Returns
-------
output : tvm.Tensor
4-D with shape [batch, channel, height, width]
mean : tvm.Tensor
1-D with shape [channel]
var : tvm.Tensor
1-D with shape [channel]
"""
assert len(data.shape) == 4, "only support 4-dim batch norm"
batch, channel, height, width = data.shape
if fix_gamma:
out = tvm.compute((batch, channel, height, width), \
lambda b, c, h, w: (data[b, c, h, w] - moving_mean[c]) / \
tvm.intrin.sqrt(moving_var[c] + eps) + beta[c])
else:
out = tvm.compute((batch, channel, height, width), \
lambda b, c, h, w: (data[b, c, h, w] - moving_mean[c]) / \
tvm.intrin.sqrt(moving_var[c] + eps) * gamma[c] + beta[c])
mean = tvm.compute((C, ), lambda c: moving_mean[c])
var = tvm.compute((C, ), lambda c: moving_var[c])
return [out, mean, var]
......@@ -89,7 +89,7 @@ def conv2d_hwcn(Input, Filter, stride, padding):
Returns
-------
Output : tvm.Tensor
output : tvm.Tensor
4-D with shape [out_height, out_width, out_channel, batch]
"""
assert isinstance(stride, int) or len(stride) == 2
......
"""TVM operator flatten compute."""
from __future__ import absolute_import
import tvm
@tvm.tag_scope(tag='flatten')
def flatten(data):
"""Flattens the input array into a 2-D array by collapsing the higher dimensions.
Parameters
----------
data : tvm.Tensor
Input array.
Returns
-------
output : tvm.Tensor
2-D array with collapsed higher dimensions.
"""
ishape = data.shape
dim = 1
for i in range(1, len(ishape)):
dim = dim * ishape[i]
oshape = [ishape[0], dim]
def unwrap(idx, shape):
index = []
for s in reversed(shape):
index.append(idx % s)
idx = idx / s
return list(reversed(index))
return tvm.compute(oshape, lambda i, j: data(i, *unwrap(j, ishape[1:])))
"""TVM operator fully connected compute."""
from __future__ import absolute_import
import tvm
@tvm.tag_scope(tag='fully_connected')
def fully_connected(data, weight):
"""Matrix multiplication
Parameters
----------
data : tvm.Tensor
2-D with shape [batch, in_dim]
weight : tvm.Tensor
2-D with shape [out_dim, in_dim]
Returns
-------
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
assert len(data.shape) == 2 and len(weight.shape) == 2, \
"only support 2-dim fully_connected"
batch, in_dim = data.shape
out_dim, _ = weight.shape
k = tvm.reduce_axis((0, in_dim), name='k')
return tvm.compute((batch, out_dim), lambda i, j: \
tvm.sum(data[i][k] * weight[j][k], axis=k))
@tvm.tag_scope(tag='fully_connected_with_bias')
def fully_connected_with_bias(data, weight, bias):
"""Applies a linear transformation: :math:`Y = XW^T + b`.
Parameters
----------
data : tvm.Tensor
2-D with shape [batch, in_dim]
weight : tvm.Tensor
2-D with shape [out_dim, in_dim]
bias : tvm.Tensor
1-D with shape [out_dim]
Returns
-------
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
assert len(data.shape) == 2 and len(weight.shape) == 2, \
"only support 2-dim fully_connected"
assert len(data.shape) == 2 and len(weight.shape) == 2 and len(bias.shape) == 1, \
"only support 2-dim fully_connected"
batch, in_dim = data.shape
out_dim, _ = weight.shape
k = tvm.reduce_axis((0, in_dim), name='k')
matmul = tvm.compute((batch, out_dim), lambda i, j: \
tvm.sum(data[i, k] * weight[j, k], axis=k))
return tvm.compute((batch, out_dim), lambda i, j: \
matmul[i, j] + bias[j])
"""TVM operator pooling compute."""
from __future__ import absolute_import
import tvm
@tvm.tag_scope(tag='max_pool')
def max_pool(data, kernel, stride, pad):
"""Perform max pooling on the data
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, channel, in_height, in_width]
kernel : list/tuple of two ints
Kernel size, or [kernel_height, kernel_width]
stride : list/tuple of two ints
Stride size, or [stride_height, stride_width]
pad : list/tuple of two ints
Pad size, or [pad_height, pad_width]
Returns
-------
output : tvm.Tensor
4-D with shape [batch, channel, out_height, out_width]
"""
assert len(data.shape) == 4, "only support 4-dim pooling"
assert len(stride.shape) == 2, "only support 2-dim stride"
assert len(pad.shape) == 2, "only support 2-dim pad"
kernel_height, kernel_width = kernel
stride_height, stride_width = stride
pad_height, pad_width = pad
batch, channel, height, width = data.shape
padded_height = height + 2*pad_height
padded_width = width + 2*pad_width
out_height = (height + 2*pad_height - kernl_height) / stride_height + 1
out_width = (width + 2*pad_width - kernel_width) / stride_width + 1
dheight = tvm.reduce_axis((0, kernel_height))
dwidth = tvm.reduce_axis((0, kernel_width))
temp = tvm.compute((batch, channel, padded_height, padded_width), lambda i, c, h, w: \
tvm.select(
tvm.make.Or(tvm.make.Or((h < pad_height), (h >= height + pad_height)),
tvm.make.Or((w < pad_width), (w >= width + pad_width))),
tvm.min_value('float32'),
data[i, c, h - pad_height, w - pad_width]), name='temp')
return tvm.compute((batch, channel, out_height, out_width), lambda i, c, h, w: \
tvm.max(temp[i, c, h*stride_height+dheight, w*stride_width+dwidth], axis=[dheight, dwidth]))
@tvm.tag_scope(tag='global_avg_pool')
def global_avg_pool(data):
"""Perform global average pooling on the data
Parameters
----------
data : tvm.Tensor
4-D with shape [batch, channel, in_height, in_width]
Returns
-------
output : tvm.Tensor
4-D with shape [batch, channel, 1, 1]
"""
assert len(data.shape) == 4, "only support 4-dim pooling"
batch, channel, height, width = data.shape
dheight = tvm.reduce_axis((0, height))
dwidth = tvm.reduce_axis((0, width))
tsum = tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tvm.sum(data[n, c, dheight, dwidth], axis=[dheight, dwidth]))
return tvm.compute((batch, channel, 1, 1), lambda n, c, h, w: \
tsum[n, c, h, w] / (height*width))
"""TVM operator softmax compute."""
from __future__ import absolute_import
import tvm
@tvm.tag_scope(tag='softmax')
def softmax(x):
"""Perform softmax activation on the data
Parameters
----------
data : tvm.Tensor
2-D input data
Returns
-------
output : tvm.Tensor
2-D output with same shape
"""
assert len(x.shape) == 2, "only support 2-dim softmax"
m, n = x.shape
k = tvm.reduce_axis((0, n), name='k')
max_elem = tvm.compute((m, ), lambda i: \
tvm.max(x[i, k]), axis=k)
expsum = tvm.compute((m, ), lambda i: \
tvm.sum(tvm.exp(x[i, k] - max_elem[i]), axis=k))
return tvm.compute(x.shape, lambda i, j: \
tvm.exp(x[i, j] - max_elem[i]) / expsum[i])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment