Commit bac22073 authored by Wuwei Lin, committed by Tianqi Chen

Alter op layout for group_conv2d on CUDA (#2148)

parent 34648272
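
This commit teaches the NNVM conv2d compute/schedule registration and the CUDA alter-op-layout pass to handle grouped (but not depthwise) convolutions in the channel-blocked NCHW4c layout, so that an int8 AutoTVM template can be used for group_conv2d_nchw the same way it already is for ordinary conv2d. As a rough illustration of what the NCHW4c layout means, here is a minimal NumPy sketch of packing an NCHW tensor with a channel block of 4; the helper name and shapes are illustrative, not part of the commit:

```python
import numpy as np

def pack_nchw_to_nchw4c(data, block=4):
    """Split the C axis of an NCHW tensor into C//block outer x block inner channels."""
    n, c, h, w = data.shape
    assert c % block == 0, "channel count must divide by the block factor"
    # (N, C, H, W) -> (N, C//block, block, H, W) -> (N, C//block, H, W, block)
    return data.reshape(n, c // block, block, h, w).transpose(0, 1, 3, 4, 2)

# e.g. a (1, 32, 56, 56) int8 tensor becomes (1, 8, 56, 56, 4), which is the
# shape pattern of the new_data placeholder built in the alter-layout hunk below.
x = np.random.randint(-128, 127, size=(1, 32, 56, 56), dtype=np.int8)
print(pack_nchw_to_nchw4c(x).shape)  # (1, 8, 56, 56, 4)
```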
@@ -108,7 +108,7 @@ def compute_conv2d(attrs, inputs, _):
          groups == channels:
         out = topi.nn.depthwise_conv2d_nchw(
             inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
-    elif layout == "NCHW":
+    elif layout in ["NCHW", "NCHW4c"]:
         out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups,
                                         out_dtype=out_dtype)
     elif layout == "NHWC" and \
@@ -146,7 +146,7 @@ def schedule_conv2d(attrs, outs, target):
             return topi.generic.schedule_depthwise_conv2d_nchw(outs)
         elif groups == channels and layout == "NHWC" and kernel_layout == "HWOI":
             return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
-        elif layout == "NCHW":
+        elif layout in ["NCHW", "NCHW4c"]:
             return topi.generic.schedule_group_conv2d_nchw(outs)
         else:
             raise ValueError("No compatible schedule")
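
Both hunks above extend the dispatch in NNVM's conv2d registration: a convolution with 1 < groups < channels now maps to topi.nn.group_conv2d_nchw and topi.generic.schedule_group_conv2d_nchw for the NCHW4c layout as well as plain NCHW. For readers unfamiliar with grouped convolution, the following NumPy reference (stride 1, no padding or dilation, illustrative names only) sketches the computation that those topi functions implement and schedule:

```python
import numpy as np

def group_conv2d_nchw_ref(data, kernel, groups):
    """Naive grouped conv2d in NCHW layout (stride 1, no padding/dilation).

    data:   (N, CI, H, W)
    kernel: (CO, CI // groups, KH, KW)
    Output channels are split into `groups` slices, and each slice only reads
    its own CI // groups slice of the input channels.
    """
    n, ci, h, w = data.shape
    co, ci_g, kh, kw = kernel.shape
    assert ci % groups == 0 and co % groups == 0 and ci_g == ci // groups
    co_g = co // groups
    out = np.zeros((n, co, h - kh + 1, w - kw + 1), dtype=data.dtype)
    for g in range(groups):
        d = data[:, g * ci_g:(g + 1) * ci_g]      # input channels of this group
        k = kernel[g * co_g:(g + 1) * co_g]       # filters of this group
        for oc in range(co_g):
            for y in range(out.shape[2]):
                for x in range(out.shape[3]):
                    out[:, g * co_g + oc, y, x] = np.sum(
                        d[:, :, y:y + kh, x:x + kw] * k[oc], axis=(1, 2, 3))
    return out
```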
@@ -7,7 +7,7 @@
 import tvm
 from tvm import autotvm
 
 from .. import nn
-from ..nn import conv2d, conv2d_winograd_without_weight_transform
+from ..nn import conv2d, group_conv2d_nchw, conv2d_winograd_without_weight_transform
 from ..util import get_const_int, get_const_tuple, const_matrix, traverse_inline
 from ..generic import schedule_conv2d_winograd_without_weight_transform
@@ -353,12 +353,12 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
     CO, _, KH, KW = get_const_tuple(kernel.shape)
 
     dispatch_ctx = autotvm.DispatchContext.current
+    target = tvm.target.current_target()
 
     if groups == 1:
         # query config of this workload
-        workload = ('conv2d',) + autotvm.task.args_to_workload(
-            [tinfos[0], tinfos[1], strides, padding, dilation, layout, out_dtype])
-        target = tvm.target.current_target()
+        workload = autotvm.task.args_to_workload(
+            [tinfos[0], tinfos[1], strides, padding, dilation, layout, out_dtype], conv2d)
         cfg = autotvm.DispatchContext.current.query(target, workload)
 
         if cfg.is_fallback:  # if is fallback, clear query cache and return None
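
The hunk above also changes how the conv2d workload key is built: instead of manually prefixing the tuple with ('conv2d',), it passes the registered TOPI compute function to autotvm.task.args_to_workload, which is what makes the same pattern reusable for group_conv2d_nchw below. A hedged sketch of that query pattern, with made-up shapes and run outside any tuned dispatch context (the calls mirror the ones in the diff; everything else is illustrative):

```python
import tvm
from tvm import autotvm
from topi.nn import conv2d

data = tvm.placeholder((1, 64, 56, 56), name='data', dtype='float32')
kernel = tvm.placeholder((64, 64, 3, 3), name='kernel', dtype='float32')

# Keyed by the registered compute function plus its serialized arguments, so the
# exact same tuple can be rebuilt at alter-layout time to look up the tuned config.
workload = autotvm.task.args_to_workload(
    [data, kernel, (1, 1), (1, 1), (1, 1), 'NCHW', 'float32'], conv2d)

target = tvm.target.cuda()
cfg = autotvm.DispatchContext.current.query(target, workload)
print(cfg.is_fallback)  # True here, since no tuned records are loaded
```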
@@ -411,6 +411,36 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
             )
             dispatch_ctx.update(target, new_workload, cfg)
             return sym.contrib.conv2d_winograd_without_weight_transform(*copy_inputs, **new_attrs)
+    elif groups != CI:
+        workload = autotvm.task.args_to_workload(
+            [tinfos[0], tinfos[1], strides, padding, dilation, groups, out_dtype],
+            group_conv2d_nchw)
+        cfg = autotvm.DispatchContext.current.query(target, workload)
+        if cfg.is_fallback:  # if is fallback, clear query cache and return None
+            autotvm.task.clear_fallback_cache(target, workload)
+            return None
+
+        if cfg.template_key == 'int8':
+            assert 'cuda' in target.keys
+            new_layout = 'NCHW4c'
+            new_attrs['layout'] = new_layout
+            new_attrs['out_layout'] = new_layout
+            new_attrs['kernel_layout'] = 'OIHW4o4i'
+            ic_block_factor = oc_block_factor = 4
+
+            # Store the same config for the altered operator (workload)
+            new_data = tvm.placeholder((N, CI // ic_block_factor, H, W, ic_block_factor),
+                                       dtype=data.dtype)
+            new_kernel = tvm.placeholder((CO // oc_block_factor, CI // ic_block_factor // groups,
+                                          KH, KW, oc_block_factor, ic_block_factor),
+                                         dtype=kernel.dtype)
+            new_workload = autotvm.task.args_to_workload(
+                [new_data, new_kernel, strides, padding, dilation, groups, out_dtype],
+                group_conv2d_nchw
+            )
+            dispatch_ctx.update(target, new_workload, cfg)
+
+            return sym.conv2d(*copy_inputs, **new_attrs)
 
     # do nothing for depthwise convolution
     return None
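
The new elif branch only fires for grouped, non-depthwise convolutions whose tuned AutoTVM config uses the 'int8' template: it switches data and output to NCHW4c, the kernel to OIHW4o4i (both blocked by 4), and re-registers the same config under the workload of the blocked placeholders so the NCHW4c compute finds it at compile time. A NumPy sketch of the OIHW to OIHW4o4i kernel packing implied by the new_kernel placeholder shape (helper name and shapes are illustrative):

```python
import numpy as np

def pack_oihw_to_oihw4o4i(kernel, block=4):
    """Pack a grouped-conv kernel from OIHW, i.e. (CO, CI // groups, KH, KW),
    into OIHW4o4i, i.e. (CO // block, CI // groups // block, KH, KW, block, block)."""
    co, ci_g, kh, kw = kernel.shape
    assert co % block == 0 and ci_g % block == 0
    k = kernel.reshape(co // block, block, ci_g // block, block, kh, kw)
    # axes: (CO/4, 4o, CIg/4, 4i, KH, KW) -> (CO/4, CIg/4, KH, KW, 4o, 4i)
    return k.transpose(0, 2, 4, 5, 1, 3)

# e.g. CO=32, CI // groups=16: (32, 16, 3, 3) -> (8, 4, 3, 3, 4, 4), matching the
# new_kernel placeholder (CO // 4, CI // 4 // groups, KH, KW, 4, 4) above.
w = np.random.randint(-128, 127, size=(32, 16, 3, 3), dtype=np.int8)
print(pack_oihw_to_oihw4o4i(w).shape)  # (8, 4, 3, 3, 4, 4)
```

Note that the pass itself does not move any data; it only rewrites the layout attributes and returns a new sym.conv2d, with the expectation that the graph-level layout passes insert the actual layout transforms around it.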