Unverified Commit 7ca3212f by Zhi Committed by GitHub

create function.py (#5087)

parent 06bbc7c9
...@@ -35,9 +35,6 @@ tvm.relay.expr ...@@ -35,9 +35,6 @@ tvm.relay.expr
.. autoclass:: tvm.relay.expr.Tuple .. autoclass:: tvm.relay.expr.Tuple
:members: :members:
.. autoclass:: tvm.relay.expr.Function
:members:
.. autoclass:: tvm.relay.expr.Call .. autoclass:: tvm.relay.expr.Call
:members: :members:
......
...@@ -120,7 +120,7 @@ Additionally, functions in Relay are higher-order, which means that a function c ...@@ -120,7 +120,7 @@ Additionally, functions in Relay are higher-order, which means that a function c
function or returned by a function, as function expressions evaluate to closures (see the `Closures`_ subsection), function or returned by a function, as function expressions evaluate to closures (see the `Closures`_ subsection),
which are values like tensors and tuples. which are values like tensors and tuples.
See :py:class:`~tvm.relay.expr.Function` for the definition and documentation of function nodes. See :py:class:`~tvm.relay.function.Function` for the definition and documentation of function nodes.
Syntax Syntax
~~~~~~ ~~~~~~
......
...@@ -69,7 +69,7 @@ class BaseGraphTuner(object): ...@@ -69,7 +69,7 @@ class BaseGraphTuner(object):
target_op in the input graph and layout transformation benchmark need to be target_op in the input graph and layout transformation benchmark need to be
executed before initialization. executed before initialization.
graph : tvm.relay.Expr.Function graph : tvm.relay.function.Function
Input graph Input graph
input_shapes : dict of str to tuple. input_shapes : dict of str to tuple.
...@@ -143,7 +143,7 @@ class BaseGraphTuner(object): ...@@ -143,7 +143,7 @@ class BaseGraphTuner(object):
if isinstance(graph, tvm.IRModule): if isinstance(graph, tvm.IRModule):
graph = graph["main"] graph = graph["main"]
if isinstance(graph, relay.expr.Function): if isinstance(graph, relay.function.Function):
node_dict = {} node_dict = {}
graph = bind_inputs(graph, input_shapes, dtype) graph = bind_inputs(graph, input_shapes, dtype)
expr2graph(graph, self._target_ops, node_dict, self._node_list) expr2graph(graph, self._target_ops, node_dict, self._node_list)
......
...@@ -21,7 +21,8 @@ import threading ...@@ -21,7 +21,8 @@ import threading
import tvm import tvm
from tvm import relay, autotvm from tvm import relay, autotvm
from tvm.relay import transform from tvm.relay import transform
from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple from tvm.relay.expr import Call, TupleGetItem, Var, Constant, Tuple
from tvm.relay.function import Function
from tvm.relay.ty import TupleType, TensorType from tvm.relay.ty import TupleType, TensorType
from tvm.autotvm.task import TaskExtractEnv from tvm.autotvm.task import TaskExtractEnv
......
...@@ -61,7 +61,7 @@ def extract_from_program(mod, params, target, target_host=None, ops=None): ...@@ -61,7 +61,7 @@ def extract_from_program(mod, params, target, target_host=None, ops=None):
Parameters Parameters
---------- ----------
mod: tvm.IRModule or relay.expr.Function mod: tvm.IRModule or relay.function.Function
The module or function to tune The module or function to tune
params: dict of str to numpy array params: dict of str to numpy array
The associated parameters of the program The associated parameters of the program
...@@ -88,7 +88,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No ...@@ -88,7 +88,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
Parameters Parameters
---------- ----------
mods: List[tvm.IRModule] or List[relay.expr.Function] mods: List[tvm.IRModule] or List[relay.function.Function]
The list of modules or functions to tune The list of modules or functions to tune
params: List of dict of str to numpy array params: List of dict of str to numpy array
The associated parameters of the programs The associated parameters of the programs
...@@ -118,7 +118,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No ...@@ -118,7 +118,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
logger.disabled = True logger.disabled = True
for mod, param in zip(mods, params): for mod, param in zip(mods, params):
if isinstance(mod, relay.expr.Function): if isinstance(mod, relay.function.Function):
mod = tvm.IRModule.from_expr(mod) mod = tvm.IRModule.from_expr(mod)
assert isinstance(mod, tvm.IRModule), \ assert isinstance(mod, tvm.IRModule), \
"only support relay Module or Function to be tuned" "only support relay Module or Function to be tuned"
......
...@@ -22,6 +22,7 @@ from sys import setrecursionlimit ...@@ -22,6 +22,7 @@ from sys import setrecursionlimit
from . import base from . import base
from . import ty from . import ty
from . import expr from . import expr
from . import function
from . import type_functor from . import type_functor
from . import expr_functor from . import expr_functor
from . import adt from . import adt
...@@ -87,7 +88,7 @@ Constant = expr.Constant ...@@ -87,7 +88,7 @@ Constant = expr.Constant
Tuple = expr.Tuple Tuple = expr.Tuple
Var = expr.Var Var = expr.Var
GlobalVar = expr.GlobalVar GlobalVar = expr.GlobalVar
Function = expr.Function Function = function.Function
Call = expr.Call Call = expr.Call
Let = expr.Let Let = expr.Let
If = expr.If If = expr.If
......
...@@ -43,6 +43,7 @@ from tvm.ir import IRModule ...@@ -43,6 +43,7 @@ from tvm.ir import IRModule
from .base import Span, SourceName from .base import Span, SourceName
from . import adt from . import adt
from . import expr from . import expr
from . import function
from . import ty from . import ty
from . import op from . import op
...@@ -481,7 +482,7 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -481,7 +482,7 @@ class ParseTreeToRelayIR(RelayVisitor):
def mk_func( def mk_func(
self, self,
ctx: Union[RelayParser.FuncContext, RelayParser.DefnContext]) \ ctx: Union[RelayParser.FuncContext, RelayParser.DefnContext]) \
-> expr.Function: -> function.Function:
"""Construct a function from either a Func or Defn.""" """Construct a function from either a Func or Defn."""
# Enter var scope early to put params in scope. # Enter var scope early to put params in scope.
self.enter_var_scope() self.enter_var_scope()
...@@ -511,10 +512,10 @@ class ParseTreeToRelayIR(RelayVisitor): ...@@ -511,10 +512,10 @@ class ParseTreeToRelayIR(RelayVisitor):
self.exit_var_scope() self.exit_var_scope()
attrs = tvm.ir.make_node("DictAttrs", **attr_list) if attr_list is not None else None attrs = tvm.ir.make_node("DictAttrs", **attr_list) if attr_list is not None else None
return expr.Function(var_list, body, ret_type, type_params, attrs) return function.Function(var_list, body, ret_type, type_params, attrs)
@spanify @spanify
def visitFunc(self, ctx: RelayParser.FuncContext) -> expr.Function: def visitFunc(self, ctx: RelayParser.FuncContext) -> function.Function:
return self.mk_func(ctx) return self.mk_func(ctx)
# TODO: how to set spans for definitions? # TODO: how to set spans for definitions?
......
...@@ -421,7 +421,7 @@ def extract_fused_functions(mod): ...@@ -421,7 +421,7 @@ def extract_fused_functions(mod):
Returns Returns
------- -------
ret : Dict[int, tvm.relay.ir.expr.Function] ret : Dict[int, tvm.relay.function.Function]
A module containing only fused primitive functions A module containing only fused primitive functions
""" """
ret_mod = _ffi_api.ExtractFusedFunctions()(mod) ret_mod = _ffi_api.ExtractFusedFunctions()(mod)
......
...@@ -25,7 +25,7 @@ from tvm import te ...@@ -25,7 +25,7 @@ from tvm import te
from tvm.runtime import Object from tvm.runtime import Object
from ... import target as _target from ... import target as _target
from ... import autotvm from ... import autotvm
from .. import expr as _expr from .. import function as _function
from .. import op as _op from .. import op as _op
from .. import ty as _ty from .. import ty as _ty
from . import _backend from . import _backend
...@@ -65,7 +65,7 @@ class CCacheValue(Object): ...@@ -65,7 +65,7 @@ class CCacheValue(Object):
def _get_cache_key(source_func, target): def _get_cache_key(source_func, target):
if isinstance(source_func, _expr.Function): if isinstance(source_func, _function.Function):
if isinstance(target, str): if isinstance(target, str):
target = _target.create(target) target = _target.create(target)
if not target: if not target:
......
...@@ -27,7 +27,8 @@ from tvm.ir import IRModule ...@@ -27,7 +27,8 @@ from tvm.ir import IRModule
from . import _backend from . import _backend
from .. import _make, analysis, transform from .. import _make, analysis, transform
from ... import nd from ... import nd
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, const
from ..function import Function
from ..scope_builder import ScopeBuilder from ..scope_builder import ScopeBuilder
......
...@@ -29,6 +29,7 @@ from ..contrib import graph_runtime as _graph_rt ...@@ -29,6 +29,7 @@ from ..contrib import graph_runtime as _graph_rt
from . import _build_module from . import _build_module
from . import ty as _ty from . import ty as _ty
from . import expr as _expr from . import expr as _expr
from . import function as _function
from .backend import interpreter as _interpreter from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor from .backend.vm import VMExecutor
...@@ -218,16 +219,16 @@ def build(mod, target=None, target_host=None, params=None): ...@@ -218,16 +219,16 @@ def build(mod, target=None, target_host=None, params=None):
params : dict params : dict
The parameters of the final graph. The parameters of the final graph.
""" """
if not isinstance(mod, (IRModule, _expr.Function)): if not isinstance(mod, (IRModule, _function.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule") raise ValueError("Type of input parameter mod must be tvm.IRModule")
if isinstance(mod, _expr.Function): if isinstance(mod, _function.Function):
if params: if params:
mod = bind_params_by_name(mod, params) mod = bind_params_by_name(mod, params)
mod = IRModule.from_expr(mod) mod = IRModule.from_expr(mod)
warnings.warn( warnings.warn(
"Please use input parameter mod (tvm.IRModule) " "Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter mod (tvm.relay.expr.Function)", "instead of deprecated parameter mod (tvm.relay.function.Function)",
DeprecationWarning) DeprecationWarning)
target = _update_target(target) target = _update_target(target)
...@@ -276,16 +277,16 @@ def optimize(mod, target=None, params=None): ...@@ -276,16 +277,16 @@ def optimize(mod, target=None, params=None):
params : dict params : dict
The parameters of the final graph. The parameters of the final graph.
""" """
if not isinstance(mod, (IRModule, _expr.Function)): if not isinstance(mod, (IRModule, _function.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule") raise ValueError("Type of input parameter mod must be tvm.IRModule")
if isinstance(mod, _expr.Function): if isinstance(mod, _function.Function):
if params: if params:
mod = bind_params_by_name(mod, params) mod = bind_params_by_name(mod, params)
mod = IRModule.from_expr(mod) mod = IRModule.from_expr(mod)
warnings.warn( warnings.warn(
"Please use input parameter mod (tvm.IRModule) " "Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter func (tvm.relay.expr.Function)", "instead of deprecated parameter func (tvm.relay.function.Function)",
DeprecationWarning) DeprecationWarning)
target = _update_target(target) target = _update_target(target)
......
...@@ -22,8 +22,8 @@ from numbers import Number as _Number ...@@ -22,8 +22,8 @@ from numbers import Number as _Number
import numpy as _np import numpy as _np
import tvm._ffi import tvm._ffi
from tvm._ffi import base as _base from tvm._ffi import base as _base
from tvm.runtime import NDArray, convert, ndarray as _nd from tvm.runtime import NDArray, ndarray as _nd
from tvm.ir import RelayExpr, GlobalVar, BaseFunc from tvm.ir import RelayExpr, GlobalVar
from .base import RelayNode from .base import RelayNode
from . import _ffi_api from . import _ffi_api
...@@ -225,68 +225,6 @@ class Var(ExprWithOp): ...@@ -225,68 +225,6 @@ class Var(ExprWithOp):
return name return name
@tvm._ffi.register_object("relay.Function")
class Function(BaseFunc):
"""A function declaration expression.
Parameters
----------
params: List[tvm.relay.Var]
List of input parameters to the function.
body: tvm.relay.Expr
The body of the function.
ret_type: Optional[tvm.relay.Type]
The return type annotation of the function.
type_params: Optional[List[tvm.relay.TypeParam]]
The additional type parameters, this is only
used in advanced usecase of template functions.
"""
def __init__(self,
params,
body,
ret_type=None,
type_params=None,
attrs=None):
if type_params is None:
type_params = convert([])
self.__init_handle_by_constructor__(
_ffi_api.Function, params, body, ret_type, type_params, attrs)
def __call__(self, *args):
"""Invoke the global function.
Parameters
----------
args: List[relay.Expr]
Arguments.
"""
return Call(self, args, None, None)
def with_attr(self, attr_key, attr_value):
"""Create a new copy of the function and update the attribute
Parameters
----------
attr_key : str
The attribute key to use.
attr_value : Object
The new attribute value.
Returns
-------
func : Function
A new copy of the function
"""
return _ffi_api.FunctionWithAttr(
self, attr_key, convert(attr_value))
@tvm._ffi.register_object("relay.Call") @tvm._ffi.register_object("relay.Call")
class Call(ExprWithOp): class Call(ExprWithOp):
"""Function call node in Relay. """Function call node in Relay.
......
...@@ -17,7 +17,8 @@ ...@@ -17,7 +17,8 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression functor of Relay.""" """The expression functor of Relay."""
from .expr import Function, Call, Let, Var, GlobalVar from .function import Function
from .expr import Call, Let, Var, GlobalVar
from .expr import If, Tuple, TupleGetItem, Constant from .expr import If, Tuple, TupleGetItem, Constant
from .expr import RefCreate, RefRead, RefWrite from .expr import RefCreate, RefRead, RefWrite
from .adt import Constructor, Match, Clause from .adt import Constructor, Match, Clause
......
...@@ -21,6 +21,7 @@ from tvm.ir import IRModule ...@@ -21,6 +21,7 @@ from tvm.ir import IRModule
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import function as _function
from .. import op as _op from .. import op as _op
from ... import nd as _nd from ... import nd as _nd
from .common import AttrCvt, Renamer from .common import AttrCvt, Renamer
...@@ -451,7 +452,7 @@ class Caffe2NetDef(object): ...@@ -451,7 +452,7 @@ class Caffe2NetDef(object):
else: else:
outputs = out[0] outputs = out[0]
func = _expr.Function(analysis.free_vars(outputs), outputs) func = _function.Function(analysis.free_vars(outputs), outputs)
self._mod["main"] = func self._mod["main"] = func
return self._mod, self._params return self._mod, self._params
...@@ -517,7 +518,7 @@ class Caffe2NetDef(object): ...@@ -517,7 +518,7 @@ class Caffe2NetDef(object):
---------- ----------
op_type : str op_type : str
Operator name, such as Convolution, FullyConnected Operator name, such as Convolution, FullyConnected
inputs : list of tvm.relay.expr.Function inputs : list of tvm.relay.function.Function
List of input inputs. List of input inputs.
args : dict args : dict
Dict of operator attributes Dict of operator attributes
...@@ -530,7 +531,7 @@ class Caffe2NetDef(object): ...@@ -530,7 +531,7 @@ class Caffe2NetDef(object):
Returns Returns
------- -------
func : tvm.relay.expr.Function func : tvm.relay.function.Function
Converted relay function Converted relay function
""" """
identity_list = identity_list if identity_list else _identity_list identity_list = identity_list if identity_list else _identity_list
......
...@@ -24,6 +24,7 @@ from tvm.ir import IRModule ...@@ -24,6 +24,7 @@ from tvm.ir import IRModule
from topi.util import get_const_tuple from topi.util import get_const_tuple
from .. import expr as _expr from .. import expr as _expr
from .. import function as _function
from .. import transform as _transform from .. import transform as _transform
from .. import op as _op from .. import op as _op
from .. import analysis from .. import analysis
...@@ -459,7 +460,7 @@ def infer_type(node, mod=None): ...@@ -459,7 +460,7 @@ def infer_type(node, mod=None):
new_mod.update(mod) new_mod.update(mod)
new_mod = _transform.InferType()(new_mod) new_mod = _transform.InferType()(new_mod)
entry = new_mod["main"] entry = new_mod["main"]
return entry if isinstance(node, _expr.Function) else entry.body return entry if isinstance(node, _function.Function) else entry.body
def infer_shape(inputs, mod=None): def infer_shape(inputs, mod=None):
"""A method to get the output type of an intermediate node in the graph.""" """A method to get the output type of an intermediate node in the graph."""
...@@ -491,7 +492,7 @@ def infer_value(input_val, params): ...@@ -491,7 +492,7 @@ def infer_value(input_val, params):
# Check that all free variables have associated parameters. # Check that all free variables have associated parameters.
assert all(var.name_hint in params.keys() for var in analysis.free_vars( assert all(var.name_hint in params.keys() for var in analysis.free_vars(
input_val)), "All inputs to infer must be available in params." input_val)), "All inputs to infer must be available in params."
func = _expr.Function(analysis.free_vars(input_val), input_val) func = _function.Function(analysis.free_vars(input_val), input_val)
with tvm.relay.build_config(opt_level=0): with tvm.relay.build_config(opt_level=0):
graph, lib, params = tvm.relay.build(func, target="llvm", params=params) graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
ctx = tvm.cpu(0) ctx = tvm.cpu(0)
......
...@@ -24,6 +24,7 @@ from tvm.ir import IRModule ...@@ -24,6 +24,7 @@ from tvm.ir import IRModule
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import function as _function
from .. import op as _op from .. import op as _op
from ... import nd as _nd from ... import nd as _nd
from ..._ffi import base as _base from ..._ffi import base as _base
...@@ -503,6 +504,6 @@ def from_coreml(model, shape=None): ...@@ -503,6 +504,6 @@ def from_coreml(model, shape=None):
for o in spec.description.output] for o in spec.description.output]
# for now return first output # for now return first output
outexpr = outexpr[0] outexpr = outexpr[0]
func = _expr.Function(analysis.free_vars(outexpr), outexpr) func = _function.Function(analysis.free_vars(outexpr), outexpr)
params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
return IRModule.from_expr(func), params return IRModule.from_expr(func), params
...@@ -26,6 +26,7 @@ from tvm.ir import IRModule ...@@ -26,6 +26,7 @@ from tvm.ir import IRModule
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import function as _function
from .common import get_relay_op, new_var from .common import get_relay_op, new_var
__all__ = ['from_darknet'] __all__ = ['from_darknet']
...@@ -821,7 +822,7 @@ class GraphProto(object): ...@@ -821,7 +822,7 @@ class GraphProto(object):
outputs = _as_list(sym) + self._outs outputs = _as_list(sym) + self._outs
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
sym = _expr.Function(analysis.free_vars(outputs), outputs) sym = _function.Function(analysis.free_vars(outputs), outputs)
return IRModule.from_expr(sym), self._tvmparams return IRModule.from_expr(sym), self._tvmparams
def from_darknet(net, def from_darknet(net,
......
...@@ -23,6 +23,7 @@ from tvm.ir import IRModule ...@@ -23,6 +23,7 @@ from tvm.ir import IRModule
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import function as _function
from .. import op as _op from .. import op as _op
from ... import nd as _nd from ... import nd as _nd
from .common import ExprTable, new_var from .common import ExprTable, new_var
...@@ -914,6 +915,6 @@ def from_keras(model, shape=None, layout='NCHW'): ...@@ -914,6 +915,6 @@ def from_keras(model, shape=None, layout='NCHW'):
outexpr = [etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2])) \ outexpr = [etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2])) \
for oc in model._output_coordinates] for oc in model._output_coordinates]
outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr) outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr)
func = _expr.Function(analysis.free_vars(outexpr), outexpr) func = _function.Function(analysis.free_vars(outexpr), outexpr)
params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
return IRModule.from_expr(func), params return IRModule.from_expr(func), params
...@@ -25,6 +25,7 @@ from tvm import relay ...@@ -25,6 +25,7 @@ from tvm import relay
from topi.util import get_const_tuple from topi.util import get_const_tuple
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import function as _function
from .. import op as _op from .. import op as _op
from .. import scope_builder as _scope_builder from .. import scope_builder as _scope_builder
from ... import nd as _nd from ... import nd as _nd
...@@ -1096,7 +1097,7 @@ def _mx_cond(inputs, attrs, subgraphs): ...@@ -1096,7 +1097,7 @@ def _mx_cond(inputs, attrs, subgraphs):
else_arg_dtype_info = [arg.type_annotation.dtype for arg in else_args] else_arg_dtype_info = [arg.type_annotation.dtype for arg in else_args]
else_func = _from_mxnet_impl(subgraphs[2], else_arg_shapes, else_arg_dtype_info) else_func = _from_mxnet_impl(subgraphs[2], else_arg_shapes, else_arg_dtype_info)
sb.ret(_expr.Call(else_func, else_args)) sb.ret(_expr.Call(else_func, else_args))
func = _expr.Function(input_args, sb.get()) func = _function.Function(input_args, sb.get())
ret = _expr.Call(func, inputs) ret = _expr.Call(func, inputs)
if num_outputs > 1: if num_outputs > 1:
ret = _expr.TupleWrapper(ret, num_outputs) ret = _expr.TupleWrapper(ret, num_outputs)
...@@ -1969,7 +1970,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None): ...@@ -1969,7 +1970,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]] outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
func = _expr.Function(analysis.free_vars(outputs), outputs) func = _function.Function(analysis.free_vars(outputs), outputs)
return func return func
......
...@@ -24,6 +24,7 @@ from tvm.ir import IRModule ...@@ -24,6 +24,7 @@ from tvm.ir import IRModule
from ... import nd as _nd from ... import nd as _nd
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import function as _function
from .. import op as _op from .. import op as _op
from .common import AttrCvt, Renamer from .common import AttrCvt, Renamer
from .common import get_relay_op, new_var, infer_shape, infer_channels from .common import get_relay_op, new_var, infer_shape, infer_channels
...@@ -1708,7 +1709,7 @@ class GraphProto(object): ...@@ -1708,7 +1709,7 @@ class GraphProto(object):
# now return the outputs # now return the outputs
outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output] outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
func = _expr.Function(analysis.free_vars(outputs), outputs) func = _function.Function(analysis.free_vars(outputs), outputs)
return IRModule.from_expr(func), self._params return IRModule.from_expr(func), self._params
def _parse_value_proto(self, value_proto): def _parse_value_proto(self, value_proto):
...@@ -1774,7 +1775,7 @@ class GraphProto(object): ...@@ -1774,7 +1775,7 @@ class GraphProto(object):
---------- ----------
op_name : str op_name : str
Operator name, such as Convolution, FullyConnected Operator name, such as Convolution, FullyConnected
inputs : list of tvm.relay.expr.Function inputs : list of tvm.relay.function.Function
List of inputs. List of inputs.
attrs : dict attrs : dict
Dict of operator attributes Dict of operator attributes
...@@ -1783,7 +1784,7 @@ class GraphProto(object): ...@@ -1783,7 +1784,7 @@ class GraphProto(object):
Returns Returns
------- -------
sym : tvm.relay.expr.Function sym : tvm.relay.function.Function
Converted relay function Converted relay function
""" """
convert_map = _get_convert_map(opset) convert_map = _get_convert_map(opset)
......
...@@ -31,6 +31,7 @@ from tvm.relay.prelude import Prelude ...@@ -31,6 +31,7 @@ from tvm.relay.prelude import Prelude
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import function as _function
from .. import op as _op from .. import op as _op
from ..expr_functor import ExprMutator from ..expr_functor import ExprMutator
from .common import AttrCvt, get_relay_op from .common import AttrCvt, get_relay_op
...@@ -2461,7 +2462,7 @@ class GraphProto(object): ...@@ -2461,7 +2462,7 @@ class GraphProto(object):
out.append(out_rnn) out.append(out_rnn)
out = out[0] if len(out) == 1 else _expr.Tuple(out) out = out[0] if len(out) == 1 else _expr.Tuple(out)
func = _expr.Function(analysis.free_vars(out), out) func = _function.Function(analysis.free_vars(out), out)
self._mod["main"] = func self._mod["main"] = func
return self._mod, self._params return self._mod, self._params
......
...@@ -24,6 +24,7 @@ from tvm.ir import IRModule ...@@ -24,6 +24,7 @@ from tvm.ir import IRModule
from tvm import relay from tvm import relay
from .. import analysis from .. import analysis
from .. import expr as _expr from .. import expr as _expr
from .. import function as _function
from .. import op as _op from .. import op as _op
from .. import qnn as _qnn from .. import qnn as _qnn
from ... import nd as _nd from ... import nd as _nd
...@@ -2365,6 +2366,6 @@ def from_tflite(model, shape_dict, dtype_dict): ...@@ -2365,6 +2366,6 @@ def from_tflite(model, shape_dict, dtype_dict):
params = {k:_nd.array(np.array(v)) for k, v in exp_tab.params.items()} params = {k:_nd.array(np.array(v)) for k, v in exp_tab.params.items()}
outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs] outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs]
outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs) outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
func = _expr.Function(analysis.free_vars(outputs), outputs) func = _function.Function(analysis.free_vars(outputs), outputs)
mod = IRModule.from_expr(func) mod = IRModule.from_expr(func)
return mod, params return mod, params
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, invalid-name, unused-import
"""The expression nodes of Relay."""
from __future__ import absolute_import
import tvm._ffi
from tvm.runtime import convert
from tvm.ir import BaseFunc
from .expr import Call
from . import _ffi_api
@tvm._ffi.register_object("relay.Function")
class Function(BaseFunc):
"""A function declaration expression.
Parameters
----------
params: List[tvm.relay.Var]
List of input parameters to the function.
body: tvm.relay.Expr
The body of the function.
ret_type: Optional[tvm.relay.Type]
The return type annotation of the function.
type_params: Optional[List[tvm.relay.TypeParam]]
The additional type parameters, this is only
used in advanced usecase of template functions.
"""
def __init__(self,
params,
body,
ret_type=None,
type_params=None,
attrs=None):
if type_params is None:
type_params = convert([])
self.__init_handle_by_constructor__(
_ffi_api.Function, params, body, ret_type, type_params, attrs)
def __call__(self, *args):
"""Invoke the global function.
Parameters
----------
args: List[relay.Expr]
Arguments.
"""
return Call(self, args, None, None)
def with_attr(self, attr_key, attr_value):
"""Create a new copy of the function and update the attribute
Parameters
----------
attr_key : str
The attribute key to use.
attr_value : Object
The new attribute value.
Returns
-------
func : Function
A new copy of the function
"""
return _ffi_api.FunctionWithAttr(
self, attr_key, convert(attr_value))
...@@ -20,6 +20,7 @@ Utilities for building Relay loops. ...@@ -20,6 +20,7 @@ Utilities for building Relay loops.
""" """
from .scope_builder import ScopeBuilder from .scope_builder import ScopeBuilder
from . import expr as _expr from . import expr as _expr
from . import function as _function
def while_loop(cond, loop_vars, loop_bodies): def while_loop(cond, loop_vars, loop_bodies):
""" """
...@@ -60,6 +61,6 @@ def while_loop(cond, loop_vars, loop_bodies): ...@@ -60,6 +61,6 @@ def while_loop(cond, loop_vars, loop_bodies):
with sb.else_scope(): with sb.else_scope():
sb.ret(_expr.Tuple(fresh_vars)) sb.ret(_expr.Tuple(fresh_vars))
func = _expr.Function(fresh_vars, sb.get()) func = _function.Function(fresh_vars, sb.get())
let = _expr.Let(loop, func, loop) let = _expr.Let(loop, func, loop)
return let return let
...@@ -19,7 +19,8 @@ ...@@ -19,7 +19,8 @@
from tvm.ir import IRModule from tvm.ir import IRModule
from .ty import GlobalTypeVar, TensorType, Any, scalar_type from .ty import GlobalTypeVar, TensorType, Any, scalar_type
from .expr import Var, Function, GlobalVar, If, const from .expr import Var, GlobalVar, If, const
from .function import Function
from .op.tensor import add, subtract, equal from .op.tensor import add, subtract, equal
from .adt import Constructor, TypeData, Clause, Match from .adt import Constructor, TypeData, Clause, Match
from .adt import PatternConstructor, PatternVar, PatternWildcard from .adt import PatternConstructor, PatternVar, PatternWildcard
......
...@@ -21,7 +21,8 @@ test cases for recursion and pattern matching.""" ...@@ -21,7 +21,8 @@ test cases for recursion and pattern matching."""
from tvm.relay.adt import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar from tvm.relay.adt import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar
from tvm.relay.backend.interpreter import ConstructorValue from tvm.relay.backend.interpreter import ConstructorValue
from tvm.relay.expr import Var, Function, GlobalVar from tvm.relay.expr import Var, GlobalVar
from tvm.relay.function import Function
from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType
def define_nat_adt(prelude): def define_nat_adt(prelude):
......
...@@ -23,7 +23,8 @@ import tvm ...@@ -23,7 +23,8 @@ import tvm
from tvm import relay from tvm import relay
from tvm.relay.adt import Pattern from tvm.relay.adt import Pattern
from tvm.relay.backend import compile_engine from tvm.relay.backend import compile_engine
from tvm.relay.expr import Expr, Function, GlobalVar, Var from tvm.relay.expr import Expr, GlobalVar, Var
from tvm.relay.function import Function
from tvm.relay.expr_functor import ExprFunctor from tvm.relay.expr_functor import ExprFunctor
OUTPUT_VAR_NAME = '_py_out' OUTPUT_VAR_NAME = '_py_out'
......
...@@ -27,10 +27,10 @@ namespace tvm { ...@@ -27,10 +27,10 @@ namespace tvm {
namespace relay { namespace relay {
Function::Function(tvm::Array<Var> params, Function::Function(tvm::Array<Var> params,
Expr body, Expr body,
Type ret_type, Type ret_type,
tvm::Array<TypeVar> type_params, tvm::Array<TypeVar> type_params,
DictAttrs attrs) { DictAttrs attrs) {
ObjectPtr<FunctionNode> n = make_object<FunctionNode>(); ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
CHECK(params.defined()); CHECK(params.defined());
CHECK(type_params.defined()); CHECK(type_params.defined());
...@@ -66,7 +66,6 @@ TVM_REGISTER_GLOBAL("relay.ir.Function") ...@@ -66,7 +66,6 @@ TVM_REGISTER_GLOBAL("relay.ir.Function")
return Function(params, body, ret_type, ty_params, attrs); return Function(params, body, ret_type, ty_params, attrs);
}); });
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) { .set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const FunctionNode*>(ref.get()); auto* node = static_cast<const FunctionNode*>(ref.get());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment