# expr.pyi: type stubs for the Relay expression node classes.
from typing import Any, List, Optional

import tvm

from .base import Span, NodeBase
from .ty import Type, TypeParam
from ._ir_pass import _get_checked_type


class Expr(NodeBase):
    """Common base stub for every Relay expression node declared in this module."""

    def checked_type(self):
        """Stub accessor for this expression's checked type.

        The return type is not declared in this stub.
        NOTE(review): confirm whether the runtime exposes this as a property.
        """

    def __call__(self, *args):
        """Stub for applying this expression to positional ``args``; result type undeclared here."""


class Constant(Expr):
    """Expression node holding a constant ``tvm.nd.NDArray`` in ``data``."""
    data = ...  # type: tvm.nd.NDArray

    def __init__(self, data):
        # type: (tvm.nd.NDArray) -> None
        """Wrap the array ``data`` as a constant expression."""


class Tuple(Expr):
    """Tuple expression node grouping an ordered list of ``fields`` expressions."""
    fields = ...  # type: List[Expr]

    def __init__(self, fields):
        # type: (List[Expr]) -> None
        ...


class Var(Expr):
    """A local variable in Relay, carrying a ``name_hint`` string."""
    name_hint = ...  # type: str

    def __init__(self, name_hint):
        # type: (str) -> None
        """Declare a local variable with the given ``name_hint``."""


class GlobalVar(Expr):
    """Global-variable node; like ``Var`` it carries a ``name_hint`` string."""
    name_hint = ...  # type: str

    def __init__(self, name_hint):
        # type: (str) -> None
        """Declare a global variable with the given ``name_hint``."""


class Param(Expr):
    """A function parameter (element of ``Function.params``): a ``Var`` paired with a ``Type``."""
    var = ...  # type: Var
    type = ...  # type: Type

    def __init__(self, var, ty):
        # type: (Var, Type) -> None
        """Construct a parameter from the variable ``var`` and its type ``ty``."""


class Function(Expr):
    """A function in Relay, see tvm/relay/expr.h for more details."""
    type_params = ...  # type: List[TypeParam]
    params = ...  # type: List[Param]
    ret_type = ...  # type: Type
    body = ...  # type: Expr

    # Per-argument type comments (PEP 484): each comment holds exactly one
    # type with no trailing comma; otherwise it parses as a tuple expression.
    def __init__(self,
                 params,  # type: List[Param]
                 ret_type,  # type: Type
                 body,  # type: Expr
                 type_params=None,  # type: Optional[List[TypeParam]]
                 ):
        # type: (...) -> None
        ...


class Call(Expr):
    """A function call in Relay, see tvm/relay/expr.h for more details."""
    op = ...  # type: Expr
    args = ...  # type: List[Expr]
    # TODO(@jroesch): add attrs; revise the attrs type in __init__.

    def __init__(self, op, args, attrs=None, ty_args=None):
        # type: (Expr, List[Expr], Optional[List[Any]], Optional[List[Type]]) -> None
        # Stub only: no implementation belongs in a .pyi file. The body this
        # stub previously carried treated a falsy ``ty_args`` as ``[]``.
        ...


class Let(Expr):
    """A variable binding in Relay, see tvm/relay/expr.h for more details."""
    var = ...  # type: Var
    value = ...  # type: Expr
    body = ...  # type: Expr
    value_type = ...  # type: Type

    def __init__(self, var, value, body, value_type):
        # type: (Var, Expr, Expr, Type) -> None
        ...


class If(Expr):
    """A conditional expression in Relay, see tvm/relay/expr.h for more details."""
    cond = ...  # type: Expr
    true_value = ...  # type: Expr
    false_value = ...  # type: Expr
    span = ...  # type: Span

    def __init__(self, cond, true_value, false_value):
        # type: (Expr, Expr, Expr) -> None
        ...