Unverified commit 7ca3212f by Zhi, committed by GitHub

create function.py (#5087)

parent 06bbc7c9
@@ -35,9 +35,6 @@ tvm.relay.expr
 .. autoclass:: tvm.relay.expr.Tuple
    :members:
 
-.. autoclass:: tvm.relay.expr.Function
-   :members:
-
 .. autoclass:: tvm.relay.expr.Call
    :members:
......
@@ -120,7 +120,7 @@ Additionally, functions in Relay are higher-order, which means that a function can be passed as an argument to a
 function or returned by a function, as function expressions evaluate to closures (see the `Closures`_ subsection),
 which are values like tensors and tuples.
-See :py:class:`~tvm.relay.expr.Function` for the definition and documentation of function nodes.
+See :py:class:`~tvm.relay.function.Function` for the definition and documentation of function nodes.
 
 Syntax
 ~~~~~~
......
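For context, the langref passage above describes Relay functions as first-class values. A minimal sketch (not part of this commit) of a higher-order function; `f` is an illustrative function-typed parameter whose type is left to inference:

from tvm import relay

f = relay.var("f")                                     # function-valued parameter
x = relay.var("x", shape=(4,), dtype="float32")
apply_fn = relay.Function([f, x], relay.Call(f, [x]))  # applies f to x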
@@ -69,7 +69,7 @@ class BaseGraphTuner(object):
         target_op in the input graph and layout transformation benchmark need to be
         executed before initialization.
 
-    graph : tvm.relay.Expr.Function
+    graph : tvm.relay.function.Function
         Input graph
 
     input_shapes : dict of str to tuple.
@@ -143,7 +143,7 @@ class BaseGraphTuner(object):
         if isinstance(graph, tvm.IRModule):
             graph = graph["main"]
 
-        if isinstance(graph, relay.expr.Function):
+        if isinstance(graph, relay.function.Function):
             node_dict = {}
             graph = bind_inputs(graph, input_shapes, dtype)
             expr2graph(graph, self._target_ops, node_dict, self._node_list)
......
@@ -21,7 +21,8 @@ import threading
 import tvm
 from tvm import relay, autotvm
 from tvm.relay import transform
-from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
+from tvm.relay.expr import Call, TupleGetItem, Var, Constant, Tuple
+from tvm.relay.function import Function
 from tvm.relay.ty import TupleType, TensorType
 from tvm.autotvm.task import TaskExtractEnv
......
@@ -61,7 +61,7 @@ def extract_from_program(mod, params, target, target_host=None, ops=None):
     Parameters
     ----------
-    mod: tvm.IRModule or relay.expr.Function
+    mod: tvm.IRModule or relay.function.Function
         The module or function to tune
     params: dict of str to numpy array
         The associated parameters of the program
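A minimal usage sketch of the API documented above (not part of the diff), assuming `mod` and `params` came from a frontend importer; per the docstring, either an IRModule or a bare relay Function is accepted:

from tvm import autotvm

# extract tunable AutoTVM tasks from the module's main function
tasks = autotvm.task.extract_from_program(mod["main"], params=params, target="llvm")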
@@ -88,7 +88,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=None):
     Parameters
     ----------
-    mods: List[tvm.IRModule] or List[relay.expr.Function]
+    mods: List[tvm.IRModule] or List[relay.function.Function]
        The list of modules or functions to tune
     params: List of dict of str to numpy array
         The associated parameters of the programs
@@ -118,7 +118,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=None):
     logger.disabled = True
 
     for mod, param in zip(mods, params):
-        if isinstance(mod, relay.expr.Function):
+        if isinstance(mod, relay.function.Function):
             mod = tvm.IRModule.from_expr(mod)
         assert isinstance(mod, tvm.IRModule), \
             "only support relay Module or Function to be tuned"
......
@@ -22,6 +22,7 @@ from sys import setrecursionlimit
 from . import base
 from . import ty
 from . import expr
+from . import function
 from . import type_functor
 from . import expr_functor
 from . import adt
@@ -87,7 +88,7 @@ Constant = expr.Constant
 Tuple = expr.Tuple
 Var = expr.Var
 GlobalVar = expr.GlobalVar
-Function = expr.Function
+Function = function.Function
 Call = expr.Call
 Let = expr.Let
 If = expr.If
......
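Because of the re-export above, the public spelling `relay.Function` keeps working after the move out of expr.py. A quick sketch:

from tvm import relay
from tvm.relay import function

assert relay.Function is function.Function  # alias established in relay/__init__.py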
@@ -43,6 +43,7 @@ from tvm.ir import IRModule
 from .base import Span, SourceName
 from . import adt
 from . import expr
+from . import function
 from . import ty
 from . import op
@@ -481,7 +482,7 @@ class ParseTreeToRelayIR(RelayVisitor):
     def mk_func(
             self,
             ctx: Union[RelayParser.FuncContext, RelayParser.DefnContext]) \
-            -> expr.Function:
+            -> function.Function:
         """Construct a function from either a Func or Defn."""
         # Enter var scope early to put params in scope.
         self.enter_var_scope()
@@ -511,10 +512,10 @@ class ParseTreeToRelayIR(RelayVisitor):
         self.exit_var_scope()
 
         attrs = tvm.ir.make_node("DictAttrs", **attr_list) if attr_list is not None else None
-        return expr.Function(var_list, body, ret_type, type_params, attrs)
+        return function.Function(var_list, body, ret_type, type_params, attrs)
 
     @spanify
-    def visitFunc(self, ctx: RelayParser.FuncContext) -> expr.Function:
+    def visitFunc(self, ctx: RelayParser.FuncContext) -> function.Function:
         return self.mk_func(ctx)
 
     # TODO: how to set spans for definitions?
......
@@ -421,7 +421,7 @@ def extract_fused_functions(mod):
     Returns
     -------
-    ret : Dict[int, tvm.relay.ir.expr.Function]
+    ret : Dict[int, tvm.relay.function.Function]
         A module containing only fused primitive functions
     """
     ret_mod = _ffi_api.ExtractFusedFunctions()(mod)
......
@@ -25,7 +25,7 @@ from tvm import te
 from tvm.runtime import Object
 from ... import target as _target
 from ... import autotvm
-from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from .. import ty as _ty
 from . import _backend
@@ -65,7 +65,7 @@ class CCacheValue(Object):
 
 def _get_cache_key(source_func, target):
-    if isinstance(source_func, _expr.Function):
+    if isinstance(source_func, _function.Function):
         if isinstance(target, str):
             target = _target.create(target)
             if not target:
......
@@ -27,7 +27,8 @@ from tvm.ir import IRModule
 from . import _backend
 from .. import _make, analysis, transform
 from ... import nd
-from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
+from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, const
+from ..function import Function
 from ..scope_builder import ScopeBuilder
......
@@ -29,6 +29,7 @@ from ..contrib import graph_runtime as _graph_rt
 from . import _build_module
 from . import ty as _ty
 from . import expr as _expr
+from . import function as _function
 from .backend import interpreter as _interpreter
 from .backend.vm import VMExecutor
@@ -218,16 +219,16 @@ def build(mod, target=None, target_host=None, params=None):
     params : dict
         The parameters of the final graph.
     """
-    if not isinstance(mod, (IRModule, _expr.Function)):
+    if not isinstance(mod, (IRModule, _function.Function)):
         raise ValueError("Type of input parameter mod must be tvm.IRModule")
 
-    if isinstance(mod, _expr.Function):
+    if isinstance(mod, _function.Function):
         if params:
             mod = bind_params_by_name(mod, params)
         mod = IRModule.from_expr(mod)
         warnings.warn(
             "Please use input parameter mod (tvm.IRModule) "
-            "instead of deprecated parameter mod (tvm.relay.expr.Function)",
+            "instead of deprecated parameter mod (tvm.relay.function.Function)",
             DeprecationWarning)
 
     target = _update_target(target)
@@ -276,16 +277,16 @@ def optimize(mod, target=None, params=None):
     params : dict
         The parameters of the final graph.
     """
-    if not isinstance(mod, (IRModule, _expr.Function)):
+    if not isinstance(mod, (IRModule, _function.Function)):
         raise ValueError("Type of input parameter mod must be tvm.IRModule")
 
-    if isinstance(mod, _expr.Function):
+    if isinstance(mod, _function.Function):
         if params:
             mod = bind_params_by_name(mod, params)
         mod = IRModule.from_expr(mod)
         warnings.warn(
             "Please use input parameter mod (tvm.IRModule) "
-            "instead of deprecated parameter func (tvm.relay.expr.Function)",
+            "instead of deprecated parameter func (tvm.relay.function.Function)",
             DeprecationWarning)
 
     target = _update_target(target)
......
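Both hunks deprecate passing a bare Function to build/optimize. A minimal sketch of the migration the warning asks for, assuming `func` is a relay.Function and `params` maps parameter names to numpy arrays:

import tvm
from tvm import relay

mod = tvm.IRModule.from_expr(func)  # wrap the Function in a module yourself
graph_json, lib, params = relay.build(mod, target="llvm", params=params)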
@@ -22,8 +22,8 @@ from numbers import Number as _Number
 import numpy as _np
 import tvm._ffi
 from tvm._ffi import base as _base
-from tvm.runtime import NDArray, convert, ndarray as _nd
-from tvm.ir import RelayExpr, GlobalVar, BaseFunc
+from tvm.runtime import NDArray, ndarray as _nd
+from tvm.ir import RelayExpr, GlobalVar
 from .base import RelayNode
 from . import _ffi_api
@@ -225,68 +225,6 @@ class Var(ExprWithOp):
         return name
 
 
-@tvm._ffi.register_object("relay.Function")
-class Function(BaseFunc):
-    """A function declaration expression.
-
-    Parameters
-    ----------
-    params: List[tvm.relay.Var]
-        List of input parameters to the function.
-
-    body: tvm.relay.Expr
-        The body of the function.
-
-    ret_type: Optional[tvm.relay.Type]
-        The return type annotation of the function.
-
-    type_params: Optional[List[tvm.relay.TypeParam]]
-        The additional type parameters, this is only
-        used in advanced usecase of template functions.
-    """
-    def __init__(self,
-                 params,
-                 body,
-                 ret_type=None,
-                 type_params=None,
-                 attrs=None):
-        if type_params is None:
-            type_params = convert([])
-
-        self.__init_handle_by_constructor__(
-            _ffi_api.Function, params, body, ret_type, type_params, attrs)
-
-    def __call__(self, *args):
-        """Invoke the global function.
-
-        Parameters
-        ----------
-        args: List[relay.Expr]
-            Arguments.
-        """
-        return Call(self, args, None, None)
-
-    def with_attr(self, attr_key, attr_value):
-        """Create a new copy of the function and update the attribute
-
-        Parameters
-        ----------
-        attr_key : str
-            The attribute key to use.
-
-        attr_value : Object
-            The new attribute value.
-
-        Returns
-        -------
-        func : Function
-            A new copy of the function
-        """
-        return _ffi_api.FunctionWithAttr(
-            self, attr_key, convert(attr_value))
-
-
 @tvm._ffi.register_object("relay.Call")
 class Call(ExprWithOp):
     """Function call node in Relay.
......
@@ -17,7 +17,8 @@
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
 """The expression functor of Relay."""
-from .expr import Function, Call, Let, Var, GlobalVar
+from .function import Function
+from .expr import Call, Let, Var, GlobalVar
 from .expr import If, Tuple, TupleGetItem, Constant
 from .expr import RefCreate, RefRead, RefWrite
 from .adt import Constructor, Match, Clause
......
@@ -21,6 +21,7 @@ from tvm.ir import IRModule
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from ... import nd as _nd
 from .common import AttrCvt, Renamer
@@ -451,7 +452,7 @@ class Caffe2NetDef(object):
         else:
             outputs = out[0]
 
-        func = _expr.Function(analysis.free_vars(outputs), outputs)
+        func = _function.Function(analysis.free_vars(outputs), outputs)
         self._mod["main"] = func
 
         return self._mod, self._params
@@ -517,7 +518,7 @@
         ----------
         op_type : str
             Operator name, such as Convolution, FullyConnected
-        inputs : list of tvm.relay.expr.Function
+        inputs : list of tvm.relay.function.Function
             List of input inputs.
         args : dict
             Dict of operator attributes
Dict of operator attributes
@@ -530,7 +531,7 @@
         Returns
         -------
-        func : tvm.relay.expr.Function
+        func : tvm.relay.function.Function
             Converted relay function
         """
         identity_list = identity_list if identity_list else _identity_list
......
@@ -24,6 +24,7 @@ from tvm.ir import IRModule
 from topi.util import get_const_tuple
 
 from .. import expr as _expr
+from .. import function as _function
 from .. import transform as _transform
 from .. import op as _op
 from .. import analysis
@@ -459,7 +460,7 @@ def infer_type(node, mod=None):
         new_mod.update(mod)
     new_mod = _transform.InferType()(new_mod)
     entry = new_mod["main"]
-    return entry if isinstance(node, _expr.Function) else entry.body
+    return entry if isinstance(node, _function.Function) else entry.body
 
 
 def infer_shape(inputs, mod=None):
     """A method to get the output type of an intermediate node in the graph."""
@@ -491,7 +492,7 @@ def infer_value(input_val, params):
     # Check that all free variables have associated parameters.
     assert all(var.name_hint in params.keys() for var in analysis.free_vars(
         input_val)), "All inputs to infer must be available in params."
-    func = _expr.Function(analysis.free_vars(input_val), input_val)
+    func = _function.Function(analysis.free_vars(input_val), input_val)
     with tvm.relay.build_config(opt_level=0):
         graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
     ctx = tvm.cpu(0)
......
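infer_value above uses the idiom that recurs throughout the frontend changes in this commit: make an expression's free variables the parameters of a Function, then lift it into a module. A standalone sketch (names and shapes are illustrative):

import tvm
from tvm import relay
from tvm.relay import analysis

x = relay.var("x", shape=(2, 2), dtype="float32")
out = relay.add(x, relay.const(1.0))
func = relay.Function(analysis.free_vars(out), out)  # params = free vars of out
mod = tvm.IRModule.from_expr(func)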
@@ -24,6 +24,7 @@ from tvm.ir import IRModule
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from ... import nd as _nd
 from ..._ffi import base as _base
@@ -503,6 +504,6 @@ def from_coreml(model, shape=None):
                for o in spec.description.output]
     # for now return first output
     outexpr = outexpr[0]
-    func = _expr.Function(analysis.free_vars(outexpr), outexpr)
+    func = _function.Function(analysis.free_vars(outexpr), outexpr)
     params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
     return IRModule.from_expr(func), params
@@ -26,6 +26,7 @@ from tvm.ir import IRModule
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .common import get_relay_op, new_var
 
 __all__ = ['from_darknet']
@@ -821,7 +822,7 @@ class GraphProto(object):
         outputs = _as_list(sym) + self._outs
         outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
 
-        sym = _expr.Function(analysis.free_vars(outputs), outputs)
+        sym = _function.Function(analysis.free_vars(outputs), outputs)
         return IRModule.from_expr(sym), self._tvmparams
 
 def from_darknet(net,
......
@@ -23,6 +23,7 @@ from tvm.ir import IRModule
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from ... import nd as _nd
 from .common import ExprTable, new_var
@@ -914,6 +915,6 @@ def from_keras(model, shape=None, layout='NCHW'):
     outexpr = [etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2])) \
                for oc in model._output_coordinates]
     outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr)
-    func = _expr.Function(analysis.free_vars(outexpr), outexpr)
+    func = _function.Function(analysis.free_vars(outexpr), outexpr)
     params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
     return IRModule.from_expr(func), params
@@ -25,6 +25,7 @@ from tvm import relay
 from topi.util import get_const_tuple
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from .. import scope_builder as _scope_builder
 from ... import nd as _nd
@@ -1096,7 +1097,7 @@ def _mx_cond(inputs, attrs, subgraphs):
         else_arg_dtype_info = [arg.type_annotation.dtype for arg in else_args]
         else_func = _from_mxnet_impl(subgraphs[2], else_arg_shapes, else_arg_dtype_info)
         sb.ret(_expr.Call(else_func, else_args))
-    func = _expr.Function(input_args, sb.get())
+    func = _function.Function(input_args, sb.get())
     ret = _expr.Call(func, inputs)
     if num_outputs > 1:
         ret = _expr.TupleWrapper(ret, num_outputs)
@@ -1969,7 +1970,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
     outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]]
     outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
-    func = _expr.Function(analysis.free_vars(outputs), outputs)
+    func = _function.Function(analysis.free_vars(outputs), outputs)
     return func
......
@@ -24,6 +24,7 @@ from tvm.ir import IRModule
 from ... import nd as _nd
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from .common import AttrCvt, Renamer
 from .common import get_relay_op, new_var, infer_shape, infer_channels
@@ -1708,7 +1709,7 @@ class GraphProto(object):
         # now return the outputs
         outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
         outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
-        func = _expr.Function(analysis.free_vars(outputs), outputs)
+        func = _function.Function(analysis.free_vars(outputs), outputs)
         return IRModule.from_expr(func), self._params
 
     def _parse_value_proto(self, value_proto):
@@ -1774,7 +1775,7 @@
         ----------
         op_name : str
             Operator name, such as Convolution, FullyConnected
-        inputs : list of tvm.relay.expr.Function
+        inputs : list of tvm.relay.function.Function
             List of inputs.
         attrs : dict
             Dict of operator attributes
@@ -1783,7 +1784,7 @@
         Returns
         -------
-        sym : tvm.relay.expr.Function
+        sym : tvm.relay.function.Function
             Converted relay function
         """
         convert_map = _get_convert_map(opset)
......
@@ -31,6 +31,7 @@ from tvm.relay.prelude import Prelude
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from ..expr_functor import ExprMutator
 from .common import AttrCvt, get_relay_op
@@ -2461,7 +2462,7 @@ class GraphProto(object):
                 out.append(out_rnn)
 
         out = out[0] if len(out) == 1 else _expr.Tuple(out)
-        func = _expr.Function(analysis.free_vars(out), out)
+        func = _function.Function(analysis.free_vars(out), out)
         self._mod["main"] = func
 
         return self._mod, self._params
......
@@ -24,6 +24,7 @@ from tvm.ir import IRModule
 from tvm import relay
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from .. import qnn as _qnn
 from ... import nd as _nd
@@ -2365,6 +2366,6 @@ def from_tflite(model, shape_dict, dtype_dict):
     params = {k:_nd.array(np.array(v)) for k, v in exp_tab.params.items()}
     outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs]
     outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
-    func = _expr.Function(analysis.free_vars(outputs), outputs)
+    func = _function.Function(analysis.free_vars(outputs), outputs)
     mod = IRModule.from_expr(func)
     return mod, params
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return, invalid-name, unused-import
+"""The expression nodes of Relay."""
+from __future__ import absolute_import
+
+import tvm._ffi
+from tvm.runtime import convert
+from tvm.ir import BaseFunc
+
+from .expr import Call
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("relay.Function")
+class Function(BaseFunc):
+    """A function declaration expression.
+
+    Parameters
+    ----------
+    params: List[tvm.relay.Var]
+        List of input parameters to the function.
+
+    body: tvm.relay.Expr
+        The body of the function.
+
+    ret_type: Optional[tvm.relay.Type]
+        The return type annotation of the function.
+
+    type_params: Optional[List[tvm.relay.TypeParam]]
+        The additional type parameters, this is only
+        used in advanced usecase of template functions.
+    """
+    def __init__(self,
+                 params,
+                 body,
+                 ret_type=None,
+                 type_params=None,
+                 attrs=None):
+        if type_params is None:
+            type_params = convert([])
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.Function, params, body, ret_type, type_params, attrs)
+
+    def __call__(self, *args):
+        """Invoke the global function.
+
+        Parameters
+        ----------
+        args: List[relay.Expr]
+            Arguments.
+        """
+        return Call(self, args, None, None)
+
+    def with_attr(self, attr_key, attr_value):
+        """Create a new copy of the function and update the attribute
+
+        Parameters
+        ----------
+        attr_key : str
+            The attribute key to use.
+
+        attr_value : Object
+            The new attribute value.
+
+        Returns
+        -------
+        func : Function
+            A new copy of the function
+        """
+        return _ffi_api.FunctionWithAttr(
+            self, attr_key, convert(attr_value))
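A minimal usage sketch of the relocated class (not part of the commit; shapes and the attribute value are illustrative): build an identity function, then derive a copy with an attribute attached. Note that with_attr returns a new function rather than mutating in place.

import tvm
from tvm import relay

x = relay.var("x", shape=(3, 4), dtype="float32")
ident = relay.Function([x], x)  # resolves to tvm.relay.function.Function
tagged = ident.with_attr("Primitive", tvm.tir.IntImm("int32", 1))  # new copy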
@@ -20,6 +20,7 @@ Utilities for building Relay loops.
 """
 from .scope_builder import ScopeBuilder
 from . import expr as _expr
+from . import function as _function
 
 def while_loop(cond, loop_vars, loop_bodies):
     """
@@ -60,6 +61,6 @@ def while_loop(cond, loop_vars, loop_bodies):
     with sb.else_scope():
         sb.ret(_expr.Tuple(fresh_vars))
 
-    func = _expr.Function(fresh_vars, sb.get())
+    func = _function.Function(fresh_vars, sb.get())
     let = _expr.Let(loop, func, loop)
     return let
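A usage sketch for the helper whose tail is shown above (assumed from its signature): `cond` and the body are Python callbacks over the loop variables, and the returned expression is invoked with the initial values.

from tvm import relay
from tvm.relay.loops import while_loop

i = relay.var("i", shape=(), dtype="int32")
loop = while_loop(
    lambda i: relay.less(i, relay.const(10)),  # continue while i < 10
    [i],                                       # loop variables
    lambda i: [i + relay.const(1)])            # next values of the loop vars
func = relay.Function([], loop(relay.const(0)))  # call with the start value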
@@ -19,7 +19,8 @@
 from tvm.ir import IRModule
 
 from .ty import GlobalTypeVar, TensorType, Any, scalar_type
-from .expr import Var, Function, GlobalVar, If, const
+from .expr import Var, GlobalVar, If, const
+from .function import Function
 from .op.tensor import add, subtract, equal
 from .adt import Constructor, TypeData, Clause, Match
 from .adt import PatternConstructor, PatternVar, PatternWildcard
......
@@ -21,7 +21,8 @@ test cases for recursion and pattern matching."""
 from tvm.relay.adt import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar
 from tvm.relay.backend.interpreter import ConstructorValue
-from tvm.relay.expr import Var, Function, GlobalVar
+from tvm.relay.expr import Var, GlobalVar
+from tvm.relay.function import Function
 from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType
 
 def define_nat_adt(prelude):
......
@@ -23,7 +23,8 @@ import tvm
 from tvm import relay
 from tvm.relay.adt import Pattern
 from tvm.relay.backend import compile_engine
-from tvm.relay.expr import Expr, Function, GlobalVar, Var
+from tvm.relay.expr import Expr, GlobalVar, Var
+from tvm.relay.function import Function
 from tvm.relay.expr_functor import ExprFunctor
 
 OUTPUT_VAR_NAME = '_py_out'
......
@@ -27,10 +27,10 @@ namespace tvm {
 namespace relay {
 
 Function::Function(tvm::Array<Var> params,
-    Expr body,
-    Type ret_type,
-    tvm::Array<TypeVar> type_params,
-    DictAttrs attrs) {
+                   Expr body,
+                   Type ret_type,
+                   tvm::Array<TypeVar> type_params,
+                   DictAttrs attrs) {
   ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
   CHECK(params.defined());
   CHECK(type_params.defined());
@@ -66,7 +66,6 @@ TVM_REGISTER_GLOBAL("relay.ir.Function")
     return Function(params, body, ret_type, ty_params, attrs);
 });
 
-
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
     auto* node = static_cast<const FunctionNode*>(ref.get());