Commit 7196c791 by Tianqi Chen, committed by GitHub

[TOPI] Isolate padding option, improve decl of depthwise/conv2d/pool (#332)

parent abccd9cd
......@@ -50,6 +50,12 @@ class ExprOp(object):
def __rtruediv__(self, other):
return self.__rdiv__(other)
def __floordiv__(self, other):
return self.__div__(other)
def __rfloordiv__(self, other):
return self.__rdiv__(other)
def __mod__(self, other):
return _make.Mod(self, other)
......
......@@ -52,10 +52,11 @@ def static_cast(dtype, expr):
"""
target_type = TVMType(dtype)
src_type = TVMType(expr.dtype)
if target_type.type_code == src_type.type_code\
and src_type.lanes == 1\
and target_type.lanes > 1:
return Broadcast(expr, target_type.lanes)
if target_type.type_code == src_type.type_code and src_type.bits == target_type.bits:
if src_type.lanes == target_type.lanes:
return expr
elif src_type.lanes == 1 and target_type.lanes > 1:
return Broadcast(expr, target_type.lanes)
return Cast(dtype, expr)
......
......@@ -15,3 +15,4 @@ from .broadcast import *
from . import nn
from . import cuda
from . import testing
from . import util
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
# pylint: disable=invalid-name, unused-variable, too-many-locals
"""Convolution operators"""
from __future__ import absolute_import as _abs
import tvm
import numpy as np
from ..util import get_const_tuple
from ..util import simplify
from .pad import pad, _spatial2d_pad_option
@tvm.tag_scope(tag="conv2d_nchw")
def conv2d_nchw(Input, Filter, stride, padding):
"""Convolution operator in HWCN layout.
......@@ -31,45 +30,33 @@ def conv2d_nchw(Input, Filter, stride, padding):
"""
assert isinstance(stride, int) or len(stride) == 2
assert isinstance(padding, int) or padding in ['VALID', 'SAME']
batch, in_channel, in_height, in_width = get_const_tuple(Input.shape)
num_filter, channel, kernel_h, kernel_w = get_const_tuple(Filter.shape)
batch, in_channel, in_height, in_width = Input.shape
num_filter, channel, kernel_h, kernel_w = Filter.shape
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
# compute the padding size
if isinstance(padding, int):
pad_h = pad_w = padding * 2
elif padding == 'VALID':
pad_h = 0
pad_w = 0
else: # 'SAME'
pad_h = kernel_h - 1
pad_w = kernel_w - 1
pad_top = int(np.ceil(float(pad_h) / 2))
pad_left = int(np.ceil(float(pad_w) / 2))
pad_top, pad_left, pad_down, pad_right = _spatial2d_pad_option(
padding, (kernel_h, kernel_w))
# compute the output shape
out_channel = num_filter
out_height = (in_height - kernel_h + pad_h) // stride_h + 1
out_width = (in_width - kernel_w + pad_w) // stride_w + 1
out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
# compute graph
temp = tvm.compute(
(batch, in_channel, in_height + pad_h, in_width + pad_w),
lambda nn, cc, yy, xx: tvm.select(
tvm.all(yy >= pad_top, yy - pad_top < in_height,
xx >= pad_left, xx - pad_left < in_width),
Input[nn, cc, yy - pad_top, xx - pad_left], tvm.const(0.)),
name='temp')
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
temp = pad(Input, pad_before, pad_after, name="pad_temp")
rc = tvm.reduce_axis((0, in_channel), name='rc')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')
return tvm.compute(
(batch, out_channel, out_height, out_width),
lambda nn, ff, yy, xx: tvm.sum(
temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx] * Filter[ff, rc, ry, rx],
axis=[rc, ry, rx]))
axis=[rc, ry, rx]), tag="conv2d_nchw")
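For reference, a minimal usage sketch of the declaration above (illustrative only, assuming conv2d_nchw is re-exported under topi.nn as in the TOPI package layout of this commit and using the TVM API of this era):

import tvm
import topi
from topi.util import get_const_tuple

# NCHW input and OIHW filter placeholders; the shapes are illustrative.
Input = tvm.placeholder((1, 3, 32, 32), name='Input')
Filter = tvm.placeholder((16, 3, 3, 3), name='Filter')

# stride may be an int or a pair; 'SAME' padding keeps the spatial size.
Output = topi.nn.conv2d_nchw(Input, Filter, stride=1, padding='SAME')
print(get_const_tuple(Output.shape))  # expected: (1, 16, 32, 32)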
@tvm.tag_scope(tag="conv2d_hwcn")
def conv2d_hwcn(Input, Filter, stride, padding):
"""Convolution operator in HWCN layout.
......@@ -93,36 +80,22 @@ def conv2d_hwcn(Input, Filter, stride, padding):
4-D with shape [out_height, out_width, out_channel, batch]
"""
assert isinstance(stride, int) or len(stride) == 2
assert isinstance(padding, int) or padding in ['VALID', 'SAME']
in_height, in_width, in_channel, batch = get_const_tuple(Input.shape)
kernel_h, kernel_w, channel, num_filter = get_const_tuple(Filter.shape)
in_height, in_width, in_channel, batch = Input.shape
kernel_h, kernel_w, channel, num_filter = Filter.shape
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
# compute the padding size
if isinstance(padding, int):
pad_h = pad_w = padding * 2
elif padding == 'VALID':
pad_h = 0
pad_w = 0
else: # 'SAME'
pad_h = kernel_h - 1
pad_w = kernel_w - 1
pad_top = int(np.ceil(float(pad_h) / 2))
pad_left = int(np.ceil(float(pad_w) / 2))
pad_top, pad_left, pad_down, pad_right = _spatial2d_pad_option(
padding, (kernel_h, kernel_w))
# compute the output shape
out_channel = num_filter
out_height = (in_height - kernel_h + pad_h) // stride_h + 1
out_width = (in_width - kernel_w + pad_w) // stride_w + 1
# compute graph
PaddedInput = tvm.compute(
(in_height + pad_h, in_width + pad_w, in_channel, batch),
lambda yy, xx, cc, nn: tvm.select(
tvm.all(yy >= pad_top, yy - pad_top < in_height,
xx >= pad_left, xx - pad_left < in_width),
Input[yy - pad_top, xx - pad_left, cc, nn], tvm.const(0.)),
name='PaddedInput')
out_height = simplify((in_height - kernel_h + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - kernel_w + pad_left + pad_right) // stride_w + 1)
pad_before = [pad_top, pad_left, 0, 0]
pad_after = [pad_down, pad_right, 0, 0]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
rc = tvm.reduce_axis((0, in_channel), name='rc')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')
......@@ -131,12 +104,11 @@ def conv2d_hwcn(Input, Filter, stride, padding):
lambda yy, xx, ff, nn: tvm.sum(
PaddedInput[yy * stride_h + ry, xx * stride_w + rx, rc, nn] * Filter[ry, rx, rc, ff],
axis=[ry, rx, rc]),
name='Conv2dOutput')
name="Conv2dOutput", tag="conv2d_hwcn")
return Output
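A similar sketch for the HWCN layout (height, width, channel, batch), under the same topi.nn re-export assumption; with 'VALID' padding the spatial dimensions shrink:

import tvm
import topi
from topi.util import get_const_tuple

Input = tvm.placeholder((32, 32, 3, 1), name='Input')    # HWCN
Filter = tvm.placeholder((3, 3, 3, 16), name='Filter')   # HWIO

Output = topi.nn.conv2d_hwcn(Input, Filter, stride=2, padding='VALID')
print(get_const_tuple(Output.shape))  # expected: (15, 15, 16, 1)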
@tvm.tag_scope(tag="depthwise_conv2d")
def depthwise_conv2d(Input, Filter, Stride, padding):
def depthwise_conv2d(Input, Filter, stride, padding):
"""Depthwise convolution operator.
Parameters
......@@ -147,8 +119,8 @@ def depthwise_conv2d(Input, Filter, Stride, padding):
Filter : tvm.Tensor
4-D with shape [in_channel, channel_multiplier, filter_height, filter_width]
Stride : tvm.Tensor
1-D of size 2
stride : tuple of two ints
The spatial stride along height and width
padding : str
'VALID' or 'SAME'
......@@ -158,49 +130,28 @@ def depthwise_conv2d(Input, Filter, Stride, padding):
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
in_shape = get_const_tuple(Input.shape)
batch = in_shape[0]
in_channel = in_shape[1]
in_height = in_shape[2]
in_width = in_shape[3]
filter_shape = get_const_tuple(Filter.shape)
filter_channel = filter_shape[0]
channel_multiplier = filter_shape[1]
filter_height = filter_shape[2]
filter_width = filter_shape[3]
stride_h = Stride.asnumpy()[0]
stride_w = Stride.asnumpy()[1]
# calculate output shape
if padding == 'VALID':
out_channel = in_channel * channel_multiplier
out_height = (in_height - filter_height) // stride_h + 1
out_width = (in_width - filter_width) // stride_w + 1
pad_along_height = 0
pad_along_width = 0
if padding == 'SAME':
out_channel = in_channel * channel_multiplier
out_height = np.int(np.ceil(float(in_height) / float(stride_h)))
out_width = np.int(np.ceil(float(in_width) / float(stride_w)))
pad_along_height = np.int(np.max((out_height - 1) * stride_h + filter_height - in_height, 0))
pad_along_width = np.int(np.max((out_width - 1) * stride_w + filter_width - in_width, 0))
height_after_pad = in_height + pad_along_height
width_after_pad = in_width + pad_along_width
pad_top = np.int(np.ceil(float(pad_along_height) / 2))
pad_left = np.int(np.ceil(float(pad_along_width) / 2))
batch, in_channel, in_height, in_width = Input.shape
filter_channel, channel_multiplier, filter_height, filter_width = Filter.shape
stride_h, stride_w = stride
pad_top, pad_left, pad_down, pad_right = _spatial2d_pad_option(
padding, (filter_height, filter_width))
out_channel = simplify(in_channel * channel_multiplier)
out_height = simplify((in_height - filter_height + pad_top + pad_down) // stride_h + 1)
out_width = simplify((in_width - filter_width + pad_left + pad_right) // stride_w + 1)
# padding stage
PaddedInput = tvm.compute(
(batch, in_channel, height_after_pad, width_after_pad),
lambda b, c, i, j: tvm.select(
tvm.all(i >= pad_top, i - pad_top < in_height, j >= pad_left, j - pad_left < in_width),
Input[b, c, i - pad_top, j - pad_left], tvm.const(0.0)),
name="PaddedInput")
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
PaddedInput = pad(Input, pad_before, pad_after, name="PaddedInput")
# depthconv stage
di = tvm.reduce_axis((0, filter_height), name='di')
dj = tvm.reduce_axis((0, filter_width), name='dj')
Output = tvm.compute(
(batch, out_channel, out_height, out_width),
lambda b, c, i, j: tvm.sum(
PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] * Filter[c/channel_multiplier, c%channel_multiplier, di, dj],
(PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] *
Filter[c/channel_multiplier, c%channel_multiplier, di, dj]),
axis=[di, dj]),
name='DepthwiseConv2d')
name='DepthwiseConv2d', tag="depthwise_conv2d")
return Output
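A usage sketch of the new depthwise signature, where stride is now a plain pair of Python ints rather than a tvm.nd.array (same hedged topi.nn re-export assumption as above):

import tvm
import topi
from topi.util import get_const_tuple

# Filter layout is [in_channel, channel_multiplier, filter_height, filter_width].
Input = tvm.placeholder((1, 32, 64, 64), name='Input')
Filter = tvm.placeholder((32, 1, 3, 3), name='Filter')

Output = topi.nn.depthwise_conv2d(Input, Filter, stride=[1, 1], padding='SAME')
print(get_const_tuple(Output.shape))  # expected: (1, 32, 64, 64)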
......@@ -6,35 +6,39 @@ from .. import util
@tvm.tag_scope(tag="dilation")
def dilate(Input, strides):
"""Dilate Input with zeros.
def dilate(data, strides, name="DilatedInput"):
"""Dilate data with zeros.
Parameters
----------
Input : tvm.Tensor
data : tvm.Tensor
n-D, can be any layout.
strides : list / tuple of n ints
Dilation stride on each dimension, 1 means no dilation.
name : str, optional
The name prefix for the generated operators
Returns
-------
Output : tvm.Tensor
n-D, the same layout as Input.
n-D, the same layout as data.
"""
n = len(Input.shape)
assert len(strides) == n, \
"Input dimension and strides size dismatch : %d vs %d" %(n, len(strides))
output_size = ()
for i in range(n):
output_size += (tvm.ir_pass.Simplify((Input.shape[i]-1)*strides[i]+1),)
def _dilate(data, *indices):
n = len(data.shape)
if len(strides) != n:
raise ValueError("data dimension and strides size dismatch : %d vs %d" % (
n, len(strides)))
out_shape = tuple(
tvm.ir_pass.Simplify((data.shape[i] - 1) * strides[i] + 1) for i in range(n))
def _dilate(*indices):
not_zero = []
index_tuple = []
for i in range(n):
if not util.equal_const_int(strides[i], 1):
index_tuple.append(indices[i]/strides[i])
index_tuple.append(indices[i] / strides[i])
not_zero.append((indices[i] % strides[i]).equal(0))
else:
index_tuple.append(indices[i])
......@@ -43,9 +47,4 @@ def dilate(Input, strides):
return tvm.select(not_zero, data(*index_tuple), tvm.const(0.0, data.dtype))
return data(*index_tuple)
Output = tvm.compute(
output_size,
lambda *indices: _dilate(Input, *indices),
name='DilatedInput')
return Output
return tvm.compute(out_shape, _dilate, name=name)
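A small sketch of dilate, assuming it is reachable as topi.nn.dilate:

import tvm
import topi
from topi.util import get_const_tuple

# Stride 2 on both axes inserts one zero between neighbouring elements,
# so a (2, 2) input grows to ((2-1)*2+1, (2-1)*2+1) = (3, 3).
Input = tvm.placeholder((2, 2), name='Input')
Output = topi.nn.dilate(Input, [2, 2])
print(get_const_tuple(Output.shape))  # expected: (3, 3)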
"""Pad the data by constant value """
from __future__ import absolute_import as _abs
import tvm
from ..util import equal_const_int
def _spatial2d_pad_option(padding, kernel):
"""Common code to get the pad option
Parameters
----------
padding : int, tuple/list of two ints, or str
Padding size, or 'VALID' / 'SAME'
kernel : tuple of int
Conv kernel size
Returns
-------
pad_top : int
Padding size on top
pad_left : int
Padding size on left
pad_down : int
Padding size on bottom.
pad_right : int
Padding size on right.
"""
# compute the padding size
if isinstance(padding, (tuple, list)):
pad_h = padding[0] * 2
pad_w = padding[1] * 2
elif isinstance(padding, int):
pad_h = pad_w = padding * 2
elif padding == "VALID":
pad_h = 0
pad_w = 0
elif padding == "SAME":
pad_h = kernel[0] - 1
pad_w = kernel[1] - 1
else:
raise ValueError("Unknown padding option %s" % padding)
pad_top = (pad_h + 1) // 2
pad_left = (pad_w + 1) // 2
return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left
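Worked examples of the padding arithmetic above (illustrative only; the helper is private, and the import path assumes it lives in topi.nn.pad as the relative imports in conv2d.py suggest):

from topi.nn.pad import _spatial2d_pad_option

# 'SAME' with a 3x3 kernel: total pad = kernel - 1 = 2 per axis,
# split as top/left = (2 + 1) // 2 = 1 and down/right = 2 - 1 = 1.
print(_spatial2d_pad_option('SAME', (3, 3)))   # (1, 1, 1, 1)

# 'SAME' with a 2x2 kernel: the single pad row/column goes to top/left.
print(_spatial2d_pad_option('SAME', (2, 2)))   # (1, 1, 0, 0)

# 'VALID' never pads; an int pads symmetrically by that amount.
print(_spatial2d_pad_option('VALID', (3, 3)))  # (0, 0, 0, 0)
print(_spatial2d_pad_option(2, (5, 5)))        # (2, 2, 2, 2)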
@tvm.tag_scope(tag="pad")
def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"):
"""Dilate Input with zeros.
Parameters
----------
data : tvm.Tensor
n-D input, can be any layout.
pad_before : list / tuple of n ints
Pad width on each dimension, applied before the axis begins.
pad_after : list / tuple of n ints, optional
Pad width on each dimension, applied after the axis ends.
pad_value : float, optional
The value to be padded.
name : str, optional
The name prefix for the generated operators
Returns
-------
Output : tvm.Tensor
n-D, the same layout as data.
"""
n = len(data.shape)
pad_after = pad_after if pad_after else pad_before
if len(pad_before) != n:
raise ValueError("Input dimension and pad_before dismatch : %d vs %d" % (
n, len(pad_before)))
if len(pad_after) != n:
raise ValueError("Input dimension and pad_after dismatch : %d vs %d" % (
n, len(pad_before)))
out_shape = tuple(
tvm.ir_pass.Simplify(
(data.shape[i] + pad_before[i] + pad_after[i])) for i in range(n))
pad_value = (pad_value if isinstance(pad_value, tvm.expr.Expr)
else tvm.const(pad_value, data.dtype))
def _pad(*indices):
not_zero = []
index_tuple = []
for i in range(n):
if equal_const_int(pad_before[i], 0) and equal_const_int(pad_after[i], 0):
index_tuple.append(indices[i])
else:
index_tuple.append(indices[i] - pad_before[i])
not_zero.append(indices[i] >= pad_before[i])
not_zero.append(indices[i] < data.shape[i] + pad_before[i])
if not_zero:
not_zero = tvm.all(*not_zero)
return tvm.select(not_zero, data(*index_tuple), pad_value)
return data(*index_tuple)
return tvm.compute(out_shape, _pad, name=name)
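A usage sketch of pad, assuming the module path topi.nn.pad shown in this diff:

import tvm
from topi.nn.pad import pad
from topi.util import get_const_tuple

# Pad one pixel on each spatial border of an NCHW tensor; the zero pad
# widths on batch and channel skip the select entirely.
data = tvm.placeholder((1, 3, 8, 8), name='data')
padded = pad(data, [0, 0, 1, 1], [0, 0, 1, 1], pad_value=0.0)
print(get_const_tuple(padded.shape))  # expected: (1, 3, 10, 10)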
"""TVM operator pooling compute."""
from __future__ import absolute_import
import tvm
from .. import util
from .pad import pad, _spatial2d_pad_option
@tvm.tag_scope(tag='max_pool')
def max_pool(data, kernel, stride, pad):
def max_pool(data, kernel, stride, padding):
"""Perform max pooling on the data
Parameters
......@@ -17,7 +18,7 @@ def max_pool(data, kernel, stride, pad):
stride : list/tuple of two ints
Stride size, or [stride_height, stride_width]
pad : list/tuple of two ints
padding : list/tuple of two ints
Pad size, or [pad_height, pad_width]
Returns
......@@ -26,29 +27,27 @@ def max_pool(data, kernel, stride, pad):
4-D with shape [batch, channel, out_height, out_width]
"""
assert len(data.shape) == 4, "only support 4-dim pooling"
assert len(stride.shape) == 2, "only support 2-dim stride"
assert len(pad.shape) == 2, "only support 2-dim pad"
assert len(stride) == 2, "only support 2-dim stride"
kernel_height, kernel_width = kernel
stride_height, stride_width = stride
pad_height, pad_width = pad
batch, channel, height, width = data.shape
padded_height = height + 2*pad_height
padded_width = width + 2*pad_width
out_height = (height + 2*pad_height - kernel_height) / stride_height + 1
out_width = (width + 2*pad_width - kernel_width) / stride_width + 1
pad_top, pad_left, pad_down, pad_right = _spatial2d_pad_option(
padding, (kernel_height, kernel_width))
pad_before = [0, 0, pad_top, pad_left]
pad_after = [0, 0, pad_down, pad_right]
temp = pad(data, pad_before, pad_after, name="pad_temp",
pad_value=tvm.min_value("float32"))
out_height = util.simplify((height - kernel_height + pad_top + pad_down) // stride_height + 1)
out_width = util.simplify((width - kernel_width + pad_left + pad_right) // stride_width + 1)
dheight = tvm.reduce_axis((0, kernel_height))
dwidth = tvm.reduce_axis((0, kernel_width))
temp = tvm.compute((batch, channel, padded_height, padded_width), lambda i, c, h, w: \
tvm.select(
tvm.make.Or(tvm.make.Or((h < pad_height), (h >= height + pad_height)),
tvm.make.Or((w < pad_width), (w >= width + pad_width))),
tvm.min_value('float32'),
data[i, c, h - pad_height, w - pad_width]), name='temp')
return tvm.compute((batch, channel, out_height, out_width), lambda i, c, h, w: \
tvm.max(temp[i, c, h*stride_height+dheight, w*stride_width+dwidth], axis=[dheight, dwidth]))
return tvm.compute(
(batch, channel, out_height, out_width),
lambda i, c, h, w:
tvm.max(temp[i, c, h*stride_height+dheight, w*stride_width+dwidth], axis=[dheight, dwidth]),
tag="max_pool")
@tvm.tag_scope(tag='global_avg_pool')
......
......@@ -19,9 +19,8 @@ def softmax(x):
assert len(x.shape) == 2, "only support 2-dim softmax"
m, n = x.shape
k = tvm.reduce_axis((0, n), name='k')
max_elem = tvm.compute((m, ), lambda i: \
tvm.max(x[i, k]), axis=k)
expsum = tvm.compute((m, ), lambda i: \
tvm.sum(tvm.exp(x[i, k] - max_elem[i]), axis=k))
return tvm.compute(x.shape, lambda i, j: \
tvm.exp(x[i, j] - max_elem[i]) / expsum[i])
max_elem = tvm.compute((m, ), lambda i: tvm.max(x[i, k], axis=k))
expsum = tvm.compute(
(m, ), lambda i: tvm.sum(tvm.exp(x[i, k] - max_elem[i]), axis=k))
return tvm.compute(
x.shape, lambda i, j: tvm.exp(x[i, j] - max_elem[i]) / expsum[i])
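The rewritten declaration computes the numerically stable form exp(x[i, j] - max_k x[i, k]) / sum_k exp(x[i, k] - max_k x[i, k]). A minimal declaration sketch, under the same topi.nn re-export assumption:

import tvm
import topi
from topi.util import get_const_tuple

# Subtracting the per-row maximum keeps exp() from overflowing.
x = tvm.placeholder((4, 10), name='x')
y = topi.nn.softmax(x)
print(get_const_tuple(y.shape))  # (4, 10): same shape as the input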
......@@ -63,3 +63,19 @@ def get_const_tuple(in_tuple):
raise ValueError("Element of input tuple should be const int")
out_tuple = out_tuple + (elem.value, )
return out_tuple
def simplify(expr):
"""Simplify the expression if it is Expr, directly return if it is int.
Parameters
----------
expr : Expr or int
The input.
Returns
-------
out : Expr or int
The simplified output
"""
return tvm.ir_pass.Simplify(expr) if isinstance(expr, tvm.expr.Expr) else expr
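A quick illustration of the helper, assuming the 2017-era tvm.var / tvm.ir_pass API used elsewhere in this diff:

import tvm
from topi.util import simplify

n = tvm.var('n')
# An Expr goes through tvm.ir_pass.Simplify; a plain int passes through unchanged.
print(simplify((n * 1) + 0))  # expected to fold to n
print(simplify(3))            # 3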
......@@ -49,7 +49,7 @@ def test_depthwise_conv2d():
# Placeholder
Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input')
Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
Stride = tvm.nd.array(np.array([stride_h, stride_w]))
Stride = [stride_h, stride_w]
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
# Declare
......
......@@ -13,7 +13,7 @@ def depthwise_conv2d_with_workload(batch, in_channel, in_height, channel_multipl
# placeholder
Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input')
Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
Stride = tvm.nd.array(np.array([stride_h, stride_w]))
Stride = [stride_h, stride_w]
Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
# declare
......
......@@ -14,9 +14,7 @@ def test_dilate():
input_np = np.random.uniform(size=input_size).astype(Input.dtype)
output_np = topi.testing.dilate_python(input_np, strides)
input_tvm = tvm.nd.array(input_np, ctx=ctx)
output_size = ()
for i in range(len(input_size)):
output_size += (tvm.ir_pass.Simplify(Output.shape[i]).value,)
output_size = topi.util.get_const_tuple(Output.shape)
output_tvm = tvm.nd.array(np.zeros(shape=output_size).astype(Output.dtype), ctx=ctx)
f = tvm.build(schedule, [Input, Output], target)
f(input_tvm, output_tvm)
......