# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression nodes of Relay."""
from __future__ import absolute_import
from numbers import Number as _Number

import numpy as _np
from .base import RelayNode, register_relay_node
from . import _make
from . import _expr
from . import ty as _ty
from .._ffi import base as _base
from .. import nd as _nd
from .. import convert
from ..ndarray import NDArray

# will be registered afterwards
_op_make = None


class Expr(RelayNode):
    """The base type for all Relay expressions."""

    @property
    def checked_type(self):
        """Get the checked type of tvm.relay.Expr.

        Returns
        -------
        checked_type : tvm.relay.Type
            The checked type.

        Raises
        ------
        ValueError
            If the type checker has not populated the type of this node.
        """
        ret = self._checked_type_
        if ret is None:
            raise ValueError("The type checker has not populated"
                             " the checked_type for this node")
        return ret

    def astype(self, dtype):
        """Cast the content type of the current data to dtype.

        Parameters
        ----------
        dtype : str
            The target data type.

        Note
        ----
        This function only works for TensorType Exprs.

        Returns
        -------
        result : tvm.relay.Expr
            The result expression.
        """
        return _make.cast(self, dtype)

    def __neg__(self):
        return _op_make.negative(self)

    def __lt__(self, other):
        if isinstance(other, Expr):
            return _op_make.less(self, other)
        elif isinstance(other, _Number):
            raise TypeError('convert "%s" with `const` first' % str(other))
        else:
            raise TypeError("type %s not supported" % str(type(other)))

    def __gt__(self, other):
        if isinstance(other, Expr):
            return _op_make.greater(self, other)
        elif isinstance(other, _Number):
            raise TypeError('convert "%s" with `const` first' % str(other))
        else:
            raise TypeError("type %s not supported" % str(type(other)))

    def __ge__(self, other):
        if isinstance(other, Expr):
            return _op_make.greater_equal(self, other)
        elif isinstance(other, _Number):
            raise TypeError('convert "%s" with `const` first' % str(other))
        else:
            raise TypeError("type %s not supported" % str(type(other)))

    def __le__(self, other):
        if isinstance(other, Expr):
            return _op_make.less_equal(self, other)
        elif isinstance(other, _Number):
            raise TypeError('convert "%s" with `const` first' % str(other))
        else:
            raise TypeError("type %s not supported" % str(type(other)))

    def __add__(self, other):
        if isinstance(other, Expr):
            return _op_make.add(self, other)
        elif isinstance(other, _Number):
            raise TypeError('convert "%s" with `const` first' % str(other))
        else:
            raise TypeError("type %s not supported" % str(type(other)))

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        if isinstance(other, Expr):
            return _op_make.subtract(self, other)
        elif isinstance(other, _Number):
            raise TypeError('convert "%s" with `const` first' % str(other))
        else:
            raise TypeError("type %s not supported" % str(type(other)))

    def __rsub__(self, other):
        # Every left operand that reaches here is rejected: numbers must be
        # wrapped with `const` first, anything else is unsupported.
        if isinstance(other, _Number):
            raise TypeError('convert "%s" with `const` first' % str(other))
        else:
            raise TypeError("type %s not supported" % str(type(other)))

    def __mul__(self, other):
        if isinstance(other, Expr):
            return _op_make.multiply(self, other)
        elif isinstance(other, _Number):
            raise TypeError('convert "%s" with `const` first' % str(other))
        else:
            raise TypeError("type %s not supported" % str(type(other)))

    def __rmul__(self, other):
        return self.__mul__(other)

    def __div__(self, other):
        if isinstance(other, Expr):
            return _op_make.divide(self, other)
        elif isinstance(other, _Number):
            raise TypeError('convert "%s" with `const` first' % str(other))
        else:
            raise TypeError("type %s not supported" % str(type(other)))

    def __rdiv__(self, other):
        # Same policy as __rsub__: the reflected form always raises.
        if isinstance(other, _Number):
            raise TypeError('convert "%s" with `const` first' % str(other))
        else:
            raise TypeError("type %s not supported" % str(type(other)))

    def __truediv__(self, other):
        return self.__div__(other)

    def __rtruediv__(self, other):
        return self.__rdiv__(other)

    def __call__(self, *args):
        """Call the variable (if it represents a function).

        Parameters
        ----------
        args: List[relay.Expr]
            The arguments to the call.

        Returns
        -------
        call: Call
            A call taking the variable as a function.
        """
        return Call(self, args)


@register_relay_node
class Constant(Expr):
    """A constant expression in Relay.

    Parameters
    ----------
    data : tvm.nd.NDArray
        The data content of the constant expression.
    """
    def __init__(self, data):
        self.__init_handle_by_constructor__(_make.Constant, data)


@register_relay_node
class Tuple(Expr):
    """Tuple expression that groups several fields together.

    Parameters
    ----------
    fields : List[tvm.relay.Expr]
        The fields in the tuple.
    """
    def __init__(self, fields):
        self.__init_handle_by_constructor__(_make.Tuple, fields)

    def __getitem__(self, index):
        """Return the field at ``index``; negative indices count from the end.

        Raises
        ------
        IndexError
            If ``index`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if index < 0:
            # Normalise Python-style negative indices before the range check.
            index += size
        if index < 0 or index >= size:
            raise IndexError("Tuple index out of range")
        return self.fields[index]

    def __len__(self):
        return len(self.fields)

    def astype(self, _):
        raise TypeError("astype cannot be used on tuple")


@register_relay_node
class Var(Expr):
    """A local variable in Relay.

    Local variable can be used to declare input
    arguments to a function, or intermediate variables.

    Parameters
    ----------
    name_hint: str
        The name of the variable.
        This name only acts as a hint, and is not used
        for equality.

    type_annotation: tvm.relay.Type, optional
        The type annotation on the variable.
    """
    def __init__(self, name_hint, type_annotation=None):
        self.__init_handle_by_constructor__(
            _make.Var, name_hint, type_annotation)

    @property
    def name_hint(self):
        """Get name hint of the current var."""
        name = self.vid.name_hint
        return name


@register_relay_node
class GlobalVar(Expr):
    """A global variable in Tvm.Relay.

    GlobalVar is used to refer to the global functions
    stored in the module.

    Parameters
    ----------
    name_hint: str
        The name of the variable.
    """
    def __init__(self, name_hint):
        self.__init_handle_by_constructor__(_make.GlobalVar, name_hint)

    def __call__(self, *args):
        """Invoke the global function.

        Parameters
        ----------
        args: List[relay.Expr]
            Arguments.
        """
        return Call(self, args, None, None)


@register_relay_node
class Function(Expr):
    """A function declaration expression.

    Parameters
    ----------
    params: List[tvm.relay.Var]
        List of input parameters to the function.

    body: tvm.relay.Expr
        The body of the function.

    ret_type: Optional[tvm.relay.Type]
        The return type annotation of the function.

    type_params: Optional[List[tvm.relay.TypeParam]]
        The additional type parameters, this is only
        used in advanced usecase of template functions.

    attrs: optional
        Passed through unchanged to the ``_make.Function`` constructor.
        NOTE(review): presumably function-level attributes -- confirm.
    """
    def __init__(self,
                 params,
                 body,
                 ret_type=None,
                 type_params=None,
                 attrs=None):
        if type_params is None:
            type_params = convert([])

        self.__init_handle_by_constructor__(
            _make.Function, params, body, ret_type, type_params, attrs)

    def __call__(self, *args):
        """Invoke the function.

        Parameters
        ----------
        args: List[relay.Expr]
            Arguments.
        """
        return Call(self, args, None, None)

    def get_params(self):
        """Return the result of ``_expr.FunctionGetParams`` for this function.

        NOTE(review): presumably the parameters previously attached with
        ``set_params`` -- confirm against the backend implementation.
        """
        return _expr.FunctionGetParams(self)

    def set_params(self, params):
        """Forward ``params`` to ``_expr.FunctionSetParams``.

        Any value that is an ``NDArray`` is wrapped in a ``Constant`` first;
        other values are passed through unchanged. The caller's mapping is
        left untouched: the wrapping is done on a fresh dict.

        Parameters
        ----------
        params: Dict[key, Union[NDArray, tvm.relay.Expr]]
            The mapping to attach.

        Returns
        -------
        The value returned by ``_expr.FunctionSetParams``.
        """
        wrapped = {}
        for key in params:
            value = params[key]
            wrapped[key] = Constant(value) if isinstance(value, NDArray) else value

        return _expr.FunctionSetParams(self, wrapped)

    def set_attribute(self, name, ref):
        """Forward ``name`` and ``ref`` to ``_expr.FunctionSetAttr``.

        Returns
        -------
        The value returned by ``_expr.FunctionSetAttr``.
        """
        return _expr.FunctionSetAttr(self, name, ref)


@register_relay_node
class Call(Expr):
    """Function call node in Relay.

    Call node corresponds the operator application node
    in computational graph terminology.

    Parameters
    ----------
    op: tvm.relay.Op or any tvm.relay.Expr with function type.
        The operation to be called.

    args: List[tvm.relay.Expr]
        The arguments to the call.

    attrs: Optional[tvm.Attrs]
        Attributes to the call, can be None

    type_args: Optional[List[tvm.relay.Type]]
        The additional type arguments, this is only
        used in advanced usecase of template functions.
    """
    def __init__(self, op, args, attrs=None, type_args=None):
        # None (or any empty sequence) is normalised to an empty list.
        if not type_args:
            type_args = []
        self.__init_handle_by_constructor__(
            _make.Call, op, args, attrs, type_args)


@register_relay_node
class Let(Expr):
    """Let variable binding expression.

    Parameters
    ----------
    variable: tvm.relay.Var
        The local variable to be bound.

    value: tvm.relay.Expr
        The value to be bound.

    body: tvm.relay.Expr
        The body of the let binding.
    """
    def __init__(self, variable, value, body):
        self.__init_handle_by_constructor__(
            _make.Let, variable, value, body)


@register_relay_node
class If(Expr):
    """A conditional expression in Relay.

    Parameters
    ----------
    cond: tvm.relay.Expr
        The condition.

    true_branch: tvm.relay.Expr
        The expression evaluated when condition is true.

    false_branch: tvm.relay.Expr
        The expression evaluated when condition is false.
    """
    def __init__(self, cond, true_branch, false_branch):
        self.__init_handle_by_constructor__(
            _make.If, cond, true_branch, false_branch)


@register_relay_node
class TupleGetItem(Expr):
    """Get index-th item from a tuple.

    Parameters
    ----------
    tuple_value: tvm.relay.Expr
        The input tuple expression.

    index: int
        The index.
    """
    def __init__(self, tuple_value, index):
        self.__init_handle_by_constructor__(
            _make.TupleGetItem, tuple_value, index)


@register_relay_node
class RefCreate(Expr):
    """Create a new reference from initial value.

    Parameters
    ----------
    value: tvm.relay.Expr
       The initial value.
    """
    def __init__(self, value):
        self.__init_handle_by_constructor__(_make.RefCreate, value)


@register_relay_node
class RefRead(Expr):
    """Get the value inside the reference.

    Parameters
    ----------
    ref: tvm.relay.Expr
         The reference.
    """
    def __init__(self, ref):
        self.__init_handle_by_constructor__(_make.RefRead, ref)


@register_relay_node
class RefWrite(Expr):
    """
    Update the value inside the reference.
    The whole expression will evaluate to an empty tuple.

    Parameters
    ----------
    ref: tvm.relay.Expr
        The reference.

    value: tvm.relay.Expr
        The new value.
    """
    def __init__(self, ref, value):
        self.__init_handle_by_constructor__(_make.RefWrite, ref, value)


class TempExpr(Expr):
    """Baseclass of all TempExpr.

    TempExprs are pass specific expression that can be
    useful to define intermediate result in the
    rewriting pass such as layout or type transformation.
    """
    def realize(self):
        """Convert the expression to a normal(non-temp) Expr.

        Returns
        -------
        The corresponding normal expression.
        """
        return _expr.TempExprRealize(self)


class TupleWrapper(object):
    """TupleWrapper.

    This class is a Python wrapper for a Relay tuple of known size.
    It allows for accessing the fields of the Relay tuple as though
    it were a Python tuple.

    Parameters
    ----------
    tuple_value: tvm.relay.Expr
        The input tuple

    size: int
        The size of the tuple.
    """
    def __init__(self, tuple_value, size):
        self.tuple_value = tuple_value
        self.size = size

    def astuple(self):
        """Returns the underlying Relay tuple if this wrapper is passed
        as an argument to an FFI function."""
        return self.tuple_value

    def astext(self):
        """Get the text format of the tuple expression.

        Returns
        -------
        text : str
            The text format of the tuple expression.
        """
        return self.tuple_value.astext()

    def __getitem__(self, index):
        """Build a ``TupleGetItem`` for the field at ``index``.

        Negative indices count from the end, as with a Python tuple, and are
        normalised before the node is built so that ``TupleGetItem`` never
        receives a negative index.

        Raises
        ------
        IndexError
            If ``index`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if index < 0:
            index += size
        if index < 0 or index >= size:
            raise IndexError("Tuple index out of range")
        return TupleGetItem(self.tuple_value, index)

    def __len__(self):
        return self.size

    def __repr__(self):
        return ("TupleWrapper(" + self.tuple_value.__repr__() +
                ", " + str(self.size) + ")")

    def astype(self, _):
        raise TypeError("astype cannot be used on tuple")


def var(name_hint,
        type_annotation=None,
        shape=None,
        dtype="float32"):
    """Create a new tvm.relay.Var.

    This is a simple wrapper function that allows specify
    shape and dtype directly.

    Parameters
    ----------
    name_hint: str
        The name of the variable.
        This name only acts as a hint, and is not used
        for equality.

    type_annotation: Optional[tvm.relay.Type, str]
        The type annotation on the variable.
        When type_annotation is a str, we will create a scalar variable.

    shape: Optional[List[tvm.Expr]]
        The shape of the tensor type.

    dtype: str, optional
        The data type of the tensor.

    Raises
    ------
    ValueError
        If both ``type_annotation`` and ``shape`` are given.

    Examples
    --------
    .. code-block:: python

      # The following 4 lines are equivalent to each other
      x = tvm.relay.Var("x", tvm.relay.TensorType([1, 2]))
      x = tvm.relay.var("x", tvm.relay.TensorType([1, 2]))
      x = tvm.relay.var("x", shape=[1, 2])
      x = tvm.relay.var("x", shape=[1, 2], dtype="float32")

      # The following 2 lines are equivalent to each other.
      y = tvm.relay.var("x", "float32")
      y = tvm.relay.var("x", shape=(), dtype="float32")
    """
    if type_annotation is not None and shape is not None:
        raise ValueError("Can only specify either type_annotation or shape.")
    if shape is not None:
        type_annotation = _ty.TensorType(shape, dtype)
    elif isinstance(type_annotation, str):
        # A bare dtype string denotes a scalar (rank-0) tensor of that dtype.
        type_annotation = _ty.TensorType((), type_annotation)
    return Var(name_hint, type_annotation)


def const(value, dtype=None):
    """Create a constant value.

    Parameters
    ----------
    value: Union[bool, int, float, numpy.ndarray, tvm.nd.NDArray]
        The constant value.

    dtype: str, optional
        The data type of the value. It is applied to Python scalars, lists
        and numpy values; a ``tvm.nd.NDArray`` input is used as-is.

    Returns
    -------
    result : Constant
        The constant expression wrapping ``value``.

    Raises
    ------
    ValueError
        If ``value`` cannot be turned into an NDArray by the rules above.

    Note
    ----
    When dtype is None, we use the following rule:

    - int maps to "int32"
    - float maps to "float32"
    - bool maps to "bool"
    - other using the same default rule as numpy.
    """
    numpy_types = (_np.ndarray, _np.generic)

    if isinstance(value, (_base.numeric_types, (bool, list))):
        value = _np.array(value, dtype=dtype)
    elif dtype and isinstance(value, numpy_types):
        # Honour an explicit dtype for numpy input instead of silently
        # ignoring it.
        value = value.astype(dtype)

    # Only numpy values take part in the default narrowing; guarding on the
    # type keeps unsupported inputs (which may lack `.dtype`) on the
    # ValueError path below instead of failing with AttributeError.
    if not dtype and isinstance(value, numpy_types):
        # when dtype is None: int maps to "int32", float maps to "float32"
        map_dtype = {
            _np.dtype('int64'): _np.int32,
            _np.dtype('float64'): _np.float32
            }.get(value.dtype, None)
        if map_dtype:
            value = value.astype(map_dtype)

    if isinstance(value, numpy_types):
        value = _nd.array(value)

    if not isinstance(value, _nd.NDArray):
        raise ValueError("value has to be scalar or NDArray")
    return Constant(value)


def bind(expr, binds):
    """Bind free variables in expr or function arguments.

    We can bind parameters expr if it is a function.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    binds : Union[Map[tvm.relay.Var, tvm.relay.Expr], Map[str, tvm.relay.Expr]]
        The specific bindings.

    Returns
    -------
    result : tvm.relay.Expr
        The expression or function after binding.
    """
    return _expr.Bind(expr, binds)