Commit eaea99c5 by Haichen Shen, committed by Tianqi Chen

[TOPI] Example for convolution in GPU (#212)

* [TOPI] Example for convolution

* update conv ex

* fix submodule HalideIR

* update conv impl

* python3

* minor fix

* fix pylint error

* Add test code

* x

* fix

* fix

* move python helper function into topi.testing

* fix pylint
parent 01cbc61a
@@ -9,3 +9,4 @@ from __future__ import absolute_import as _abs
from .math import *
from . import nn
from . import cuda
from . import testing
@@ -2,4 +2,5 @@
"""CUDA specific declaration and schedules."""
from __future__ import absolute_import as _abs
from .conv2d_hwcn_map import schedule_conv2d_hwcn_map
from .depthwise_conv2d_map import schedule_depthwise_conv2d_map
# pylint: disable=invalid-name
"""Schedule for conv2d_hwcn with auto fusion"""
import tvm


def _schedule_conv2d_hwcn(op, sch):
    assert len(op.input_tensors) == 2
    Apad = op.input_tensors[0]
    W = op.input_tensors[1]
    B = op.output(0)

    sch[Apad].compute_inline()
    AA = sch.cache_read(Apad, "shared", [B])
    WW = sch.cache_read(W, "shared", [B])
    AL = sch.cache_read(AA, "local", [B])
    WL = sch.cache_read(WW, "local", [B])

    if op in sch.outputs:
        Out = op.output(0)
        BL = sch.cache_write(Out, "local")
    else:
        Out = sch.outputs[0].output(0)
        sch[B].set_scope("local")
        BL = B

    tile = 8
    num_thread = 8
    block_factor = tile * num_thread
    step = 8
    vthread = 2

    block_x = tvm.thread_axis("blockIdx.x")
    block_y = tvm.thread_axis("blockIdx.y")
    block_z = tvm.thread_axis("blockIdx.z")
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
    thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
    thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")

    hi, wi, fi, ni = sch[Out].op.axis
    bz = sch[Out].fuse(wi, hi)
    by, fi = sch[Out].split(fi, factor=block_factor)
    bx, ni = sch[Out].split(ni, factor=block_factor)
    tyz, fi = sch[Out].split(fi, nparts=vthread)
    txz, ni = sch[Out].split(ni, nparts=vthread)
    ty, fi = sch[Out].split(fi, nparts=num_thread)
    tx, ni = sch[Out].split(ni, nparts=num_thread)
    sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)
    sch[Out].bind(bz, block_z)
    sch[Out].bind(by, block_y)
    sch[Out].bind(bx, block_x)
    sch[Out].bind(tyz, thread_yz)
    sch[Out].bind(txz, thread_xz)
    sch[Out].bind(ty, thread_y)
    sch[Out].bind(tx, thread_x)

    # Schedule BL local write
    sch[BL].compute_at(sch[Out], tx)
    yi, xi, fi, ni = sch[BL].op.axis
    ry, rx, rc = sch[BL].op.reduce_axis
    rco, rci = sch[BL].split(rc, factor=step)
    sch[BL].reorder(rco, ry, rx, rci, fi, ni)
    fuse_index = sch[BL].fuse(rx, ry)
    fuse_index = sch[BL].fuse(fuse_index, rco)
    rx = fuse_index

    sch[AA].compute_at(sch[BL], rx)
    sch[WW].compute_at(sch[BL], rx)
    sch[AL].compute_at(sch[BL], rci)
    sch[WL].compute_at(sch[BL], rci)
    # Schedule for A's shared memory load
    yi, xi, ci, ni = sch[AA].op.axis
    ty, ci = sch[AA].split(ci, nparts=num_thread)
    tx, ni = sch[AA].split(ni, nparts=num_thread)
    _, ni = sch[AA].split(ni, factor=4)
    sch[AA].reorder(ty, tx, yi, xi, ci, ni)
    sch[AA].bind(ty, thread_y)
    sch[AA].bind(tx, thread_x)
    sch[AA].vectorize(ni)
    # Schedule for W's shared memory load
    yi, xi, ci, fi = sch[WW].op.axis
    ty, ci = sch[WW].split(ci, nparts=num_thread)
    tx, fi = sch[WW].split(fi, nparts=num_thread)
    _, fi = sch[WW].split(fi, factor=4)
    sch[WW].reorder(ty, tx, yi, xi, ci, fi)
    sch[WW].bind(ty, thread_y)
    sch[WW].bind(tx, thread_x)
    sch[WW].vectorize(fi)
    return sch
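# Editorial note (not part of the original file), summarizing my reading of the
# tiling above: each thread block covers block_factor = tile * num_thread = 64
# output filters and 64 batch elements; with the 8x8 grid of real threads and
# the 2x2 virtual threads, every physical thread accumulates a 4x4 register
# tile per vthread position (64 outputs per thread). The channel reduction runs
# in chunks of step = 8: the padded input and filter tiles are staged into
# shared memory (AA, WW) once per chunk and into registers (AL, WL) once per
# channel.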
def schedule_conv2d_hwcn_map(op):
    """Schedule for conv2d_hwcn map ops.

    Parameters
    ----------
    op: tvm.tensor.Operation
        The symbolic description of the operation, should be conv2d_hwcn or
        conv2d_hwcn followed by a sequence of one-to-one-mapping operators.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    def traverse(operator):
        if operator.tag == 'ewise' or operator.tag == 'scale_shift':
            if operator not in sch.outputs:
                sch[operator].compute_inline()
            for tensor in operator.input_tensors:
                if tensor.op.input_tensors:
                    traverse(tensor.op)
        elif operator.tag == 'conv2d_hwcn':
            _schedule_conv2d_hwcn(operator, sch)
        else:
            raise RuntimeError("Unsupported operator: %s" % operator.tag)

    sch = tvm.create_schedule(op)
    traverse(op)
    return sch
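# Editorial usage sketch (mirrors the test added in this commit; shapes are
# illustrative and a CUDA-enabled TVM build is assumed):
#
#   import tvm
#   import topi
#
#   A = tvm.placeholder((32, 32, 128, 4), name='A')
#   W = tvm.placeholder((3, 3, 128, 256), name='W')
#   B = topi.nn.conv2d_hwcn(A, W, stride=1, padding='SAME')
#   C = topi.nn.relu(B)
#   sch = topi.cuda.schedule_conv2d_hwcn_map(C.op)
#   func = tvm.build(sch, [A, W, C], 'cuda')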
@@ -5,6 +5,73 @@ import tvm
import numpy as np
from .util import get_const_tuple


@tvm.tag_scope(tag="conv2d_hwcn")
def conv2d_hwcn(Input, Filter, stride, padding):
    """Convolution operator in HWCN layout.

    Parameters
    ----------
    Input : tvm.Tensor
        4-D with shape [in_height, in_width, in_channel, batch]
    Filter : tvm.Tensor
        4-D with shape [filter_height, filter_width, in_channel, num_filter]
    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]
    padding : int or str
        Padding size, or ['VALID', 'SAME']

    Returns
    -------
    Output : tvm.Tensor
        4-D with shape [out_height, out_width, out_channel, batch]
    """
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(padding, int) or padding in ['VALID', 'SAME']
    in_height, in_width, in_channel, batch = get_const_tuple(Input.shape)
    kernel_h, kernel_w, channel, num_filter = get_const_tuple(Filter.shape)
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride
    # compute the padding size
    if isinstance(padding, int):
        pad_h = pad_w = padding * 2
    elif padding == 'VALID':
        pad_h = 0
        pad_w = 0
    else:  # 'SAME'
        pad_h = kernel_h - 1
        pad_w = kernel_w - 1
    pad_top = int(np.ceil(float(pad_h) / 2))
    pad_left = int(np.ceil(float(pad_w) / 2))
    # compute the output shape
    out_channel = num_filter
    out_height = (in_height - kernel_h + pad_h) // stride_h + 1
    out_width = (in_width - kernel_w + pad_w) // stride_w + 1
    # compute graph
    PaddedInput = tvm.compute(
        (in_height + pad_h, in_width + pad_w, in_channel, batch),
        lambda yy, xx, cc, nn: tvm.select(
            tvm.all(yy >= pad_top, yy - pad_top < in_height,
                    xx >= pad_left, xx - pad_left < in_width),
            Input[yy - pad_top, xx - pad_left, cc, nn], tvm.const(0.)),
        name='PaddedInput')
    rc = tvm.reduce_axis((0, in_channel), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')
    Output = tvm.compute(
        (out_height, out_width, out_channel, batch),
        lambda yy, xx, ff, nn: tvm.sum(
            PaddedInput[yy * stride_h + ry, xx * stride_w + rx, rc, nn] * Filter[ry, rx, rc, ff],
            axis=[ry, rx, rc]),
        name='Conv2dOutput')
    return Output
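# Editorial worked example (not part of the original file): with 'SAME' padding,
# a 16x16 input, a 5x5 kernel and stride 2 give
#   pad_h = pad_w = 5 - 1 = 4, pad_top = pad_left = 2,
#   out_height = out_width = (16 - 5 + 4) // 2 + 1 = 8,
# while 'VALID' padding gives (16 - 5 + 0) // 2 + 1 = 6.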
@tvm.tag_scope(tag="depthwise_conv2d")
def depthwise_conv2d(Input, Filter, Stride, padding):
"""Depthwise convolution operator.
......
"""TOPI Testing Util functions.
Used to verify the correctness of operators in TOPI .
"""
from __future__ import absolute_import as _abs
from .conv2d_hwcn_python import conv2d_hwcn_python
# pylint: disable=invalid-name, line-too-long, unused-variable
"""Convolution in python"""
import numpy as np
import scipy.signal


def conv2d_hwcn_python(a_np, w_np, stride, padding):
    """Convolution operator in HWCN layout.

    Parameters
    ----------
    a_np : numpy.ndarray
        4-D with shape [in_height, in_width, in_channel, batch]
    w_np : numpy.ndarray
        4-D with shape [filter_height, filter_width, in_channel, num_filter]
    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]
    padding : int or str
        Padding size, or ['VALID', 'SAME']

    Returns
    -------
    b_np : np.ndarray
        4-D with shape [out_height, out_width, out_channel, batch]
    """
    in_height, in_width, in_channel, batch = a_np.shape
    kernel_h, kernel_w, _, num_filter = w_np.shape
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride
    if isinstance(padding, int):
        pad_h = pad_w = padding * 2
    elif padding == 'VALID':
        pad_h = 0
        pad_w = 0
    else:  # 'SAME'
        pad_h = kernel_h - 1
        pad_w = kernel_w - 1
    pad_top = int(np.ceil(float(pad_h) / 2))
    pad_bottom = pad_h - pad_top
    pad_left = int(np.ceil(float(pad_w) / 2))
    pad_right = pad_w - pad_left
    # compute the output shape
    out_channel = num_filter
    out_height = (in_height - kernel_h + pad_h) // stride_h + 1
    out_width = (in_width - kernel_w + pad_w) // stride_w + 1
    # change the layout from HWCN to NCHW
    at = a_np.transpose((3, 2, 0, 1))
    wt = w_np.transpose((3, 2, 0, 1))
    bt = np.zeros((batch, out_channel, out_height, out_width))
    # computation
    for n in range(batch):
        for f in range(out_channel):
            for c in range(in_channel):
                if pad_h > 0:
                    apad = np.zeros((in_height + pad_h, in_width + pad_w))
                    apad[pad_top:-pad_bottom, pad_left:-pad_right] = at[n, c]
                else:
                    apad = at[n, c]
                out = scipy.signal.convolve2d(
                    apad, np.rot90(np.rot90(wt[f, c])), mode='valid')
                bt[n, f] += out[::stride_h, ::stride_w]
    return bt.transpose((2, 3, 1, 0))
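# Editorial note (not part of the original file): scipy.signal.convolve2d flips
# the kernel, so rotating the filter by 180 degrees (np.rot90 applied twice)
# turns it back into the cross-correlation that conv2d actually computes. A
# quick self-contained check of that equivalence:
#
#   k = np.arange(9, dtype=float).reshape(3, 3)
#   x = np.random.rand(5, 5)
#   corr = scipy.signal.correlate2d(x, k, mode='valid')
#   conv = scipy.signal.convolve2d(x, np.rot90(np.rot90(k)), mode='valid')
#   np.testing.assert_allclose(corr, conv)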
"""Example code to do convolution."""
import os
import numpy as np
import scipy.signal
import tvm
from tvm.contrib import nvcc
import topi
from topi.nn.util import get_const_tuple
TASK = "conv2d_hwcn_map"
USE_MANUAL_CODE = False
@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
return ptx
def write_code(code, fname):
with open(fname, "w") as f:
f.write(code)
@tvm.register_func
def tvm_callback_cuda_postproc(code):
if not os.path.exists("perf"):
os.mkdir("perf")
write_code(code, "perf/%s_generated.cu" % TASK)
if USE_MANUAL_CODE:
code = open("perf/%s_manual.cu" % TASK).read()
return code
def test_conv2d_hwcn_map():
batch = 64
in_channel = 128
in_height = 16
in_width = 16
num_filter = 128
kernel = 3
stride = 2
padding = 'SAME'
A = tvm.placeholder((in_height, in_width, in_channel, batch), name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
B = topi.nn.conv2d_hwcn(A, W, stride, padding)
C = topi.nn.relu(B)
s1 = topi.cuda.schedule_conv2d_hwcn_map(B.op)
s2 = topi.cuda.schedule_conv2d_hwcn_map(C.op)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
b_np = topi.testing.conv2d_hwcn_python(a_np, w_np, stride, padding)
c_np = np.maximum(b_np, 0)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
with tvm.build_config(auto_unroll_max_step=32,
auto_unroll_min_depth=0,
unroll_explicit=False):
func1 = tvm.build(s1, [A, W, B], device)
func1(a, w, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
func2 = tvm.build(s2, [A, W, C], device)
func2(a, w, c)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in ['cuda', 'opencl']:
check_device(device)
if __name__ == "__main__":
test_conv2d_hwcn_map()
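# Editorial note (not part of the original file): tvm_callback_cuda_compile
# compiles the generated CUDA source to PTX with nvcc (targeting sm_52 here),
# and tvm_callback_cuda_postproc dumps that source to
# perf/conv2d_hwcn_map_generated.cu; setting USE_MANUAL_CODE = True swaps in a
# hand-edited perf/conv2d_hwcn_map_manual.cu instead. To measure kernel time,
# one hedged sketch (assuming TVM's time_evaluator API) inside check_device:
#
#   timer = func1.time_evaluator(func1.entry_name, ctx, number=10)
#   print("conv2d_hwcn: %f ms" % (timer(a, w, b).mean * 1e3))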
"""Example code to do convolution."""
import os
import numpy as np
import tvm
import topi
from topi.nn.util import get_const_tuple
def verify_conv2d_hwcn_map(batch, in_channel, in_size, num_filter, kernel, stride, padding):
in_height = in_width = in_size
A = tvm.placeholder((in_height, in_width, in_channel, batch), name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
B = topi.nn.conv2d_hwcn(A, W, stride, padding)
C = topi.nn.relu(B)
s1 = topi.cuda.schedule_conv2d_hwcn_map(B.op)
s2 = topi.cuda.schedule_conv2d_hwcn_map(C.op)
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
b_np = topi.testing.conv2d_hwcn_python(a_np, w_np, stride, padding)
c_np = np.maximum(b_np, 0)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
with tvm.build_config(auto_unroll_max_step=32,
auto_unroll_min_depth=0,
unroll_explicit=False):
func1 = tvm.build(s1, [A, W, B], device)
func2 = tvm.build(s2, [A, W, C], device)
func1(a, w, b)
func2(a, w, c)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal']:
check_device(device)
def test_conv2d_hwcn_map():
verify_conv2d_hwcn_map(1, 256, 32, 256, 3, 1, "SAME")
verify_conv2d_hwcn_map(1, 256, 32, 256, 3, 1, "SAME")
verify_conv2d_hwcn_map(4, 128, 16, 128, 5, 2, "SAME")
verify_conv2d_hwcn_map(4, 128, 16, 256, 5, 2, "SAME")
verify_conv2d_hwcn_map(1, 256, 32, 256, 3, 1, "VALID")
verify_conv2d_hwcn_map(1, 256, 32, 256, 3, 1, "VALID")
verify_conv2d_hwcn_map(4, 128, 16, 128, 5, 2, "VALID")
verify_conv2d_hwcn_map(4, 128, 16, 256, 5, 2, "VALID")
if __name__ == "__main__":
test_conv2d_hwcn_map()