Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
7ca3212f
Unverified
Commit
7ca3212f
authored
Mar 18, 2020
by
Zhi
Committed by
GitHub
Mar 18, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
create function.py (#5087)
parent
06bbc7c9
Show whitespace changes
Inline
Side-by-side
Showing
28 changed files
with
148 additions
and
109 deletions
+148
-109
docs/api/python/relay/expr.rst
+0
-3
docs/langref/relay_expr.rst
+1
-1
python/tvm/autotvm/graph_tuner/base_graph_tuner.py
+2
-2
python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
+2
-1
python/tvm/autotvm/task/relay_integration.py
+3
-3
python/tvm/relay/__init__.py
+2
-1
python/tvm/relay/_parser.py
+4
-3
python/tvm/relay/analysis/analysis.py
+1
-1
python/tvm/relay/backend/compile_engine.py
+2
-2
python/tvm/relay/backend/interpreter.py
+2
-1
python/tvm/relay/build_module.py
+7
-6
python/tvm/relay/expr.py
+2
-64
python/tvm/relay/expr_functor.py
+2
-1
python/tvm/relay/frontend/caffe2.py
+4
-3
python/tvm/relay/frontend/common.py
+3
-2
python/tvm/relay/frontend/coreml.py
+2
-1
python/tvm/relay/frontend/darknet.py
+2
-1
python/tvm/relay/frontend/keras.py
+2
-1
python/tvm/relay/frontend/mxnet.py
+3
-2
python/tvm/relay/frontend/onnx.py
+4
-3
python/tvm/relay/frontend/tensorflow.py
+2
-1
python/tvm/relay/frontend/tflite.py
+2
-1
python/tvm/relay/function.py
+86
-0
python/tvm/relay/loops.py
+2
-1
python/tvm/relay/prelude.py
+2
-1
python/tvm/relay/testing/nat.py
+2
-1
python/tvm/relay/testing/py_converter.py
+2
-1
src/relay/ir/function.cc
+0
-1
No files found.
docs/api/python/relay/expr.rst
View file @
7ca3212f
...
...
@@ -35,9 +35,6 @@ tvm.relay.expr
.. autoclass:: tvm.relay.expr.Tuple
:members:
.. autoclass:: tvm.relay.expr.Function
:members:
.. autoclass:: tvm.relay.expr.Call
:members:
...
...
docs/langref/relay_expr.rst
View file @
7ca3212f
...
...
@@ -120,7 +120,7 @@ Additionally, functions in Relay are higher-order, which means that a function c
function or returned by a function, as function expressions evaluate to closures (see the `Closures`_ subsection),
which are values like tensors and tuples.
See :py:class:`~tvm.relay.
expr
.Function` for the definition and documentation of function nodes.
See :py:class:`~tvm.relay.
function
.Function` for the definition and documentation of function nodes.
Syntax
~~~~~~
...
...
python/tvm/autotvm/graph_tuner/base_graph_tuner.py
View file @
7ca3212f
...
...
@@ -69,7 +69,7 @@ class BaseGraphTuner(object):
target_op in the input graph and layout transformation benchmark need to be
executed before initialization.
graph : tvm.relay.
Expr
.Function
graph : tvm.relay.
function
.Function
Input graph
input_shapes : dict of str to tuple.
...
...
@@ -143,7 +143,7 @@ class BaseGraphTuner(object):
if
isinstance
(
graph
,
tvm
.
IRModule
):
graph
=
graph
[
"main"
]
if
isinstance
(
graph
,
relay
.
expr
.
Function
):
if
isinstance
(
graph
,
relay
.
function
.
Function
):
node_dict
=
{}
graph
=
bind_inputs
(
graph
,
input_shapes
,
dtype
)
expr2graph
(
graph
,
self
.
_target_ops
,
node_dict
,
self
.
_node_list
)
...
...
python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
View file @
7ca3212f
...
...
@@ -21,7 +21,8 @@ import threading
import
tvm
from
tvm
import
relay
,
autotvm
from
tvm.relay
import
transform
from
tvm.relay.expr
import
Call
,
Function
,
TupleGetItem
,
Var
,
Constant
,
Tuple
from
tvm.relay.expr
import
Call
,
TupleGetItem
,
Var
,
Constant
,
Tuple
from
tvm.relay.function
import
Function
from
tvm.relay.ty
import
TupleType
,
TensorType
from
tvm.autotvm.task
import
TaskExtractEnv
...
...
python/tvm/autotvm/task/relay_integration.py
View file @
7ca3212f
...
...
@@ -61,7 +61,7 @@ def extract_from_program(mod, params, target, target_host=None, ops=None):
Parameters
----------
mod: tvm.IRModule or relay.
expr
.Function
mod: tvm.IRModule or relay.
function
.Function
The module or function to tune
params: dict of str to numpy array
The associated parameters of the program
...
...
@@ -88,7 +88,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
Parameters
----------
mods: List[tvm.IRModule] or List[relay.
expr
.Function]
mods: List[tvm.IRModule] or List[relay.
function
.Function]
The list of modules or functions to tune
params: List of dict of str to numpy array
The associated parameters of the programs
...
...
@@ -118,7 +118,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
logger
.
disabled
=
True
for
mod
,
param
in
zip
(
mods
,
params
):
if
isinstance
(
mod
,
relay
.
expr
.
Function
):
if
isinstance
(
mod
,
relay
.
function
.
Function
):
mod
=
tvm
.
IRModule
.
from_expr
(
mod
)
assert
isinstance
(
mod
,
tvm
.
IRModule
),
\
"only support relay Module or Function to be tuned"
...
...
python/tvm/relay/__init__.py
View file @
7ca3212f
...
...
@@ -22,6 +22,7 @@ from sys import setrecursionlimit
from
.
import
base
from
.
import
ty
from
.
import
expr
from
.
import
function
from
.
import
type_functor
from
.
import
expr_functor
from
.
import
adt
...
...
@@ -87,7 +88,7 @@ Constant = expr.Constant
Tuple
=
expr
.
Tuple
Var
=
expr
.
Var
GlobalVar
=
expr
.
GlobalVar
Function
=
expr
.
Function
Function
=
function
.
Function
Call
=
expr
.
Call
Let
=
expr
.
Let
If
=
expr
.
If
...
...
python/tvm/relay/_parser.py
View file @
7ca3212f
...
...
@@ -43,6 +43,7 @@ from tvm.ir import IRModule
from
.base
import
Span
,
SourceName
from
.
import
adt
from
.
import
expr
from
.
import
function
from
.
import
ty
from
.
import
op
...
...
@@ -481,7 +482,7 @@ class ParseTreeToRelayIR(RelayVisitor):
def
mk_func
(
self
,
ctx
:
Union
[
RelayParser
.
FuncContext
,
RelayParser
.
DefnContext
])
\
->
expr
.
Function
:
->
function
.
Function
:
"""Construct a function from either a Func or Defn."""
# Enter var scope early to put params in scope.
self
.
enter_var_scope
()
...
...
@@ -511,10 +512,10 @@ class ParseTreeToRelayIR(RelayVisitor):
self
.
exit_var_scope
()
attrs
=
tvm
.
ir
.
make_node
(
"DictAttrs"
,
**
attr_list
)
if
attr_list
is
not
None
else
None
return
expr
.
Function
(
var_list
,
body
,
ret_type
,
type_params
,
attrs
)
return
function
.
Function
(
var_list
,
body
,
ret_type
,
type_params
,
attrs
)
@spanify
def
visitFunc
(
self
,
ctx
:
RelayParser
.
FuncContext
)
->
expr
.
Function
:
def
visitFunc
(
self
,
ctx
:
RelayParser
.
FuncContext
)
->
function
.
Function
:
return
self
.
mk_func
(
ctx
)
# TODO: how to set spans for definitions?
...
...
python/tvm/relay/analysis/analysis.py
View file @
7ca3212f
...
...
@@ -421,7 +421,7 @@ def extract_fused_functions(mod):
Returns
-------
ret : Dict[int, tvm.relay.
ir.expr
.Function]
ret : Dict[int, tvm.relay.
function
.Function]
A module containing only fused primitive functions
"""
ret_mod
=
_ffi_api
.
ExtractFusedFunctions
()(
mod
)
...
...
python/tvm/relay/backend/compile_engine.py
View file @
7ca3212f
...
...
@@ -25,7 +25,7 @@ from tvm import te
from
tvm.runtime
import
Object
from
...
import
target
as
_target
from
...
import
autotvm
from
..
import
expr
as
_expr
from
..
import
function
as
_function
from
..
import
op
as
_op
from
..
import
ty
as
_ty
from
.
import
_backend
...
...
@@ -65,7 +65,7 @@ class CCacheValue(Object):
def
_get_cache_key
(
source_func
,
target
):
if
isinstance
(
source_func
,
_
expr
.
Function
):
if
isinstance
(
source_func
,
_
function
.
Function
):
if
isinstance
(
target
,
str
):
target
=
_target
.
create
(
target
)
if
not
target
:
...
...
python/tvm/relay/backend/interpreter.py
View file @
7ca3212f
...
...
@@ -27,7 +27,8 @@ from tvm.ir import IRModule
from
.
import
_backend
from
..
import
_make
,
analysis
,
transform
from
...
import
nd
from
..expr
import
Tuple
,
RefCreate
,
Call
,
Constant
,
GlobalVar
,
Function
,
const
from
..expr
import
Tuple
,
RefCreate
,
Call
,
Constant
,
GlobalVar
,
const
from
..function
import
Function
from
..scope_builder
import
ScopeBuilder
...
...
python/tvm/relay/build_module.py
View file @
7ca3212f
...
...
@@ -29,6 +29,7 @@ from ..contrib import graph_runtime as _graph_rt
from
.
import
_build_module
from
.
import
ty
as
_ty
from
.
import
expr
as
_expr
from
.
import
function
as
_function
from
.backend
import
interpreter
as
_interpreter
from
.backend.vm
import
VMExecutor
...
...
@@ -218,16 +219,16 @@ def build(mod, target=None, target_host=None, params=None):
params : dict
The parameters of the final graph.
"""
if
not
isinstance
(
mod
,
(
IRModule
,
_
expr
.
Function
)):
if
not
isinstance
(
mod
,
(
IRModule
,
_
function
.
Function
)):
raise
ValueError
(
"Type of input parameter mod must be tvm.IRModule"
)
if
isinstance
(
mod
,
_
expr
.
Function
):
if
isinstance
(
mod
,
_
function
.
Function
):
if
params
:
mod
=
bind_params_by_name
(
mod
,
params
)
mod
=
IRModule
.
from_expr
(
mod
)
warnings
.
warn
(
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter mod (tvm.relay.
expr
.Function)"
,
"instead of deprecated parameter mod (tvm.relay.
function
.Function)"
,
DeprecationWarning
)
target
=
_update_target
(
target
)
...
...
@@ -276,16 +277,16 @@ def optimize(mod, target=None, params=None):
params : dict
The parameters of the final graph.
"""
if
not
isinstance
(
mod
,
(
IRModule
,
_
expr
.
Function
)):
if
not
isinstance
(
mod
,
(
IRModule
,
_
function
.
Function
)):
raise
ValueError
(
"Type of input parameter mod must be tvm.IRModule"
)
if
isinstance
(
mod
,
_
expr
.
Function
):
if
isinstance
(
mod
,
_
function
.
Function
):
if
params
:
mod
=
bind_params_by_name
(
mod
,
params
)
mod
=
IRModule
.
from_expr
(
mod
)
warnings
.
warn
(
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter func (tvm.relay.
expr
.Function)"
,
"instead of deprecated parameter func (tvm.relay.
function
.Function)"
,
DeprecationWarning
)
target
=
_update_target
(
target
)
...
...
python/tvm/relay/expr.py
View file @
7ca3212f
...
...
@@ -22,8 +22,8 @@ from numbers import Number as _Number
import
numpy
as
_np
import
tvm._ffi
from
tvm._ffi
import
base
as
_base
from
tvm.runtime
import
NDArray
,
convert
,
ndarray
as
_nd
from
tvm.ir
import
RelayExpr
,
GlobalVar
,
BaseFunc
from
tvm.runtime
import
NDArray
,
ndarray
as
_nd
from
tvm.ir
import
RelayExpr
,
GlobalVar
from
.base
import
RelayNode
from
.
import
_ffi_api
...
...
@@ -225,68 +225,6 @@ class Var(ExprWithOp):
return
name
@tvm._ffi.register_object
(
"relay.Function"
)
class
Function
(
BaseFunc
):
"""A function declaration expression.
Parameters
----------
params: List[tvm.relay.Var]
List of input parameters to the function.
body: tvm.relay.Expr
The body of the function.
ret_type: Optional[tvm.relay.Type]
The return type annotation of the function.
type_params: Optional[List[tvm.relay.TypeParam]]
The additional type parameters, this is only
used in advanced usecase of template functions.
"""
def
__init__
(
self
,
params
,
body
,
ret_type
=
None
,
type_params
=
None
,
attrs
=
None
):
if
type_params
is
None
:
type_params
=
convert
([])
self
.
__init_handle_by_constructor__
(
_ffi_api
.
Function
,
params
,
body
,
ret_type
,
type_params
,
attrs
)
def
__call__
(
self
,
*
args
):
"""Invoke the global function.
Parameters
----------
args: List[relay.Expr]
Arguments.
"""
return
Call
(
self
,
args
,
None
,
None
)
def
with_attr
(
self
,
attr_key
,
attr_value
):
"""Create a new copy of the function and update the attribute
Parameters
----------
attr_key : str
The attribute key to use.
attr_value : Object
The new attribute value.
Returns
-------
func : Function
A new copy of the function
"""
return
_ffi_api
.
FunctionWithAttr
(
self
,
attr_key
,
convert
(
attr_value
))
@tvm._ffi.register_object
(
"relay.Call"
)
class
Call
(
ExprWithOp
):
"""Function call node in Relay.
...
...
python/tvm/relay/expr_functor.py
View file @
7ca3212f
...
...
@@ -17,7 +17,8 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression functor of Relay."""
from
.expr
import
Function
,
Call
,
Let
,
Var
,
GlobalVar
from
.function
import
Function
from
.expr
import
Call
,
Let
,
Var
,
GlobalVar
from
.expr
import
If
,
Tuple
,
TupleGetItem
,
Constant
from
.expr
import
RefCreate
,
RefRead
,
RefWrite
from
.adt
import
Constructor
,
Match
,
Clause
...
...
python/tvm/relay/frontend/caffe2.py
View file @
7ca3212f
...
...
@@ -21,6 +21,7 @@ from tvm.ir import IRModule
from
..
import
analysis
from
..
import
expr
as
_expr
from
..
import
function
as
_function
from
..
import
op
as
_op
from
...
import
nd
as
_nd
from
.common
import
AttrCvt
,
Renamer
...
...
@@ -451,7 +452,7 @@ class Caffe2NetDef(object):
else
:
outputs
=
out
[
0
]
func
=
_
expr
.
Function
(
analysis
.
free_vars
(
outputs
),
outputs
)
func
=
_
function
.
Function
(
analysis
.
free_vars
(
outputs
),
outputs
)
self
.
_mod
[
"main"
]
=
func
return
self
.
_mod
,
self
.
_params
...
...
@@ -517,7 +518,7 @@ class Caffe2NetDef(object):
----------
op_type : str
Operator name, such as Convolution, FullyConnected
inputs : list of tvm.relay.
expr
.Function
inputs : list of tvm.relay.
function
.Function
List of input inputs.
args : dict
Dict of operator attributes
...
...
@@ -530,7 +531,7 @@ class Caffe2NetDef(object):
Returns
-------
func : tvm.relay.
expr
.Function
func : tvm.relay.
function
.Function
Converted relay function
"""
identity_list
=
identity_list
if
identity_list
else
_identity_list
...
...
python/tvm/relay/frontend/common.py
View file @
7ca3212f
...
...
@@ -24,6 +24,7 @@ from tvm.ir import IRModule
from
topi.util
import
get_const_tuple
from
..
import
expr
as
_expr
from
..
import
function
as
_function
from
..
import
transform
as
_transform
from
..
import
op
as
_op
from
..
import
analysis
...
...
@@ -459,7 +460,7 @@ def infer_type(node, mod=None):
new_mod
.
update
(
mod
)
new_mod
=
_transform
.
InferType
()(
new_mod
)
entry
=
new_mod
[
"main"
]
return
entry
if
isinstance
(
node
,
_
expr
.
Function
)
else
entry
.
body
return
entry
if
isinstance
(
node
,
_
function
.
Function
)
else
entry
.
body
def
infer_shape
(
inputs
,
mod
=
None
):
"""A method to get the output type of an intermediate node in the graph."""
...
...
@@ -491,7 +492,7 @@ def infer_value(input_val, params):
# Check that all free variables have associated parameters.
assert
all
(
var
.
name_hint
in
params
.
keys
()
for
var
in
analysis
.
free_vars
(
input_val
)),
"All inputs to infer must be available in params."
func
=
_
expr
.
Function
(
analysis
.
free_vars
(
input_val
),
input_val
)
func
=
_
function
.
Function
(
analysis
.
free_vars
(
input_val
),
input_val
)
with
tvm
.
relay
.
build_config
(
opt_level
=
0
):
graph
,
lib
,
params
=
tvm
.
relay
.
build
(
func
,
target
=
"llvm"
,
params
=
params
)
ctx
=
tvm
.
cpu
(
0
)
...
...
python/tvm/relay/frontend/coreml.py
View file @
7ca3212f
...
...
@@ -24,6 +24,7 @@ from tvm.ir import IRModule
from
..
import
analysis
from
..
import
expr
as
_expr
from
..
import
function
as
_function
from
..
import
op
as
_op
from
...
import
nd
as
_nd
from
..._ffi
import
base
as
_base
...
...
@@ -503,6 +504,6 @@ def from_coreml(model, shape=None):
for
o
in
spec
.
description
.
output
]
# for now return first output
outexpr
=
outexpr
[
0
]
func
=
_
expr
.
Function
(
analysis
.
free_vars
(
outexpr
),
outexpr
)
func
=
_
function
.
Function
(
analysis
.
free_vars
(
outexpr
),
outexpr
)
params
=
{
k
:
_nd
.
array
(
np
.
array
(
v
,
dtype
=
np
.
float32
))
for
k
,
v
in
etab
.
params
.
items
()}
return
IRModule
.
from_expr
(
func
),
params
python/tvm/relay/frontend/darknet.py
View file @
7ca3212f
...
...
@@ -26,6 +26,7 @@ from tvm.ir import IRModule
from
..
import
analysis
from
..
import
expr
as
_expr
from
..
import
function
as
_function
from
.common
import
get_relay_op
,
new_var
__all__
=
[
'from_darknet'
]
...
...
@@ -821,7 +822,7 @@ class GraphProto(object):
outputs
=
_as_list
(
sym
)
+
self
.
_outs
outputs
=
outputs
[
0
]
if
len
(
outputs
)
==
1
else
_expr
.
Tuple
(
outputs
)
sym
=
_
expr
.
Function
(
analysis
.
free_vars
(
outputs
),
outputs
)
sym
=
_
function
.
Function
(
analysis
.
free_vars
(
outputs
),
outputs
)
return
IRModule
.
from_expr
(
sym
),
self
.
_tvmparams
def
from_darknet
(
net
,
...
...
python/tvm/relay/frontend/keras.py
View file @
7ca3212f
...
...
@@ -23,6 +23,7 @@ from tvm.ir import IRModule
from
..
import
analysis
from
..
import
expr
as
_expr
from
..
import
function
as
_function
from
..
import
op
as
_op
from
...
import
nd
as
_nd
from
.common
import
ExprTable
,
new_var
...
...
@@ -914,6 +915,6 @@ def from_keras(model, shape=None, layout='NCHW'):
outexpr
=
[
etab
.
get_expr
(
oc
[
0
]
.
name
+
":"
+
str
(
oc
[
1
])
+
":"
+
str
(
oc
[
2
]))
\
for
oc
in
model
.
_output_coordinates
]
outexpr
=
outexpr
[
0
]
if
len
(
outexpr
)
==
1
else
_expr
.
Tuple
(
outexpr
)
func
=
_
expr
.
Function
(
analysis
.
free_vars
(
outexpr
),
outexpr
)
func
=
_
function
.
Function
(
analysis
.
free_vars
(
outexpr
),
outexpr
)
params
=
{
k
:
_nd
.
array
(
np
.
array
(
v
,
dtype
=
np
.
float32
))
for
k
,
v
in
etab
.
params
.
items
()}
return
IRModule
.
from_expr
(
func
),
params
python/tvm/relay/frontend/mxnet.py
View file @
7ca3212f
...
...
@@ -25,6 +25,7 @@ from tvm import relay
from
topi.util
import
get_const_tuple
from
..
import
analysis
from
..
import
expr
as
_expr
from
..
import
function
as
_function
from
..
import
op
as
_op
from
..
import
scope_builder
as
_scope_builder
from
...
import
nd
as
_nd
...
...
@@ -1096,7 +1097,7 @@ def _mx_cond(inputs, attrs, subgraphs):
else_arg_dtype_info
=
[
arg
.
type_annotation
.
dtype
for
arg
in
else_args
]
else_func
=
_from_mxnet_impl
(
subgraphs
[
2
],
else_arg_shapes
,
else_arg_dtype_info
)
sb
.
ret
(
_expr
.
Call
(
else_func
,
else_args
))
func
=
_
expr
.
Function
(
input_args
,
sb
.
get
())
func
=
_
function
.
Function
(
input_args
,
sb
.
get
())
ret
=
_expr
.
Call
(
func
,
inputs
)
if
num_outputs
>
1
:
ret
=
_expr
.
TupleWrapper
(
ret
,
num_outputs
)
...
...
@@ -1969,7 +1970,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
outputs
=
[
node_map
[
e
[
0
]][
e
[
1
]]
for
e
in
jgraph
[
"heads"
]]
outputs
=
outputs
[
0
]
if
len
(
outputs
)
==
1
else
_expr
.
Tuple
(
outputs
)
func
=
_
expr
.
Function
(
analysis
.
free_vars
(
outputs
),
outputs
)
func
=
_
function
.
Function
(
analysis
.
free_vars
(
outputs
),
outputs
)
return
func
...
...
python/tvm/relay/frontend/onnx.py
View file @
7ca3212f
...
...
@@ -24,6 +24,7 @@ from tvm.ir import IRModule
from
...
import
nd
as
_nd
from
..
import
analysis
from
..
import
expr
as
_expr
from
..
import
function
as
_function
from
..
import
op
as
_op
from
.common
import
AttrCvt
,
Renamer
from
.common
import
get_relay_op
,
new_var
,
infer_shape
,
infer_channels
...
...
@@ -1708,7 +1709,7 @@ class GraphProto(object):
# now return the outputs
outputs
=
[
self
.
_nodes
[
self
.
_parse_value_proto
(
i
)]
for
i
in
graph
.
output
]
outputs
=
outputs
[
0
]
if
len
(
outputs
)
==
1
else
_expr
.
Tuple
(
outputs
)
func
=
_
expr
.
Function
(
analysis
.
free_vars
(
outputs
),
outputs
)
func
=
_
function
.
Function
(
analysis
.
free_vars
(
outputs
),
outputs
)
return
IRModule
.
from_expr
(
func
),
self
.
_params
def
_parse_value_proto
(
self
,
value_proto
):
...
...
@@ -1774,7 +1775,7 @@ class GraphProto(object):
----------
op_name : str
Operator name, such as Convolution, FullyConnected
inputs : list of tvm.relay.
expr
.Function
inputs : list of tvm.relay.
function
.Function
List of inputs.
attrs : dict
Dict of operator attributes
...
...
@@ -1783,7 +1784,7 @@ class GraphProto(object):
Returns
-------
sym : tvm.relay.
expr
.Function
sym : tvm.relay.
function
.Function
Converted relay function
"""
convert_map
=
_get_convert_map
(
opset
)
...
...
python/tvm/relay/frontend/tensorflow.py
View file @
7ca3212f
...
...
@@ -31,6 +31,7 @@ from tvm.relay.prelude import Prelude
from
..
import
analysis
from
..
import
expr
as
_expr
from
..
import
function
as
_function
from
..
import
op
as
_op
from
..expr_functor
import
ExprMutator
from
.common
import
AttrCvt
,
get_relay_op
...
...
@@ -2461,7 +2462,7 @@ class GraphProto(object):
out
.
append
(
out_rnn
)
out
=
out
[
0
]
if
len
(
out
)
==
1
else
_expr
.
Tuple
(
out
)
func
=
_
expr
.
Function
(
analysis
.
free_vars
(
out
),
out
)
func
=
_
function
.
Function
(
analysis
.
free_vars
(
out
),
out
)
self
.
_mod
[
"main"
]
=
func
return
self
.
_mod
,
self
.
_params
...
...
python/tvm/relay/frontend/tflite.py
View file @
7ca3212f
...
...
@@ -24,6 +24,7 @@ from tvm.ir import IRModule
from
tvm
import
relay
from
..
import
analysis
from
..
import
expr
as
_expr
from
..
import
function
as
_function
from
..
import
op
as
_op
from
..
import
qnn
as
_qnn
from
...
import
nd
as
_nd
...
...
@@ -2365,6 +2366,6 @@ def from_tflite(model, shape_dict, dtype_dict):
params
=
{
k
:
_nd
.
array
(
np
.
array
(
v
))
for
k
,
v
in
exp_tab
.
params
.
items
()}
outputs
=
[
exp_tab
.
get_expr
(
get_tensor_name
(
subgraph
,
i
))
for
i
in
model_outputs
]
outputs
=
outputs
[
0
]
if
len
(
outputs
)
==
1
else
_expr
.
Tuple
(
outputs
)
func
=
_
expr
.
Function
(
analysis
.
free_vars
(
outputs
),
outputs
)
func
=
_
function
.
Function
(
analysis
.
free_vars
(
outputs
),
outputs
)
mod
=
IRModule
.
from_expr
(
func
)
return
mod
,
params
python/tvm/relay/function.py
0 → 100644
View file @
7ca3212f
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, invalid-name, unused-import
"""The expression nodes of Relay."""
from
__future__
import
absolute_import
import
tvm._ffi
from
tvm.runtime
import
convert
from
tvm.ir
import
BaseFunc
from
.expr
import
Call
from
.
import
_ffi_api
@tvm._ffi.register_object
(
"relay.Function"
)
class
Function
(
BaseFunc
):
"""A function declaration expression.
Parameters
----------
params: List[tvm.relay.Var]
List of input parameters to the function.
body: tvm.relay.Expr
The body of the function.
ret_type: Optional[tvm.relay.Type]
The return type annotation of the function.
type_params: Optional[List[tvm.relay.TypeParam]]
The additional type parameters, this is only
used in advanced usecase of template functions.
"""
def
__init__
(
self
,
params
,
body
,
ret_type
=
None
,
type_params
=
None
,
attrs
=
None
):
if
type_params
is
None
:
type_params
=
convert
([])
self
.
__init_handle_by_constructor__
(
_ffi_api
.
Function
,
params
,
body
,
ret_type
,
type_params
,
attrs
)
def
__call__
(
self
,
*
args
):
"""Invoke the global function.
Parameters
----------
args: List[relay.Expr]
Arguments.
"""
return
Call
(
self
,
args
,
None
,
None
)
def
with_attr
(
self
,
attr_key
,
attr_value
):
"""Create a new copy of the function and update the attribute
Parameters
----------
attr_key : str
The attribute key to use.
attr_value : Object
The new attribute value.
Returns
-------
func : Function
A new copy of the function
"""
return
_ffi_api
.
FunctionWithAttr
(
self
,
attr_key
,
convert
(
attr_value
))
python/tvm/relay/loops.py
View file @
7ca3212f
...
...
@@ -20,6 +20,7 @@ Utilities for building Relay loops.
"""
from
.scope_builder
import
ScopeBuilder
from
.
import
expr
as
_expr
from
.
import
function
as
_function
def
while_loop
(
cond
,
loop_vars
,
loop_bodies
):
"""
...
...
@@ -60,6 +61,6 @@ def while_loop(cond, loop_vars, loop_bodies):
with
sb
.
else_scope
():
sb
.
ret
(
_expr
.
Tuple
(
fresh_vars
))
func
=
_
expr
.
Function
(
fresh_vars
,
sb
.
get
())
func
=
_
function
.
Function
(
fresh_vars
,
sb
.
get
())
let
=
_expr
.
Let
(
loop
,
func
,
loop
)
return
let
python/tvm/relay/prelude.py
View file @
7ca3212f
...
...
@@ -19,7 +19,8 @@
from
tvm.ir
import
IRModule
from
.ty
import
GlobalTypeVar
,
TensorType
,
Any
,
scalar_type
from
.expr
import
Var
,
Function
,
GlobalVar
,
If
,
const
from
.expr
import
Var
,
GlobalVar
,
If
,
const
from
.function
import
Function
from
.op.tensor
import
add
,
subtract
,
equal
from
.adt
import
Constructor
,
TypeData
,
Clause
,
Match
from
.adt
import
PatternConstructor
,
PatternVar
,
PatternWildcard
...
...
python/tvm/relay/testing/nat.py
View file @
7ca3212f
...
...
@@ -21,7 +21,8 @@ test cases for recursion and pattern matching."""
from
tvm.relay.adt
import
Constructor
,
TypeData
,
Clause
,
Match
,
PatternConstructor
,
PatternVar
from
tvm.relay.backend.interpreter
import
ConstructorValue
from
tvm.relay.expr
import
Var
,
Function
,
GlobalVar
from
tvm.relay.expr
import
Var
,
GlobalVar
from
tvm.relay.function
import
Function
from
tvm.relay.ty
import
GlobalTypeVar
,
TypeVar
,
FuncType
def
define_nat_adt
(
prelude
):
...
...
python/tvm/relay/testing/py_converter.py
View file @
7ca3212f
...
...
@@ -23,7 +23,8 @@ import tvm
from
tvm
import
relay
from
tvm.relay.adt
import
Pattern
from
tvm.relay.backend
import
compile_engine
from
tvm.relay.expr
import
Expr
,
Function
,
GlobalVar
,
Var
from
tvm.relay.expr
import
Expr
,
GlobalVar
,
Var
from
tvm.relay.function
import
Function
from
tvm.relay.expr_functor
import
ExprFunctor
OUTPUT_VAR_NAME
=
'_py_out'
...
...
src/relay/ir/function.cc
View file @
7ca3212f
...
...
@@ -66,7 +66,6 @@ TVM_REGISTER_GLOBAL("relay.ir.Function")
return
Function
(
params
,
body
,
ret_type
,
ty_params
,
attrs
);
});
TVM_STATIC_IR_FUNCTOR
(
ReprPrinter
,
vtable
)
.
set_dispatch
<
FunctionNode
>
([](
const
ObjectRef
&
ref
,
ReprPrinter
*
p
)
{
auto
*
node
=
static_cast
<
const
FunctionNode
*>
(
ref
.
get
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment