Commit 153fd7ff by Animesh Jain Committed by Yizhi Liu

[AlterOpLayout][x86] NHWC to NCHWc conv support. (#4080)

parent b5bcdbb0
...@@ -158,21 +158,19 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F): ...@@ -158,21 +158,19 @@ def _alter_conv2d_layout(attrs, inputs, tinfo, F):
return None return None
return F.nn.contrib_conv2d_nchwc_int8(*copy_inputs, **new_attrs) return F.nn.contrib_conv2d_nchwc_int8(*copy_inputs, **new_attrs)
if data_layout == 'NCHW' and attrs['kernel_layout'] == 'OIHW': # (oc, ic, h, w) -> (OC, IC, h, w, ic, oc)
# (oc, ic, h, w) -> (OC, IC, h, w, ic, oc) new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn)
new_attrs['kernel_layout'] = 'OIHW%di%do' % (ic_bn, oc_bn) # Store altered operator's config
# Store altered operator's config new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn,
new_kernel = tvm.placeholder((out_channel//oc_bn, in_channel//ic_bn, kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype)
kh, kw, ic_bn, oc_bn), dtype=kernel_tensor.dtype) new_workload = autotvm.task.args_to_workload(
new_workload = autotvm.task.args_to_workload( [new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name],
[new_data, new_kernel, strides, padding, dilation, new_attrs[layout_name], new_attrs['out_layout'], out_dtype], conv2d_NCHWc)
new_attrs['out_layout'], out_dtype], conv2d_NCHWc) dispatch_ctx.update(target, new_workload, cfg)
dispatch_ctx.update(target, new_workload, cfg)
if F.__name__ == 'nnvm.symbol':
if F.__name__ == 'nnvm.symbol': return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs)
return F.contrib.conv2d_NCHWc(*copy_inputs, **new_attrs) return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
return F.nn.contrib_conv2d_nchwc(*copy_inputs, **new_attrs)
return None
@conv2d_legalize.register("cpu") @conv2d_legalize.register("cpu")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment