# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=invalid-name, unused-argument, len-as-condition
"""Backend compiler related feature registration"""
import topi

from tvm.runtime import convert
from tvm.te.hybrid import script
from topi.util import get_const_tuple
from .op import register_compute, register_shape_func
from .op import register_broadcast_schedule, register_injective_schedule
from .op import register_pattern, OpPattern


# Ops lowered with the generic broadcast schedule.
register_broadcast_schedule("log")
register_broadcast_schedule("tan")
register_broadcast_schedule("cos")
register_broadcast_schedule("sin")
register_broadcast_schedule("atan")
register_broadcast_schedule("exp")
register_broadcast_schedule("erf")
register_broadcast_schedule("sqrt")
register_broadcast_schedule("rsqrt")
register_broadcast_schedule("sigmoid")
register_broadcast_schedule("floor")
register_broadcast_schedule("ceil")
register_broadcast_schedule("trunc")
register_broadcast_schedule("round")
register_broadcast_schedule("sign")
register_broadcast_schedule("abs")
register_broadcast_schedule("tanh")
register_broadcast_schedule("add")
register_broadcast_schedule("subtract")
register_broadcast_schedule("multiply")
register_broadcast_schedule("divide")
register_broadcast_schedule("floor_divide")
register_broadcast_schedule("power")
register_broadcast_schedule("copy")
register_broadcast_schedule("logical_not")
register_broadcast_schedule("logical_and")
register_broadcast_schedule("logical_or")
register_broadcast_schedule("bitwise_not")
register_broadcast_schedule("bitwise_and")
register_broadcast_schedule("bitwise_or")
register_broadcast_schedule("bitwise_xor")
register_broadcast_schedule("negative")
register_broadcast_schedule("mod")
register_broadcast_schedule("floor_mod")
register_broadcast_schedule("equal")
register_broadcast_schedule("not_equal")
register_broadcast_schedule("less")
register_broadcast_schedule("less_equal")
register_broadcast_schedule("greater")
register_broadcast_schedule("greater_equal")

register_broadcast_schedule("isfinite")
register_broadcast_schedule("isinf")

# Ops lowered with the generic injective schedule.
register_injective_schedule("maximum")
register_injective_schedule("minimum")
register_injective_schedule("right_shift")
register_injective_schedule("left_shift")
register_injective_schedule("shape_of")
register_injective_schedule("ndarray_size")

# fast_exp / fast_tanh use the broadcast schedule as well.
register_broadcast_schedule("fast_exp")
register_broadcast_schedule("fast_tanh")


# zeros
@register_compute("zeros")
def zeros_compute(attrs, inputs, output_type):
    """Compute for ``zeros``: a tensor of ``output_type`` filled with 0.

    Parameters
    ----------
    attrs : op attributes (unused here).
    inputs : list of input tensors; must be empty, since shape and dtype
        are taken entirely from ``output_type``.
    output_type : the inferred output type, providing ``shape`` and ``dtype``.

    Returns
    -------
    A one-element list holding the ``topi.full`` result.
    """
    assert not inputs
    return [topi.full(output_type.shape, output_type.dtype, 0.0)]


register_broadcast_schedule("zeros")
register_pattern("zeros", OpPattern.ELEMWISE)


# zeros_like
@register_compute("zeros_like")
def zeros_like_compute(attrs, inputs, output_type):
    """Compute for ``zeros_like``: a tensor like ``inputs[0]`` filled with 0.

    Parameters
    ----------
    attrs : op attributes (unused here).
    inputs : list holding exactly one tensor, whose shape and dtype are
        reused via ``topi.full_like``.
    output_type : the inferred output type (unused here).

    Returns
    -------
    A one-element list holding the ``topi.full_like`` result.
    """
    assert len(inputs) == 1
    return [topi.full_like(inputs[0], 0.0)]


register_broadcast_schedule("zeros_like")


# ones
@register_compute("ones")
def ones_compute(attrs, inputs, output_type):
    """Compute for ``ones``: a tensor of ``output_type`` filled with 1.

    Parameters
    ----------
    attrs : op attributes (unused here).
    inputs : list of input tensors; must be empty, since shape and dtype
        are taken entirely from ``output_type``.
    output_type : the inferred output type, providing ``shape`` and ``dtype``.

    Returns
    -------
    A one-element list holding the ``topi.full`` result.
    """
    assert not inputs
    return [topi.full(output_type.shape, output_type.dtype, 1.0)]


register_broadcast_schedule("ones")
register_pattern("ones", OpPattern.ELEMWISE)


# ones_like
@register_compute("ones_like")
def ones_like_compute(attrs, inputs, output_type):
    """Compute for ``ones_like``: a tensor like ``inputs[0]`` filled with 1.

    Parameters
    ----------
    attrs : op attributes (unused here).
    inputs : list holding exactly one tensor, whose shape and dtype are
        reused via ``topi.full_like``.
    output_type : the inferred output type (unused here).

    Returns
    -------
    A one-element list holding the ``topi.full_like`` result.
    """
    assert len(inputs) == 1
    return [topi.full_like(inputs[0], 1.0)]


register_broadcast_schedule("ones_like")


# clip
@register_compute("clip")
def clip_compute(attrs, inputs, output_type):
    """Compute for ``clip``.

    Applies ``topi.clip`` to the single input, using ``attrs.a_min`` and
    ``attrs.a_max`` as the lower and upper bounds.

    Parameters
    ----------
    attrs : op attributes carrying ``a_min`` and ``a_max``.
    inputs : list holding exactly one tensor.
    output_type : the inferred output type (unused here).

    Returns
    -------
    A one-element list holding the clipped tensor.
    """
    assert len(inputs) == 1
    return [topi.clip(inputs[0], attrs.a_min, attrs.a_max)]


register_injective_schedule("clip")


# Hybrid-script helper: copy the shape tensor ``x`` element by element into a
# fresh int64 tensor of length ``len(x)``.  ``x`` is presumably the 1-D shape
# tensor of cast's input -- TODO confirm against the shape-func caller.
@script
def _cast_shape_function(x):
    out_ndim = len(x)
    out = output_tensor((out_ndim,), "int64")
    for i in const_range(out_ndim):
        out[i] = x[i]
    return out


def cast_shape_func(attrs, inputs, out_ndims):
    """
    Shape function for ``cast``.

    The output shape is a copy of the input shape, computed by
    ``_cast_shape_function``; ``attrs`` and ``out_ndims`` are unused.
    """
    return [_cast_shape_function(*inputs)]


# Hybrid-script helper: materialize the constant ``shape`` (converted by
# ``full_shape_func``) as a 1-D int64 tensor with out[i] = int64(shape[i]).
@script
def _full_shape_func(shape):
    out_ndim = len(shape)
    out = output_tensor((out_ndim,), "int64")
    for i in const_range(out_ndim):
        out[i] = int64(shape[i])
    return out


def full_shape_func(attrs, inputs, out_ndims):
    """
    Shape func for zeros, ones and full.

    These ops take their output shape from ``attrs.shape`` (read as a
    constant tuple) rather than from an input tensor, so ``inputs`` and
    ``out_ndims`` are not consulted.  ``zeros_like`` / ``ones_like`` use
    ``elemwise_shape_func`` instead.
    """
    shape = get_const_tuple(attrs.shape)
    return [_full_shape_func(convert(shape))]


# Hybrid-script shape function for binary broadcast ops.  ``x`` and ``y`` are
# the shape tensors of the two operands and ``ndim`` is the output rank.
# - If either shape tensor has rank 0 (len(t.shape) == 0), the other operand's
#   shape is copied to the output.
# - Otherwise trailing dimensions are matched right-to-left: equal extents pass
#   through, an extent of 1 takes the other operand's extent, and any other
#   mismatch trips the assert.  Remaining leading dimensions are copied from
#   the longer of the two shapes.
@script
def _broadcast_shape_func(x, y, ndim):
    out = output_tensor((ndim,), "int64")
    if len(x.shape) == 0:
        for i in const_range(ndim):
            out[i] = y[i]
    elif len(y.shape) == 0:
        for i in const_range(ndim):
            out[i] = x[i]
    else:
        ndim1 = x.shape[0]
        ndim2 = y.shape[0]
        # Overlapping trailing dimensions.
        for i in const_range(1, min(ndim1, ndim2)+1):
            if x[ndim1-i] == y[ndim2-i]:
                out[ndim-i] = x[ndim1-i]
            elif x[ndim1-i] == 1:
                out[ndim-i] = y[ndim2-i]
            else:
                assert y[ndim2 - i] == 1, "Incompatible broadcast type %s and %s" % (
                    x[ndim1-i], y[ndim2-i])
                out[ndim-i] = x[ndim1-i]
        # Leading dimensions present only in the longer shape.
        for i in const_range(min(ndim1, ndim2)+1, ndim+1):
            if ndim1 >= ndim2:
                out[ndim-i] = x[ndim1-i]
            else:
                out[ndim-i] = y[ndim2-i]
    return out


def broadcast_shape_func(attrs, inputs, out_ndims):
    """
    Shape function for broadcast op.

    ``inputs`` must hold exactly two shape tensors (one per operand); they
    are forwarded positionally to ``_broadcast_shape_func`` together with
    ``out_ndims[0]``, the output rank.  ``attrs`` is unused.
    """
    return [_broadcast_shape_func(*inputs, out_ndims[0])]


def elemwise_shape_func(attrs, inputs, _):
    """
    Shape function for elemwise op.

    The output shape equals the shape of the first input, so that input's
    shape tensor is returned through ``topi.math.identity``.  ``attrs`` and
    the output-rank argument are unused.
    """
    return [topi.math.identity(inputs[0])]


# Ops whose output shape comes from attributes or mirrors their input.
register_shape_func("cast", False, cast_shape_func)
register_shape_func("zeros", False, full_shape_func)
register_shape_func("zeros_like", False, elemwise_shape_func)
register_shape_func("ones", False, full_shape_func)
register_shape_func("ones_like", False, elemwise_shape_func)
register_shape_func("full", False, full_shape_func)
register_shape_func("full_like", False, elemwise_shape_func)

# Two-operand ops: output shape is the broadcast of both input shapes.
register_shape_func("add", False, broadcast_shape_func)
register_shape_func("subtract", False, broadcast_shape_func)
register_shape_func("multiply", False, broadcast_shape_func)
register_shape_func("divide", False, broadcast_shape_func)
register_shape_func("floor_divide", False, broadcast_shape_func)
register_shape_func("mod", False, broadcast_shape_func)
register_shape_func("floor_mod", False, broadcast_shape_func)
register_shape_func("logical_and", False, broadcast_shape_func)
register_shape_func("logical_or", False, broadcast_shape_func)
# bitwise_not has a single operand; broadcast_shape_func unpacks exactly two
# input shapes and would fail, so its output simply mirrors its input shape.
register_shape_func("bitwise_not", False, elemwise_shape_func)
register_shape_func("bitwise_and", False, broadcast_shape_func)
register_shape_func("bitwise_or", False, broadcast_shape_func)
register_shape_func("bitwise_xor", False, broadcast_shape_func)
register_shape_func("equal", False, broadcast_shape_func)
register_shape_func("not_equal", False, broadcast_shape_func)
register_shape_func("less", False, broadcast_shape_func)
register_shape_func("less_equal", False, broadcast_shape_func)
register_shape_func("greater", False, broadcast_shape_func)
register_shape_func("greater_equal", False, broadcast_shape_func)
register_shape_func("maximum", False, broadcast_shape_func)
register_shape_func("minimum", False, broadcast_shape_func)

# Single-operand elementwise ops: output shape equals the input shape.
register_shape_func("sqrt", False, elemwise_shape_func)
register_shape_func("negative", False, elemwise_shape_func)
register_shape_func("exp", False, elemwise_shape_func)
register_shape_func("tan", False, elemwise_shape_func)
register_shape_func("fast_exp", False, elemwise_shape_func)
register_shape_func("fast_tanh", False, elemwise_shape_func)