6.48 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
"""Tensor and Operation class for computation declaration."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs
import numpy as _np
from .. import expr as _expr
from .. import api as _api
from .. import tensor as _tensor
from .. import ndarray as _nd

# Default value dtype used by placeholder() when the caller passes none.
float32 = "float32"
# Integer dtype used for the CSR index arrays (indices and indptr).
itype = "int32"

class CSRNDArray(object):
    """Sparse tensor object in CSR format.

    Attributes
    ----------
    data : NDArray
        The non-zero values, stored row by row.

    indices : NDArray
        The column index of every entry in ``data``.

    indptr : NDArray
        Row pointers; row ``i`` owns the entries ``indptr[i]:indptr[i+1]``.

    shape : tuple of int
        The shape of the equivalent dense matrix.

    stype : str
        The storage type, always ``'csr'``.

    dtype : str
        The data type of ``data``.
    """
    def __init__(self, arg1, ctx=None, shape=None):
        """Construct a sparse matrix in CSR format.

        Parameters
        ----------
        arg1 : numpy.ndarray or a tuple with (data, indices, indptr)
            The corresponding a dense numpy array,
            or a tuple for constructing a sparse matrix directly.

        ctx: tvm.TVMContext
            The corresponding context.

        shape : tuple of int
            The shape of the array. Required when ``arg1`` is a tuple,
            ignored when ``arg1`` is a dense array.

        Raises
        ------
        RuntimeError
            If ``arg1`` is neither a tuple nor a numpy.ndarray.
        """
        if isinstance(arg1, tuple):
            assert len(arg1) == 3
            self.data, self.indices, self.indptr = arg1
            self.shape = shape
        elif isinstance(arg1, _np.ndarray):
            source_array = arg1
            # nonzero() reports coordinates in row-major order, which is the
            # order in which CSR stores its entries.
            ridx, cidx = _np.nonzero(source_array)
            self.data = _nd.array(source_array[ridx, cidx], ctx)
            self.indices = _nd.array(cidx.astype(itype), ctx)
            # indptr is the running total of non-zeros per row, prefixed by 0.
            row_counts = _np.count_nonzero(source_array, axis=1)
            indptr = _np.zeros(len(row_counts) + 1, dtype=itype)
            indptr[1:] = _np.cumsum(row_counts)
            self.indptr = _nd.array(indptr, ctx)
            self.shape = source_array.shape
        else:
            raise RuntimeError("Construct CSRNDArray with either a tuple (data, indices, indptr) "
                               "or a numpy.array, can't handle type %s." % (type(arg1),))
        self.stype = 'csr'
        self.dtype = self.data.dtype
        assert self.shape is not None
        assert isinstance(self.data, _nd.NDArray)
        assert isinstance(self.indices, _nd.NDArray)
        assert str(self.indices.dtype) == 'int32' or \
            str(self.indices.dtype) == 'int64', str(self.indices.dtype)
        assert isinstance(self.indptr, _nd.NDArray)
        assert str(self.indptr.dtype) == 'int32' or \
            str(self.indptr.dtype) == 'int64', str(self.indptr.dtype)

    def asnumpy(self):
        """Construct a full matrix and convert it to numpy array.

        Returns
        -------
        full : numpy.ndarray
            A dense array of shape ``self.shape`` and dtype ``self.dtype``
            with the stored values scattered to their (row, column) positions.
        """
        full = _np.zeros(self.shape, self.dtype)
        # Consecutive differences of indptr give the number of entries per row.
        row_counts = _np.diff(self.indptr.asnumpy())
        # Expand the counts into one row index per stored value. repeat() also
        # copes with a matrix that has no rows, unlike hstack on a generator.
        ridx = _np.repeat(_np.arange(len(row_counts), dtype=itype), row_counts)
        full[ridx, self.indices.asnumpy().astype(itype)] = self.data.asnumpy()
        return full

def array(source_array, ctx=None, shape=None, stype='csr'):
    """Construct a sparse NDArray from numpy.ndarray

    Parameters
    ----------
    source_array : numpy.ndarray or a tuple with (data, indices, indptr)
        The source passed through to the sparse array constructor.

    ctx : optional
        The context passed through to the sparse array constructor.

    shape : tuple of int, optional
        The shape passed through to the sparse array constructor.

    stype : str, optional
        The storage type of the result; only 'csr' is supported.

    Returns
    -------
    ret : CSRNDArray
        The constructed sparse array.

    Raises
    ------
    NotImplementedError
        If ``stype`` is anything other than 'csr'.
    """
    # Reject unsupported storage types up front so the supported path can
    # return directly.
    if stype != 'csr':
        raise NotImplementedError('stype=%s is not supported yet.' % (stype,))
    return CSRNDArray(source_array, shape=shape, ctx=ctx)

class SparsePlaceholderOp(object):
    """Placeholder class for sparse tensor representations."""
    def __init__(self, shape, nonzeros, dtype, name):
        # pylint: disable=unused-argument
        """Contructing a bare bone structure for a sparse matrix

        Parameters
        ----------
        shape: Tuple of Expr
            The shape of the tensor

        nonzeros: int
            The number of non-zero values (not stored by this base class)

        dtype: str, optional
            The data type of the tensor

        name: str, optional
            The name hint of the tensor
        """
        self.shape = shape
        self.dtype = dtype
        self.name = name
        # Subclasses overwrite this with their concrete storage type.
        self.stype = 'unknown'

class CSRPlaceholderOp(SparsePlaceholderOp):
    """Placeholder class for CSR based sparse tensor representation."""
    def __init__(self, shape, nonzeros, dtype, name):
        """Contructing a bare bone structure for a csr_matrix

        Parameters
        ----------
        shape: Tuple of Expr
            The shape of the tensor

        nonzeros: int
            The number of non-zero values

        dtype: str, optional
            The data type of the tensor

        name: str, optional
            The name hint of the tensor
        """
        SparsePlaceholderOp.__init__(self, shape, nonzeros, dtype, name)
        self.stype = 'csr'
        # One placeholder per CSR component, each named after the tensor.
        self.data = _api.placeholder((nonzeros,), dtype=dtype, name=name+'_data')
        self.indices = _api.placeholder((nonzeros,), dtype=itype, name=name+'_indices')
        # indptr holds one entry per row plus a leading zero.
        self.indptr = _api.placeholder((self.shape[0]+1,), dtype=itype, name=name+'_indptr')
        assert isinstance(self.data, _tensor.Tensor)
        assert isinstance(self.indices, _tensor.Tensor)
        assert isinstance(self.indptr, _tensor.Tensor)

def placeholder(shape, nonzeros=None, dtype=None, name="placeholder", stype=None):
    """Construct an empty sparse tensor object.

    Parameters
    ----------
    shape: Tuple of Expr
        The shape of the tensor

    nonzeros: int
        The number of non-zero values

    dtype: str, optional
        The data type of the tensor

    name: str, optional
        The name hint of the tensor

    stype: str, optional
        The name storage type of the sparse tensor (e.g. csr, coo, ell)

    Returns
    -------
    tensor: SparsePlaceholderOp
        The created sparse tensor placeholder

    Raises
    ------
    NotImplementedError
        If ``stype`` is anything other than 'csr' (the default).
    """
    stype = 'csr' if stype is None else stype
    # Reject unsupported storage types before doing any other work.
    if stype != 'csr':
        raise NotImplementedError('stype=%s is not supported yet.' % (stype,))
    # A single expression is treated as a one-dimensional shape.
    shape = (shape,) if isinstance(shape, _expr.PrimExpr) else shape
    nonzeros = 0 if nonzeros is None else nonzeros
    dtype = float32 if dtype is None else dtype
    return CSRPlaceholderOp(shape=shape, nonzeros=nonzeros, dtype=dtype, name=name)