Commit 2ea7969b by Lianmin Zheng Committed by Tianqi Chen

[TOPI] Fix adding dilation arguments (#2047)

parent cb31b1d0
...@@ -2,6 +2,8 @@ ...@@ -2,6 +2,8 @@
"""Conv2D schedule for ARM CPU""" """Conv2D schedule for ARM CPU"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import warnings
import numpy as np import numpy as np
import tvm import tvm
...@@ -522,7 +524,10 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos): ...@@ -522,7 +524,10 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
out_dtype = attrs["out_dtype"] out_dtype = attrs["out_dtype"]
out_dtype = tinfos[0].dtype if out_dtype == "same" else out_dtype out_dtype = tinfos[0].dtype if out_dtype == "same" else out_dtype
if layout != 'NCHW' or groups != 1 or dilation != (1, 1): if layout != 'NCHW' or groups != 1:
return None
if dilation != (1, 1):
warnings.warn("Does not support weight pre-transform for dilated convolution.")
return None return None
data, kernel = tinfos[0:2] data, kernel = tinfos[0:2]
...@@ -531,7 +536,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos): ...@@ -531,7 +536,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
# query config of this workload # query config of this workload
workload = autotvm.task.args_to_workload( workload = autotvm.task.args_to_workload(
[data, kernel, strides, padding, layout, out_dtype], conv2d) [data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
target = tvm.target.current_target() target = tvm.target.current_target()
dispatch_ctx = autotvm.DispatchContext.current dispatch_ctx = autotvm.DispatchContext.current
cfg = dispatch_ctx.query(target, workload) cfg = dispatch_ctx.query(target, workload)
...@@ -548,7 +553,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos): ...@@ -548,7 +553,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
new_data = data new_data = data
new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype) new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload( new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, 'NCHW', out_dtype], conv2d) [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
dispatch_ctx.update(target, new_workload, cfg) dispatch_ctx.update(target, new_workload, cfg)
return sym.conv2d(*copy_inputs, **new_attrs) return sym.conv2d(*copy_inputs, **new_attrs)
...@@ -574,7 +579,8 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos): ...@@ -574,7 +579,8 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos):
new_weight = tvm.placeholder((KH + tile_size - 1, KH + tile_size -1, CO // VC, CI, VC), new_weight = tvm.placeholder((KH + tile_size - 1, KH + tile_size -1, CO // VC, CI, VC),
kernel.dtype) kernel.dtype)
new_workload = autotvm.task.args_to_workload( new_workload = autotvm.task.args_to_workload(
[new_data, new_weight, strides, padding, new_attrs['layout'], out_dtype, tile_size], [new_data, new_weight, strides, padding, dilation,
new_attrs['layout'], out_dtype, tile_size],
conv2d_winograd_without_weight_transform) conv2d_winograd_without_weight_transform)
dispatch_ctx.update(target, new_workload, cfg) dispatch_ctx.update(target, new_workload, cfg)
......
...@@ -375,6 +375,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): ...@@ -375,6 +375,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
new_attrs['out_layout'] = new_layout new_attrs['out_layout'] = new_layout
new_attrs['kernel_layout'] = 'OIHW4o4i' new_attrs['kernel_layout'] = 'OIHW4o4i'
ic_block_factor = oc_block_factor = 4 ic_block_factor = oc_block_factor = 4
# Store the same config for the altered operator (workload)
new_data = tvm.placeholder((N, CI // ic_block_factor, H, W, ic_block_factor), new_data = tvm.placeholder((N, CI // ic_block_factor, H, W, ic_block_factor),
dtype=data.dtype) dtype=data.dtype)
new_kernel = tvm.placeholder((CO // oc_block_factor, CI // ic_block_factor, KH, KW,\ new_kernel = tvm.placeholder((CO // oc_block_factor, CI // ic_block_factor, KH, KW,\
...@@ -387,7 +389,9 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): ...@@ -387,7 +389,9 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
return sym.conv2d(*copy_inputs, **new_attrs) return sym.conv2d(*copy_inputs, **new_attrs)
if attrs.get_int_tuple("dilation") != (1, 1): if attrs.get_int_tuple("dilation") != (1, 1):
warnings.warn("Does not support weight pre-transform for dilated convolution.")
return None return None
# pre-compute weight transformation in winograd # pre-compute weight transformation in winograd
tile_size = _infer_tile_size(tinfos[0], tinfos[1]) tile_size = _infer_tile_size(tinfos[0], tinfos[1])
...@@ -397,6 +401,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos): ...@@ -397,6 +401,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos):
copy_inputs[1] = weight copy_inputs[1] = weight
new_attrs['tile_size'] = tile_size new_attrs['tile_size'] = tile_size
# Store the same config for the altered operator (workload)
new_data = data new_data = data
new_weight = tvm.placeholder((KH + tile_size - 1, KW + tile_size - 1, CI, CO), new_weight = tvm.placeholder((KH + tile_size - 1, KW + tile_size - 1, CI, CO),
dtype=kernel.dtype) dtype=kernel.dtype)
......
...@@ -440,7 +440,7 @@ def _schedule_winograd(cfg, s, op): ...@@ -440,7 +440,7 @@ def _schedule_winograd(cfg, s, op):
##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM ##### ##### REGISTER TOPI COMPUTE / SCHEDULE FOR WINOGRAD WITH WEIGHT TRANSFORM #####
@autotvm.register_topi_compute(conv2d_winograd_without_weight_transform, 'mali', ['winograd']) @autotvm.register_topi_compute(conv2d_winograd_without_weight_transform, 'mali', ['winograd'])
def conv2d_winograd_ww(cfg, data, kernel, strides, padding, layout, out_dtype, tile_size): def conv2d_winograd_ww(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, tile_size):
"""TOPI compute callback""" """TOPI compute callback"""
return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype, return _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
tile_size) tile_size)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment