Commit 4580e690 by Yuwei HU, committed by Tianqi Chen

[TOPI] Example for depthwise convolution (#197)

* first commit

* move to topi/recipe

* refactor, almost rewrite

* 2-D sum reduction; implement SAME pad; improve schedule

* add util.py; separate test script

* conv + bn + relu fusion

* auto fusion

* separate declare and schedule; using op tag

* divide large image into blocks

* move to topi; improve blocking schedule

* restructure

* add doc

* using time_evaluator

parent 825566cc
# pylint: disable=redefined-builtin, wildcard-import
"""CUDA specific declaration and schedule.""" """CUDA specific declaration and schedule."""
from __future__ import absolute_import as _abs
from .depthwise_conv2d_map import *
# pylint: disable=invalid-name
"""Schedule for depthwise_conv2d with auto fusion"""
import tvm
from ..nn.util import get_const_tuple

def schedule_depthwise_conv2d_map(op):
    """Schedule for depthwise_conv2d and auto fusion with
    one-to-one-mapping operators, e.g. scale-shift and relu.

    Parameters
    ----------
    op: Operation
        The symbolic description of the operation, should be depthwise_conv2d or
        depthwise_conv2d followed by a sequence of one-to-one-mapping operators.

    Returns
    -------
    s: Schedule
        The computation schedule for the op.
    """
    s = tvm.create_schedule(op)

    def schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d):
        """Schedule for depthwise_conv2d declared in topi.nn.conv"""
        out_shape = get_const_tuple(DepthwiseConv2d.shape)
        out_height = out_shape[2]
        out_width = out_shape[3]
        channel_multiplier = get_const_tuple(Filter.shape)[1]
        s[PaddedInput].compute_inline()
        IS = s.cache_read(PaddedInput, "shared", [DepthwiseConv2d])
        FS = s.cache_read(Filter, "shared", [DepthwiseConv2d])
        IL = s.cache_read(IS, "local", [DepthwiseConv2d])
        FL = s.cache_read(FS, "local", [DepthwiseConv2d])
        if DepthwiseConv2d.op in s.outputs:
            Output = DepthwiseConv2d
            CL = s.cache_write(DepthwiseConv2d, "local")
        else:
            Output = op.output(0)
            s[DepthwiseConv2d].set_scope("local")
        # schedule parameters
        num_thread = 8
        num_vthread_x = 1
        num_vthread_y = 1
        blocking_h = out_height
        blocking_w = out_width
        if out_height % 48 == 0:
            blocking_h = 48
        elif out_height % 32 == 0:
            blocking_h = 32
        if out_width % 48 == 0:
            blocking_w = 48
            num_vthread_y = 3
        elif out_width % 32 == 0:
            blocking_w = 32
        block_x = tvm.thread_axis("blockIdx.x")
        block_y = tvm.thread_axis("blockIdx.y")
        thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
        thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
        thread_vx = tvm.thread_axis((0, num_vthread_x), "vthread", name="vx")
        thread_vy = tvm.thread_axis((0, num_vthread_y), "vthread", name="vy")
        # split and bind
        bx, bxi = s[Output].split(Output.op.axis[1], factor=channel_multiplier)
        s[Output].reorder(Output.op.axis[2], Output.op.axis[3], bxi)
        bx = s[Output].fuse(bx, Output.op.axis[0])
        s[Output].bind(bx, block_x)
        by1, y1i = s[Output].split(Output.op.axis[2], factor=blocking_h)
        tvx, vxi = s[Output].split(y1i, nparts=num_vthread_x)
        tx, xi = s[Output].split(vxi, nparts=num_thread)
        by2, y2i = s[Output].split(Output.op.axis[3], factor=blocking_w)
        tvy, vyi = s[Output].split(y2i, nparts=num_vthread_y)
        ty, yi = s[Output].split(vyi, nparts=num_thread)
        s[Output].reorder(by1, by2, tvx, tvy, tx, ty, xi, yi)
        by = s[Output].fuse(by2, by1)
        s[Output].bind(tvx, thread_vx)
        s[Output].bind(tvy, thread_vy)
        s[Output].bind(tx, thread_x)
        s[Output].bind(ty, thread_y)
        s[Output].bind(by, block_y)
        # local memory load
        s[IL].compute_at(s[Output], ty)
        s[FL].compute_at(s[Output], ty)
        if DepthwiseConv2d.op in s.outputs:
            s[CL].compute_at(s[Output], ty)
        else:
            s[DepthwiseConv2d].compute_at(s[Output], ty)
        # input's shared memory load
        s[IS].compute_at(s[Output], by)
        tx, xi = s[IS].split(IS.op.axis[2], nparts=num_thread)
        ty, yi = s[IS].split(IS.op.axis[3], nparts=num_thread)
        s[IS].bind(tx, thread_x)
        s[IS].bind(ty, thread_y)
        # filter's shared memory load
        s[FS].compute_at(s[Output], by)
        s[FS].reorder(FS.op.axis[2], FS.op.axis[3], FS.op.axis[1])
        tx, xi = s[FS].split(FS.op.axis[2], nparts=num_thread)
        ty, yi = s[FS].split(FS.op.axis[3], nparts=num_thread)
        s[FS].bind(tx, thread_x)
        s[FS].bind(ty, thread_y)

    def traverse(OP):
        # inline all one-to-one-mapping operators except the last stage (output)
        if OP.tag == 'ewise' or OP.tag == 'scale_shift':
            if OP not in s.outputs:
                s[OP].compute_inline()
            for tensor in OP.input_tensors:
                if str(tensor.op.input_tensors) != str([]):
                    traverse(tensor.op)
        # schedule depthwise_conv2d
        if OP.tag == 'depthwise_conv2d':
            PaddedInput = OP.input_tensors[0]
            Filter = OP.input_tensors[1]
            DepthwiseConv2d = OP.output(0)
            schedule_depthwise_conv2d(PaddedInput, Filter, DepthwiseConv2d)

    traverse(op)
    return s
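Outside of the test script further below, a minimal usage sketch of this schedule might look like the following. This is illustrative only (hypothetical shapes, not part of the commit) and assumes a CUDA-enabled TVM build of the same vintage as this change:

# Hypothetical usage sketch: declare the op, schedule it, and inspect the generated CUDA.
import numpy as np
import tvm
import topi
from topi.cuda.depthwise_conv2d_map import schedule_depthwise_conv2d_map

Input = tvm.placeholder((1, 256, 32, 32), name='Input')
Filter = tvm.placeholder((256, 1, 3, 3), name='Filter')
Stride = tvm.nd.array(np.array([1, 1]))
DepthwiseConv2d = topi.nn.depthwise_conv2d(Input, Filter, Stride, 'SAME')
s = schedule_depthwise_conv2d_map(DepthwiseConv2d.op)
f = tvm.build(s, [Input, Filter, DepthwiseConv2d], "cuda")
# Print the generated CUDA source to see the block/thread decomposition chosen above.
print(f.imported_modules[0].get_source())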
@@ -2,6 +2,7 @@
from __future__ import absolute_import as _abs
import tvm

@tvm.tag_scope(tag="ewise")
def exp(x):
    """Take exponential of input x.
@@ -18,6 +19,7 @@ def exp(x):
    return tvm.compute(x.shape, lambda *i: tvm.exp(x(*i)))

@tvm.tag_scope(tag="ewise")
def tanh(x):
    """Take hyperbolic tanh of input x.
@@ -34,6 +36,7 @@ def tanh(x):
    return tvm.compute(x.shape, lambda *i: tvm.tanh(x(*i)))

@tvm.tag_scope(tag="ewise")
def log(x):
    """Take logarithm of input x.
@@ -50,6 +53,7 @@ def log(x):
    return tvm.compute(x.shape, lambda *i: tvm.log(x(*i)))

@tvm.tag_scope(tag="ewise")
def sqrt(x):
    """Take square root of input x.
@@ -66,6 +70,7 @@ def sqrt(x):
    return tvm.compute(x.shape, lambda *i: tvm.sqrt(x(*i)))

@tvm.tag_scope(tag="ewise")
def sigmoid(x):
    """Take sigmoid of input x.
@@ -80,3 +85,20 @@ def sigmoid(x):
        The result.
    """
    return tvm.compute(x.shape, lambda *i: tvm.sigmoid(x(*i)))

@tvm.tag_scope(tag="ewise")
def relu(x):
    """Take relu of input x.

    Parameters
    ----------
    x : tvm.Tensor
        Input argument.

    Returns
    -------
    y : tvm.Tensor
        The result.
    """
    return tvm.compute(x.shape, lambda *i: tvm.max(x(*i), 0))
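The "ewise" tag added above is what the CUDA schedule's traverse step keys on when deciding which stages to inline into the depthwise_conv2d output. A small illustrative check (not from this commit, placeholder shape is arbitrary):

import tvm
import topi

A = tvm.placeholder((1, 4, 8, 8), name='A')
B = topi.ewise.relu(A)
# tag_scope stamps the tag onto the generated ComputeOp; schedule_depthwise_conv2d_map
# inlines any op tagged 'ewise' or 'scale_shift' that is not the final output stage.
print(B.op.tag)  # 'ewise'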
# pylint: disable=wildcard-import
"""Neural network operators"""
from __future__ import absolute_import as _abs
from .util import *
from .mapping import *
from .conv import *
# pylint: disable=invalid-name, line-too-long, unused-variable
"""Convolution operators"""
from __future__ import absolute_import as _abs
import tvm
import numpy as np
from .util import get_const_tuple
@tvm.tag_scope(tag="depthwise_conv2d")
def depthwise_conv2d(Input, Filter, Stride, padding):
    """Depthwise convolution operator, as depthwise_conv2d in tensorflow.

    Parameters
    ----------
    Input : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]
    Filter : tvm.Tensor
        4-D with shape [in_channel, channel_multiplier, filter_height, filter_width]
    Stride : tvm.Tensor
        1-D of size 2
    padding : str
        'VALID' or 'SAME'

    Returns
    -------
    Output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    in_shape = get_const_tuple(Input.shape)
    batch = in_shape[0]
    in_channel = in_shape[1]
    in_height = in_shape[2]
    in_width = in_shape[3]
    filter_shape = get_const_tuple(Filter.shape)
    filter_channel = filter_shape[0]
    channel_multiplier = filter_shape[1]
    filter_height = filter_shape[2]
    filter_width = filter_shape[3]
    stride_h = Stride.asnumpy()[0]
    stride_w = Stride.asnumpy()[1]
    # calculate output shape
    if padding == 'VALID':
        out_channel = in_channel * channel_multiplier
        out_height = (in_height - filter_height) / stride_h + 1
        out_width = (in_width - filter_width) / stride_w + 1
        pad_along_height = 0
        pad_along_width = 0
    if padding == 'SAME':
        out_channel = in_channel * channel_multiplier
        out_height = np.int(np.ceil(float(in_height) / float(stride_h)))
        out_width = np.int(np.ceil(float(in_width) / float(stride_w)))
        pad_along_height = np.int(np.maximum((out_height - 1) * stride_h + filter_height - in_height, 0))
        pad_along_width = np.int(np.maximum((out_width - 1) * stride_w + filter_width - in_width, 0))
    height_after_pad = in_height + pad_along_height
    width_after_pad = in_width + pad_along_width
    pad_top = np.int(np.ceil(float(pad_along_height) / 2))
    pad_left = np.int(np.ceil(float(pad_along_width) / 2))
    # padding stage
    PaddedInput = tvm.compute(
        (batch, in_channel, height_after_pad, width_after_pad),
        lambda b, c, i, j: tvm.select(
            tvm.all(i >= pad_top, i - pad_top < in_height, j >= pad_left, j - pad_left < in_width),
            Input[b, c, i - pad_top, j - pad_left], tvm.const(0.0)),
        name="PaddedInput")
    # depthconv stage
    di = tvm.reduce_axis((0, filter_height), name='di')
    dj = tvm.reduce_axis((0, filter_width), name='dj')
    Output = tvm.compute(
        (batch, out_channel, out_height, out_width),
        lambda b, c, i, j: tvm.sum(
            PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] * Filter[c/channel_multiplier, c%channel_multiplier, di, dj],
            axis=[di, dj]),
        name='DepthwiseConv2d')
    return Output
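For the configuration exercised in the test script below (32x32 input, 5x5 filter, stride 2, 'SAME'), the padding arithmetic above works out as follows. This is a worked check, not part of the commit:

import numpy as np

in_size, filter_size, stride = 32, 5, 2
out_size = int(np.ceil(float(in_size) / stride))                     # ceil(32 / 2) = 16
pad_along = max((out_size - 1) * stride + filter_size - in_size, 0)  # 15 * 2 + 5 - 32 = 3
pad_top = int(np.ceil(pad_along / 2.0))                              # ceil(3 / 2) = 2 rows on top, 1 on the bottom
print(out_size, pad_along, pad_top)                                  # 16 3 2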
# pylint: disable=invalid-name, line-too-long
"""Operators of one-to-one-mapping on the first input"""
from __future__ import absolute_import as _abs
import tvm
@tvm.tag_scope(tag="scale_shift")
def scale_shift(Input, Scale, Shift):
    """Batch normalization operator in inference.

    Parameters
    ----------
    Input : tvm.Tensor
        Input tensor, layout is NCHW
    Scale : tvm.Tensor
        Scale tensor, 1-D of size channel number
    Shift : tvm.Tensor
        Shift tensor, 1-D of size channel number

    Returns
    -------
    Output : tvm.Tensor
        Output tensor, layout is NCHW
    """
    return tvm.compute(Input.shape, lambda b, c, i, j: Input[b, c, i, j] * Scale[c] + Shift[c], name='ScaleShift')
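scale_shift is the inference-time form of batch normalization: the trained (gamma, beta, mean, var) statistics can be folded into the per-channel Scale and Shift tensors ahead of time. A hedged numpy sketch of that folding (the names are illustrative and not part of this commit):

import numpy as np

def fold_batch_norm(gamma, beta, mean, var, eps=1e-5):
    # gamma * (x - mean) / sqrt(var + eps) + beta  ==  x * scale + shift
    scale = gamma / np.sqrt(var + eps)
    shift = beta - mean * scale
    return scale, shift

The resulting 1-D arrays can then be fed to topi.nn.scale_shift as Scale and Shift.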
"""Common topi utilities"""
from __future__ import absolute_import as _abs
import tvm
def get_const_tuple(in_tuple):
"""Verifies input tuple is IntImm, returns tuple of int.
Parameters
----------
in_tuple : tuple of tvm.expr.IntImm
The input.
Returns
-------
out_tuple : tuple of int
The output.
"""
out_tuple = ()
for elem in in_tuple:
if not isinstance(elem, tvm.expr.IntImm):
raise ValueError("Element of input tuple should be IntImm")
out_tuple = out_tuple + (elem.value, )
return out_tuple
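A usage example (the placeholder is illustrative): get_const_tuple converts the IntImm shape of a placeholder into plain Python ints, which the schedule and test code above and below use for blocking decisions and numpy allocations.

import tvm
from topi.nn.util import get_const_tuple

A = tvm.placeholder((2, 256, 32, 32), name='A')
print(get_const_tuple(A.shape))  # (2, 256, 32, 32)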
import os
import tvm
import numpy as np
from scipy import signal
from tvm.contrib import nvcc_compiler
import topi
from topi.nn.util import get_const_tuple
from topi.cuda.depthwise_conv2d_map import schedule_depthwise_conv2d_map
TASK = "depthwise_conv2d_map"
USE_MANUAL_CODE = False

@tvm.register_func
def tvm_callback_cuda_compile(code):
    ptx = nvcc_compiler.compile_source(code, target="ptx", options=["-arch=sm_52"])
    return ptx

def write_code(code, fname):
    with open(fname, "w") as f:
        f.write(code)

@tvm.register_func
def tvm_callback_cuda_postproc(code):
    if not os.path.exists("perf"):
        os.mkdir("perf")
    write_code(code, "perf/%s_generated.cu" % TASK)
    if USE_MANUAL_CODE:
        code = open("perf/%s_manual.cu" % TASK).read()
    return code

def test_depthwise_conv2d_map():
    """You may test different settings."""
    batch = 2
    in_channel = 256
    in_height = 32
    in_width = 32
    filter_channel = in_channel
    channel_multiplier = 2
    filter_height = 5
    filter_width = 5
    stride_h = 2
    stride_w = 2
    padding = 'SAME'  # or 'VALID'
    # Placeholder
    Input = tvm.placeholder((batch, in_channel, in_height, in_width), name='Input')
    Filter = tvm.placeholder((filter_channel, channel_multiplier, filter_height, filter_width), name='Filter')
    Stride = tvm.nd.array(np.array([stride_h, stride_w]))
    Scale = tvm.placeholder((in_channel * channel_multiplier,), name='Scale')
    Shift = tvm.placeholder((in_channel * channel_multiplier,), name='Shift')
    # Declare
    DepthwiseConv2d = topi.nn.depthwise_conv2d(Input, Filter, Stride, padding)
    ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift)
    Relu = topi.ewise.relu(ScaleShift)
    # Schedule
    s1 = schedule_depthwise_conv2d_map(DepthwiseConv2d.op)
    s2 = schedule_depthwise_conv2d_map(ScaleShift.op)
    s3 = schedule_depthwise_conv2d_map(Relu.op)

    def depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np):
        out_shape = get_const_tuple(DepthwiseConv2d.shape)
        out_channel = out_shape[1]
        out_height = out_shape[2]
        out_width = out_shape[3]
        depthwise_conv2d_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=DepthwiseConv2d.dtype)
        scale_shift_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=ScaleShift.dtype)
        relu_scipy = np.zeros((batch, out_channel, out_height, out_width), dtype=Relu.dtype)
        if padding == 'SAME':
            pad_top_tvm = np.int(np.ceil(float(np.maximum((out_height - 1) * stride_h + filter_height - in_height, 0)) / 2))
            pad_left_tvm = np.int(np.ceil(float(np.maximum((out_width - 1) * stride_w + filter_width - in_width, 0)) / 2))
            pad_top_scipy = np.int(np.ceil(float(filter_height - 1) / 2))
            pad_left_scipy = np.int(np.ceil(float(filter_width - 1) / 2))
            index_h = pad_top_scipy - pad_top_tvm
            index_w = pad_left_scipy - pad_left_tvm
            for i in range(batch):
                for j in range(out_channel):
                    depthwise_conv2d_scipy[i,j,:,:] = signal.convolve2d(input_np[i,j/channel_multiplier,:,:], np.rot90(filter_np[j/channel_multiplier,j%channel_multiplier,:,:], 2),
                        mode='same')[index_h:in_height:stride_h, index_w:in_width:stride_w]
        if padding == 'VALID':
            for i in range(batch):
                for j in range(out_channel):
                    depthwise_conv2d_scipy[i,j,:,:] = signal.convolve2d(input_np[i,j/channel_multiplier,:,:], np.rot90(filter_np[j/channel_multiplier,j%channel_multiplier,:,:], 2),
                        mode='valid')[0:(in_height - filter_height + 1):stride_h, 0:(in_width - filter_width + 1):stride_w]
        for c in range(out_channel):
            scale_shift_scipy[:,c,:,:] = depthwise_conv2d_scipy[:,c,:,:] * scale_np[c] + shift_np[c]
        relu_scipy[:,:,:,:] = np.maximum(scale_shift_scipy[:,:,:,:], 0)
        return depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
        # Build the kernel
        f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
        f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
        f3 = tvm.build(s3, [Input, Filter, Scale, Shift, Relu], device)
        # Prepare data
        input_np = np.random.uniform(size=get_const_tuple(Input.shape)).astype(Input.dtype)
        filter_np = np.random.uniform(size=get_const_tuple(Filter.shape)).astype(Filter.dtype)
        input_tvm = tvm.nd.array(input_np, ctx)
        filter_tvm = tvm.nd.array(filter_np, ctx)
        scale_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Scale.dtype)
        shift_np = np.random.uniform(size=(in_channel * channel_multiplier)).astype(Shift.dtype)
        scale_tvm = tvm.nd.array(scale_np, ctx)
        shift_tvm = tvm.nd.array(shift_np, ctx)
        depthwise_conv2d_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(DepthwiseConv2d.shape), dtype=DepthwiseConv2d.dtype), ctx)
        scale_shift_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(ScaleShift.shape), dtype=ScaleShift.dtype), ctx)
        relu_tvm = tvm.nd.array(np.zeros(shape=get_const_tuple(Relu.shape), dtype=Relu.dtype), ctx)
        # Measure time cost of kernel 1 (depthwise_conv2d)
        timer_1 = f1.time_evaluator(f1.entry_name, ctx, number=10000)
        tcost_1 = timer_1(input_tvm, filter_tvm, depthwise_conv2d_tvm)
        # Measure time cost of kernel 2 (depthwise_conv2d + scale_shift)
        timer_2 = f2.time_evaluator(f2.entry_name, ctx, number=10000)
        tcost_2 = timer_2(input_tvm, filter_tvm, scale_tvm, shift_tvm, scale_shift_tvm)
        # Measure time cost of kernel 3 (depthwise_conv2d + scale_shift + relu)
        timer_3 = f3.time_evaluator(f3.entry_name, ctx, number=10000)
        tcost_3 = timer_3(input_tvm, filter_tvm, scale_tvm, shift_tvm, relu_tvm)
        print("Input shape = " + str(get_const_tuple(Input.shape)))
        print("Filter shape = " + str(get_const_tuple(Filter.shape)))
        print("Stride = (%d, %d)" % (stride_h, stride_w))
        print("padding = %s\n" % padding)
        print("Output shape = " + str(get_const_tuple(DepthwiseConv2d.shape)))
        print("average time cost of 10000 runs (depthwise_conv2d) = %g sec" % tcost_1)
        print("average time cost of 10000 runs (depthwise_conv2d + scale_shift) = %g sec" % tcost_2)
        print("average time cost of 10000 runs (depthwise_conv2d + scale_shift + relu) = %g sec" % tcost_3)
        depthwise_conv2d_scipy, scale_shift_scipy, relu_scipy = depthwise_conv2d_map_scipy(input_np, filter_np, scale_np, shift_np)
        np.testing.assert_allclose(depthwise_conv2d_tvm.asnumpy(), depthwise_conv2d_scipy, rtol=1e-5)
        np.testing.assert_allclose(scale_shift_tvm.asnumpy(), scale_shift_scipy, rtol=1e-5)
        np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
print "success"

    with tvm.build_config(auto_unroll_max_step=32,
                          auto_unroll_min_depth=0,
                          unroll_explicit=True,
                          detect_global_barrier=False,
                          restricted_func=True):
        check_device("cuda")

if __name__ == "__main__":
    test_depthwise_conv2d_map()
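The timings printed by the script are average seconds per launch. To compare different settings, they can be converted to throughput: the depthwise stage performs filter_height * filter_width multiply-adds per output element. A small helper along these lines (illustrative, not part of the commit; tcost is the float returned by the time_evaluator call above):

def depthwise_gflops(tcost, batch, out_channel, out_height, out_width, filter_height, filter_width):
    # 2 FLOPs (one multiply, one add) per filter tap per output element.
    flops = 2.0 * batch * out_channel * out_height * out_width * filter_height * filter_width
    return flops / tcost / 1e9

# With the default settings above (output 2 x 512 x 16 x 16, 5 x 5 filter):
# depthwise_gflops(tcost_1, 2, 512, 16, 16, 5, 5)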