Commit c870261f by Jon Soifer, committed by hlu1

[TOPI] Use cblas for dense and batch_matmul when "cblas" is in the target libraries (#3787)

* Support cblas library in dense

* start to add support for generic batch_matmul compute

* Add x86 override for batch_matmul

* Fix linting

* reset file

* Fix typos

* dummy change to re-trigger CI
parent 95f12e31
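Note: the cblas path added here is opt-in through the target's library list. A minimal sketch of how a user would request it (assuming TVM was built with the cblas contrib enabled; the target string is illustrative):

    import tvm
    # "-libs=cblas" puts "cblas" into target.libs, which the new compute/schedule code checks.
    target = tvm.target.create("llvm -libs=cblas")
    print(target.libs)   # ['cblas']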
@@ -73,6 +73,7 @@ reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
 @reg.register_compute("nn.batch_matmul")
 def compute_batch_matmul(attrs, inputs, out_type, target):
     """Compute definition of batch_matmul"""
-    return [topi.nn.batch_matmul(inputs[0], inputs[1])]
+    with target:
+        return [topi.nn.batch_matmul(inputs[0], inputs[1])]
...
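For context, the `with target:` wrapper matters because the TOPI code below reads tvm.target.current_target(), which only reflects the compilation target inside an active target scope. A small illustrative sketch (target string and printed values are hypothetical):

    import tvm
    target = tvm.target.create("llvm -libs=cblas")
    print(tvm.target.current_target())       # None: no target scope is active
    with target:
        # inside the scope, generic funcs and library checks can see the target
        print(tvm.target.current_target())   # llvm -libs=cblas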
@@ -20,8 +20,7 @@ from __future__ import absolute_import as _abs
 import tvm
 from ..util import get_const_tuple
 
-
-def batch_matmul(x, y):
+def batch_matmul_default(x, y):
     """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
     data in batch.
@@ -30,7 +29,7 @@ def batch_matmul(x, y):
     x : tvm.Tensor
         3-D with shape [batch, M, K]
 
-    y : tvm.TEnsor
+    y : tvm.Tensor
         3-D with shape [batch, N, K]
 
     Returns
@@ -49,3 +48,23 @@ def batch_matmul(x, y):
     return tvm.compute((batch, M, N),
                        lambda b, i, j: tvm.sum(x[b, i, k] * y[b, j, k], axis=k),
                        tag='batch_matmul')
+
+
+@tvm.target.generic_func
+def batch_matmul(x, y):
+    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+    data in batch.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        3-D with shape [batch, M, K]
+
+    y : tvm.Tensor
+        3-D with shape [batch, N, K]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        3-D with shape [batch, M, N]
+    """
+    return batch_matmul_default(x, y)
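The @tvm.target.generic_func decorator makes batch_matmul dispatchable: the body above is only the fallback, and targets can register overrides keyed by the target's keys (e.g. "cpu"). A toy sketch of the mechanism, with made-up function names:

    import tvm

    @tvm.target.generic_func
    def toy_op(x):
        return "generic fallback"

    @toy_op.register(["cpu"])
    def toy_op_cpu(x):
        return "cpu override"

    with tvm.target.create("llvm"):   # "llvm" targets carry the "cpu" key
        print(toy_op(None))           # -> "cpu override"
    print(toy_op(None))               # no target scope -> "generic fallback"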
@@ -18,10 +18,33 @@
 """x86 batch_matmul operators"""
 from __future__ import absolute_import as _abs
 import tvm
+from tvm.contrib import cblas
+from topi.nn import batch_matmul, batch_matmul_default
 from .. import generic
 from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
 
+
+@batch_matmul.register(["cpu"])
+def batch_matmul_x86(x, y):
+    """Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
+    data in batch.
+
+    Parameters
+    ----------
+    x : tvm.Tensor
+        3-D with shape [batch, M, K]
+
+    y : tvm.Tensor
+        3-D with shape [batch, N, K]
+
+    Returns
+    -------
+    output : tvm.Tensor
+        3-D with shape [batch, M, N]
+    """
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        return cblas.batch_matmul(x, y, False, True)
+    return batch_matmul_default(x, y)
+
 @generic.schedule_batch_matmul.register(["cpu"])
 def schedule_batch_matmul(outs):
@@ -38,6 +61,10 @@ def schedule_batch_matmul(outs):
     sch: Schedule
         The computation schedule for the op.
     """
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        return generic.schedule_extern(outs)
+
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
...
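Putting the x86 pieces together, a rough usage sketch of the new batch_matmul path (assumes a TVM build with cblas support; shapes and names are arbitrary):

    import tvm
    import topi

    x = tvm.placeholder((4, 32, 64), name="x")    # [batch, M, K]
    y = tvm.placeholder((4, 48, 64), name="y")    # [batch, N, K]
    target = tvm.target.create("llvm -libs=cblas")
    with target:
        out = topi.nn.batch_matmul(x, y)                # should dispatch to batch_matmul_x86 -> cblas extern op
        s = topi.generic.schedule_batch_matmul([out])   # falls back to schedule_extern for the cblas case
    func = tvm.build(s, [x, y, out], target=target)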
@@ -20,6 +20,7 @@ from __future__ import absolute_import as _abs
 import tvm
 from tvm import autotvm
 from tvm.autotvm.task.space import SplitEntity
+from tvm.contrib import cblas
 
 from .util import get_fp32_len
 from .. import generic, tag, nn
@@ -40,6 +41,10 @@ def _declaration_dense(cfg, data, weight, bias=None, out_dtype=None):
 # Declare dense compute with packing weight into cache-friendly layout
 @autotvm.register_topi_compute(nn.dense, "cpu", "direct_pack")
 def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
-    if out_dtype is None:
-        out_dtype = data.dtype
-    batch, in_dim = get_const_tuple(data.shape)
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        C = cblas.matmul(data, weight, False, True)
+    else:
+        if out_dtype is None:
+            out_dtype = data.dtype
+        batch, in_dim = get_const_tuple(data.shape)
@@ -72,6 +77,10 @@ def _declaration_dense_pack(cfg, data, weight, bias=None, out_dtype=None):
 # Declare dense compute without packing weight
 @autotvm.register_topi_compute(nn.dense, "cpu", "direct_nopack")
 def _declaration_dense_nopack(cfg, data, weight, bias=None, out_dtype=None):
-    if out_dtype is None:
-        out_dtype = data.dtype
-    batch, in_dim = get_const_tuple(data.shape)
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        C = cblas.matmul(data, weight, False, True)
+    else:
+        if out_dtype is None:
+            out_dtype = data.dtype
+        batch, in_dim = get_const_tuple(data.shape)
@@ -116,6 +125,10 @@ def _schedule_dense(cfg, outs):
 @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_pack")
 def _schedule_dense_pack(cfg, outs):
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        return generic.schedule_extern(outs)
+
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
@@ -127,6 +140,10 @@ def _schedule_dense_pack(cfg, outs):
 @autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_nopack")
 def _schedule_dense_nopack(cfg, outs):
+    target = tvm.target.current_target()
+    if "cblas" in target.libs:
+        return generic.schedule_extern(outs)
+
     s = tvm.create_schedule([x.op for x in outs])
 
     def _callback(op):
...
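For dense, both autotvm templates now short-circuit to cblas.matmul(data, weight, False, True), i.e. data is taken as-is and weight is transposed, which matches nn.dense's [out_dim, in_dim] weight layout. A small standalone sketch of that call (hypothetical shapes; requires a cblas-enabled TVM build):

    import tvm
    from tvm.contrib import cblas

    data = tvm.placeholder((16, 256), name="data")        # [batch, in_dim]
    weight = tvm.placeholder((512, 256), name="weight")   # [out_dim, in_dim]
    # transa=False, transb=True: C[i, j] = sum_k data[i, k] * weight[j, k]
    C = cblas.matmul(data, weight, False, True)
    s = tvm.create_schedule(C.op)
    func = tvm.build(s, [data, weight, C], target="llvm")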