# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=unused-argument
"""The base node types for the Relay language."""
import topi

from ..._ffi.function import _init_api

from ..base import register_relay_node
from ..expr import Expr
from ...api import register_func
from ...build_module import lower, build
from . import _make

@register_relay_node
class Op(Expr):
    """A Relay operator definition."""

    def __init__(self):
        raise RuntimeError("Cannot create op, use get instead")

    def get_attr(self, attr_name):
        """Get additional attribute about the operator.

        Parameters
        ----------
        attr_name : str
            The attribute name.

        Returns
        -------
        value : object
            The attribute value
        """
        return _OpGetAttr(self, attr_name)

    def set_attr(self, attr_name, value, plevel=10):
        """Set attribute about the operator.

        Parameters
        ----------
        attr_name : str
            The attribute name

        value : object
            The attribute value

        plevel : int
            The priority level
        """
        _OpSetAttr(self, attr_name, value, plevel)


def get(op_name):
    """Get the Op for a given name

    Parameters
    ----------
    op_name : str
        The operator name

    Returns
    -------
    op : Op
        The op of the corresponding name
    """
    return _GetOp(op_name)

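# A minimal usage sketch (assumes the stock "add" operator has already been
# registered, as it is in the standard Relay op library):
#
#   add_op = get("add")
#   pattern = add_op.get_attr("TOpPattern")   # query a registered attribute
#   add_op.set_attr("MyCustomAttr", 1)        # attach a hypothetical attribute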

def register(op_name, attr_key, value=None, level=10):
    """Register an operator property of an operator.


    Parameters
    ----------
    op_name : str
        The name of the operator.

    attr_key : str
        The attribute name.

    value : object, optional
        The value to set

    level : int, optional
        The priority level

    Returns
    -------
    fregister : function
        Register function if value is not specified.
    """
    def _register(v):
        """internal register function"""
        _Register(op_name, attr_key, v, level)
        return v
    return _register(value) if value is not None else _register

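# A minimal sketch of the two ways to use `register`; the op name "my_op" and
# the attribute key "MyAttr" below are hypothetical:
#
#   # direct form, when the value is known up front
#   register("my_op", "MyAttr", 10)
#
#   # decorator form, when the value is a function
#   @register("my_op", "MyAttr")
#   def my_attr_func(*args):
#       ...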

class OpPattern(object):
    """Operator generic patterns

    See Also
    --------
    topi.tag : Contains explanation of the tag type.
    """
    # Elementwise operator
    ELEMWISE = 0
    # Broadcast operator
    BROADCAST = 1
    # Injective mapping
    INJECTIVE = 2
    # Commutative reduction operator
    COMM_REDUCE = 3
    # Complex op, can still fuse elemwise ops into it
    OUT_ELEMWISE_FUSABLE = 4
    # Represents tuple node
    TUPLE = 7
    # Not fusable opaque op
    OPAQUE = 8


def register_schedule(op_name, schedule=None, level=10):
    """Register schedule function for an op

    Parameters
    ----------
    op_name : str
        The name of the op.

    schedule : function (attrs: Attrs, outs: List[Tensor], target: Target) -> sch: Schedule
        The schedule function.

    level : int
        The priority level
    """
    return register(op_name, "FTVMSchedule", schedule, level)

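# A minimal sketch, assuming a hypothetical op "my_op" whose output can be
# scheduled with the generic injective schedule defined at the bottom of
# this file:
#
#   register_schedule("my_op", schedule_injective)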

def register_compute(op_name, compute=None, level=10):
    """Register compute function for an op.

    Parameters
    ----------
    op_name : str
        The name of the op.

    compute : function (attrs: Attrs, inputs: List[Tensor], out_type: Type, target: Target)
                       -> List[Tensor]
        The compute function.

    level : int
        The priority level
    """
    return register(op_name, "FTVMCompute", compute, level)

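# A minimal sketch of a compute registration for a hypothetical elementwise
# op "my_relu", expressed with topi:
#
#   @register_compute("my_relu")
#   def _my_relu_compute(attrs, inputs, out_type, target):
#       return [topi.nn.relu(inputs[0])]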

def register_alter_op_layout(op_name, alter_layout=None, level=10):
    """Register alter op layout function for an op

    Parameters
    ----------
    op_name : str
        The name of the operator

    alter_layout: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
        The function for changing the layout or replacing the operator

    level : int
        The priority level
    """
    return register(op_name, "FTVMAlterOpLayout", alter_layout, level)

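# A minimal sketch following the signature documented above; the op name
# "my_conv" is hypothetical, and returning None keeps the original operator
# and layout unchanged:
#
#   @register_alter_op_layout("my_conv")
#   def _my_conv_alter_layout(attrs, inputs):
#       return None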

def register_legalize(op_name, legal_op=None, level=10):
    """Register legal transformation function for an op

    Parameters
    ----------
    op_name : str
        The name of the operator

    legal_op: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
        The function for transforming an expr to another expr.

    level : int
        The priority level
    """
    return register(op_name, "FTVMLegalize", legal_op, level)

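# A minimal sketch for a hypothetical op "my_add" that is legalized into the
# stock Relay add operator:
#
#   from tvm import relay
#
#   @register_legalize("my_add")
#   def _my_add_legalize(attrs, inputs):
#       return relay.add(inputs[0], inputs[1])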

def register_pattern(op_name, pattern, level=10):
    """Register operator pattern for an op.

    Parameters
    ----------
    op_name : str
        The name of the op.

    pattern : int
        The pattern being used.

    level : int
        The priority level
    """
    return register(op_name, "TOpPattern", pattern, level)

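# A minimal sketch tying an OpPattern to a hypothetical op "my_relu" so the
# fusion pass knows how to treat it:
#
#   register_pattern("my_relu", OpPattern.ELEMWISE)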
def register_gradient(op_name, fgradient=None, level=10):
    """Register operator pattern for an op.

    Parameters
    ----------
    op_name : str
        The name of the op.

    fgradient : function (orig_expr : Expr, output_grad : Expr) -> List[Expr]
        The function that computes the gradients with respect to each input.

    level : int
        The priority level
    """
    return register(op_name, "FPrimalGradient", fgradient, level)

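# A minimal sketch for a hypothetical identity-like op "my_copy"; the
# registered function is expected to return one gradient expression per
# input of the original call:
#
#   @register_gradient("my_copy")
#   def _my_copy_grad(orig_expr, output_grad):
#       return [output_grad]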
def register_shape_func(op_name, data_dependant, shape_func=None, level=10):
    """Register operator shape function for an op.

    Parameters
    ----------
    op_name : str
        The name of the op.

    data_dependant : bool
        Whether the shape function depends on input data.

    shape_func : function (attrs: Attrs, inputs: List[Tensor], out_ndims: List[IndexExpr])
                 -> shape_tensors: List[Tensor]
        The function for computing the dynamic output shapes

    level : int
        The priority level
    """
    get(op_name).set_attr("TShapeDataDependant", data_dependant, level)
    return register(op_name, "FShapeFunc", shape_func, level)

_init_api("relay.op", __name__)

@register_func("relay.op.compiler._lower")
def _lower(name, schedule, inputs, outputs):
    return lower(schedule, list(inputs) + list(outputs), name=name)

@register_func("relay.op.compiler._build")
def _build(lowered_funcs):
    return build(lowered_funcs, target="llvm")


def schedule_injective(attrs, outputs, target):
    """Generic schedule for binary broadcast."""
    with target:
        return topi.generic.schedule_injective(outputs)


def schedule_concatenate(attrs, outputs, target):
    """Generic schedule for concatinate."""
    with target:
        return topi.generic.schedule_concatenate(outputs)


__DEBUG_COUNTER__ = 0

def debug(expr, debug_func=None):
    """The main entry point to the debugger."""
    global __DEBUG_COUNTER__

    if debug_func:
        name = "debugger_func{}".format(__DEBUG_COUNTER__)
        register_func(name, debug_func)
        __DEBUG_COUNTER__ += 1
    else:
        name = ''

    return _make.debug(expr, name)
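

# A minimal usage sketch; the callback name is hypothetical, and the callback
# is intended to be invoked with the value of `expr` when the debug node is
# evaluated:
#
#   def my_breakpoint(value):
#       print(value)
#
#   expr_with_breakpoint = debug(expr, debug_func=my_breakpoint)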