Commit ff7bab80 (unverified) · authored by Animesh Jain · committed by GitHub

[TOPI] Setting workload correctly for Depthwise conv ARM. (#5182)

parent ff6fa399
......@@ -154,14 +154,14 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
if topi_tmpl == "depthwise_conv2d_nchw_spatial_pack.arm_cpu":
assert data_layout == "NCHW" and kernel_layout == "OIHW"
N, CI, H, W = get_const_tuple(data.shape)
-CO, _, KH, KW = get_const_tuple(kernel.shape)
+CO, M, KH, KW = get_const_tuple(kernel.shape)
VC = cfg['tile_co'].size[-1]
new_attrs['kernel_layout'] = 'OIHW%do' % (cfg['tile_co'].size[-1])
# Store the same config for the altered operator (workload)
new_data = data
-new_kernel = te.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype)
+new_kernel = te.placeholder((idxd(CO, VC), M, KH, KW, VC), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, out_dtype],
"depthwise_conv2d_nchw_spatial_pack.arm_cpu")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment