Commit 14a5a358 by Josh Fromm Committed by Haichen Shen

[AutoTVM] Add batch_matmul to tunable operations (#4242)

* Batch matmul tuning running but with errors.

* Default x86 schedule as good as before.

* Code Cleanup

* Remove unused argument.

* improved template documentation.

* Silly lint fix

* Removed leftover comment.

* Moved cfg declaration to schedule for batch_matmul

* Moved x86 dense cfg declaration to schedule.

* lint fix

* Removed duplicate cfg declaration in dense.

* Reverted changes to dense.
parent 7211c277
......@@ -117,6 +117,7 @@ def extract_from_multiple_program(funcs, params, ops, target, target_host=None):
topi.nn.group_conv2d_nchw, topi.nn.conv2d_NCHWc],
tvm.relay.op.nn.conv2d_transpose: [topi.nn.conv2d_transpose_nchw],
tvm.relay.op.nn.dense: [topi.nn.dense],
tvm.relay.op.nn.batch_matmul: [topi.nn.batch_matmul],
tvm.relay.op.nn.deformable_conv2d: [topi.nn.deformable_conv2d_nchw],
}
......
......@@ -87,6 +87,7 @@ class TaskExtractEnv:
topi.nn.conv2d_NCHWc: "topi_x86_conv2d_NCHWc",
topi.nn.conv2d_NCHWc_int8: "topi_x86_conv2d_NCHWc_int8",
topi.nn.dense: "topi_nn_dense",
topi.nn.batch_matmul: "topi_nn_batch_matmul",
topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
topi.nn.bitserial_dense: "topi_nn_bitserial_dense",
......@@ -103,6 +104,7 @@ class TaskExtractEnv:
topi.nn.conv2d_NCHWc: [topi.generic.schedule_conv2d_NCHWc],
topi.nn.conv2d_NCHWc_int8: [topi.generic.schedule_conv2d_NCHWc_int8],
topi.nn.dense: [topi.generic.schedule_dense],
topi.nn.batch_matmul: [topi.generic.schedule_batch_matmul],
topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
topi.nn.bitserial_dense: [topi.generic.schedule_bitserial_dense],
......@@ -118,6 +120,7 @@ class TaskExtractEnv:
topi.nn.group_conv2d_nchw: lambda x: setattr(topi.nn, 'group_conv2d_nchw', x),
topi.nn.conv2d_transpose_nchw: lambda x: setattr(topi.nn, 'conv2d_transpose_nchw', x),
topi.nn.dense: lambda x: setattr(topi.nn, 'dense', x),
topi.nn.batch_matmul: lambda x: setattr(topi.nn, 'batch_matmul', x),
topi.nn.bitserial_conv2d_nchw: lambda x: setattr(topi.nn, 'bitserial_conv2d_nchw', x),
topi.nn.bitserial_conv2d_nhwc: lambda x: setattr(topi.nn, 'bitserial_conv2d_nhwc', x),
topi.nn.bitserial_dense: lambda x: setattr(topi.nn, 'bitserial_dense', x),
......@@ -226,6 +229,15 @@ class TaskExtractEnv:
return s, [data, weight, bias, C]
return s, [data, weight, C]
@register("topi_nn_batch_matmul")
def _topi_nn_batch_matmul(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
args = deserialize_args(args)
A, B = args
C = topi.nn.batch_matmul(A, B)
s = topi.generic.schedule_batch_matmul([C])
return s, [A, B, C]
@register("topi_nn_bitserial_conv2d_nhwc")
def _topi_bitserial_conv2d_nhwc(*args, **kwargs):
args = deserialize_args(args)
......
......@@ -18,24 +18,26 @@
"""x86 batch_matmul operators"""
from __future__ import absolute_import as _abs
import tvm
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity
from tvm.contrib import cblas
from topi.nn import batch_matmul, batch_matmul_default
from .. import generic
from .. import generic, nn
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
@batch_matmul.register(["cpu"])
def batch_matmul_x86(x, y):
@autotvm.register_topi_compute(nn.batch_matmul, "cpu", "direct")
def _declaration_batch_matmul_nopack(cfg, x, y):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
Parameters
----------
cfg : ConfigSpace
Autotvm tuning space config file
x : tvm.Tensor
3-D with shape [batch, M, K]
y : tvm.Tensor
3-D with shape [batch, N, K]
Returns
-------
output : tvm.Tensor
......@@ -44,17 +46,37 @@ def batch_matmul_x86(x, y):
target = tvm.target.current_target()
if "cblas" in target.libs:
return cblas.batch_matmul(x, y, False, True)
return batch_matmul_default(x, y)
@generic.schedule_batch_matmul.register(["cpu"])
def schedule_batch_matmul(outs):
assert len(x.shape) == 3 and len(
y.shape) == 3, "only support 3-dim batch_matmul"
XB, M, XK = get_const_tuple(x.shape)
YB, N, YK = get_const_tuple(y.shape)
assert XB == YB, "batch dimension doesn't match"
assert XK == YK, "shapes of x and y is inconsistant"
B = XB
K = XK
if cfg.is_fallback:
_default_batch_matmul_nopack_config(cfg, M, N, K)
k = tvm.reduce_axis((0, K), name='k')
C = tvm.compute(
(B, M, N),
lambda b, i, j: tvm.sum(x[b, i, k] * y[b, j, k], axis=k),
tag='batch_matmul')
return C
@autotvm.register_topi_schedule(generic.schedule_batch_matmul, "cpu", "direct")
def schedule_batch_matmul(cfg, outs):
"""Schedule for batch_matmul
Parameters
----------
outs: Array of Tensor
The computation graph description of batch_matmul
in the format of an array of tensors.
cfg : ConfigSpace
AutoTVM tuning space config file.
outs : Array of Tensor
The computation graph description of batch_matmul
in the format of an array of tensors.
Returns
-------
......@@ -71,16 +93,22 @@ def schedule_batch_matmul(outs):
if "batch_matmul" in op.tag:
C = op.output(0)
A, B = s[C].op.input_tensors
_, M, N = get_const_tuple(C.shape)
_, M, K = get_const_tuple(A.shape)
_, _, N = get_const_tuple(C.shape)
# create tuning space
cfg.define_split("tile_y", M, num_outputs=2)
cfg.define_split("tile_x", N, num_outputs=2)
cfg.define_split("tile_k", K, num_outputs=2)
k, = s[C].op.reduce_axis
ko, ki = s[C].split(k, 16)
ko, ki = cfg["tile_k"].apply(s, C, k)
CC = s.rfactor(C, ki)
b, y, x = s[C].op.axis
y_bn = get_max_power2_factor(M, 8)
x_bn = get_max_power2_factor(N, 8)
yo, yi = s[C].split(y, y_bn)
xo, xi = s[C].split(x, x_bn)
yo, yi = cfg["tile_y"].apply(s, C, y)
xo, xi = cfg["tile_x"].apply(s, C, x)
s[C].reorder(b, yo, xo, yi, xi)
bxyo = s[C].fuse(b, yo, xo)
s[C].parallel(bxyo)
......@@ -94,3 +122,11 @@ def schedule_batch_matmul(outs):
traverse_inline(s, outs[0].op, _callback)
return s
def _default_batch_matmul_nopack_config(cfg, M, N, K):
cfg["tile_k"] = SplitEntity([K // 16, 16])
x_bn = get_max_power2_factor(N, 8)
cfg["tile_x"] = SplitEntity([N // x_bn, x_bn])
y_bn = get_max_power2_factor(M, 8)
cfg["tile_y"] = SplitEntity([M // y_bn, y_bn])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment