Commit 79735eb2 by Lianmin Zheng Committed by Tianqi Chen

[TOPHUB] fix x86 backend after introducing dilation (#2129)

parent a376eb30
......@@ -21,7 +21,7 @@ AUTOTVM_TOPHUB_ROOT_PATH = os.path.join(os.path.expanduser('~'), ".tvm", "tophub
# the version of each package
PACKAGE_VERSION = {
'arm_cpu': "v0.04",
'llvm': "v0.02",
'llvm': "v0.03",
'cuda': "v0.04",
'rocm': "v0.02",
......
......@@ -292,6 +292,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo):
out_channel = attrs.get_int("channels")
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
layout = attrs['layout']
kh, kw = attrs.get_int_tuple("kernel_size")
......@@ -309,10 +310,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfo):
target = tvm.target.current_target()
# query schedule and fallback if necessary
workload = autotvm.task.args_to_workload(
[data, kernel, strides, padding, out_dtype], depthwise_conv2d_nchw) \
[data, kernel, strides, padding, dilation, out_dtype], depthwise_conv2d_nchw) \
if is_depthwise else \
autotvm.task.args_to_workload(
[data, kernel, strides, padding, layout, out_dtype], conv2d)
[data, kernel, strides, padding, dilation, layout, out_dtype], conv2d)
cfg = dispatch_ctx.query(target, workload)
if cfg.is_fallback:
_get_default_config(cfg, data, kernel, strides, padding, out_dtype, is_depthwise)
......@@ -334,7 +335,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo):
# Store altered operator's config
new_kernel = tvm.placeholder((out_channel//oc_bn, kh, kw, oc_bn), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, new_attrs['layout'],
[new_data, new_kernel, strides, padding, dilation, new_attrs['layout'],
new_attrs['out_layout'], out_dtype], depthwise_conv2d_NCHWc)
else:
out_channel, _, kh, kw = get_const_tuple(kernel.shape)
......@@ -345,7 +346,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfo):
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn),
dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, new_attrs['layout'],
[new_data, new_kernel, strides, padding, dilation, new_attrs['layout'],
new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
dispatch_ctx.update(target, new_workload, cfg)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment