Commit cbff637f by Leyuan Wang Committed by Tianqi Chen

[TOPI] conv2d nchw gpu scheduler (#315)

* __init__ updated

* pull request updated

* build_module added

* typo fixed

* another typo fixed

* conv2d gpu scheduler for two layouts moved to tvm

* changes made according to CR

* conv2d_nchw formating updated, conv2d_hwcn tests updated

* lint error fixed

* element wise operator schedule fusing fixed for conv2d

* conv2d_nchw topi test added, all resnet workloads now pass

* conv compute lint error fixed

* fixed python 3 compatibility problem

* conv2d tensor input support added, test typo fixed, ir_pass.Simplify changed to util.get_const_int
parent d76712d1
...@@ -2,5 +2,6 @@ ...@@ -2,5 +2,6 @@
"""CUDA specific declaration and schedules.""" """CUDA specific declaration and schedules."""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .conv2d_hwcn_map import schedule_conv2d_hwcn_map from .conv2d_nchw import schedule_conv2d_nchw
from .conv2d_hwcn import schedule_conv2d_hwcn
from .depthwise_conv2d_map import schedule_depthwise_conv2d_map from .depthwise_conv2d_map import schedule_depthwise_conv2d_map
# pylint: disable=invalid-name # pylint: disable=invalid-name, too-many-locals, too-many-statements
"""Schedule for conv2d_hwcn with auto fusion""" """Schedule for conv2d_hwcn with auto fusion"""
import tvm import tvm
def _schedule_conv2d_hwcn(op, sch): def schedule_conv2d_hwcn(outs):
assert len(op.input_tensors) == 2 """Schedule for conv2d_hwcn and any element-wise operations.
Apad = op.input_tensors[0]
W = op.input_tensors[1]
B = op.output(0)
sch[Apad].compute_inline() Parameters
AA = sch.cache_read(Apad, "shared", [B]) ----------
WW = sch.cache_read(W, "shared", [B]) outs: Array of Tensor
AL = sch.cache_read(AA, "local", [B]) The computation graph description of conv2d_hwcn in the format
WL = sch.cache_read(WW, "local", [B]) of an array of tensors.
if op in sch.outputs:
Out = op.output(0)
BL = sch.cache_write(Out, "local")
else:
Out = sch.outputs[0].output(0)
sch[B].set_scope("local")
BL = B
tile = 8
num_thread = 8
block_factor = tile * num_thread
step = 8
vthread = 2
block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
block_z = tvm.thread_axis("blockIdx.z")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
hi, wi, fi, ni = sch[Out].op.axis Returns
bz = sch[Out].fuse(hi, wi) -------
by, fi = sch[Out].split(fi, factor=block_factor) s: Schedule
bx, ni = sch[Out].split(ni, factor=block_factor) The computation schedule for conv2d_hwcn.
tyz, fi = sch[Out].split(fi, nparts=vthread) """
txz, ni = sch[Out].split(ni, nparts=vthread) outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
ty, fi = sch[Out].split(fi, nparts=num_thread) sch = tvm.create_schedule([x.op for x in outs])
tx, ni = sch[Out].split(ni, nparts=num_thread) def schedule(Apad, W, B):
sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni) """Schedule conv2d_hwcn"""
sch[Out].bind(bz, block_z) sch[Apad].compute_inline()
sch[Out].bind(by, block_y) AA = sch.cache_read(Apad, "shared", [B])
sch[Out].bind(bx, block_x) WW = sch.cache_read(W, "shared", [B])
sch[Out].bind(tyz, thread_yz) AL = sch.cache_read(AA, "local", [B])
sch[Out].bind(txz, thread_xz) WL = sch.cache_read(WW, "local", [B])
sch[Out].bind(ty, thread_y)
sch[Out].bind(tx, thread_x)
# Schedule BL local write if B.op in sch.outputs:
sch[BL].compute_at(sch[Out], tx) Out = B
yi, xi, fi, ni = sch[BL].op.axis BL = sch.cache_write(Out, "local")
ry, rx, rc = sch[BL].op.reduce_axis else:
rco, rci = sch[BL].split(rc, factor=step) Out = sch.outputs[0].output(0)
sch[BL].reorder(rco, ry, rx, rci, fi, ni) sch[B].set_scope("local")
fuse_index = sch[BL].fuse(ry, rx) BL = B
fuse_index = sch[BL].fuse(fuse_index, rco)
rx = fuse_index
sch[AA].compute_at(sch[BL], rx) tile = 8
sch[WW].compute_at(sch[BL], rx) num_thread = 8
sch[AL].compute_at(sch[BL], rci) block_factor = tile * num_thread
sch[WL].compute_at(sch[BL], rci) step = 8
# Schedule for A's shared memory load vthread = 2
yi, xi, ci, ni = sch[AA].op.axis
ty, ci = sch[AA].split(ci, nparts=num_thread)
tx, ni = sch[AA].split(ni, nparts=num_thread)
_, ni = sch[AA].split(ni, factor=4)
sch[AA].reorder(ty, tx, yi, xi, ci, ni)
sch[AA].bind(ty, thread_y)
sch[AA].bind(tx, thread_x)
sch[AA].vectorize(ni)
# Schedule for W's shared memory load
yi, xi, ci, fi = sch[WW].op.axis
ty, ci = sch[WW].split(ci, nparts=num_thread)
tx, fi = sch[WW].split(fi, nparts=num_thread)
_, fi = sch[WW].split(fi, factor=4)
sch[WW].reorder(ty, tx, yi, xi, ci, fi)
sch[WW].bind(ty, thread_y)
sch[WW].bind(tx, thread_x)
sch[WW].vectorize(fi)
return sch block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
block_z = tvm.thread_axis("blockIdx.z")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
hi, wi, fi, ni = sch[Out].op.axis
bz = sch[Out].fuse(hi, wi)
by, fi = sch[Out].split(fi, factor=block_factor)
bx, ni = sch[Out].split(ni, factor=block_factor)
tyz, fi = sch[Out].split(fi, nparts=vthread)
txz, ni = sch[Out].split(ni, nparts=vthread)
ty, fi = sch[Out].split(fi, nparts=num_thread)
tx, ni = sch[Out].split(ni, nparts=num_thread)
sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)
sch[Out].bind(bz, block_z)
sch[Out].bind(by, block_y)
sch[Out].bind(bx, block_x)
sch[Out].bind(tyz, thread_yz)
sch[Out].bind(txz, thread_xz)
sch[Out].bind(ty, thread_y)
sch[Out].bind(tx, thread_x)
def schedule_conv2d_hwcn_map(op): # Schedule BL local write
"""Schedule for conv2d_hwcn map ops. sch[BL].compute_at(sch[Out], tx)
yi, xi, fi, ni = sch[BL].op.axis
ry, rx, rc = sch[BL].op.reduce_axis
rco, rci = sch[BL].split(rc, factor=step)
sch[BL].reorder(rco, ry, rx, rci, fi, ni)
fuse_index = sch[BL].fuse(ry, rx)
fuse_index = sch[BL].fuse(fuse_index, rco)
rx = fuse_index
Parameters sch[AA].compute_at(sch[BL], rx)
---------- sch[WW].compute_at(sch[BL], rx)
op: tvm.tensor.Operation sch[AL].compute_at(sch[BL], rci)
The symbolic description of the operation, should be conv2d_hwcn or sch[WL].compute_at(sch[BL], rci)
conv2d_hwcn followed by a sequence of one-to-one-mapping operators. # Schedule for A's shared memory load
yi, xi, ci, ni = sch[AA].op.axis
ty, ci = sch[AA].split(ci, nparts=num_thread)
tx, ni = sch[AA].split(ni, nparts=num_thread)
_, ni = sch[AA].split(ni, factor=4)
sch[AA].reorder(ty, tx, yi, xi, ci, ni)
sch[AA].bind(ty, thread_y)
sch[AA].bind(tx, thread_x)
sch[AA].vectorize(ni)
# Schedule for W's shared memory load
yi, xi, ci, fi = sch[WW].op.axis
ty, ci = sch[WW].split(ci, nparts=num_thread)
tx, fi = sch[WW].split(fi, nparts=num_thread)
_, fi = sch[WW].split(fi, factor=4)
sch[WW].reorder(ty, tx, yi, xi, ci, fi)
sch[WW].bind(ty, thread_y)
sch[WW].bind(tx, thread_x)
sch[WW].vectorize(fi)
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
def traverse(operator): def traverse(operator):
"""Traverse operators from computation graph"""
if operator.tag == 'ewise' or operator.tag == 'scale_shift': if operator.tag == 'ewise' or operator.tag == 'scale_shift':
if operator not in sch.outputs: if operator not in sch.outputs:
sch[operator].compute_inline() sch[operator].compute_inline()
...@@ -112,10 +108,12 @@ def schedule_conv2d_hwcn_map(op): ...@@ -112,10 +108,12 @@ def schedule_conv2d_hwcn_map(op):
if tensor.op.input_tensors: if tensor.op.input_tensors:
traverse(tensor.op) traverse(tensor.op)
elif operator.tag == 'conv2d_hwcn': elif operator.tag == 'conv2d_hwcn':
_schedule_conv2d_hwcn(operator, sch) Apad = operator.input_tensors[0]
W = operator.input_tensors[1]
B = operator.output(0)
schedule(Apad, W, B)
else: else:
raise RuntimeError("Unsupported operator: %s" % operator.tag) raise RuntimeError("Unsupported operator: %s" % operator.tag)
sch = tvm.create_schedule(op) traverse(outs[0].op)
traverse(op)
return sch return sch
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-statements
"""Schedule for conv2d_nchw with auto fusion"""
import tvm
from .. import util
def schedule_conv2d_small_batch(outs):
"""Create schedule for tensors or return error if batch size is larager than 1"""
s = tvm.create_schedule([x.op for x in outs])
def schedule(temp, Filter, Output):
"""Schedule conv2d_nchw"""
block_h = util.get_const_int(Output.shape[3])
block_w = util.get_const_int(temp.shape[1])
if block_h % 48 == 0:
block_h = 48
elif block_h % 32 == 0:
block_h = 32
if block_w % 48 == 0:
block_w = 48
elif block_w % 32 == 0:
block_w = 32
s[temp].compute_inline()
temp_S = s.cache_read(temp, "shared", [Output])
Filter_S = s.cache_read(Filter, "shared", [Output])
if Output.op in s.outputs:
Out = Output
Out_L = s.cache_write(Out, "local")
else:
Out = outs[0].op.output(0)
s[Output].set_scope("local")
Out_L = Output
# sheduler params
num_thread = 8
vthread = 2
out_filter = min(64, util.get_const_int(Filter.shape[0]))
in_filter = util.get_const_int(Filter.shape[1])
opart2 = out_filter//8
ofactor = out_filter
wfactor = block_h
ifactor = in_filter//4
sfactor = max(1, ofactor//(opart2*2))
spart = (wfactor + vthread-1) // vthread
block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
block_z = tvm.thread_axis("blockIdx.z")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
i, oc, h, w = s[Out].op.axis
ooc, ioc = s[Out].split(oc, factor=ofactor)
ow, iw = s[Out].split(w, factor=wfactor)
ow = s[Out].fuse(ow, h)
oioc, iioc = s[Out].split(ioc, nparts=vthread)
oiw, iiw = s[Out].split(iw, nparts=vthread)
oiioc, iiioc = s[Out].split(iioc, nparts=opart2)
s[Out].reorder(i, ooc, ow, oioc, oiw, oiioc, iiw, iiioc)
s[Out].bind(iiioc, thread_x)
s[Out].bind(iiw, thread_y)
s[Out].bind(oiioc, thread_xz)
s[Out].bind(oiw, thread_yz)
s[Out].bind(oioc, block_x)
s[Out].bind(ow, block_y)
s[Out].bind(ooc, block_z)
s[Out_L].compute_at(s[Out], iiioc)
# schedule Out_L local write
i, oc, h, w = s[Out_L].op.axis
ic, dh, dw = s[Out_L].op.reduce_axis
oic, iic = s[Out_L].split(ic, factor=ifactor)
s[Out_L].reorder(oic, dh, dw, iic, h, w)
fuse_index = s[Out_L].fuse(dw, dh)
fuse_index = s[Out_L].fuse(fuse_index, oic)
dw = fuse_index
s[temp_S].compute_at(s[Out_L], dw)
s[Filter_S].compute_at(s[Out_L], dw)
#schedule temp_S shared mem load
i, ic, h, w = s[temp_S].op.axis
_, iic = s[temp_S].split(ic, factor=sfactor)
_, iw = s[temp_S].split(w, factor=spart)
s[temp_S].bind(iic, thread_x)
s[temp_S].bind(iw, thread_y)
#schedule Filter_S shared mem load
i, oc, h, w = s[Filter_S].op.axis
_, ioc = s[Filter_S].split(oc, factor=sfactor)
_, ii = s[Filter_S].split(i, factor=spart)
s[Filter_S].bind(ioc, thread_x)
s[Filter_S].bind(ii, thread_y)
def traverse(OP):
"""Traverse operators from computation graph"""
# inline all one-to-one-mapping operators except the last stage (output)
if 'ewise' in OP.tag or 'bcast' in OP.tag:
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule conv2d
if 'conv2d_nchw' in OP.tag:
temp = OP.input_tensors[0]
Filter = OP.input_tensors[1]
Output = OP.output(0)
schedule(temp, Filter, Output)
traverse(outs[0].op)
return s
def schedule_conv2d_nchw(outs):
"""Schedule for conv2d_nchw and any element-wise operations.
Parameters
----------
outs: Array of Tensor
The computation graph description of conv2d_nchw
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for conv2d_nchw.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
batch_size = util.get_const_int(outs[0].op.output(0).shape[0])
if batch_size > 1:
raise RuntimeError("Batch size: %d is too large for this schedule" % batch_size)
return schedule_conv2d_small_batch(outs)
# pylint: disable=invalid-name, line-too-long, unused-variable # pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Convolution operators""" """Convolution operators"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
...@@ -6,6 +6,69 @@ import numpy as np ...@@ -6,6 +6,69 @@ import numpy as np
from ..util import get_const_tuple from ..util import get_const_tuple
@tvm.tag_scope(tag="conv2d_nchw")
def conv2d_nchw(Input, Filter, stride, padding):
"""Convolution operator in HWCN layout.
Parameters
----------
Input : tvm.Tensor
4-D with shape [batch, in_channel, in_height, in_width]
Filter : tvm.Tensor
4-D with shape [num_filter, in_channel, filter_height, filter_width]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
Output : tvm.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
assert isinstance(stride, int) or len(stride) == 2
assert isinstance(padding, int) or padding in ['VALID', 'SAME']
batch, in_channel, in_height, in_width = get_const_tuple(Input.shape)
num_filter, channel, kernel_h, kernel_w = get_const_tuple(Filter.shape)
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
# compute the padding size
if isinstance(padding, int):
pad_h = pad_w = padding * 2
elif padding == 'VALID':
pad_h = 0
pad_w = 0
else: # 'SAME'
pad_h = kernel_h - 1
pad_w = kernel_w - 1
pad_top = int(np.ceil(float(pad_h) / 2))
pad_left = int(np.ceil(float(pad_w) / 2))
# compute the output shape
out_channel = num_filter
out_height = (in_height - kernel_h + pad_h) // stride_h + 1
out_width = (in_width - kernel_w + pad_w) // stride_w + 1
# compute graph
temp = tvm.compute(
(batch, in_channel, in_height + pad_h, in_width + pad_w),
lambda nn, cc, yy, xx: tvm.select(
tvm.all(yy >= pad_top, yy - pad_top < in_height,
xx >= pad_left, xx - pad_left < in_width),
Input[nn, cc, yy - pad_top, xx - pad_left], tvm.const(0.)),
name='temp')
rc = tvm.reduce_axis((0, in_channel), name='rc')
ry = tvm.reduce_axis((0, kernel_h), name='ry')
rx = tvm.reduce_axis((0, kernel_w), name='rx')
return tvm.compute(
(batch, out_channel, out_height, out_width),
lambda nn, ff, yy, xx: tvm.sum(
temp[nn, rc, yy * stride_h + ry, xx * stride_w + rx] * Filter[ff, rc, ry, rx],
axis=[rc, ry, rx]))
@tvm.tag_scope(tag="conv2d_hwcn") @tvm.tag_scope(tag="conv2d_hwcn")
def conv2d_hwcn(Input, Filter, stride, padding): def conv2d_hwcn(Input, Filter, stride, padding):
"""Convolution operator in HWCN layout. """Convolution operator in HWCN layout.
......
...@@ -5,3 +5,4 @@ Used to verify the correctness of operators in TOPI . ...@@ -5,3 +5,4 @@ Used to verify the correctness of operators in TOPI .
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from .conv2d_hwcn_python import conv2d_hwcn_python from .conv2d_hwcn_python import conv2d_hwcn_python
from .conv2d_nchw_python import conv2d_nchw_python
# pylint: disable=invalid-name, line-too-long, unused-variable # pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Convolution in python""" """Convolution in python"""
import numpy as np import numpy as np
import scipy.signal import scipy.signal
......
# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
"""Convolution in python"""
import numpy as np
import scipy.signal
def conv2d_nchw_python(a_np, w_np, stride, padding):
"""Convolution operator in HWCN layout.
Parameters
----------
a_np : numpy.ndarray
4-D with shape [batch, in_channel, in_height, in_width]
w_np : numpy.ndarray
4-D with shape [num_filter, in_channel, filter_height, filter_width]
stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]
padding : int or str
Padding size, or ['VALID', 'SAME']
Returns
-------
b_np : np.ndarray
4-D with shape [batch, out_channel, out_height, out_width]
"""
batch, in_channel, in_height, in_width = a_np.shape
num_filter, _, kernel_h, kernel_w = w_np.shape
if isinstance(stride, int):
stride_h = stride_w = stride
else:
stride_h, stride_w = stride
if isinstance(padding, int):
pad_h = pad_w = padding * 2
elif padding == 'VALID':
pad_h = 0
pad_w = 0
else: # 'SAME'
pad_h = kernel_h - 1
pad_w = kernel_w - 1
pad_top = int(np.ceil(float(pad_h) / 2))
pad_bottom = pad_h - pad_top
pad_left = int(np.ceil(float(pad_w) / 2))
pad_right = pad_w - pad_left
# compute the output shape
out_channel = num_filter
out_height = (in_height - kernel_h + pad_h) // stride_h + 1
out_width = (in_width - kernel_w + pad_w) // stride_w + 1
b_np = np.zeros((batch, out_channel, out_height, out_width))
# computation
for n in range(batch):
for f in range(out_channel):
for c in range(in_channel):
if pad_h > 0:
apad = np.zeros((in_height + pad_h, in_width + pad_w))
apad[pad_top:-pad_bottom, pad_left:-pad_right] = a_np[n, c]
else:
apad = a_np[n, c]
out = scipy.signal.convolve2d(
apad, np.rot90(np.rot90(w_np[f, c])), mode='valid')
b_np[n, f] += out[::stride, ::stride]
return b_np
...@@ -12,7 +12,7 @@ USE_MANUAL_CODE = False ...@@ -12,7 +12,7 @@ USE_MANUAL_CODE = False
@tvm.register_func @tvm.register_func
def tvm_callback_cuda_compile(code): def tvm_callback_cuda_compile(code):
ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"]) ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_37"])
return ptx return ptx
def write_code(code, fname): def write_code(code, fname):
...@@ -43,8 +43,8 @@ def test_conv2d_hwcn_map(): ...@@ -43,8 +43,8 @@ def test_conv2d_hwcn_map():
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W') W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
B = topi.nn.conv2d_hwcn(A, W, stride, padding) B = topi.nn.conv2d_hwcn(A, W, stride, padding)
C = topi.nn.relu(B) C = topi.nn.relu(B)
s1 = topi.cuda.schedule_conv2d_hwcn_map(B.op) s1 = topi.cuda.schedule_conv2d_hwcn([B])
s2 = topi.cuda.schedule_conv2d_hwcn_map(C.op) s2 = topi.cuda.schedule_conv2d_hwcn([C])
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype) w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
......
...@@ -13,8 +13,8 @@ def verify_conv2d_hwcn_map(batch, in_channel, in_size, num_filter, kernel, strid ...@@ -13,8 +13,8 @@ def verify_conv2d_hwcn_map(batch, in_channel, in_size, num_filter, kernel, strid
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W') W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
B = topi.nn.conv2d_hwcn(A, W, stride, padding) B = topi.nn.conv2d_hwcn(A, W, stride, padding)
C = topi.nn.relu(B) C = topi.nn.relu(B)
s1 = topi.cuda.schedule_conv2d_hwcn_map(B.op) s1 = topi.cuda.schedule_conv2d_hwcn([B])
s2 = topi.cuda.schedule_conv2d_hwcn_map(C.op) s2 = topi.cuda.schedule_conv2d_hwcn([C])
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype) a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype) w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
......
"""Example code to do convolution."""
import os
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple
def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding):
in_height = in_width = in_size
A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
B = topi.nn.conv2d_nchw(A, W, stride, padding)
C = topi.nn.relu(B)
s1 = topi.cuda.schedule_conv2d_nchw([B])
s2 = topi.cuda.schedule_conv2d_nchw([C])
a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
b_np = topi.testing.conv2d_nchw_python(a_np, w_np, stride, padding)
c_np = np.maximum(b_np, 0)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
with tvm.build_config(auto_unroll_max_step=32,
auto_unroll_min_depth=0,
unroll_explicit=False):
func1 = tvm.build(s1, [A, W, B], device)
func2 = tvm.build(s2, [A, W, C], device)
func1(a, w, b)
func2(a, w, c)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal']:
check_device(device)
def test_conv2d_nchw():
verify_conv2d_nchw(1, 64, 56, 64, 3, 1, 1)
verify_conv2d_nchw(1, 64, 56, 64, 1, 1, 0)
verify_conv2d_nchw(1, 64, 56, 128, 3, 2, 1)
verify_conv2d_nchw(1, 64, 56, 128, 1, 2, 0)
verify_conv2d_nchw(1, 128, 28, 128, 3, 1, 1)
verify_conv2d_nchw(1, 128, 28, 256, 3, 2, 1)
verify_conv2d_nchw(1, 128, 28, 256, 1, 2, 0)
verify_conv2d_nchw(1, 256, 14, 256, 3, 1, 1)
verify_conv2d_nchw(1, 256, 14, 512, 3, 2, 1)
verify_conv2d_nchw(1, 256, 14, 512, 1, 2, 0)
verify_conv2d_nchw(1, 512, 7, 512, 3, 1, 1)
if __name__ == "__main__":
test_conv2d_nchw()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment