Commit 84590063 by Haichen Shen Committed by Leyuan Wang

[Relay/TOPI][Op] Add batch_matmul in relay and TOPI (#2561)

* Add batch_dot and cpu schedule

* Add relay support for batch_dot

* Rename batch_dot to batch_matmul

* nits

* Add missing file

* Put batch_matmul and dense x86 schedule in separate files

* Fix pylint

* Remove unused import

* Add cuda schedule for batch_matmul

* Add test case with larger batch size

* Add batch_matmul in api doc

* Fix quantize pass rounding error

* Fix pylint and minor change

* bug fix
parent d546bb77
...@@ -41,6 +41,7 @@ List of operators ...@@ -41,6 +41,7 @@ List of operators
topi.nn.upsampling topi.nn.upsampling
topi.nn.softmax topi.nn.softmax
topi.nn.dense topi.nn.dense
topi.nn.batch_matmul
topi.nn.log_softmax topi.nn.log_softmax
topi.nn.conv2d_nchw topi.nn.conv2d_nchw
topi.nn.conv2d_hwcn topi.nn.conv2d_hwcn
...@@ -138,6 +139,7 @@ topi.nn ...@@ -138,6 +139,7 @@ topi.nn
.. autofunction:: topi.nn.upsampling .. autofunction:: topi.nn.upsampling
.. autofunction:: topi.nn.softmax .. autofunction:: topi.nn.softmax
.. autofunction:: topi.nn.dense .. autofunction:: topi.nn.dense
.. autofunction:: topi.nn.batch_matmul
.. autofunction:: topi.nn.log_softmax .. autofunction:: topi.nn.log_softmax
.. autofunction:: topi.nn.conv2d_nchw .. autofunction:: topi.nn.conv2d_nchw
.. autofunction:: topi.nn.conv2d_hwcn .. autofunction:: topi.nn.conv2d_hwcn
......
...@@ -152,6 +152,7 @@ This level support backpropagation of broadcast operators. It is temporary. ...@@ -152,6 +152,7 @@ This level support backpropagation of broadcast operators. It is temporary.
tvm.relay.device_copy tvm.relay.device_copy
tvm.relay.annotation.on_device tvm.relay.annotation.on_device
tvm.relay.reverse_reshape tvm.relay.reverse_reshape
tvm.relay.nn.batch_matmul
Level 1 Definitions Level 1 Definitions
...@@ -264,3 +265,4 @@ Level 10 Definitions ...@@ -264,3 +265,4 @@ Level 10 Definitions
.. autofunction:: tvm.relay.device_copy .. autofunction:: tvm.relay.device_copy
.. autofunction:: tvm.relay.annotation.on_device .. autofunction:: tvm.relay.annotation.on_device
.. autofunction:: tvm.relay.reverse_reshape .. autofunction:: tvm.relay.reverse_reshape
.. autofunction:: tvm.relay.nn.batch_matmul
...@@ -283,6 +283,18 @@ def _mx_multibox_detection(inputs, attrs): ...@@ -283,6 +283,18 @@ def _mx_multibox_detection(inputs, attrs):
return _op.vision.nms(ret[0], ret[1], **new_attrs1) return _op.vision.nms(ret[0], ret[1], **new_attrs1)
def _mx_batch_dot(inputs, attrs):
assert len(inputs) == 2
a, b = inputs
transpose_a = attrs.get_bool("transpose_a", False)
transpose_b = attrs.get_bool("transpose_b", False)
if transpose_a is True:
raise RuntimeError("batch_dot: only support transpose_a=False")
if transpose_b is False:
b = _op.transpose(b, axes=[0, 2, 1])
return _op.batch_matmul(a, b)
def _mx_arange(inputs, attrs): def _mx_arange(inputs, attrs):
assert len(inputs) == 0 assert len(inputs) == 0
if attrs.get_int("repeat", 1) != 1: if attrs.get_int("repeat", 1) != 1:
...@@ -389,6 +401,7 @@ _convert_map = { ...@@ -389,6 +401,7 @@ _convert_map = {
"expand_dims" : _mx_expand_dims, "expand_dims" : _mx_expand_dims,
"Concat" : _mx_concat, "Concat" : _mx_concat,
"concat" : _mx_concat, "concat" : _mx_concat,
"batch_dot" : _mx_batch_dot,
"LeakyReLU" : _mx_leaky_relu, "LeakyReLU" : _mx_leaky_relu,
"_arange" : _mx_arange, "_arange" : _mx_arange,
"SoftmaxOutput" : _mx_softmax_output, "SoftmaxOutput" : _mx_softmax_output,
...@@ -403,7 +416,6 @@ _convert_map = { ...@@ -403,7 +416,6 @@ _convert_map = {
# "broadcast_to", # "broadcast_to",
# "gather_nd", # "gather_nd",
# "Crop" : _crop_like, # "Crop" : _crop_like,
} }
# set identity list # set identity list
......
...@@ -46,6 +46,21 @@ def schedule_dense(attrs, outputs, target): ...@@ -46,6 +46,21 @@ def schedule_dense(attrs, outputs, target):
reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE) reg.register_pattern("nn.dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
# batch_matmul
@reg.register_compute("nn.batch_matmul")
def compute_batch_matmul(attrs, inputs, out_type, target):
"""Compute definition of batch_matmul"""
return [topi.nn.batch_matmul(inputs[0], inputs[1])]
@reg.register_schedule("nn.batch_matmul")
def schedule_batch_matmul(attrs, outputs, target):
"""Schedule definition of batch_matmul"""
with target:
return topi.generic.schedule_batch_matmul(outputs)
reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
# conv2d # conv2d
@reg.register_compute("nn.conv2d") @reg.register_compute("nn.conv2d")
def compute_conv2d(attrs, inputs, out_type, target): def compute_conv2d(attrs, inputs, out_type, target):
......
...@@ -767,6 +767,31 @@ def batch_norm(data, ...@@ -767,6 +767,31 @@ def batch_norm(data,
return TupleWrapper(result, 3) return TupleWrapper(result, 3)
def batch_matmul(x, y):
r"""
Computes batch matrix multiplication of `x` and `y` when `x` and `y` are data
in batch.
.. math::
\mbox{batch_matmul}(x, y)[i, :, :] = \mbox{matmul}(x[i, :, :], y[i, :, :]^T)
Parameters
----------
x : tvm.relay.Expr
The first input.
y : tvm.relay.Expr
The second input.
Returns
-------
result: tvm.relay.Expr
The computed result.
"""
return _make.batch_matmul(x, y)
def contrib_conv2d_winograd_without_weight_transform(data, def contrib_conv2d_winograd_without_weight_transform(data,
weight, weight,
tile_size, tile_size,
......
...@@ -654,5 +654,68 @@ axis to be the last item in the input shape. ...@@ -654,5 +654,68 @@ axis to be the last item in the input shape.
.set_support_level(1) .set_support_level(1)
.add_type_rel("BatchNorm", BatchNormRel); .add_type_rel("BatchNorm", BatchNormRel);
// relay.nn.batch_matmul
bool BatchMatmulRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
const auto* x = types[0].as<TensorTypeNode>();
const auto* y = types[1].as<TensorTypeNode>();
if (x == nullptr || y == nullptr) return false;
if (x->shape.size() != 3 || y->shape.size() != 3) return false;
CHECK(reporter->AssertEQ(x->shape[0], y->shape[0]))
<< "BatchDot: batch dimension doesn't match, "
<< " x shape=" << x->shape
<< ", y shape=" << y->shape;
CHECK(reporter->AssertEQ(x->shape[2], y->shape[2]))
<< "BatchDot: shapes of x and y is inconsistent, "
<< " x shape=" << x->shape
<< ", y shape=" << y->shape;
Array<tvm::Expr> oshape = x->shape;
oshape.Set(2, y->shape[1]);
// assign output type
reporter->Assign(types[2], TensorTypeNode::make(oshape, x->dtype));
return true;
}
// Positional relay function to create batch_matmul operator used by frontend FFI.
Expr MakeBatchMatmul(Expr x,
Expr y) {
static const Op& op = Op::Get("nn.batch_matmul");
return CallNode::make(op, {x, y}, Attrs(), {});
}
TVM_REGISTER_API("relay.op.nn._make.batch_matmul")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeBatchMatmul, args, rv);
});
RELAY_REGISTER_OP("nn.batch_matmul")
.describe(R"code(Computes matrix multiplication of `x` and `y` when `x` and `y`
are data in batch.
.. math::
batch\_matmul(x, y)[i, :, :] = matmul(x[i, :, :], y[i, :, :]^T)
- **x**: `(b, m, k)`
- **y**: `(b, n, k)`
- **out**: `(b, m, n)`.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("x", "3D Tensor", "First input.")
.add_argument("y", "3D Tensor", "Second input.")
.set_support_level(10)
.add_type_rel("BatchMatmul", BatchMatmulRel);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -306,7 +306,6 @@ def test_dense(): ...@@ -306,7 +306,6 @@ def test_dense():
tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5)
if __name__ == "__main__": if __name__ == "__main__":
test_concatenate() test_concatenate()
test_bias_add() test_bias_add()
......
...@@ -4,6 +4,8 @@ import numpy as np ...@@ -4,6 +4,8 @@ import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.testing import ctx_list from tvm.relay.testing import ctx_list
import topi
import topi.testing
def test_collapse_sum_like(): def test_collapse_sum_like():
shape = (3, 4, 5, 6) shape = (3, 4, 5, 6)
...@@ -126,7 +128,6 @@ def test_reverse_reshape(): ...@@ -126,7 +128,6 @@ def test_reverse_reshape():
x = relay.var("x", relay.TensorType(shape, "float32")) x = relay.var("x", relay.TensorType(shape, "float32"))
z = relay.reverse_reshape(x, newshape=newshape) z = relay.reverse_reshape(x, newshape=newshape)
zz = relay.ir_pass.infer_type(z) zz = relay.ir_pass.infer_type(z)
print(zz.checked_type)
assert "newshape=" in z.astext() assert "newshape=" in z.astext()
assert zz.checked_type == relay.ty.TensorType(oshape, "float32") assert zz.checked_type == relay.ty.TensorType(oshape, "float32")
...@@ -144,8 +145,41 @@ def test_reverse_reshape(): ...@@ -144,8 +145,41 @@ def test_reverse_reshape():
verify_reverse_reshape((2, 3, 4), (-1, 0), (6, 4)) verify_reverse_reshape((2, 3, 4), (-1, 0), (6, 4))
verify_reverse_reshape((2, 3, 4), (0, -3), (2, 12)) verify_reverse_reshape((2, 3, 4), (0, -3), (2, 12))
def verify_batch_matmul(x_shape, y_shape, out_shape, dtype="float32"):
x = relay.var("x", relay.TensorType(x_shape, dtype))
y = relay.var("y", relay.TensorType(y_shape, dtype))
z = relay.nn.batch_matmul(x, y)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType(out_shape, dtype)
func = relay.Function([x, y], z)
x_np = np.random.uniform(size=x_shape).astype(dtype)
y_np = np.random.uniform(size=y_shape).astype(dtype)
z_np = topi.testing.batch_matmul(x_np, y_np)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
z = intrp.evaluate(func)(x_np, y_np)
tvm.testing.assert_allclose(z.asnumpy(), z_np, rtol=1e-5)
def test_batch_matmul():
b, m, n, k = tvm.var("b"), tvm.var("m"), tvm.var("n"), tvm.var("k")
x = relay.var("x", relay.TensorType((b, m, k), "float32"))
y = relay.var("y", relay.TensorType((b, n, k), "float32"))
z = relay.nn.batch_matmul(x, y)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.TensorType((b, m, n), "float32")
verify_batch_matmul((1, 16, 32), (1, 16, 32), (1, 16, 16))
verify_batch_matmul((5, 16, 32), (5, 16, 32), (5, 16, 16))
verify_batch_matmul((5, 16, 32), (5, 20, 32), (5, 16, 20))
verify_batch_matmul((30, 16, 32), (30, 20, 32), (30, 16, 20))
if __name__ == "__main__": if __name__ == "__main__":
test_collapse_sum_like() test_collapse_sum_like()
test_broadcast_to_like() test_broadcast_to_like()
test_slice_like() test_slice_like()
test_reverse_reshape() test_reverse_reshape()
test_batch_matmul()
...@@ -75,7 +75,7 @@ def test_quantize_pass(): ...@@ -75,7 +75,7 @@ def test_quantize_pass():
graph = relay.create_executor('graph') graph = relay.create_executor('graph')
res0 = graph.evaluate(qgraph0)(dataset[0]['data']) res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
res1 = graph.evaluate(qgraph1)(dataset[0]['data']) res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy()) tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3)
if __name__ == "__main__": if __name__ == "__main__":
......
/*!
* Copyright (c) 2019 by Contributors
* \brief Batch matmul op constructions
* \file nn/batch_matmul.h
*/
#ifndef TOPI_NN_BATCH_MATMUL_H_
#define TOPI_NN_BATCH_MATMUL_H_
#include <string>
#include "topi/tags.h"
#include "tvm/tvm.h"
namespace topi {
namespace nn {
using namespace tvm;
/*!
* \brief Creates an operation that calculates matrix multiplication in batch.
*
* \param x Tensor with shape [batch, M, K]
* \param y Tensor with shape [batch, N, K]
*
* \return Tensor with shape [batch, M, N]
*/
inline tvm::Tensor batch_matmul(const tvm::Tensor& x,
const tvm::Tensor& y) {
CHECK_EQ(x->shape.size(), 3) << "batch_matmul requires 3-D data";
CHECK_EQ(y->shape.size(), 3) << "batch_matmul requires 3-D data";
auto batch = x->shape[0];
auto M = x->shape[1];
auto K = x->shape[2];
auto N = y->shape[1];
auto k = tvm::reduce_axis(Range(0, K), "k");
auto result = tvm::compute(
{ batch, M, N },
[&](Var b, Var i, Var j) {
return tvm::sum(x(b, i, k) * y(b, j, k), { k });
}, "tensor", "batch_matmul");
return result;
}
} // namespace nn
} // namespace topi
#endif // TOPI_NN_BATCH_MATMUL_H_
...@@ -14,6 +14,7 @@ from .dense import dense_cuda, schedule_dense ...@@ -14,6 +14,7 @@ from .dense import dense_cuda, schedule_dense
from .pooling import schedule_pool, schedule_global_pool from .pooling import schedule_pool, schedule_global_pool
from .extern import schedule_extern from .extern import schedule_extern
from .nn import schedule_lrn, schedule_l2_normalize from .nn import schedule_lrn, schedule_l2_normalize
from .batch_matmul import schedule_batch_matmul
from .vision import * from .vision import *
from . import ssd from . import ssd
from .ssd import * from .ssd import *
......
# pylint: disable=invalid-name,too-many-locals,unused-variable
"""cuda batch_matmul operators"""
from __future__ import absolute_import as _abs
import tvm
from .. import generic
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
@generic.schedule_batch_matmul.register(["cuda", "gpu"])
def schedule_batch_matmul(outs):
"""Schedule for batch_matmul
Parameters
----------
outs: Array of Tensor
The computation graph description of batch_matmul
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
s = tvm.create_schedule([x.op for x in outs])
def _schedule(op):
C = op.output(0)
A, B = s[C].op.input_tensors
_, M, N = get_const_tuple(C.shape)
AA = s.cache_read(A, "shared", [C])
AL = s.cache_read(AA, "local", [C])
BB = s.cache_read(B, "shared", [C])
BL = s.cache_read(BB, "local", [C])
CC = s.cache_write(C, "local")
b, y, x = s[C].op.axis
y_bn = get_max_power2_factor(M, 64)
x_bn = get_max_power2_factor(N, 64)
by, y = s[C].split(y, y_bn)
bx, x = s[C].split(x, x_bn)
y_nthreads = min(y_bn, 8)
x_nthreads = min(x_bn, 8)
ty, yi = s[C].split(y, nparts=y_nthreads)
tx, xi = s[C].split(x, nparts=x_nthreads)
thread_x = tvm.thread_axis((0, x_nthreads), "threadIdx.x")
thread_y = tvm.thread_axis((0, y_nthreads), "threadIdx.y")
s[C].reorder(b, by, bx, ty, tx, yi, xi)
s[C].bind(b, tvm.thread_axis("blockIdx.z"))
s[C].bind(by, tvm.thread_axis("blockIdx.y"))
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(ty, thread_y)
s[C].bind(tx, thread_x)
s[C].pragma(yi, "auto_unroll_max_step", 16)
s[CC].compute_at(s[C], tx)
_, yi, xi = s[CC].op.axis
k, = s[CC].op.reduce_axis
ko, ki = s[CC].split(k, 8)
s[CC].reorder(ko, ki, yi, xi)
s[CC].pragma(ki, "auto_unroll_max_step", 16)
s[AA].compute_at(s[CC], ko)
s[AL].compute_at(s[CC], ki)
s[BB].compute_at(s[CC], ko)
s[BL].compute_at(s[CC], ki)
_, y, k = s[AA].op.axis
ty, yi = s[AA].split(y, nparts=y_nthreads)
tx, ki = s[AA].split(k, nparts=x_nthreads)
s[AA].reorder(ty, tx, yi, ki)
s[AA].bind(ty, thread_y)
s[AA].bind(tx, thread_x)
s[AA].pragma(yi, "auto_unroll_max_step", 16)
_, x, k = s[BB].op.axis
ty, xi = s[BB].split(x, nparts=y_nthreads)
tx, ki = s[BB].split(k, nparts=x_nthreads)
s[BB].bind(ty, thread_y)
s[BB].bind(tx, thread_x)
s[BB].reorder(ty, tx, xi, ki)
s[BB].pragma(xi, "auto_unroll_max_step", 16)
def _callback(op):
if "batch_matmul" in op.tag:
_schedule(op)
traverse_inline(s, outs[0].op, _callback)
return s
...@@ -410,3 +410,9 @@ def schedule_l2_normalize(outs): ...@@ -410,3 +410,9 @@ def schedule_l2_normalize(outs):
target = tvm.target.current_target(allow_none=False) target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name) cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.generic.default_schedule(cpp_target, outs, False) return cpp.generic.default_schedule(cpp_target, outs, False)
@tvm.target.generic_func
def schedule_batch_matmul(outs):
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
return cpp.generic.default_schedule(cpp_target, outs, False)
...@@ -17,3 +17,4 @@ from .upsampling import * ...@@ -17,3 +17,4 @@ from .upsampling import *
from .local_response_norm import * from .local_response_norm import *
from .bitserial_conv2d import * from .bitserial_conv2d import *
from .l2_normalize import * from .l2_normalize import *
from .batch_matmul import *
"""Binary Neural Network (BNN) Operators"""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import tvm
from ..util import get_const_tuple
def batch_matmul(x, y):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
Parameters
----------
x : tvm.Tensor
3-D with shape [batch, M, K]
y : tvm.TEnsor
3-D with shape [batch, N, K]
Returns
-------
output : tvm.Tensor
3-D with shape [batch, M, N]
"""
assert len(x.shape) == 3 and len(y.shape) == 3, "only support 3-dim batch_matmul"
x_shape = get_const_tuple(x.shape)
y_shape = get_const_tuple(y.shape)
assert x_shape[0] == y_shape[0], "batch dimension doesn't match"
assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistant"
batch, M, K = x.shape
N = y.shape[1]
k = tvm.reduce_axis((0, K), name='k')
return tvm.compute((batch, M, N),
lambda b, i, j: tvm.sum(x[b, i, k] * y[b, j, k], axis=k),
tag='batch_matmul')
...@@ -19,3 +19,4 @@ from .lrn_python import lrn_python ...@@ -19,3 +19,4 @@ from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python from .l2_normalize_python import l2_normalize_python
from .gather_nd_python import gather_nd_python from .gather_nd_python import gather_nd_python
from .strided_slice_python import strided_slice_python from .strided_slice_python import strided_slice_python
from .batch_matmul import batch_matmul
# pylint: disable=invalid-name
"""Batch matmul in python"""
import numpy as np
def batch_matmul(x, y):
"""batch_matmul operator implemented in numpy.
Parameters
----------
x : numpy.ndarray
3-D with shape [batch, M, K]
y : numpy.ndarray
3-D with shape [batch, N, K]
Returns
-------
out : numpy.ndarray
3-D with shape [batch, M, N]
"""
batch, M, _ = x.shape
N = y.shape[1]
out = np.zeros((batch, M, N)).astype(x.dtype)
for i in range(batch):
out[i] = np.dot(x[i], y[i].T)
return out
...@@ -255,3 +255,29 @@ def const_matrix(matrix, name="const_matrix"): ...@@ -255,3 +255,29 @@ def const_matrix(matrix, name="const_matrix"):
return now return now
return tvm.compute(matrix.shape, select_array, name=name) return tvm.compute(matrix.shape, select_array, name=name)
def get_max_power2_factor(n, max_value=None):
"""Get max factor of n in power of 2. If max_value is specificed, max factor
value will be no more max_value,
Parameter
---------
n : int
The input value
max_value : int, optional
The max value for the factor
Returns
-------
factor : int
The max factor in power of 2.
"""
x = 1
while n % 2 == 0:
if max_value is not None and max_value < x * 2:
break
x *= 2
n /= 2
return x
# pylint: disable=invalid-name,too-many-locals,unused-variable
"""x86 batch_matmul operators"""
from __future__ import absolute_import as _abs
import tvm
from .. import generic
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor
@generic.schedule_batch_matmul.register(["cpu"])
def schedule_batch_matmul(outs):
"""Schedule for batch_matmul
Parameters
----------
outs: Array of Tensor
The computation graph description of batch_matmul
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
s = tvm.create_schedule([x.op for x in outs])
def _callback(op):
if "batch_matmul" in op.tag:
C = op.output(0)
A, B = s[C].op.input_tensors
_, M, N = get_const_tuple(C.shape)
k, = s[C].op.reduce_axis
ko, ki = s[C].split(k, 16)
CC = s.rfactor(C, ki)
b, y, x = s[C].op.axis
y_bn = get_max_power2_factor(M, 8)
x_bn = get_max_power2_factor(N, 8)
yo, yi = s[C].split(y, y_bn)
xo, xi = s[C].split(x, x_bn)
s[C].reorder(b, yo, xo, yi, xi)
bxyo = s[C].fuse(b, yo, xo)
s[C].parallel(bxyo)
s[C].fuse(yi, xi)
s[CC].compute_at(s[C], bxyo)
_, _, y, x = s[CC].op.axis
s[CC].fuse(y, x)
s[CC].vectorize(s[CC].op.axis[0])
s[C].pragma(bxyo, 'auto_unroll_max_step', 16)
traverse_inline(s, outs[0].op, _callback)
return s
# pylint: disable=invalid-name,too-many-locals,unused-variable
"""x86 dense operators"""
from __future__ import absolute_import as _abs
import tvm
from tvm import autotvm
from tvm.autotvm.task.space import SplitEntity
from .util import get_fp32_len
from .. import generic, tag, nn
from ..util import traverse_inline, get_const_tuple
@autotvm.register_topi_compute(nn.dense, "cpu", "direct")
def _declaration_dense(cfg, data, weight, bias=None):
batch, _ = get_const_tuple(data.shape)
# For small batch sizes, don't pack weight into cache-friendly layout
# because of overhead in packing and limited reuse from batch dimension
# TODO(icemelon9): use a more systematic way to determine which schedule to use
if batch <= 16:
return _declaration_dense_nopack(cfg, data, weight, bias)
return _declaration_dense_pack(cfg, data, weight, bias)
# Declare dense compute with packing weight into cache-friendly layout
@autotvm.register_topi_compute(nn.dense, "cpu", "direct_pack")
def _declaration_dense_pack(cfg, data, weight, bias=None):
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
# create tuning space
cfg.define_split("tile_y", batch, num_outputs=3)
cfg.define_split("tile_x", out_dim, num_outputs=3)
cfg.define_split("tile_k", in_dim, num_outputs=2)
if cfg.is_fallback:
_default_dense_pack_config(cfg, batch, out_dim, in_dim)
packw_bn = cfg["tile_x"].size[-1]
packw_shape = (out_dim // packw_bn, in_dim, packw_bn)
packw = tvm.compute(packw_shape,
lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
k = tvm.reduce_axis((0, in_dim), name="k")
C = tvm.compute((batch, out_dim),
lambda y, x: tvm.sum(
data[y, k] * packw[x // packw_bn, k, x % packw_bn],
axis=k),
tag="dense_pack")
if bias is not None:
C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j],
tag=tag.BROADCAST)
return C
# Declare dense compute without packing weight
@autotvm.register_topi_compute(nn.dense, "cpu", "direct_nopack")
def _declaration_dense_nopack(cfg, data, weight, bias=None):
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
# create tuning space
cfg.define_split("tile_x", out_dim, num_outputs=2)
cfg.define_split("tile_y", batch, num_outputs=2)
cfg.define_split("tile_k", in_dim, num_outputs=2)
if cfg.is_fallback:
_default_dense_nopack_config(cfg, batch, out_dim, in_dim)
vec = cfg["tile_k"].size[-1]
k = tvm.reduce_axis((0, in_dim // vec), "k")
CC = tvm.compute((batch, out_dim, vec),
lambda z, y, x: tvm.sum(
data[z, k * vec + x] * weight[y, k * vec + x], axis=k))
kk = tvm.reduce_axis((0, vec), "kk")
C = tvm.compute((batch, out_dim),
lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
tag="dense_nopack")
if bias is not None:
C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j],
tag=tag.BROADCAST)
return C
@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct")
def _schedule_dense(cfg, outs):
s = tvm.create_schedule([x.op for x in outs])
def _callback(op):
if "dense_pack" in op.tag:
_schedule_dense_pack_template(cfg, s, op.output(0))
elif 'dense_nopack' in op.tag:
_schedule_dense_nopack_template(cfg, s, op.output(0))
traverse_inline(s, outs[0].op, _callback)
return s
@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_pack")
def _schedule_dense_pack(cfg, outs):
s = tvm.create_schedule([x.op for x in outs])
def _callback(op):
if "dense_pack" in op.tag:
_schedule_dense_pack_template(cfg, s, op.output(0))
traverse_inline(s, outs[0].op, _callback)
return s
@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_nopack")
def _schedule_dense_nopack(cfg, outs):
s = tvm.create_schedule([x.op for x in outs])
def _callback(op):
if 'dense_nopack' in op.tag:
_schedule_dense_nopack_template(cfg, s, op.output(0))
traverse_inline(s, outs[0].op, _callback)
return s
def _schedule_dense_pack_template(cfg, s, C):
A, packedB = s[C].op.input_tensors
CC = s.cache_write(C, "global")
y, x = s[C].op.axis
k, = s[CC].op.reduce_axis
yt, yo, yi = cfg["tile_y"].apply(s, C, y)
xt, xo, xi = cfg["tile_x"].apply(s, C, x)
s[C].reorder(yt, xt, yo, xo, yi, xi)
xyt = s[C].fuse(yt, xt)
s[C].parallel(xyt)
xyo = s[C].fuse(yo, xo)
s[C].unroll(yi)
s[C].vectorize(xi)
s[CC].compute_at(s[C], xyo)
y, x = s[CC].op.axis
ko, ki = cfg["tile_k"].apply(s, CC, k)
s[CC].reorder(ko, ki, y, x)
s[CC].vectorize(x)
s[CC].unroll(y)
s[CC].unroll(ki)
z, y, x = s[packedB].op.axis
s[packedB].reorder(z, x, y)
s[packedB].parallel(z)
s[packedB].vectorize(y)
return s
def _schedule_dense_nopack_template(cfg, s, C):
y, x = s[C].op.axis
kk, = s[C].op.reduce_axis
yo, yi = cfg["tile_y"].apply(s, C, y)
xo, xi = cfg["tile_x"].apply(s, C, x)
s[C].reorder(yo, xo, yi, xi)
xyo = s[C].fuse(yo, xo)
s[C].parallel(xyo)
s[C].unroll(kk)
CC, = s[C].op.input_tensors
s[CC].compute_at(s[C], xyo)
z, y, x = s[CC].op.axis
k, = s[CC].op.reduce_axis
yz = s[CC].fuse(z, y)
s[CC].reorder(k, yz, x)
s[CC].unroll(yz)
s[CC].vectorize(x)
return s
def _default_dense_pack_config(cfg, M, N, K):
vec_width = get_fp32_len()
tilex_ii = 1
for bn in range(vec_width*2, 0, -1):
if N % bn == 0:
tilex_ii = bn
break
NN = N // tilex_ii
tilex_oi = 1
while NN // tilex_oi > 4:
if (NN // tilex_oi) % 2 == 1:
break
tilex_oi *= 2
tiley_ii = 8
while M % tiley_ii != 0:
tiley_ii //= 2
MM = M // tiley_ii
tiley_oi = 1
while MM // tiley_oi > 4:
if (MM // tiley_oi) % 2 == 1:
break
tiley_oi *= 2
cfg["tile_y"] = SplitEntity([MM // tiley_oi, tiley_oi, tiley_ii])
cfg["tile_x"] = SplitEntity([NN // tilex_oi, tilex_oi, tilex_ii])
cfg["tile_k"] = SplitEntity([K, 1])
def _default_dense_nopack_config(cfg, M, N, K):
vec_width = get_fp32_len()
tilek_bn = 1
for bn in range(vec_width*2, 0, -1):
if K % bn == 0:
tilek_bn = bn
break
cfg["tile_k"] = SplitEntity([K // tilek_bn, tilek_bn])
cfg["tile_x"] = SplitEntity([N, 1])
cfg["tile_y"] = SplitEntity([1, M])
...@@ -2,12 +2,7 @@ ...@@ -2,12 +2,7 @@
"""x86 nn operators""" """x86 nn operators"""
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import tvm import tvm
from tvm import autotvm from .. import generic
from tvm.autotvm.task.space import SplitEntity
from .util import get_fp32_len
from .. import generic, tag, nn
from ..util import traverse_inline, get_const_tuple
@generic.schedule_softmax.register(["cpu"]) @generic.schedule_softmax.register(["cpu"])
def schedule_softmax(outs): def schedule_softmax(outs):
...@@ -37,205 +32,3 @@ def schedule_softmax(outs): ...@@ -37,205 +32,3 @@ def schedule_softmax(outs):
else: else:
s[x].parallel(s[x].op.axis[0]) s[x].parallel(s[x].op.axis[0])
return s return s
@autotvm.register_topi_compute(nn.dense, "cpu", "direct")
def _declaration_dense(cfg, data, weight, bias=None):
batch, _ = get_const_tuple(data.shape)
# For small batch sizes, don't pack weight into cache-friendly layout
# because of overhead in packing and limited reuse from batch dimension
# TODO(icemelon9): use a more systematic way to determine which schedule to use
if batch <= 16:
return _declaration_dense_nopack(cfg, data, weight, bias)
return _declaration_dense_pack(cfg, data, weight, bias)
# Declare dense compute with packing weight into cache-friendly layout
@autotvm.register_topi_compute(nn.dense, "cpu", "direct_pack")
def _declaration_dense_pack(cfg, data, weight, bias=None):
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
# create tuning space
cfg.define_split("tile_y", batch, num_outputs=3)
cfg.define_split("tile_x", out_dim, num_outputs=3)
cfg.define_split("tile_k", in_dim, num_outputs=2)
if cfg.is_fallback:
_default_dense_pack_config(cfg, batch, out_dim, in_dim)
packw_bn = cfg["tile_x"].size[-1]
packw_shape = (out_dim // packw_bn, in_dim, packw_bn)
packw = tvm.compute(packw_shape,
lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")
k = tvm.reduce_axis((0, in_dim), name="k")
C = tvm.compute((batch, out_dim),
lambda y, x: tvm.sum(
data[y, k] * packw[x // packw_bn, k, x % packw_bn],
axis=k),
tag="dense_pack")
if bias is not None:
C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j],
tag=tag.BROADCAST)
return C
# Declare dense compute without packing weight
@autotvm.register_topi_compute(nn.dense, "cpu", "direct_nopack")
def _declaration_dense_nopack(cfg, data, weight, bias=None):
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
# create tuning space
cfg.define_split("tile_x", out_dim, num_outputs=2)
cfg.define_split("tile_y", batch, num_outputs=2)
cfg.define_split("tile_k", in_dim, num_outputs=2)
if cfg.is_fallback:
_default_dense_nopack_config(cfg, batch, out_dim, in_dim)
vec = cfg["tile_k"].size[-1]
k = tvm.reduce_axis((0, in_dim // vec), "k")
CC = tvm.compute((batch, out_dim, vec),
lambda z, y, x: tvm.sum(
data[z, k * vec + x] * weight[y, k * vec + x], axis=k))
kk = tvm.reduce_axis((0, vec), "kk")
C = tvm.compute((batch, out_dim),
lambda y, x: tvm.sum(CC[y, x, kk], axis=kk),
tag="dense_nopack")
if bias is not None:
C = tvm.compute((batch, out_dim), lambda i, j: C[i, j] + bias[j],
tag=tag.BROADCAST)
return C
@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct")
def _schedule_dense(cfg, outs):
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def _callback(op):
if "dense_pack" in op.tag:
_schedule_dense_pack_template(cfg, s, op.output(0))
elif 'dense_nopack' in op.tag:
_schedule_dense_nopack_template(cfg, s, op.output(0))
traverse_inline(s, outs[0].op, _callback)
return s
@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_pack")
def _schedule_dense_pack(cfg, outs):
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def _callback(op):
if "dense_pack" in op.tag:
_schedule_dense_pack_template(cfg, s, op.output(0))
traverse_inline(s, outs[0].op, _callback)
return s
@autotvm.register_topi_schedule(generic.schedule_dense, "cpu", "direct_nopack")
def _schedule_dense_nopack(cfg, outs):
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
def _callback(op):
if 'dense_nopack' in op.tag:
_schedule_dense_nopack_template(cfg, s, op.output(0))
traverse_inline(s, outs[0].op, _callback)
return s
def _schedule_dense_pack_template(cfg, s, C):
A, packedB = s[C].op.input_tensors
CC = s.cache_write(C, "global")
y, x = s[C].op.axis
k, = s[CC].op.reduce_axis
yt, yo, yi = cfg["tile_y"].apply(s, C, y)
xt, xo, xi = cfg["tile_x"].apply(s, C, x)
s[C].reorder(yt, xt, yo, xo, yi, xi)
xyt = s[C].fuse(yt, xt)
s[C].parallel(xyt)
xyo = s[C].fuse(yo, xo)
s[C].unroll(yi)
s[C].vectorize(xi)
s[CC].compute_at(s[C], xyo)
y, x = s[CC].op.axis
ko, ki = cfg["tile_k"].apply(s, CC, k)
s[CC].reorder(ko, ki, y, x)
s[CC].vectorize(x)
s[CC].unroll(y)
s[CC].unroll(ki)
z, y, x = s[packedB].op.axis
s[packedB].reorder(z, x, y)
s[packedB].parallel(z)
s[packedB].vectorize(y)
return s
def _schedule_dense_nopack_template(cfg, s, C):
y, x = s[C].op.axis
kk, = s[C].op.reduce_axis
yo, yi = cfg["tile_y"].apply(s, C, y)
xo, xi = cfg["tile_x"].apply(s, C, x)
s[C].reorder(yo, xo, yi, xi)
xyo = s[C].fuse(yo, xo)
s[C].parallel(xyo)
s[C].unroll(kk)
CC, = s[C].op.input_tensors
s[CC].compute_at(s[C], xyo)
z, y, x = s[CC].op.axis
k, = s[CC].op.reduce_axis
yz = s[CC].fuse(z, y)
s[CC].reorder(k, yz, x)
s[CC].unroll(yz)
s[CC].vectorize(x)
return s
def _default_dense_pack_config(cfg, M, N, K):
vec_width = get_fp32_len()
tilex_ii = 1
for bn in range(vec_width*2, 0, -1):
if N % bn == 0:
tilex_ii = bn
break
NN = N // tilex_ii
tilex_oi = 1
while NN // tilex_oi > 4:
if (NN // tilex_oi) % 2 == 1:
break
tilex_oi *= 2
tiley_ii = 8
while M % tiley_ii != 0:
tiley_ii //= 2
MM = M // tiley_ii
tiley_oi = 1
while MM // tiley_oi > 4:
if (MM // tiley_oi) % 2 == 1:
break
tiley_oi *= 2
cfg["tile_y"] = SplitEntity([MM // tiley_oi, tiley_oi, tiley_ii])
cfg["tile_x"] = SplitEntity([NN // tilex_oi, tilex_oi, tilex_ii])
cfg["tile_k"] = SplitEntity([K, 1])
def _default_dense_nopack_config(cfg, M, N, K):
vec_width = get_fp32_len()
tilek_bn = 1
for bn in range(vec_width*2, 0, -1):
if K % bn == 0:
tilek_bn = bn
break
cfg["tile_k"] = SplitEntity([K // tilek_bn, tilek_bn])
cfg["tile_x"] = SplitEntity([N, 1])
cfg["tile_y"] = SplitEntity([1, M])
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <topi/nn/upsampling.h> #include <topi/nn/upsampling.h>
#include <topi/nn/l2_normalize.h> #include <topi/nn/l2_normalize.h>
#include <topi/nn/local_response_norm.h> #include <topi/nn/local_response_norm.h>
#include <topi/nn/batch_matmul.h>
#include <topi/vision/reorg.h> #include <topi/vision/reorg.h>
#include <topi/image/resize.h> #include <topi/image/resize.h>
...@@ -351,6 +352,12 @@ TVM_REGISTER_GLOBAL("topi.nn.dense") ...@@ -351,6 +352,12 @@ TVM_REGISTER_GLOBAL("topi.nn.dense")
*rv = nn::dense(args[0], args[1], args[2]); *rv = nn::dense(args[0], args[1], args[2]);
}); });
/* Ops from nn/batch_matmul.h */
TVM_REGISTER_GLOBAL("topi.nn.batch_matmul")
.set_body([](TVMArgs args, TVMRetValue *rv) {
*rv = nn::batch_matmul(args[0], args[1]);
});
/* Ops from nn/dilate.h */ /* Ops from nn/dilate.h */
TVM_REGISTER_GLOBAL("topi.nn.dilate") TVM_REGISTER_GLOBAL("topi.nn.dilate")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
...@@ -589,6 +596,9 @@ TVM_REGISTER_GENERIC_FUNC(schedule_dense) ...@@ -589,6 +596,9 @@ TVM_REGISTER_GENERIC_FUNC(schedule_dense)
.register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_dense)) .register_func({ "cuda", "gpu" }, WrapSchedule(topi::cuda::schedule_dense))
.register_func({ "rocm" }, WrapSchedule(topi::rocm::schedule_dense)); .register_func({ "rocm" }, WrapSchedule(topi::rocm::schedule_dense));
TVM_REGISTER_GENERIC_FUNC(schedule_batch_matmul)
.set_default(WrapSchedule(topi::generic::default_schedule));
TVM_REGISTER_GENERIC_FUNC(schedule_pool) TVM_REGISTER_GENERIC_FUNC(schedule_pool)
.set_default(WrapSchedule(topi::generic::default_schedule)) .set_default(WrapSchedule(topi::generic::default_schedule))
.register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule)) .register_func({ "cpu" }, WrapSchedule(topi::x86::default_schedule))
......
"""Test code for batch_matmul operator"""
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
from tvm.contrib.pickle_memoize import memoize
from common import get_all_backend
def verify_batch_matmul(batch, M, N, K):
x = tvm.placeholder((batch, M, K), name='x')
y = tvm.placeholder((batch, N, K), name='y')
dtype = x.dtype
# use memoize to pickle the test data for next time use
@memoize("topi.tests.test_topi_batch_matmul")
def get_ref_data():
a_np = np.random.uniform(size=(batch, M, K)).astype(dtype)
b_np = np.random.uniform(size=(batch, N, K)).astype(dtype)
c_np = topi.testing.batch_matmul(a_np, b_np)
return (a_np, b_np, c_np)
# get the test data
a_np, b_np, c_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
out = topi.nn.batch_matmul(x, y)
s = topi.generic.schedule_batch_matmul([out])
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=dtype), ctx)
f = tvm.build(s, [x, y, out], device, name="dense")
f(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
for device in get_all_backend():
check_device(device)
def test_batch_matmul():
verify_batch_matmul(1, 16, 16, 32)
verify_batch_matmul(5, 16, 16, 32)
verify_batch_matmul(5, 16, 20, 32)
verify_batch_matmul(30, 16, 20, 32)
if __name__ == "__main__":
test_batch_matmul()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment