Commit 7b098c9a by Tianqi Chen Committed by GitHub

[TOPI] Fix x86 schedule for conv out_dtype (#1072)

parent 8065f2e0
...@@ -65,6 +65,7 @@ def _get_schedule_conv(wkl): ...@@ -65,6 +65,7 @@ def _get_schedule_conv(wkl):
@conv2d.register("cpu") @conv2d.register("cpu")
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype): def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
out_dtype = data.dtype if out_dtype is None else out_dtype
target = tvm.target.current_target(allow_none=False) target = tvm.target.current_target(allow_none=False)
wkl = _get_workload(data, kernel, stride, padding, out_dtype) wkl = _get_workload(data, kernel, stride, padding, out_dtype)
if wkl in _WORKLOADS and 'avx' in str(target) and layout == 'NCHW': if wkl in _WORKLOADS and 'avx' in str(target) and layout == 'NCHW':
......
...@@ -12,6 +12,7 @@ from ..nn.pad import pad ...@@ -12,6 +12,7 @@ from ..nn.pad import pad
AVXConvCommonFwd = namedtuple('AVXConvCommonFwd', ['ic_bn', 'oc_bn', 'reg_n', 'unroll_kw']) AVXConvCommonFwd = namedtuple('AVXConvCommonFwd', ['ic_bn', 'oc_bn', 'reg_n', 'unroll_kw'])
def _declaration_conv(data, kernel, stride, padding, layout, out_dtype): def _declaration_conv(data, kernel, stride, padding, layout, out_dtype):
out_dtype = data.dtype if out_dtype is None else out_dtype
assert layout == 'NCHW', "only support NCHW convolution for AVX" assert layout == 'NCHW', "only support NCHW convolution for AVX"
wkl = _get_workload(data, kernel, stride, padding, out_dtype) wkl = _get_workload(data, kernel, stride, padding, out_dtype)
sch = _get_schedule(wkl) sch = _get_schedule(wkl)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment