Commit d87c94d4 by Liangfu Chen Committed by Tianqi Chen

[Sparse] add sparse tensor computation support (#1289)

parent 75654835
......@@ -16,8 +16,8 @@ from __future__ import absolute_import as _abs
import logging
from decorator import decorate
import numpy as np
from decorator import decorate
from tvm import target as _target
......
"""Tensor and Operation class for computation declaration."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import numpy as _np
from .. import expr as _expr
from .. import api as _api
from .. import tensor as _tensor
from .. import ndarray as _nd
float32 = "float32"
itype = 'int32'
class CSRNDArray(object):
"""Sparse tensor object in CSR format."""
def __init__(self, arg1, ctx=None, shape=None):
"""Construct a sparse matrix in CSR format.
Parameters
----------
arg1 : numpy.ndarray or a tuple with (data, indices, indptr)
The corresponding a dense numpy array,
or a tuple for constructing a sparse matrix directly.
ctx: tvm.TVMContext
The corresponding context.
shape : tuple of int
The shape of the array
"""
if isinstance(arg1, tuple):
assert len(arg1) == 3
self.data, self.indices, self.indptr = arg1
self.shape = shape
elif isinstance(arg1, _np.ndarray):
source_array = arg1
ridx, cidx = _np.nonzero(source_array)
data = source_array[ridx, cidx]
self.data = _nd.array(data, ctx)
indices = _np.nonzero(source_array)[1].astype(itype)
self.indices = _nd.array(indices, ctx)
indptr = [0]+_np.apply_along_axis(_np.count_nonzero, axis=1, arr=source_array).tolist()
indptr = _np.cumsum(_np.array(indptr, itype)).astype(itype)
self.indptr = _nd.array(indptr, ctx)
self.shape = source_array.shape
else:
raise RuntimeError("Construct CSRNDArray with either a tuple (data, indices, indptr) "
"or a numpy.array, can't handle type %s." % (type(arg1),))
self.stype = 'csr'
self.dtype = self.data.dtype
assert self.shape is not None
assert isinstance(self.data, _nd.NDArray)
assert isinstance(self.indices, _nd.NDArray)
assert str(self.indices.dtype) == 'int32' or \
str(self.indices.dtype) == 'int64', str(self.indices.dtype)
assert isinstance(self.indptr, _nd.NDArray)
assert str(self.indptr.dtype) == 'int32' or \
str(self.indptr.dtype) == 'int64', str(self.indptr.dtype)
def asnumpy(self):
"""Construct a full matrix and convert it to numpy array."""
full = _np.zeros(self.shape, self.dtype)
ridx = _np.diff(self.indptr.asnumpy())
ridx = _np.hstack((_np.ones((v,), itype)*i for i, v in enumerate(ridx)))
full[ridx, self.indices.asnumpy().astype(itype)] = self.data.asnumpy()
return full
def array(source_array, ctx=None, shape=None, stype='csr'):
"""Construct a sparse NDArray from numpy.ndarray"""
ret = None
if stype == 'csr':
ret = CSRNDArray(source_array, shape=shape, ctx=ctx)
else:
raise NotImplementedError('stype=%s is not supported yet.' % (stype,))
return ret
class SparsePlaceholderOp(object):
"""Placeholder class for sparse tensor representations."""
def __init__(self, shape, nonzeros, dtype, name):
# pylint: disable=unused-argument
"""Contructing a bare bone structure for a sparse matrix
Parameters
----------
shape: Tuple of Expr
The shape of the tensor
nonzeros: int
The number of non-zero values
dtype: str, optional
The data type of the tensor
name: str, optional
The name hint of the tensor
"""
self.shape = shape
self.dtype = dtype
self.name = name
self.stype = 'unknown'
class CSRPlaceholderOp(SparsePlaceholderOp):
"""Placeholder class for CSR based sparse tensor representation."""
def __init__(self, shape, nonzeros, dtype, name):
"""Contructing a bare bone structure for a csr_matrix
Parameters
----------
shape: Tuple of Expr
The shape of the tensor
nonzeros: int
The number of non-zero values
dtype: str, optional
The data type of the tensor
name: str, optional
The name hint of the tensor
"""
SparsePlaceholderOp.__init__(self, shape, nonzeros, dtype, name)
self.stype = 'csr'
self.data = _api.placeholder((nonzeros,), dtype=dtype, name=self.name+'_data')
self.indices = _api.placeholder((nonzeros,), dtype=itype, name=self.name+'_indices')
self.indptr = _api.placeholder((self.shape[0]+1,), dtype=itype, name=self.name+'_indptr')
assert isinstance(self.data, _tensor.Tensor)
assert isinstance(self.indices, _tensor.Tensor)
assert isinstance(self.indptr, _tensor.Tensor)
def placeholder(shape, nonzeros=None, dtype=None, name="placeholder", stype=None):
"""Construct an empty sparse tensor object.
Parameters
----------
shape: Tuple of Expr
The shape of the tensor
nonzeros: int
The number of non-zero values
dtype: str, optional
The data type of the tensor
name: str, optional
The name hint of the tensor
stype: str, optional
The name storage type of the sparse tensor (e.g. csr, coo, ell)
Returns
-------
tensor: SparsePlaceholderOp
The created sparse tensor placeholder
"""
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
nonzeros = 0 if nonzeros is None else nonzeros
dtype = float32 if dtype is None else dtype
stype = 'csr' if stype is None else stype
ret = None
if stype == 'csr':
ret = CSRPlaceholderOp(shape=shape, nonzeros=nonzeros, dtype=dtype, name=name)
else:
raise NotImplementedError('stype=%s is not supported yet.' % (stype,))
return ret
import tvm
import tvm.contrib.sparse as tvmsp
import tvm.ndarray as _nd
import numpy as np
from collections import namedtuple
def test_static_tensor():
dtype = 'float32'
stype = 'csr'
target = 'llvm'
ctx = tvm.context(target, 0)
m = tvm.var('m')
n = tvm.var('n')
A = tvmsp.placeholder(shape=(m, n), name='A', dtype=dtype)
assert(A.stype == 'csr')
n = 3
a = np.maximum(np.random.uniform(size=(n,n)).astype(dtype)-.6, 0.)
a = tvmsp.array(a, ctx)
A.data = tvm.placeholder(a.data.shape, dtype, name='A_data')
Ab = tvm.decl_buffer(a.data.shape, dtype, name='A_data')
binds = {A.data: Ab}
C = tvm.compute(A.data.shape, lambda i: A.data[i] * 2., tag='cs_scatter')
s = tvm.create_schedule(C.op)
f = tvm.build(s, [A.data, C], target, binds=binds)
c = tvmsp.array(np.zeros((n,n), dtype), ctx)
c.data = tvm.nd.empty(a.data.shape, dtype)
c.indices = a.indices
c.indptr = a.indptr
f(a.data, c.data)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() * 2., rtol=1e-5)
def test_dynamic_tensor():
dtype = 'float32'
stype = 'csr'
target = 'llvm'
ctx = tvm.context(target, 0)
nr, nc, n = tvm.var('nr'), tvm.var('nc'), tvm.var('n')
A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, name='A', dtype=dtype)
assert(A.stype == 'csr')
C = tvm.compute(A.data.shape, lambda i: A.data[i] * 2., tag='cs_scatter')
s = tvm.create_schedule(C.op)
_nr, _nc = 3, 5
a = np.maximum(np.random.uniform(size=(_nr, _nc)).astype(dtype)-.6, 0.)
a = tvmsp.array(a, ctx)
assert a.data.dtype == a.dtype
Ab = namedtuple('CSRBuffer', ['data', 'indices', 'indptr'])
Ab.data = tvm.decl_buffer(a.data.shape, a.data.dtype, name='A_data')
Ab.indices = tvm.decl_buffer(a.data.shape, a.data.dtype, name='A_indices')
binds = {A.data: Ab.data, A.indices: Ab.indices}
f = tvm.build(s, [nr, A.data, C], target, binds=binds)
c = tvmsp.array(np.zeros((_nr, _nc), dtype), ctx)
c.data = tvm.nd.empty(a.data.shape, dtype)
c.indices = a.indices
c.indptr = a.indptr
f(a.data.shape[0], a.data, c.data)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() * 2., rtol=1e-5)
def test_sparse_array_tuple():
dtype, itype = 'float32', 'int32'
stype = 'csr'
target = 'llvm'
ctx = tvm.context(target, 0)
nr, nc, n = tvm.var('nr'), tvm.var('nc'), tvm.var('n')
A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, name='A', dtype=dtype)
assert(A.stype == 'csr')
C = tvm.compute(A.data.shape, lambda i: A.data[i] * 2., tag='cs_scatter')
s = tvm.create_schedule(C.op)
_nr, _nc = 3, 5
a = np.maximum(np.random.uniform(size=(_nr, _nc)).astype(dtype)-.6, 0.)
# convert to sparse array tuple
source_array = a
ridx, cidx = np.nonzero(source_array)
data = source_array[ridx, cidx]
a_data = _nd.array(data, ctx)
indices = np.nonzero(source_array)[1].astype(itype)
a_indices = _nd.array(indices, ctx)
indptr = [0]+np.apply_along_axis(np.count_nonzero, axis=1, arr=source_array).tolist()
indptr = np.cumsum(np.array(indptr, itype)).astype(itype)
a_indptr = _nd.array(indptr, ctx)
a_init = (a_data, a_indices, a_indptr)
# construct tvm sparse array with tuple
a = tvmsp.array(a_init, shape=source_array.shape, ctx=ctx)
assert a.data.dtype == a.dtype
Ab = namedtuple('CSRBuffer', ['data', 'indices', 'indptr'])
Ab.data = tvm.decl_buffer(a.data.shape, a.data.dtype, name='A_data')
Ab.indices = tvm.decl_buffer(a.data.shape, a.data.dtype, name='A_indices')
binds = {A.data: Ab.data, A.indices: Ab.indices}
f = tvm.build(s, [nr, A.data, C], target, binds=binds)
c = tvmsp.array(np.zeros((_nr, _nc), dtype), ctx)
c.data = tvm.nd.empty(a.data.shape, dtype)
c.indices = a.indices
c.indptr = a.indptr
f(a.data.shape[0], a.data, c.data)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() * 2., rtol=1e-5)
if __name__ == "__main__":
test_static_tensor()
test_dynamic_tensor()
test_sparse_array_tuple()
......@@ -32,6 +32,7 @@ from . import util
from . import rocm
from . import vision
from . import image
from . import sparse
from . import hls
# not import testing by default
# because testing can have extra deps that are not necessary
......
# pylint: disable=wildcard-import
"""Sparse operators"""
from __future__ import absolute_import as _abs
from .csrmv import csrmv
from .csrmm import csrmm
from .dense import dense
"""TVM operator compute SpMM in CSR format."""
from __future__ import absolute_import
import tvm
from .. import tag
from ..util import simplify
def csrmm_default(data, indices, indptr, weight, bias=None):
# pylint: disable=invalid-name
"""The default implementation of csrmm in topi.
Parameters
----------
data : tvm.Tensor
1-D with shape [nonzeros]
indices : tvm.Tensor
1-D with shape [nonzeros]
indptr : tvm.Tensor
1-D with shape [m+1]
weight : tvm.Tensor
2-D with shape [k, n]
bias : tvm.Tensor, optional
1-D with shape [m]
Returns
-------
output : tvm.Tensor
2-D with shape [m, n]
"""
assert len(data.shape) == 1 and len(indices.shape) == 1 and len(indptr.shape) == 1 \
and len(weight.shape) == 2, "only support 2-dim csrmm"
assert isinstance(weight, tvm.tensor.Tensor), \
"weight matrix is assumed to be tvm.Tensor, but weight is `%s`" % (type(weight))
if bias is not None:
assert len(bias.shape) == 1
M = simplify(indptr.shape[0]-1)
_, N = weight.shape
def csrmm_default_ir(data, indices, indptr, weight, out):
"""define ir for csrmm"""
irb = tvm.ir_builder.create()
data_ptr = irb.buffer_ptr(data)
indices_ptr = irb.buffer_ptr(indices)
indptr_ptr = irb.buffer_ptr(indptr)
weight_ptr = irb.buffer_ptr(weight)
out_ptr = irb.buffer_ptr(out)
M = simplify(indptr.shape[0]-1)
_, N = weight.shape
with irb.for_range(0, N, for_type="vectorize", name='n') as n:
with irb.for_range(0, M, for_type="parallel", name='row') as row:
dot = irb.allocate('float32', (1,), name='dot', scope='local')
out_ptr[row*N+n] = 0.
dot[0] = 0.
row_start = indptr_ptr[row]
row_end = indptr_ptr[row+1]
row_elems = row_end-row_start
with irb.for_range(0, row_elems, name='idx') as idx:
elem = row_start+idx
dot[0] += data_ptr[elem] * weight_ptr[indices_ptr[elem]*N+n]
out_ptr[row*N+n] += dot[0]
return irb.get()
oshape = (M, N)
matmul = tvm.extern(oshape, [data, indices, indptr, weight],
lambda ins, outs: csrmm_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
tag="csrmm", dtype='float32', name='out')
if bias is not None:
matmul = tvm.compute(oshape, lambda i, j: matmul[i, j] + bias[i], \
tag=tag.BROADCAST)
return matmul
def csrmm(a, b, c=None):
"""The `csrmm` routine performs a matrix-matrix operation defined as :math:`C := A*B + C`,
where `B` and `C` are dense matrices, `A` is an m-by-k sparse matrix in the CSR format.
Parameters
----------
a : tvm.contrib.sparse.CSRNDArray
2-D sparse matrix with shape [m, k]
b : tvm.Tensor
2-D dense matrix with shape [k, n]
c : tvm.Tensor, optional
1-D dense vector with shape [n]
Returns
-------
output : tvm.Tensor
2-D with shape [m, n]
"""
return csrmm_default(a.data, a.indices, a.indptr, b, c)
"""TVM operator compute SpMV in CSR format."""
from __future__ import absolute_import
import tvm
from .. import tag
def csrmv_default(data, indices, indptr, weight, bias=None):
"""The default implementation of csrmv in topi.
Parameters
----------
data : tvm.Tensor
1-D with shape [nonzeros]
indices : tvm.Tensor
1-D with shape [nonzeros]
indptr : tvm.Tensor
1-D with shape [m+1]
weight : tvm.Tensor
2-D with shape [k, 1]
bias : tvm.Tensor, optional
1-D with shape [1]
Returns
-------
output : tvm.Tensor
2-D with shape [m, 1]
"""
assert len(data.shape) == 1 and len(weight.shape) == 2, \
"only support 2-dim csrmv"
assert isinstance(weight, tvm.tensor.Tensor), \
"weight matrix is assumed to be tvm.Tensor, but weight is `%s`" % (type(weight))
if bias is not None:
assert len(bias.shape) == 1
batch = indptr.shape[0]-1
def csrmv_default_ir(data, indices, indptr, weight, out):
"""define ir for csrmv"""
irb = tvm.ir_builder.create()
data_ptr = irb.buffer_ptr(data)
indices_ptr = irb.buffer_ptr(indices)
indptr_ptr = irb.buffer_ptr(indptr)
weight_ptr = irb.buffer_ptr(weight)
out_ptr = irb.buffer_ptr(out)
num_rows = indptr.shape[0]-1
with irb.for_range(0, num_rows, for_type="parallel", name='row') as row:
dot = irb.allocate('float32', (1,), name='dot', scope='local')
out_ptr[row] = 0.
dot[0] = 0.
row_start = indptr_ptr[row]
row_end = indptr_ptr[row+1]
row_elems = row_end-row_start
with irb.for_range(0, row_elems, name='elemidx') as elemidx:
elem = row_start+elemidx
dot[0] += data_ptr[elem] * weight_ptr[indices_ptr[elem]]
out_ptr[row] += dot[0]
return irb.get()
oshape = (batch, 1)
matmul = tvm.extern(oshape, [data, indices, indptr, weight],
lambda ins, outs: csrmv_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
tag="csrmv", dtype='float32', name='csrmv')
if bias is not None:
matmul = tvm.compute((batch, 1), lambda i, j: matmul[i, 0] + bias[i], \
tag=tag.BROADCAST)
return matmul
def csrmv(a, x, y=None):
"""The `csrmv` routine performs a matrix-vector operation defined as :math:`y := A*x + y`,
where `x` and `y` are vectors, `A` is an m-by-k sparse matrix in the CSR format.
Parameters
----------
a : tvm.contrib.sparse.CSRNDArray
2-D sparse matrix with shape [m, k]
x : tvm.Tensor
2-D dense matrix with shape [k, 1]
y : tvm.Tensor, optional
1-D dense vector with shape [1]
Returns
-------
output : tvm.Tensor
2-D dense matrix with shape [m, 1]
"""
return csrmv_default(a.data, a.indices, a.indptr, x, y)
"""TVM operator compute Dense in CSR format."""
from __future__ import absolute_import
import tvm
from .. import tag
from ..util import simplify
def dense_si(data, indices, indptr, weight, bias=None):
# pylint: disable=invalid-name
"""The implementation of dense in topi, assuming sparse input.
Parameters
----------
data : tvm.Tensor
1-D with shape [num_nonzeros]
indices : tvm.Tensor
1-D with shape [num_nonzeros]
indptr : tvm.Tensor
1-D with shape [m+1]
weight : tvm.Tensor
2-D with shape [k, n]
bias : tvm.Tensor, optional
1-D with shape [m]
Returns
-------
output : tvm.Tensor
2-D with shape [m, n]
"""
assert len(data.shape) == 1 and len(indices.shape) == 1 and len(indptr.shape) == 1 \
and len(weight.shape) == 2, "only support 2-dim dense"
assert isinstance(weight, tvm.tensor.Tensor), \
"weight matrix is assumed to be tvm.Tensor, but weight is `%s`" % (type(weight))
if bias is not None:
assert len(bias.shape) == 1
dtype = data.dtype
M = simplify(indptr.shape[0]-1)
N, _ = weight.shape
def dense_default_ir(data, indices, indptr, weight, out):
"""Define IR for Dense"""
dtype = data.dtype
irb = tvm.ir_builder.create()
data_ptr = irb.buffer_ptr(data)
indices_ptr = irb.buffer_ptr(indices)
indptr_ptr = irb.buffer_ptr(indptr)
weight_ptr = irb.buffer_ptr(weight)
out_ptr = irb.buffer_ptr(out)
M = simplify(indptr.shape[0]-1)
N, K = weight.shape
with irb.for_range(0, N, for_type="vectorize", name='n') as n:
with irb.for_range(0, M, for_type="parallel", name='m') as m:
dot = irb.allocate(dtype, (1,), name='dot', scope='local')
out_ptr[m*N+n] = tvm.const(0, dtype)
dot[0] = tvm.const(0, dtype)
row_start = indptr_ptr[m]
row_elems = indptr_ptr[m+1]-row_start
with irb.for_range(0, row_elems, name='k') as k:
elem = row_start+k
dot[0] += data_ptr[elem] * weight_ptr[indices_ptr[elem]+n*K]
out_ptr[m*N+n] += dot[0]
return irb.get()
oshape = (M, N)
matmul = tvm.extern(oshape, [data, indices, indptr, weight],
lambda ins, outs: dense_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
tag="dense", dtype=dtype, name='out')
if bias is not None:
matmul = tvm.compute(oshape, lambda i, j: matmul[i, j] + bias[j], \
tag=tag.BROADCAST)
return matmul
def dense_sw(data, w_data, w_indices, w_indptr, bias=None):
# pylint: disable=invalid-name
"""The implementation of dense in topi, assuming sparse weight.
Parameters
----------
data : tvm.Tensor
2-D with shape [m, k]
w_data : tvm.Tensor
1-D with shape [nonzeros]
w_indices : tvm.Tensor
1-D with shape [nonzeros]
w_indptr : tvm.Tensor
1-D with shape [n+1]
bias : tvm.Tensor, optional
1-D with shape [n]
Returns
-------
output : tvm.Tensor
2-D with shape [m, n]
"""
assert len(w_data.shape) == 1 and len(w_indices.shape) == 1 and len(w_indptr.shape) == 1 \
and len(data.shape) == 2, "only support 2-dim dense"
assert isinstance(data, tvm.tensor.Tensor), \
"data matrix is assumed to be tvm.Tensor, but weight is `%s`" % (type(data))
if bias is not None:
assert len(bias.shape) == 1
dtype = data.dtype
M, _ = data.shape
N = simplify(w_indptr.shape[0]-1)
def dense_default_ir(data, w_data, w_indices, w_indptr, out):
"""Define IR for Dense"""
dtype = data.dtype
irb = tvm.ir_builder.create()
data_ptr = irb.buffer_ptr(data)
w_data_ptr = irb.buffer_ptr(w_data)
w_indices_ptr = irb.buffer_ptr(w_indices)
w_indptr_ptr = irb.buffer_ptr(w_indptr)
out_ptr = irb.buffer_ptr(out)
M, K = data.shape
N = simplify(w_indptr.shape[0]-1)
with irb.for_range(0, M, for_type="vectorize", name='m') as m:
with irb.for_range(0, N, for_type="parallel", name='n') as n:
dot = irb.allocate(dtype, (1,), name='dot', scope='local')
out_ptr[m*N+n] = tvm.const(0, dtype)
dot[0] = tvm.const(0, dtype)
row_start = w_indptr_ptr[n]
row_elems = w_indptr_ptr[n+1]-row_start
with irb.for_range(0, row_elems, name='k') as k:
elem = row_start+k
dot[0] += w_data_ptr[elem] * data_ptr[w_indices_ptr[elem]+m*K]
out_ptr[m*N+n] += dot[0]
return irb.get()
oshape = (M, N)
matmul = tvm.extern(oshape, [data, w_data, w_indices, w_indptr],
lambda ins, outs: dense_default_ir(ins[0], ins[1], ins[2], ins[3], outs[0]),
tag="dense", dtype=dtype, name='out')
if bias is not None:
matmul = tvm.compute(oshape, lambda i, j: matmul[i, j] + bias[j], \
tag=tag.BROADCAST)
return matmul
def dense(data, weight, bias=None):
"""Applies a linear transformation: :math:`Y = XW^T + b`.
Either data or weight should be tvm.contrib.sparse.CSRNDArray.
Parameters
----------
data : tvm.contrib.sparse.CSRNDArray or tvm.tensor.Tensor
2-D with shape [batch, in_dim]
weight : tvm.tensor.Tensor or tvm.contrib.sparse.CSRNDArray
2-D with shape [out_dim, in_dim]
bias : tvm.tensor.Tensor, optional
1-D with shape [out_dim]
Returns
-------
output : tvm.Tensor
2-D with shape [batch, out_dim]
"""
ret = None
if isinstance(data, tvm.contrib.sparse.CSRPlaceholderOp) and \
isinstance(weight, tvm.tensor.Tensor):
ret = dense_si(data.data, data.indices, data.indptr, weight, bias)
elif isinstance(data, tvm.tensor.Tensor) and \
isinstance(weight, tvm.contrib.sparse.CSRPlaceholderOp):
ret = dense_sw(data, weight.data, weight.indices, weight.indptr, bias)
else:
raise NotImplementedError("implementation for %s as data and %s as weights, "
"is not supported yet." % (type(data), type(weight), ))
return ret
"""Test code for sparse operator"""
import numpy as np
import tvm
import topi
import topi.testing
from topi.util import get_const_tuple
import tvm.contrib.sparse as tvmsp
from collections import namedtuple
import time
def verify_dynamic_csrmv(batch, in_dim, out_dim, use_bias=True):
nr, nc, n = tvm.var("nr"), tvm.var("nc"), tvm.var("n")
dtype = 'float32'
A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, dtype=dtype, name='A')
B = tvm.placeholder((in_dim, 1), name='B')
C = tvm.placeholder((nr,), name='C')
D = topi.sparse.csrmv(A, B, C if use_bias else None)
s = tvm.create_schedule(D.op)
dtype = A.dtype
# get the test data
def get_ref_data():
a_np = np.maximum(np.random.uniform(size=(batch, in_dim)).astype(dtype)-0.5, 0.)
b_np = np.random.uniform(size=(in_dim, 1)).astype(dtype)-0.5
c_np = np.random.uniform(size=(batch, )).astype(dtype)
if use_bias:
d_np = np.dot(a_np, b_np) + c_np.reshape((batch, 1))
else:
d_np = np.dot(a_np, b_np)
return (a_np, b_np, c_np, d_np)
a_np, b_np, c_np, d_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
a = tvmsp.array(a_np, ctx)
_nr, _nc, _n = a.shape[0], a.shape[1], a.data.shape[0]
assert a.shape[0] == a.indptr.shape[0]-1
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(c_np, ctx)
d = tvm.nd.array(np.zeros((_nr, 1), dtype=dtype), ctx)
assert a.data.dtype == A.data.dtype
assert a.indices.dtype == A.indices.dtype
assert a.indptr.dtype == A.indptr.dtype
f = tvm.build(s, [nr, A.data, A.indices, A.indptr, B, C, D], device, name="csrmv")
f(_nr, a.data, a.indices, a.indptr, b, c, d)
np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-4)
for device in ["llvm"]:
check_device(device)
def verify_dynamic_csrmm(batch, in_dim, out_dim, use_bias=True):
nr, nc, n = tvm.var("nr"), tvm.var("nc"), tvm.var("n")
dtype = 'float32'
A = tvmsp.placeholder(shape=(nr, nc), nonzeros=n, dtype=dtype, name='A')
B = tvm.placeholder((in_dim, out_dim), name='B')
C = tvm.placeholder((nr,), name='C')
D = topi.sparse.csrmm(A, B, C if use_bias else None)
s = tvm.create_schedule(D.op)
dtype = A.dtype
# get the test data
def get_ref_data():
a_np = np.maximum(np.random.uniform(size=(batch, in_dim)).astype(dtype)-0.5, 0.)
b_np = np.random.uniform(size=(in_dim, out_dim)).astype(dtype)-0.5
c_np = np.random.uniform(size=(batch, )).astype(dtype)
if use_bias:
d_np = np.dot(a_np, b_np) + c_np.reshape((batch, 1))
else:
d_np = np.dot(a_np, b_np)
return (a_np, b_np, c_np, d_np)
a_np, b_np, c_np, d_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
a = tvmsp.array(a_np, ctx)
_nr, _nc, _n = a.shape[0], a.shape[1], a.data.shape[0]
assert a.shape[0] == a.indptr.shape[0]-1
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(c_np, ctx)
d = tvm.nd.array(np.zeros((_nr, out_dim), dtype=dtype), ctx)
f = tvm.build(s, [nr, A.data, A.indices, A.indptr, B, C, D], device, name="csrmm")
f(_nr, a.data, a.indices, a.indptr, b, c, d)
np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-2)
for device in ["llvm"]:
check_device(device)
def verify_dense_si(batch, in_dim, out_dim, use_bias=True, dtype='float32'):
nonzeros = tvm.var('nonzeros')
A = tvmsp.placeholder(shape=(batch, in_dim), nonzeros=nonzeros, dtype=dtype, name='A')
B = tvm.placeholder((out_dim, in_dim), dtype=dtype, name='B')
C = tvm.placeholder((out_dim,), dtype=dtype, name='C')
D = topi.sparse.dense(A, B, C if use_bias else None)
s = tvm.create_schedule(D.op)
# get the test data
def get_ref_data():
mag = 10.
a_np = np.maximum(mag*(np.random.uniform(size=(batch, in_dim)).astype('float32')-0.5), 0.).astype(dtype)
b_np = (mag*(np.random.uniform(size=(out_dim, in_dim)).astype('float32')-.5)).astype(dtype)
c_np = (mag*(np.random.uniform(size=(out_dim,)).astype('float32')-.5)).astype(dtype)
if use_bias:
d_np = np.dot(a_np, b_np.T) + c_np
else:
d_np = np.dot(a_np, b_np.T)
return (a_np, b_np, c_np, d_np)
a_np, b_np, c_np, d_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
a = tvmsp.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(c_np, ctx)
d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx)
f = tvm.build(s, [A.data, A.indices, A.indptr, B, C, D], device, name="dense")
f(a.data, a.indices, a.indptr, b, c, d)
np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
check_device('llvm')
def verify_dense_sw(batch, in_dim, out_dim, use_bias=True, dtype='float32'):
nonzeros = tvm.var('nonzeros')
A = tvm.placeholder((batch, in_dim), dtype=dtype, name='A')
B = tvmsp.placeholder(shape=(out_dim, in_dim), nonzeros=nonzeros, dtype=dtype, name='B')
C = tvm.placeholder((out_dim,), dtype=dtype, name='C')
D = topi.sparse.dense(A, B, C if use_bias else None)
s = tvm.create_schedule(D.op)
# get the test data
def get_ref_data():
mag = 10.
a_np = (mag*(np.random.uniform(size=(batch, in_dim)).astype('float32')-.5)).astype(dtype)
b_np = np.maximum(mag*(np.random.uniform(size=(out_dim, in_dim)).astype('float32')-0.5), 0.).astype(dtype)
c_np = (mag*(np.random.uniform(size=(out_dim,)).astype('float32')-.5)).astype(dtype)
if use_bias:
d_np = np.dot(a_np, b_np.T) + c_np
else:
d_np = np.dot(a_np, b_np.T)
return (a_np, b_np, c_np, d_np)
a_np, b_np, c_np, d_np = get_ref_data()
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
a = tvm.nd.array(a_np, ctx)
b = tvmsp.array(b_np, ctx)
c = tvm.nd.array(c_np, ctx)
d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=dtype), ctx)
f = tvm.build(s, [A, B.data, B.indices, B.indptr, C, D], device, name="dense")
f(a, b.data, b.indices, b.indptr, c, d)
np.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-5)
check_device('llvm')
def test_csrmv():
verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, use_bias=False)
verify_dynamic_csrmv(batch=5, in_dim=7, out_dim=1, use_bias=True)
def test_csrmm():
M, K, N = 5, 7, 2
verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, use_bias=False)
verify_dynamic_csrmm(batch=M, in_dim=K, out_dim=N, use_bias=True)
def test_dense_si():
M, K, N = 3, 5, 2
verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='float32')
verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='float32')
verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='int32')
verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='int32')
verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='int16')
verify_dense_si(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='int16')
def test_dense_sw():
M, K, N = 3, 5, 2
verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='float32')
verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='float32')
verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='int32')
verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='int32')
verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=False, dtype='int16')
verify_dense_sw(batch=M, in_dim=K, out_dim=N, use_bias=True, dtype='int16')
def test_dense():
test_dense_si()
test_dense_sw()
if __name__ == "__main__":
test_csrmv()
test_csrmm()
test_dense()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment