Commit 4ab73634 authored by Cody Hao Yu, committed by Leyuan Wang

[TOPI] Tunable Template for Conv2D HWCN on CUDA (#4168)

* support conv2d HWCN in AutoTVM and Relay

* fix lint

* fix comments and unit tests
parent 2e0dbaa6
......@@ -226,7 +226,7 @@ def args_to_workload(x, topi_compute_func=None):
elif x is None:
workload = 0
else:
raise RuntimeError('Do not support type "%s" in argument. Consider to use'
raise RuntimeError('Do not support type "%s" in argument. Consider to use '
'primitive types only' % type(x))
return (get_func_name(topi_compute_func), ) + workload if topi_compute_func else workload
......
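The fix above is easy to miss: Python concatenates adjacent string literals verbatim, so without the trailing space the two halves of the error message run together. A minimal, plain-Python illustration:

```python
# Adjacent string literals are joined with no separator, so the trailing
# space inside the first literal is what keeps the words apart.
msg_old = ('Do not support type "%s" in argument. Consider to use'
           'primitive types only')
msg_new = ('Do not support type "%s" in argument. Consider to use '
           'primitive types only')

print(msg_old % list)  # ...Consider to useprimitive types only
print(msg_new % list)  # ...Consider to use primitive types only
```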
......@@ -176,9 +176,12 @@ class TaskExtractEnv:
args = deserialize_args(args)
A, W = args[:2]
layout = args[-2]
assert layout == 'NCHW', "only support NCHW currently"
assert layout == 'NCHW' or layout == 'HWCN', "only support NCHW/HWCN currently"
C = topi.nn.conv2d(*args, **kwargs)
s = topi.generic.schedule_conv2d_nchw([C])
if layout == 'NCHW':
s = topi.generic.schedule_conv2d_nchw([C])
else:
s = topi.generic.schedule_conv2d_hwcn([C])
return s, [A, W, C]
@register("topi_nn_depthwise_conv2d_nchw")
......
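With the task template above accepting HWCN, a Relay conv2d in that layout can be extracted as an AutoTVM task instead of tripping the old NCHW-only assertion. A rough sketch, assuming the TVM 0.6-era Relay/AutoTVM APIs and a CUDA-enabled build; the shapes and the HWIO kernel layout are illustrative, not taken from the patch:

```python
import tvm
from tvm import relay, autotvm

# A single HWCN conv2d: 224x224x3 input, batch 1, 16 output channels.
data = relay.var("data", shape=(224, 224, 3, 1))       # HWCN
weight = relay.var("weight", shape=(3, 3, 3, 16))      # HWIO
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), channels=16,
                       padding=(1, 1), data_layout="HWCN",
                       kernel_layout="HWIO")
mod = relay.Module.from_expr(relay.Function([data, weight], conv))

# Before this patch the template asserted layout == 'NCHW'; now the HWCN
# task is extracted and scheduled with schedule_conv2d_hwcn.
tasks = autotvm.task.extract_from_program(mod["main"], target="cuda",
                                          params={},
                                          ops=(relay.op.nn.conv2d,))
print(tasks)
```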
......@@ -153,14 +153,14 @@ def compute_conv2d(attrs, inputs, out_type, target):
out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
else out_dtype)
assert layout in ["NCHW", "NHWC", "NCHW4c"]
assert layout in ["NCHW", "NHWC", "NCHW4c", "HWCN"]
(dilation_h, dilation_w) = dilation
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
def _get_out_depth():
weight_shape = get_const_tuple(inputs[1].shape)
if kernel_layout == "HWOI":
if kernel_layout.startswith("HW"):
return weight_shape[2] * weight_shape[3]
return weight_shape[0] * weight_shape[1]
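The `startswith("HW")` change generalizes the depthwise output-depth computation from HWOI to any HW-leading kernel layout (e.g. HWIO): in both cases the channel dimensions sit in the last two axes. A small standalone illustration of the same logic, with hypothetical weight shapes:

```python
def get_out_depth(weight_shape, kernel_layout):
    # HW-leading layouts (HWOI, HWIO) keep the channel dims last, so the
    # depthwise output depth is their product; OIHW-style layouts keep
    # them first.
    if kernel_layout.startswith("HW"):
        return weight_shape[2] * weight_shape[3]
    return weight_shape[0] * weight_shape[1]

print(get_out_depth((3, 3, 32, 1), "HWOI"))  # 32: 32 channels x multiplier 1
print(get_out_depth((3, 3, 1, 32), "HWIO"))  # 32
print(get_out_depth((32, 1, 3, 3), "OIHW"))  # 32
```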
......@@ -192,11 +192,13 @@ def schedule_conv2d(attrs, outs, target):
with target:
if groups == 1 and layout == "NCHW":
return topi.generic.schedule_conv2d_nchw(outs)
if groups == 1 and layout == "NCHW4c":
elif groups == 1 and layout == "NCHW4c":
return topi.generic.schedule_conv2d_nchw(outs)
if groups == 1 and layout == "NHWC":
elif groups == 1 and layout == "NHWC":
return topi.generic.schedule_conv2d_nhwc(outs)
if groups != 1:
elif groups == 1 and layout == "HWCN":
return topi.generic.schedule_conv2d_hwcn(outs)
elif groups != 1:
# collect in_channels to distinguish depthwise and group conv2d
op = _find_conv2d_op(outs[0].op)
assert op is not None
......
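The new `elif` branch routes single-group HWCN convolutions to `schedule_conv2d_hwcn`. The same pair of generic functions can be exercised directly at the TOPI level; a minimal sketch mirroring the updated unit test below (shapes are illustrative, and it requires a CUDA-enabled TVM build):

```python
import tvm
import topi

A = tvm.placeholder((56, 56, 64, 1), name="A")    # HWCN data
W = tvm.placeholder((3, 3, 64, 64), name="W")     # HWIO kernel

with tvm.target.create("cuda"):
    B = topi.nn.conv2d(A, W, 1, 1, 1, layout="HWCN")
    # Under a CUDA target this resolves to the tunable template in
    # topi/cuda/conv2d_hwcn.py; without a tuning log it uses the fallback
    # SplitEntity configuration defined there.
    s = topi.generic.schedule_conv2d_hwcn([B])
    func = tvm.build(s, [A, W, B], "cuda")
```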
......@@ -368,7 +368,6 @@ class Vectorizer : public IRMutator {
CHECK(!op->extent.type().is_vector());
Expr extent = Mutate(op->extent);
if (extent.type().is_vector()) {
LOG(WARNING) << "Detect vectorized extent type, scalarizing...";
return Scalarize(s);
}
Stmt body = Mutate(op->body);
......@@ -386,7 +385,6 @@ class Vectorizer : public IRMutator {
CHECK(!op->condition.type().is_vector());
Expr condition = this->Mutate(op->condition);
if (condition.type().is_vector()) {
LOG(WARNING) << "Detect vector condition in Vectorized Loop, scalarizing...";
return Scalarize(s);
}
Stmt then_case = this->Mutate(op->then_case);
......
......@@ -17,9 +17,14 @@
# pylint: disable=invalid-name, too-many-locals, too-many-statements
"""Schedule for conv2d_hwcn with auto fusion"""
import tvm
from .. import tag
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity
def schedule_conv2d_hwcn(outs):
from .. import generic, tag
@autotvm.register_topi_schedule(generic.schedule_conv2d_hwcn, ["cuda", "gpu"], ["direct"])
def schedule_conv2d_hwcn(cfg, outs):
"""Schedule for conv2d_hwcn and any element-wise operations.
Parameters
......@@ -51,36 +56,44 @@ def schedule_conv2d_hwcn(outs):
sch[B].set_scope("local")
BL = B
tile = 8
num_thread = 8
block_factor = tile * num_thread
step = 8
vthread = 2
hi, wi, fi, ni = sch[Out].op.axis
block_x = tvm.thread_axis("blockIdx.x")
block_y = tvm.thread_axis("blockIdx.y")
block_z = tvm.thread_axis("blockIdx.z")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")
# Create tuning space
n_thread_cand = [1, 2, 4, 8, 16, 32]
vthread_cand = [1, 2, 4, 8]
cfg.define_split(
'tile_fi',
fi,
num_outputs=4,
filter=lambda x:
(x.size[1] in vthread_cand and x.size[2] in n_thread_cand))
cfg.define_split(
'tile_ni',
ni,
num_outputs=4,
filter=lambda x:
(x.size[1] in vthread_cand and x.size[2] in n_thread_cand))
if cfg.is_fallback:
cfg['tile_fi'] = SplitEntity([-1, 2, 8, 4])
cfg['tile_ni'] = SplitEntity([-1, 2, 8, 4])
# Scheduling
step = 8
hi, wi, fi, ni = sch[Out].op.axis
bz = sch[Out].fuse(hi, wi)
by, fi = sch[Out].split(fi, factor=block_factor)
bx, ni = sch[Out].split(ni, factor=block_factor)
tyz, fi = sch[Out].split(fi, nparts=vthread)
txz, ni = sch[Out].split(ni, nparts=vthread)
ty, fi = sch[Out].split(fi, nparts=num_thread)
tx, ni = sch[Out].split(ni, nparts=num_thread)
by, tyz, ty, fi = cfg['tile_fi'].apply(sch, Out, fi)
bx, txz, tx, ni = cfg['tile_ni'].apply(sch, Out, ni)
sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)
sch[Out].bind(bz, block_z)
sch[Out].bind(by, block_y)
sch[Out].bind(bx, block_x)
sch[Out].bind(tyz, thread_yz)
sch[Out].bind(txz, thread_xz)
sch[Out].bind(ty, thread_y)
sch[Out].bind(tx, thread_x)
sch[Out].bind(bz, tvm.thread_axis('blockIdx.z'))
sch[Out].bind(by, tvm.thread_axis('blockIdx.y'))
sch[Out].bind(bx, tvm.thread_axis('blockIdx.x'))
sch[Out].bind(tyz, tvm.thread_axis('vthread'))
sch[Out].bind(txz, tvm.thread_axis('vthread'))
sch[Out].bind(ty, tvm.thread_axis('threadIdx.y'))
sch[Out].bind(tx, tvm.thread_axis('threadIdx.x'))
# Schedule BL local write
sch[BL].compute_at(sch[Out], tx)
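The tuning space above replaces the hard-coded `num_thread`/`vthread` constants with two 4-way splits whose candidates are pruned by a `filter`. A toy standalone template, assuming the TVM 0.6-era `@autotvm.template` API, shows the same `define_split` / `SplitEntity.apply` pattern on a trivial compute (the name `demo_split` is made up for illustration):

```python
import tvm
from tvm import autotvm

@autotvm.template
def demo_split(n):
    A = tvm.placeholder((n,), name="A")
    B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
    s = tvm.create_schedule(B.op)

    cfg = autotvm.get_config()
    i = s[B].op.axis[0]
    # 3-way split; the filter keeps only candidates whose middle factor is
    # a plausible thread count, mirroring the vthread_cand/n_thread_cand
    # filters used by the conv2d_hwcn template.
    cfg.define_split("tile_i", i, num_outputs=3,
                     filter=lambda x: x.size[1] in (1, 2, 4, 8))
    io, im, ii = cfg["tile_i"].apply(s, B, i)
    return s, [A, B]

task = autotvm.task.create(demo_split, args=(1024,), target="llvm")
print(task.config_space)   # lists the surviving tile_i candidates
```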
......@@ -98,21 +111,21 @@ def schedule_conv2d_hwcn(outs):
sch[WL].compute_at(sch[BL], rci)
# Schedule for A's shared memory load
yi, xi, ci, ni = sch[AA].op.axis
ty, ci = sch[AA].split(ci, nparts=num_thread)
tx, ni = sch[AA].split(ni, nparts=num_thread)
ty, ci = sch[AA].split(ci, nparts=cfg['tile_fi'].size[2])
tx, ni = sch[AA].split(ni, nparts=cfg['tile_ni'].size[2])
_, ni = sch[AA].split(ni, factor=4)
sch[AA].reorder(ty, tx, yi, xi, ci, ni)
sch[AA].bind(ty, thread_y)
sch[AA].bind(tx, thread_x)
sch[AA].bind(ty, tvm.thread_axis('threadIdx.y'))
sch[AA].bind(tx, tvm.thread_axis('threadIdx.x'))
sch[AA].vectorize(ni)
# Schedule for W's shared memory load
yi, xi, ci, fi = sch[WW].op.axis
ty, ci = sch[WW].split(ci, nparts=num_thread)
tx, fi = sch[WW].split(fi, nparts=num_thread)
ty, ci = sch[WW].split(ci, nparts=cfg['tile_fi'].size[2])
tx, fi = sch[WW].split(fi, nparts=cfg['tile_ni'].size[2])
_, fi = sch[WW].split(fi, factor=4)
sch[WW].reorder(ty, tx, yi, xi, ci, fi)
sch[WW].bind(ty, thread_y)
sch[WW].bind(tx, thread_x)
sch[WW].bind(ty, tvm.thread_axis('threadIdx.y'))
sch[WW].bind(tx, tvm.thread_axis('threadIdx.x'))
sch[WW].vectorize(fi)
scheduled_ops = []
......
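With the knobs in place, the template can be tuned like any other AutoTVM conv2d task. A hedged sketch of driving it directly through the internal `topi_nn_conv2d` task template; the usual route is `autotvm.task.extract_from_program`, the exact task-creation incantation may differ between TVM versions, and the shapes and log file name are illustrative:

```python
import tvm
from tvm import autotvm

# Ensure the TOPI task templates (including "topi_nn_conv2d") are registered.
autotvm.task.TaskExtractEnv.get()

data = ("TENSOR", (56, 56, 64, 1), "float32")     # HWCN
kernel = ("TENSOR", (3, 3, 64, 64), "float32")    # HWIO
# (data, kernel, strides, padding, dilation, layout, out_dtype)
args = (data, kernel, 1, 1, 1, "HWCN", "float32")

task = autotvm.task.create("topi_nn_conv2d", args=args,
                           target="cuda", template_key="direct")
print(task.config_space)   # contains the tile_fi and tile_ni knobs

measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(),
    runner=autotvm.LocalRunner(number=5, timeout=10))

tuner = autotvm.tuner.XGBTuner(task)
tuner.tune(n_trial=50, measure_option=measure_option,
           callbacks=[autotvm.callback.log_to_file("conv2d_hwcn.log")])
```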
......@@ -35,6 +35,24 @@ def _default_schedule(outs, auto_inline):
@tvm.target.generic_func
def schedule_conv2d_hwcn(outs):
"""Schedule for conv2d_hwcn
Parameters
----------
outs: Array of Tensor
The computation graph description of conv2d_hwcn
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_conv2d_nchw(outs):
"""Schedule for conv2d_nchw
......
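The new stub relies on TVM's target-generic function mechanism: outside any target context (or without a registered override) the call falls through to `_default_schedule`, while the CUDA registration above takes over under a `cuda`/`gpu` target. A toy example of the mechanism itself; `demo_schedule` is a made-up name, not part of TOPI:

```python
import tvm

@tvm.target.generic_func
def demo_schedule(outs):
    return "default (fallback) schedule"

@demo_schedule.register(["cuda", "gpu"])
def demo_schedule_cuda(outs):
    return "cuda schedule"

print(demo_schedule(None))          # default (fallback) schedule
with tvm.target.create("cuda"):
    print(demo_schedule(None))      # cuda schedule
```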
......@@ -64,9 +64,9 @@ def conv2d(input, filter, strides, padding, dilation, layout='NCHW', out_dtype=N
# default declaration
if layout == 'NCHW':
return conv2d_nchw(input, filter, strides, padding, dilation, out_dtype)
if layout == 'HWCN':
elif layout == 'HWCN':
return conv2d_hwcn(input, filter, strides, padding, dilation, out_dtype)
if layout == 'NHWC':
elif layout == 'NHWC':
return conv2d_nhwc(input, filter, strides, padding, dilation, out_dtype)
raise ValueError("not support this layout {} yet".format(layout))
......
......@@ -29,24 +29,25 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
A = tvm.placeholder((in_height, in_width, in_channel, batch), name='A')
W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
B = topi.nn.conv2d_hwcn(A, W, stride, padding, dilation)
C = topi.nn.relu(B)
s1 = topi.cuda.schedule_conv2d_hwcn([B])
s2 = topi.cuda.schedule_conv2d_hwcn([C])
B = tvm.placeholder((1, num_filter, 1), name='bias')
a_shape = get_const_tuple(A.shape)
w_shape = get_const_tuple(W.shape)
b_shape = get_const_tuple(B.shape)
dtype = A.dtype
@memoize("topi.tests.test_topi_conv2d_hwcn.verify_hwcn")
def get_ref_data():
a_np = np.random.uniform(size=a_shape).astype(dtype)
w_np = np.random.uniform(size=w_shape).astype(dtype)
b_np = np.random.uniform(size=b_shape).astype(dtype)
dw_np = topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
b_np = topi.testing.conv2d_hwcn_python(a_np, dw_np, stride, padding)
c_np = np.maximum(b_np, 0)
return a_np, w_np, b_np, c_np
a_np, w_np, b_np, c_np = get_ref_data()
c1_np = topi.testing.conv2d_hwcn_python(a_np, dw_np, stride, padding)
c2_np = c1_np + b_np
c3_np = np.maximum(c2_np, 0)
return a_np, w_np, b_np, c1_np, c2_np, c3_np
a_np, w_np, b_np, c1_np, c2_np, c3_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
......@@ -54,16 +55,32 @@ def verify_conv2d_hwcn(batch, in_channel, in_size, num_filter, kernel, stride, p
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
t_conv = topi.nn.conv2d(A, W, stride, padding, dilation, layout='HWCN')
t_bias = topi.add(t_conv, B)
t_relu = topi.nn.relu(t_bias)
s1 = topi.generic.schedule_conv2d_hwcn([t_conv])
s2 = topi.generic.schedule_conv2d_hwcn([t_bias])
s3 = topi.generic.schedule_conv2d_hwcn([t_relu])
a = tvm.nd.array(a_np, ctx)
w = tvm.nd.array(w_np, ctx)
b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
func1 = tvm.build(s1, [A, W, B], device)
func2 = tvm.build(s2, [A, W, C], device)
func1(a, w, b)
func2(a, w, c)
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
b = tvm.nd.array(b_np, ctx)
conv_out = tvm.nd.array(
np.zeros(get_const_tuple(t_conv.shape), dtype=t_conv.dtype), ctx)
bias_out = tvm.nd.array(
np.zeros(get_const_tuple(t_bias.shape), dtype=t_bias.dtype), ctx)
relu_out = tvm.nd.array(
np.zeros(get_const_tuple(t_relu.shape), dtype=t_relu.dtype), ctx)
func1 = tvm.build(s1, [A, W, t_conv], device)
func2 = tvm.build(s2, [A, W, B, t_bias], device)
func3 = tvm.build(s3, [A, W, B, t_relu], device)
func1(a, w, conv_out)
func2(a, w, b, bias_out)
func3(a, w, b, relu_out)
tvm.testing.assert_allclose(conv_out.asnumpy(), c1_np, rtol=1e-5)
tvm.testing.assert_allclose(bias_out.asnumpy(), c2_np, rtol=1e-5)
tvm.testing.assert_allclose(relu_out.asnumpy(), c3_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
check_device(device)
......
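The reworked test compiles under a fallback configuration (no tuning log), which is why the `SplitEntity` defaults in the template matter. Once a log produced by tuning the template is available, it can be applied around the same compilation; a brief sketch assuming a log file named conv2d_hwcn.log with entries matching these shapes:

```python
import tvm
import topi
from tvm import autotvm

A = tvm.placeholder((56, 56, 64, 1), name="A")    # HWCN
W = tvm.placeholder((3, 3, 64, 64), name="W")     # HWIO

# Pick the best measured configs from the tuning log instead of the
# fallback SplitEntity defaults.
with autotvm.apply_history_best("conv2d_hwcn.log"):
    with tvm.target.create("cuda"):
        B = topi.nn.conv2d(A, W, 1, 1, 1, layout="HWCN")
        s = topi.generic.schedule_conv2d_hwcn([B])
        func = tvm.build(s, [A, W, B], "cuda")
```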
......@@ -48,7 +48,6 @@ def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, p
dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
if add_bias:
b_np = np.random.uniform(size=bias_shape).astype(dtype)
c_np += b_np
if add_relu:
c_np = np.maximum(c_np, 0)
......