Commit eb3a7382 by Yizhi Liu Committed by Haichen Shen

[Arm] parallel batch axis (#3931)

* support LLVM trunk

* guard with USE_LLVM in if condition for c++14

* GREATER_EQUAL -> GREATER

* [Arm] parallel batch axis
parent 968ffef6
......@@ -280,13 +280,15 @@ def _schedule_spatial_pack(cfg, s, data_vec, kernel_vec,
s[conv].compute_at(s[last], ow)
# mark parallel
s[last].parallel(co)
p = s[last].fuse(n, co)
s[last].parallel(p)
if data_vec.op.name == 'data_vec_undilated':
_, h, _, _, _, _, _, _ = s[data_vec].op.axis
n, h, _, _, _, _, _, _ = s[data_vec].op.axis
else:
_, h, _, _, _, _ = s[data_vec].op.axis
s[data_vec].parallel(h)
n, h, _, _, _, _ = s[data_vec].op.axis
p = s[data_vec].fuse(n, h)
s[data_vec].parallel(p)
if kernel_vec.op.name == 'kernel_vec':
co, _, _, _, _ = s[kernel_vec].op.axis
......@@ -470,8 +472,9 @@ def _schedule_winograd(cfg, s, output, last):
# output
n, co, h, w = s[last].op.axis
co, coi = cfg['tile_k'].apply(s, last, co)
s[M].compute_at(s[last], co)
s[last].parallel(co)
p = s[last].fuse(n, co)
s[M].compute_at(s[last], p)
s[last].parallel(p)
MM = s.cache_read(M, 'global', [Y])
m = get_const_int(V.shape[0]) + 1 - 3
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment