Commit ca768109 by Lianmin Zheng Committed by Tianqi Chen

[TOPI][ARM CPU] fuse bias to depthwise conv2d (#1631)

parent b7beb1eb
......@@ -9,11 +9,11 @@ from ..nn import depthwise_conv2d_nchw
from ..util import traverse_inline
# register original implementation of depthwise_conv2d_nchw since we don't need to change this part
autotvm.task.register_topi_compute(depthwise_conv2d_nchw, 'arm_cpu', 'direct',
autotvm.task.register_topi_compute(depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct',
depthwise_conv2d_nchw.fdefault)
# register customized schedule for arm cpu.
@autotvm.task.register_topi_schedule(schedule_depthwise_conv2d_nchw, 'arm_cpu', 'direct')
@autotvm.task.register_topi_schedule(schedule_depthwise_conv2d_nchw, ['arm_cpu', 'cpu'], 'direct')
def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
"""Schedule depthwise conv2d
......@@ -44,15 +44,15 @@ def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
cfg.define_split('tile_w', w, num_outputs=2)
if cfg.is_fallback:
cfg.fallback_split('tile_c', [-1, 8])
cfg.fallback_split('tile_c', [-1, 4])
cfg.fallback_split('tile_h', [-1, 2])
cfg.fallback_split('tile_w', [-1, 8])
cfg.fallback_split('tile_w', [-1, 4])
# park data to vector form [n, c, h, w] -> [n, C, h, w, VC]
A0 = s.cache_read(data_pad, "global", C)
_, c, h, w = s[A0].op.axis
n, c, h, w = s[A0].op.axis
c, vc = cfg['tile_c'].apply(s, A0, c)
s[A0].reorder(c, h, w, vc)
s[A0].reorder(n, c, h, w, vc)
A1 = s.cache_write(A0, 'global')
s[A0].compute_inline()
......@@ -64,9 +64,9 @@ def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
B1 = s.cache_write(B0, 'global')
s[B0].compute_inline()
_, c, h, w = s[C].op.axis
n, c, h, w = s[C].op.axis
c, vc, = cfg['tile_c'].apply(s, C, c)
s[C].reorder(c, h, w, vc)
s[C].reorder(n, c, h, w, vc)
# depthwise conv
C0 = s.cache_write(C, 'global')
......@@ -86,9 +86,14 @@ def schedule_depthwise_conv2d_nchw_arm(cfg, outs):
max_unroll=16,
cfg=cfg)
# fusion
if C.op not in s.outputs:
s[C].compute_inline()
# mark parallel
n, c, h, w = s[C].op.axis
s[C].parallel(c)
last = outs[0]
n, c, h, w = s[last].op.axis
s[last].parallel(c)
n, c, h, w, vc = s[C0].op.axis
s[C0].parallel(c)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment