sparse.py 5.71 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
"""Tensor and Operation class for computation declaration."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import numpy as _np
from .. import expr as _expr
from .. import api as _api
from .. import tensor as _tensor
from .. import ndarray as _nd

float32 = "float32"
itype = 'int32'

class CSRNDArray(object):
    """Sparse tensor object in CSR format."""
    def __init__(self, arg1, ctx=None, shape=None):
        """Construct a sparse matrix in CSR format.

        Parameters
        ----------
        arg1 : numpy.ndarray or a tuple with (data, indices, indptr)
            The corresponding a dense numpy array,
            or a tuple for constructing a sparse matrix directly.

        ctx: tvm.TVMContext
            The corresponding context.

        shape : tuple of int
            The shape of the array
        """
        if isinstance(arg1, tuple):
            assert len(arg1) == 3
            self.data, self.indices, self.indptr = arg1
            self.shape = shape
        elif isinstance(arg1, _np.ndarray):
            source_array = arg1
            ridx, cidx = _np.nonzero(source_array)
            data = source_array[ridx, cidx]
            self.data = _nd.array(data, ctx)
            indices = _np.nonzero(source_array)[1].astype(itype)
            self.indices = _nd.array(indices, ctx)
            indptr = [0]+_np.apply_along_axis(_np.count_nonzero, axis=1, arr=source_array).tolist()
            indptr = _np.cumsum(_np.array(indptr, itype)).astype(itype)
            self.indptr = _nd.array(indptr, ctx)
            self.shape = source_array.shape
        else:
            raise RuntimeError("Construct CSRNDArray with either a tuple (data, indices, indptr) "
                               "or a numpy.array, can't handle type %s." % (type(arg1),))
        self.stype = 'csr'
        self.dtype = self.data.dtype
        assert self.shape is not None
        assert isinstance(self.data, _nd.NDArray)
        assert isinstance(self.indices, _nd.NDArray)
        assert str(self.indices.dtype) == 'int32' or \
            str(self.indices.dtype) == 'int64', str(self.indices.dtype)
        assert isinstance(self.indptr, _nd.NDArray)
        assert str(self.indptr.dtype) == 'int32' or \
            str(self.indptr.dtype) == 'int64', str(self.indptr.dtype)

    def asnumpy(self):
        """Construct a full matrix and convert it to numpy array."""
        full = _np.zeros(self.shape, self.dtype)
        ridx = _np.diff(self.indptr.asnumpy())
        ridx = _np.hstack((_np.ones((v,), itype)*i for i, v in enumerate(ridx)))
        full[ridx, self.indices.asnumpy().astype(itype)] = self.data.asnumpy()
        return full

def array(source_array, ctx=None, shape=None, stype='csr'):
    """Construct a sparse NDArray from numpy.ndarray"""
    ret = None
    if stype == 'csr':
        ret = CSRNDArray(source_array, shape=shape, ctx=ctx)
    else:
        raise NotImplementedError('stype=%s is not supported yet.' % (stype,))
    return ret

class SparsePlaceholderOp(object):
    """Placeholder class for sparse tensor representations."""
    def __init__(self, shape, nonzeros, dtype, name):
        # pylint: disable=unused-argument
        """Contructing a bare bone structure for a sparse matrix

        Parameters
        ----------
        shape: Tuple of Expr
            The shape of the tensor

        nonzeros: int
            The number of non-zero values

        dtype: str, optional
            The data type of the tensor

        name: str, optional
            The name hint of the tensor
        """
        self.shape = shape
        self.dtype = dtype
        self.name = name
        self.stype = 'unknown'

class CSRPlaceholderOp(SparsePlaceholderOp):
    """Placeholder class for CSR based sparse tensor representation."""
    def __init__(self, shape, nonzeros, dtype, name):
        """Contructing a bare bone structure for a csr_matrix

        Parameters
        ----------
        shape: Tuple of Expr
            The shape of the tensor

        nonzeros: int
            The number of non-zero values

        dtype: str, optional
            The data type of the tensor

        name: str, optional
            The name hint of the tensor
        """
        SparsePlaceholderOp.__init__(self, shape, nonzeros, dtype, name)
        self.stype = 'csr'
        self.data = _api.placeholder((nonzeros,), dtype=dtype, name=self.name+'_data')
        self.indices = _api.placeholder((nonzeros,), dtype=itype, name=self.name+'_indices')
        self.indptr = _api.placeholder((self.shape[0]+1,), dtype=itype, name=self.name+'_indptr')
        assert isinstance(self.data, _tensor.Tensor)
        assert isinstance(self.indices, _tensor.Tensor)
        assert isinstance(self.indptr, _tensor.Tensor)

def placeholder(shape, nonzeros=None, dtype=None, name="placeholder", stype=None):
    """Construct an empty sparse tensor object.

    Parameters
    ----------
    shape: Tuple of Expr
        The shape of the tensor

    nonzeros: int
        The number of non-zero values

    dtype: str, optional
        The data type of the tensor

    name: str, optional
        The name hint of the tensor

    stype: str, optional
        The name storage type of the sparse tensor (e.g. csr, coo, ell)

    Returns
    -------
    tensor: SparsePlaceholderOp
        The created sparse tensor placeholder
    """
    shape = (shape,) if isinstance(shape, _expr.Expr) else shape
    nonzeros = 0 if nonzeros is None else nonzeros
    dtype = float32 if dtype is None else dtype
    stype = 'csr' if stype is None else stype
    ret = None
    if stype == 'csr':
        ret = CSRPlaceholderOp(shape=shape, nonzeros=nonzeros, dtype=dtype, name=name)
    else:
        raise NotImplementedError('stype=%s is not supported yet.' % (stype,))
    return ret