Commit 8ba9e8bc by Haichen Shen Committed by Yao Wang

[TOPI] Allow batch matmul to be fused into injective ops (#4537)

parent e274e66e
...@@ -92,33 +92,40 @@ def schedule_batch_matmul(cfg, outs): ...@@ -92,33 +92,40 @@ def schedule_batch_matmul(cfg, outs):
def _callback(op): def _callback(op):
if "batch_matmul" in op.tag: if "batch_matmul" in op.tag:
C = op.output(0) C = op.output(0)
A, B = s[C].op.input_tensors A, B = op.input_tensors
_, M, K = get_const_tuple(A.shape) _, M, K = get_const_tuple(A.shape)
_, _, N = get_const_tuple(C.shape) _, _, N = get_const_tuple(C.shape)
if op not in s.outputs:
s[C].compute_inline()
O = outs[0]
else:
O = C
CC = s.cache_write(C, "global")
# create tuning space # create tuning space
cfg.define_split("tile_y", M, num_outputs=2) cfg.define_split("tile_y", M, num_outputs=2)
cfg.define_split("tile_x", N, num_outputs=2) cfg.define_split("tile_x", N, num_outputs=2)
cfg.define_split("tile_k", K, num_outputs=2) cfg.define_split("tile_k", K, num_outputs=2)
k, = s[C].op.reduce_axis b, y, x = s[O].op.axis
yo, yi = cfg["tile_y"].apply(s, O, y)
ko, ki = cfg["tile_k"].apply(s, C, k) xo, xi = cfg["tile_x"].apply(s, O, x)
CC = s.rfactor(C, ki) s[O].reorder(b, yo, xo, yi, xi)
bxyo = s[O].fuse(b, yo, xo)
b, y, x = s[C].op.axis s[O].parallel(bxyo)
yo, yi = cfg["tile_y"].apply(s, C, y)
xo, xi = cfg["tile_x"].apply(s, C, x) s[CC].compute_at(s[O], bxyo)
s[C].reorder(b, yo, xo, yi, xi) k, = s[CC].op.reduce_axis
bxyo = s[C].fuse(b, yo, xo) ko, ki = cfg["tile_k"].apply(s, CC, k)
s[C].parallel(bxyo)
s[C].fuse(yi, xi) Crf = s.rfactor(CC, ki)
s[Crf].compute_at(s[CC], s[CC].op.axis[0])
s[CC].compute_at(s[C], bxyo) _, _, y, x = s[Crf].op.axis
_, _, y, x = s[CC].op.axis s[Crf].fuse(y, x)
s[CC].fuse(y, x) s[Crf].vectorize(s[Crf].op.axis[0])
s[CC].vectorize(s[CC].op.axis[0]) s[O].pragma(bxyo, 'auto_unroll_max_step', 16)
s[C].pragma(bxyo, 'auto_unroll_max_step', 16)
traverse_inline(s, outs[0].op, _callback) traverse_inline(s, outs[0].op, _callback)
return s return s
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment