Commit d6dcd6c5 by Andrew Tulloch, committed by Tianqi Chen

We observe multiple groups, both internally and externally, across a range of
domains (ASR, NMT, LM, etc.), interested in replacing standard dense layers with
block-sparse matrix multiplication layers (#3566). The motivations are generally
higher performance (due to the reduction in FLOPs, memory bandwidth, and cache
footprint) and enabling larger models (e.g. fitting more layers into a given
memory budget).

Some public work along these lines:

* https://openai.com/blog/block-sparse-gpu-kernels/
* https://openai.com/blog/sparse-transformer/
* https://arxiv.org/abs/1802.08435
* https://arxiv.org/abs/1711.02782

Various groups have successfully trained models at high levels of sparsity
(90%+) with marginal accuracy changes, which suggests substantial speedups are
possible (90% sparsity implies a >10x reduction in FLOPs).

It is fairly straightforward to realize these theoretical speedups; see, e.g., the
TVM benchmarks for Intel CPUs in
https://gist.github.com/ajtulloch/e65f90487bceb8848128e8db582fe902 and the CUDA
results in https://github.com/openai/blocksparse. Related kernels and
representations include:

* https://github.com/openai/blocksparse (CUDA)
* https://software.intel.com/en-us/mkl-developer-reference-c-mkl-bsrmm (MKL BSRMM)
* https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.bsr_matrix.html (SciPy BSR representation)

This is extracted from a patch we have been using internally. Various extensions
are possible (int8/fp16/bf16, CUDA/other GPU architectures), but this is a
reasonable starting point. It needs more thorough unit-test coverage, however.

We follow the conventions established by scipy.sparse.bsr_matrix and other
libraries; see the unit tests for details.
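For concreteness, here is a minimal sketch (not part of the patch) of the
scipy.sparse BSR layout these kernels consume:

```python
# Illustrative only: `data` holds the nonzero blocks, `indices` the
# block-column index of each block, and `indptr` the per-block-row offsets.
import numpy as np
import scipy.sparse as sp

W = np.zeros((4, 6), dtype="float32")
W[0:2, 2:4] = 1.0  # one dense 2x2 block
W[2:4, 4:6] = 2.0  # another dense 2x2 block
W_bsr = sp.bsr_matrix(W, blocksize=(2, 2))

print(W_bsr.data.shape)  # (num_blocks, BS_R, BS_C) == (2, 2, 2)
print(W_bsr.indices)     # block-column index per block: [1 2]
print(W_bsr.indptr)      # block-row offsets, shape (4 // 2 + 1,): [0 1 2]
```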

For folks interested in experimenting with scheduling, AutoTVM, etc.,
https://gist.github.com/ajtulloch/e65f90487bceb8848128e8db582fe902 is a useful
starting point.
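As a rough usage sketch of the Relay op added below (hedged; `SparseWeight` here
is only an illustrative namedtuple matching the fields the frontend wrapper
expects, not an API shipped by this patch):

```python
# Hypothetical usage sketch: dense activation plus the three CSR arrays.
import collections
import scipy.sparse as sp
from tvm import relay

SparseWeight = collections.namedtuple("SparseWeight", ["data", "indices", "indptr"])

M, N, K = 8, 64, 128
x = relay.var("x", shape=(M, K), dtype="float32")
w_sp = sp.random(N, K, density=0.1, format="csr", dtype="float32")
weight = SparseWeight(
    data=relay.var("w_data", shape=w_sp.data.shape, dtype="float32"),
    indices=relay.var("w_indices", shape=w_sp.indices.shape, dtype="int32"),
    indptr=relay.var("w_indptr", shape=w_sp.indptr.shape, dtype="int32"))
y = relay.nn.sparse_dense(x, weight)  # dense x, sparse weight -> shape (M, N)
```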
parent 2ed31b24
...@@ -366,6 +366,10 @@ struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
}
};
/*! \brief Attributes for sparse_dense operator */
struct SparseDenseAttrs : public tvm::AttrsNode<SparseDenseAttrs> {
TVM_DECLARE_ATTRS(SparseDenseAttrs, "relay.attrs.SparseDenseAttrs") {}
};
/*! \brief Attributes for upsampling operator */
struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
...
...@@ -85,6 +85,19 @@ def schedule_batch_matmul(attrs, outputs, target):
reg.register_pattern("nn.batch_matmul", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
# sparse_dense
@reg.register_compute("nn.sparse_dense")
def compute_sparse_dense(attrs, inputs, out_type, target):
"""Compute definition of sparse_dense"""
return [topi.nn.sparse_dense(inputs[0], inputs[1], inputs[2], inputs[3])]
@reg.register_schedule("nn.sparse_dense")
def schedule_sparse_dense(attrs, outputs, target):
"""Schedule definition of batch_matmul"""
with target:
return topi.generic.schedule_sparse_dense(outputs)
reg.register_pattern("nn.sparse_dense", reg.OpPattern.OUT_ELEMWISE_FUSABLE)
# conv2d
def _find_conv2d_op(op):
...
...@@ -839,6 +839,39 @@ def batch_matmul(x, y):
"""
return _make.batch_matmul(x, y)
def sparse_dense(data, weight):
r"""
Computes the matrix multiplication of `data` and `weight`, where `data` is
a dense matrix and `weight` is a sparse (either BSR or CSR) namedtuple with
fields `data`, `indices`, and `indptr`.
.. math::

    \mbox{sparse_dense}(data, weight)[m, n] = \mbox{matmul}(data, \mbox{as_dense}(weight)^T)[m, n]

where `as_dense` returns the dense equivalent of the given sparse matrix.
See
https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.csr_matrix.html
and
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.sparse.bsr_matrix.html
for more detail on the sparse matrix representation.
Parameters
----------
data : tvm.relay.Expr
The input data for the matrix multiplication
weight : namedtuple.
The sparse weight matrix for the matrix multiplication.
Returns
-------
result: tvm.relay.Expr
The computed result.
"""
return _make.sparse_dense(data, weight.data, weight.indices, weight.indptr)
def contrib_conv2d_winograd_without_weight_transform(data,
                                                     weight,
...
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file sparse.cc
* \brief Property def of nn.sparse_dense operator.
*/
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <vector>
#include "../../pass/alter_op_layout.h"
namespace tvm {
namespace relay {
// relay.nn.sparse_dense
TVM_REGISTER_NODE_TYPE(SparseDenseAttrs);
bool SparseDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 5);
const auto* data = types[0].as<TensorTypeNode>();
const auto* weight_data = types[1].as<TensorTypeNode>();
const auto* weight_indptr = types[3].as<TensorTypeNode>();
if (data == nullptr || weight_data == nullptr || weight_indptr == nullptr) return false;
CHECK(weight_data->shape.size() == 1 || weight_data->shape.size() == 3);
if (weight_data->shape.size() == 1) {
// CSR case.
Array<IndexExpr> oshape({data->shape[0], weight_indptr->shape[0] - 1});
reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
return true;
}
if (weight_data->shape.size() == 3) {
// BSR case.
Array<IndexExpr> oshape({
data->shape[0],
(weight_indptr->shape[0] - 1) * weight_data->shape[1]});
reporter->Assign(types[4], TensorTypeNode::make(oshape, data->dtype));
return true;
}
LOG(FATAL) << "Unknown weight ndim for nn.sparse_dense, should be 1 (CSR) or 3 (BSR)";
return false;
}
// Positional relay function to create sparse_dense operator used by frontend FFI.
Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr) {
auto attrs = make_node<SparseDenseAttrs>();
static const Op& op = Op::Get("nn.sparse_dense");
return CallNode::make(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op.nn._make.sparse_dense")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 4>(MakeSparseDense, args, rv);
});
RELAY_REGISTER_OP("nn.sparse_dense")
.describe(R"code(Applies a sparse linear transformation: :math:`Y = XW^T` with X sparse.
- **data**: `(x1, x2, ..., xn, input_dim)`
- **weight**: `(units, input_dim)`
- **out**: `(x1, x2, ..., xn, units)`.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SparseDenseAttrs")
.set_num_inputs(4)
.add_argument("data", "nD Tensor", "Input data.")
.add_argument("weight_data", "1D Tensor", "Weight data matrix.")
.add_argument("weight_indices", "1D Tensor", "Weight indices matrix.")
.add_argument("weight_indptr", "1D Tensor", "Weight indptr matrix.")
.set_support_level(1)
.add_type_rel("SparseDense", SparseDenseRel);
} // namespace relay
} // namespace tvm
...@@ -514,6 +514,23 @@ def schedule_l2_normalize(outs):
return cpp.generic.default_schedule(cpp_target, outs, False)

@tvm.target.generic_func
def schedule_sparse_dense(outs):
"""Schedule for sparse_dense
Parameters
----------
outs: Array of Tensor
The computation graph description of sparse_dense
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
@tvm.target.generic_func
def schedule_batch_matmul(outs):
target = tvm.target.current_target(allow_none=False)
cpp_target = cpp.TEST_create_target(target.target_name)
...
...@@ -20,3 +20,4 @@ from .bitserial_conv2d import *
from .bitserial_dense import *
from .l2_normalize import *
from .batch_matmul import *
from .sparse import *
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Sparse operators"""
from __future__ import absolute_import
import tvm
from ..util import get_const_tuple
@tvm.target.generic_func
def sparse_dense(data, weight_data, weight_indices, weight_indptr):
"""
Computes sparse-dense matrix multiplication of `data` and
`(weight_data, weight_indices, weight_indptr).T`
Parameters
----------
data : tvm.Tensor
2-D with shape [M, K], float32
weight_data : tvm.Tensor
1-D with shape [nnz] (CSR) or
3-D with shape [num_blocks, bs_r, bs_c] (BSR)
weight_indices : tvm.Tensor
1-D with shape [nnz] (CSR) or
1-D with shape [num_blocks] (BSR)
weight_indptr : tvm.Tensor
1-D with shape [N + 1] (CSR) or
1-D with shape [(N // bs_r) + 1] (BSR)
Returns
-------
output : tvm.Tensor
2-D with shape [M, N]
"""
assert len(weight_data.shape) in (1, 3)
if len(weight_data.shape) == 1:
func = _sparse_dense_csrmm
if len(weight_data.shape) == 3:
func = _sparse_dense_bsrmm
return func(data, weight_data, weight_indices, weight_indptr)
def _sparse_dense_csrmm(data, weight_data, weight_indices, weight_indptr):
oshape = (
get_const_tuple(data.shape)[0],
get_const_tuple(weight_indptr.shape)[0] - 1)
def f(i, row):
row_start = weight_indptr[row]
row_end = weight_indptr[row + 1]
row_elems = row_end - row_start
elem_idx = tvm.reduce_axis((0, row_elems), name="elem_idx")
elem = row_start + elem_idx
weight_val = weight_data[elem]
x_val = data[i, weight_indices[elem]]
return tvm.sum(weight_val * x_val, axis=elem_idx)
return tvm.compute(oshape, f, tag="sparse_dense_csrmm")
def _sparse_dense_bsrmm(data, weight_data, weight_indices, weight_indptr):
(m, _) = get_const_tuple(data.shape)
(_, bs_r, bs_c) = get_const_tuple(weight_data.shape)
(num_blocks_plus_1, ) = get_const_tuple(weight_indptr.shape)
num_blocks = num_blocks_plus_1 - 1
def _compute_block(i, nb_j, j):
row_start = weight_indptr[nb_j]
row_end = weight_indptr[nb_j + 1]
row_elems = row_end - row_start
elem_idx = tvm.reduce_axis(
(0, row_elems), name="elem_idx")
block_offset = row_start + elem_idx
c = tvm.reduce_axis((0, bs_c), name="c")
block_j = weight_indices[block_offset]
block_ij_val = weight_data[block_offset][j][c]
x_val = data[i, bs_c * block_j + c]
return tvm.sum(block_ij_val * x_val, axis=[elem_idx, c])
bsrmm_block = tvm.compute(
(m, num_blocks, bs_r), _compute_block,
tag="sparse_dense_bsrmm_block")
return tvm.compute(
(m, num_blocks * bs_r),
lambda m, n: bsrmm_block[m, n // bs_r, n % bs_r],
tag="sparse_dense_bsrmm")
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""sparse_dense schedule on x86"""
import tvm
from .. import generic
from ..util import traverse_inline, get_const_int
from .util import get_fp32_len
@generic.schedule_sparse_dense.register(["cpu"])
def _schedule_sparse_dense(outs):
s = tvm.create_schedule([x.op for x in outs])
def _callback(op):
simd_width = get_fp32_len()
if op.tag == "sparse_dense_csrmm" and op != outs[0].op:
(_, v_i) = s[op].op.axis
s[op].vectorize(v_i)
(y_o, y_i) = s[outs[0].op].split(
s[outs[0].op].op.axis[1], 2 * simd_width)
s[op].compute_at(s[outs[0]], y_o)
s[outs[0].op].vectorize(y_i)
if op.tag == "sparse_dense_bsrmm":
y_bsrmm = op.input_tensors[0]
assert y_bsrmm.op.tag == "sparse_dense_bsrmm_block"
y_reshape = op
(m, num_blocks, b_r) = s[y_bsrmm].op.axis
bs_r = get_const_int(b_r.dom.extent)
(elem_idx, c) = s[y_bsrmm].op.reduce_axis
s[y_bsrmm].reorder(num_blocks, m, elem_idx, b_r, c)
s[y_bsrmm].vectorize(b_r)
(m_o, n_o) = s[y_reshape].op.axis
(noo, noi) = s[y_reshape].split(n_o, bs_r)
s[y_bsrmm].compute_at(s[y_reshape], noi)
s[y_reshape].vectorize(noi)
if op != s[outs[0]].op:
(y_o, y_i) = s[outs[0].op].split(
s[outs[0].op].op.axis[1], 2 * simd_width)
s[y_reshape].compute_at(s[outs[0]], y_o)
s[outs[0].op].parallel(y_o)
s[outs[0].op].vectorize(y_i)
else:
m_o_noo = s[y_reshape].fuse(m_o, noo)
s[y_reshape].parallel(m_o_noo)
traverse_inline(s, outs[0].op, _callback)
return s
...@@ -215,7 +215,106 @@ def test_dense():
test_dense_si()
test_dense_sw()
def test_sparse_dense_csr():
import scipy.sparse as sp
M, N, K, density = 1, 17, 47, 0.2
X_np = np.random.randn(M, K).astype("float32")
W_sp_np = sp.random(N, K, density=density, format='csr', dtype="float32")
W_np = W_sp_np.todense()
Y_np = X_np.dot(W_np.T)
W_data = tvm.placeholder(shape=W_sp_np.data.shape, dtype=str(W_sp_np.data.dtype))
W_indices = tvm.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype))
W_indptr = tvm.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype))
X = tvm.placeholder(shape=X_np.shape, dtype=str(X_np.dtype))
Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
s = tvm.create_schedule(Y.op)
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.ndarray.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.ndarray.array(X_np), tvm.ndarray.array(W_sp_np.data), tvm.ndarray.array(W_sp_np.indices), tvm.ndarray.array(W_sp_np.indptr), Y_tvm)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
def random_bsr_matrix(M, N, BS_R, BS_C, density, dtype):
import scipy.sparse as sp
import itertools
Y = np.zeros((M, N), dtype=dtype)
assert M % BS_R == 0
assert N % BS_C == 0
nnz = int(density * M * N)
num_blocks = int(nnz / (BS_R * BS_C)) + 1
candidate_blocks = np.asarray(list(itertools.product(range(0, M, BS_R), range(0, N, BS_C))))
assert candidate_blocks.shape[0] == M // BS_R * N // BS_C
chosen_blocks = candidate_blocks[np.random.choice(candidate_blocks.shape[0], size=num_blocks, replace=False)]
for i in range(len(chosen_blocks)):
r, c = chosen_blocks[i]
Y[r:r + BS_R, c:c + BS_C] = np.random.randn(BS_R, BS_C)
s = sp.bsr_matrix(Y, blocksize=(BS_R, BS_C))
assert s.data.shape == (num_blocks, BS_R, BS_C)
assert s.indices.shape == (num_blocks, )
assert s.indptr.shape == (M // BS_R + 1, )
return s
def test_sparse_dense_bsr():
M, N, K, BS_R, BS_C, density = 1, 64, 128, 8, 16, 0.9
X_np = np.random.randn(M, K).astype("float32")
W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32")
W_np = W_sp_np.todense()
Y_np = X_np.dot(W_np.T)
W_data = tvm.placeholder(shape=W_sp_np.data.shape, dtype=str(W_sp_np.data.dtype))
W_indices = tvm.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype))
W_indptr = tvm.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype))
X = tvm.placeholder(shape=X_np.shape, dtype=str(X_np.dtype))
Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
s = tvm.create_schedule(Y.op)
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.ndarray.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.ndarray.array(X_np),
tvm.ndarray.array(W_sp_np.data),
tvm.ndarray.array(W_sp_np.indices),
tvm.ndarray.array(W_sp_np.indptr),
Y_tvm)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-4, rtol=1e-4)
def test_sparse_dense_bsr_randomized():
for _ in range(20):
BS_R = np.random.randint(1, 16)
BS_C = np.random.randint(1, 16)
M = np.random.randint(1, 32)
N = int(np.random.randint(1, 16) * BS_R)
K = int(np.random.randint(1, 16) * BS_C)
density = np.clip(np.random.random(), 0.1, 0.9)
X_np = np.random.randn(M, K).astype("float32")
W_sp_np = random_bsr_matrix(N, K, BS_R, BS_C, density=density, dtype="float32")
W_np = W_sp_np.todense()
Y_np = np.array(X_np.dot(W_np.T))
W_data = tvm.placeholder(shape=W_sp_np.data.shape, dtype=str(W_sp_np.data.dtype))
W_indices = tvm.placeholder(shape=W_sp_np.indices.shape, dtype=str(W_sp_np.indices.dtype))
W_indptr = tvm.placeholder(shape=W_sp_np.indptr.shape, dtype=str(W_sp_np.indptr.dtype))
X = tvm.placeholder(shape=X_np.shape, dtype=str(X_np.dtype))
Y = topi.nn.sparse_dense(X, W_data, W_indices, W_indptr)
s = tvm.create_schedule(Y.op)
func = tvm.build(s, [X, W_data, W_indices, W_indptr, Y])
Y_tvm = tvm.ndarray.array(np.zeros(Y_np.shape, dtype=Y_np.dtype))
func(tvm.ndarray.array(X_np),
tvm.ndarray.array(W_sp_np.data),
tvm.ndarray.array(W_sp_np.indices),
tvm.ndarray.array(W_sp_np.indptr),
Y_tvm)
tvm.testing.assert_allclose(Y_tvm.asnumpy(), Y_np, atol=1e-5, rtol=1e-5)
def test_sparse_dense():
test_sparse_dense_csr()
test_sparse_dense_bsr()
test_sparse_dense_bsr_randomized()
if __name__ == "__main__": if __name__ == "__main__":
test_csrmv() test_csrmv()
test_csrmm() test_csrmm()
test_dense() test_dense()
test_sparse_dense()