Commit 7ca3212f (Unverified)
Authored Mar 18, 2020 by Zhi; committed by GitHub on Mar 18, 2020
create function.py (#5087)
Parent: 06bbc7c9
Showing 28 changed files with 152 additions and 113 deletions (+152 / -113)
Changed files:

docs/api/python/relay/expr.rst                            +0   -3
docs/langref/relay_expr.rst                               +1   -1
python/tvm/autotvm/graph_tuner/base_graph_tuner.py        +2   -2
python/tvm/autotvm/graph_tuner/utils/traverse_graph.py    +2   -1
python/tvm/autotvm/task/relay_integration.py              +3   -3
python/tvm/relay/__init__.py                              +2   -1
python/tvm/relay/_parser.py                               +4   -3
python/tvm/relay/analysis/analysis.py                     +1   -1
python/tvm/relay/backend/compile_engine.py                +2   -2
python/tvm/relay/backend/interpreter.py                   +2   -1
python/tvm/relay/build_module.py                          +7   -6
python/tvm/relay/expr.py                                  +2   -64
python/tvm/relay/expr_functor.py                          +2   -1
python/tvm/relay/frontend/caffe2.py                       +4   -3
python/tvm/relay/frontend/common.py                       +3   -2
python/tvm/relay/frontend/coreml.py                       +2   -1
python/tvm/relay/frontend/darknet.py                      +2   -1
python/tvm/relay/frontend/keras.py                        +2   -1
python/tvm/relay/frontend/mxnet.py                        +3   -2
python/tvm/relay/frontend/onnx.py                         +4   -3
python/tvm/relay/frontend/tensorflow.py                   +2   -1
python/tvm/relay/frontend/tflite.py                       +2   -1
python/tvm/relay/function.py                              +86  -0
python/tvm/relay/loops.py                                 +2   -1
python/tvm/relay/prelude.py                               +2   -1
python/tvm/relay/testing/nat.py                           +2   -1
python/tvm/relay/testing/py_converter.py                  +2   -1
src/relay/ir/function.cc                                  +4   -5
docs/api/python/relay/expr.rst

@@ -35,9 +35,6 @@ tvm.relay.expr
 .. autoclass:: tvm.relay.expr.Tuple
    :members:
-.. autoclass:: tvm.relay.expr.Function
-   :members:
 .. autoclass:: tvm.relay.expr.Call
    :members:
docs/langref/relay_expr.rst

@@ -120,7 +120,7 @@ Additionally, functions in Relay are higher-order, which means that a function c
 function or returned by a function, as function expressions evaluate to closures (see the `Closures`_ subsection),
 which are values like tensors and tuples.
-See :py:class:`~tvm.relay.expr.Function` for the definition and documentation of function nodes.
+See :py:class:`~tvm.relay.function.Function` for the definition and documentation of function nodes.

 Syntax
 ~~~~~~
python/tvm/autotvm/graph_tuner/base_graph_tuner.py

@@ -69,7 +69,7 @@ class BaseGraphTuner(object):
         target_op in the input graph and layout transformation benchmark need to be
         executed before initialization.
-    graph : tvm.relay.Expr.Function
+    graph : tvm.relay.function.Function
         Input graph
     input_shapes : dict of str to tuple.

@@ -143,7 +143,7 @@ class BaseGraphTuner(object):
         if isinstance(graph, tvm.IRModule):
             graph = graph["main"]
-        if isinstance(graph, relay.expr.Function):
+        if isinstance(graph, relay.function.Function):
             node_dict = {}
             graph = bind_inputs(graph, input_shapes, dtype)
             expr2graph(graph, self._target_ops, node_dict, self._node_list)
python/tvm/autotvm/graph_tuner/utils/traverse_graph.py

@@ -21,7 +21,8 @@ import threading
 import tvm
 from tvm import relay, autotvm
 from tvm.relay import transform
-from tvm.relay.expr import Call, Function, TupleGetItem, Var, Constant, Tuple
+from tvm.relay.expr import Call, TupleGetItem, Var, Constant, Tuple
+from tvm.relay.function import Function
 from tvm.relay.ty import TupleType, TensorType
 from tvm.autotvm.task import TaskExtractEnv
python/tvm/autotvm/task/relay_integration.py

@@ -61,7 +61,7 @@ def extract_from_program(mod, params, target, target_host=None, ops=None):
     Parameters
     ----------
-    mod: tvm.IRModule or relay.expr.Function
+    mod: tvm.IRModule or relay.function.Function
         The module or function to tune
     params: dict of str to numpy array
         The associated parameters of the program

@@ -88,7 +88,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
     Parameters
     ----------
-    mods: List[tvm.IRModule] or List[relay.expr.Function]
+    mods: List[tvm.IRModule] or List[relay.function.Function]
         The list of modules or functions to tune
     params: List of dict of str to numpy array
         The associated parameters of the programs

@@ -118,7 +118,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
     logger.disabled = True

     for mod, param in zip(mods, params):
-        if isinstance(mod, relay.expr.Function):
+        if isinstance(mod, relay.function.Function):
             mod = tvm.IRModule.from_expr(mod)
         assert isinstance(mod, tvm.IRModule), \
             "only support relay Module or Function to be tuned"
python/tvm/relay/__init__.py

@@ -22,6 +22,7 @@ from sys import setrecursionlimit
 from . import base
 from . import ty
 from . import expr
+from . import function
 from . import type_functor
 from . import expr_functor
 from . import adt

@@ -87,7 +88,7 @@ Constant = expr.Constant
 Tuple = expr.Tuple
 Var = expr.Var
 GlobalVar = expr.GlobalVar
-Function = expr.Function
+Function = function.Function
 Call = expr.Call
 Let = expr.Let
 If = expr.If
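The public alias keeps working after this change: tvm.relay.Function now resolves to function.Function instead of expr.Function, so code that goes through the relay namespace is unaffected, while code that referenced the old module path breaks. A minimal sketch, assuming a TVM build that includes this commit:

import tvm
from tvm import relay

x = relay.var("x", shape=(2, 2), dtype="float32")
func = relay.Function([x], x)                       # public alias, unchanged
assert isinstance(func, relay.function.Function)    # new home of the class
# isinstance(func, relay.expr.Function) now raises AttributeError,
# because the class no longer lives in tvm.relay.expr.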
python/tvm/relay/_parser.py

@@ -43,6 +43,7 @@ from tvm.ir import IRModule
 from .base import Span, SourceName
 from . import adt
 from . import expr
+from . import function
 from . import ty
 from . import op

@@ -481,7 +482,7 @@ class ParseTreeToRelayIR(RelayVisitor):
     def mk_func(
             self,
             ctx: Union[RelayParser.FuncContext, RelayParser.DefnContext]) \
-        -> expr.Function:
+        -> function.Function:
         """Construct a function from either a Func or Defn."""
         # Enter var scope early to put params in scope.
         self.enter_var_scope()

@@ -511,10 +512,10 @@ class ParseTreeToRelayIR(RelayVisitor):
         self.exit_var_scope()
         attrs = tvm.ir.make_node("DictAttrs", **attr_list) if attr_list is not None else None
-        return expr.Function(var_list, body, ret_type, type_params, attrs)
+        return function.Function(var_list, body, ret_type, type_params, attrs)

     @spanify
-    def visitFunc(self, ctx: RelayParser.FuncContext) -> expr.Function:
+    def visitFunc(self, ctx: RelayParser.FuncContext) -> function.Function:
         return self.mk_func(ctx)

     # TODO: how to set spans for definitions?
python/tvm/relay/analysis/analysis.py

@@ -421,7 +421,7 @@ def extract_fused_functions(mod):
     Returns
     -------
-    ret : Dict[int, tvm.relay.ir.expr.Function]
+    ret : Dict[int, tvm.relay.function.Function]
         A module containing only fused primitive functions
     """
     ret_mod = _ffi_api.ExtractFusedFunctions()(mod)
python/tvm/relay/backend/compile_engine.py

@@ -25,7 +25,7 @@ from tvm import te
 from tvm.runtime import Object
 from ... import target as _target
 from ... import autotvm
-from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from .. import ty as _ty
 from . import _backend

@@ -65,7 +65,7 @@ class CCacheValue(Object):
 def _get_cache_key(source_func, target):
-    if isinstance(source_func, _expr.Function):
+    if isinstance(source_func, _function.Function):
         if isinstance(target, str):
             target = _target.create(target)
             if not target:
python/tvm/relay/backend/interpreter.py

@@ -27,7 +27,8 @@ from tvm.ir import IRModule
 from . import _backend
 from .. import _make, analysis, transform
 from ... import nd
-from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
+from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, const
+from ..function import Function
 from ..scope_builder import ScopeBuilder
python/tvm/relay/build_module.py

@@ -29,6 +29,7 @@ from ..contrib import graph_runtime as _graph_rt
 from . import _build_module
 from . import ty as _ty
 from . import expr as _expr
+from . import function as _function
 from .backend import interpreter as _interpreter
 from .backend.vm import VMExecutor

@@ -218,16 +219,16 @@ def build(mod, target=None, target_host=None, params=None):
     params : dict
         The parameters of the final graph.
     """
-    if not isinstance(mod, (IRModule, _expr.Function)):
+    if not isinstance(mod, (IRModule, _function.Function)):
         raise ValueError("Type of input parameter mod must be tvm.IRModule")

-    if isinstance(mod, _expr.Function):
+    if isinstance(mod, _function.Function):
         if params:
             mod = bind_params_by_name(mod, params)
         mod = IRModule.from_expr(mod)
         warnings.warn(
             "Please use input parameter mod (tvm.IRModule) "
-            "instead of deprecated parameter mod (tvm.relay.expr.Function)",
+            "instead of deprecated parameter mod (tvm.relay.function.Function)",
             DeprecationWarning)

     target = _update_target(target)

@@ -276,16 +277,16 @@ def optimize(mod, target=None, params=None):
     params : dict
         The parameters of the final graph.
     """
-    if not isinstance(mod, (IRModule, _expr.Function)):
+    if not isinstance(mod, (IRModule, _function.Function)):
         raise ValueError("Type of input parameter mod must be tvm.IRModule")

-    if isinstance(mod, _expr.Function):
+    if isinstance(mod, _function.Function):
         if params:
             mod = bind_params_by_name(mod, params)
         mod = IRModule.from_expr(mod)
         warnings.warn(
             "Please use input parameter mod (tvm.IRModule) "
-            "instead of deprecated parameter func (tvm.relay.expr.Function)",
+            "instead of deprecated parameter func (tvm.relay.function.Function)",
             DeprecationWarning)

     target = _update_target(target)
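As the updated warnings above state, relay.build and relay.optimize still accept a bare Function, but only through a deprecation path that wraps it in an IRModule internally. A short sketch of the two call styles (an LLVM-enabled TVM build is assumed):

import tvm
from tvm import relay

x = relay.var("x", shape=(4,), dtype="float32")
func = relay.Function([x], relay.nn.relu(x))

# Deprecated: passing the Function directly still works, with a DeprecationWarning.
graph, lib, params = relay.build(func, target="llvm")

# Preferred: wrap the function in an IRModule first.
mod = tvm.IRModule.from_expr(func)
graph, lib, params = relay.build(mod, target="llvm")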
python/tvm/relay/expr.py

@@ -22,8 +22,8 @@ from numbers import Number as _Number
 import numpy as _np
 import tvm._ffi
 from tvm._ffi import base as _base
-from tvm.runtime import NDArray, convert, ndarray as _nd
-from tvm.ir import RelayExpr, GlobalVar, BaseFunc
+from tvm.runtime import NDArray, ndarray as _nd
+from tvm.ir import RelayExpr, GlobalVar
 from .base import RelayNode
 from . import _ffi_api

@@ -225,68 +225,6 @@ class Var(ExprWithOp):
         return name

-@tvm._ffi.register_object("relay.Function")
-class Function(BaseFunc):
-    """A function declaration expression.
-
-    Parameters
-    ----------
-    params: List[tvm.relay.Var]
-        List of input parameters to the function.
-
-    body: tvm.relay.Expr
-        The body of the function.
-
-    ret_type: Optional[tvm.relay.Type]
-        The return type annotation of the function.
-
-    type_params: Optional[List[tvm.relay.TypeParam]]
-        The additional type parameters, this is only
-        used in advanced usecase of template functions.
-    """
-    def __init__(self, params, body, ret_type=None, type_params=None, attrs=None):
-        if type_params is None:
-            type_params = convert([])
-
-        self.__init_handle_by_constructor__(
-            _ffi_api.Function, params, body, ret_type, type_params, attrs)
-
-    def __call__(self, *args):
-        """Invoke the global function.
-
-        Parameters
-        ----------
-        args: List[relay.Expr]
-            Arguments.
-        """
-        return Call(self, args, None, None)
-
-    def with_attr(self, attr_key, attr_value):
-        """Create a new copy of the function and update the attribute
-
-        Parameters
-        ----------
-        attr_key : str
-            The attribute key to use.
-
-        attr_value : Object
-            The new attribute value.
-
-        Returns
-        -------
-        func : Function
-            A new copy of the function
-        """
-        return _ffi_api.FunctionWithAttr(
-            self, attr_key, convert(attr_value))
-
 @tvm._ffi.register_object("relay.Call")
 class Call(ExprWithOp):
     """Function call node in Relay.
python/tvm/relay/expr_functor.py

@@ -17,7 +17,8 @@
 # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
 """The expression functor of Relay."""
-from .expr import Function, Call, Let, Var, GlobalVar
+from .function import Function
+from .expr import Call, Let, Var, GlobalVar
 from .expr import If, Tuple, TupleGetItem, Constant
 from .expr import RefCreate, RefRead, RefWrite
 from .adt import Constructor, Match, Clause
python/tvm/relay/frontend/caffe2.py

@@ -21,6 +21,7 @@ from tvm.ir import IRModule
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from ... import nd as _nd
 from .common import AttrCvt, Renamer

@@ -451,7 +452,7 @@ class Caffe2NetDef(object):
         else:
             outputs = out[0]

-        func = _expr.Function(analysis.free_vars(outputs), outputs)
+        func = _function.Function(analysis.free_vars(outputs), outputs)
         self._mod["main"] = func

         return self._mod, self._params

@@ -517,7 +518,7 @@ class Caffe2NetDef(object):
         ----------
         op_type : str
             Operator name, such as Convolution, FullyConnected
-        inputs : list of tvm.relay.expr.Function
+        inputs : list of tvm.relay.function.Function
             List of input inputs.
         args : dict
             Dict of operator attributes

@@ -530,7 +531,7 @@ class Caffe2NetDef(object):
         Returns
         -------
-        func : tvm.relay.expr.Function
+        func : tvm.relay.function.Function
             Converted relay function
         """
         identity_list = identity_list if identity_list else _identity_list
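The same pattern recurs in every frontend touched below: group the converted outputs, build a Function whose parameters are the free variables of that expression, and lift it into an IRModule. Only the module providing the constructor changes. A condensed sketch of that shared pattern (helper name and variables are illustrative, not from any one frontend):

from tvm.ir import IRModule
from tvm.relay import analysis
from tvm.relay import expr as _expr
from tvm.relay import function as _function

def _finalize_graph(outputs, params):
    # A single output is returned as-is; multiple outputs become a Tuple.
    outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
    # The free variables of the output expression become the function parameters.
    func = _function.Function(analysis.free_vars(outputs), outputs)
    return IRModule.from_expr(func), params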
python/tvm/relay/frontend/common.py

@@ -24,6 +24,7 @@ from tvm.ir import IRModule
 from topi.util import get_const_tuple
 from .. import expr as _expr
+from .. import function as _function
 from .. import transform as _transform
 from .. import op as _op
 from .. import analysis

@@ -459,7 +460,7 @@ def infer_type(node, mod=None):
         new_mod.update(mod)
     new_mod = _transform.InferType()(new_mod)
     entry = new_mod["main"]
-    return entry if isinstance(node, _expr.Function) else entry.body
+    return entry if isinstance(node, _function.Function) else entry.body

 def infer_shape(inputs, mod=None):
     """A method to get the output type of an intermediate node in the graph."""

@@ -491,7 +492,7 @@ def infer_value(input_val, params):
     # Check that all free variables have associated parameters.
     assert all(var.name_hint in params.keys() for var in analysis.free_vars(
         input_val)), "All inputs to infer must be available in params."
-    func = _expr.Function(analysis.free_vars(input_val), input_val)
+    func = _function.Function(analysis.free_vars(input_val), input_val)
     with tvm.relay.build_config(opt_level=0):
         graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
         ctx = tvm.cpu(0)
python/tvm/relay/frontend/coreml.py

@@ -24,6 +24,7 @@ from tvm.ir import IRModule
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from ... import nd as _nd
 from ..._ffi import base as _base

@@ -503,6 +504,6 @@ def from_coreml(model, shape=None):
                for o in spec.description.output]
     # for now return first output
     outexpr = outexpr[0]
-    func = _expr.Function(analysis.free_vars(outexpr), outexpr)
+    func = _function.Function(analysis.free_vars(outexpr), outexpr)
     params = {k: _nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
     return IRModule.from_expr(func), params
python/tvm/relay/frontend/darknet.py

@@ -26,6 +26,7 @@ from tvm.ir import IRModule
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .common import get_relay_op, new_var

 __all__ = ['from_darknet']

@@ -821,7 +822,7 @@ class GraphProto(object):
         outputs = _as_list(sym) + self._outs
         outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)

-        sym = _expr.Function(analysis.free_vars(outputs), outputs)
+        sym = _function.Function(analysis.free_vars(outputs), outputs)
         return IRModule.from_expr(sym), self._tvmparams

 def from_darknet(net,
python/tvm/relay/frontend/keras.py

@@ -23,6 +23,7 @@ from tvm.ir import IRModule
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from ... import nd as _nd
 from .common import ExprTable, new_var

@@ -914,6 +915,6 @@ def from_keras(model, shape=None, layout='NCHW'):
     outexpr = [etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2])) \
                for oc in model._output_coordinates]
     outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr)
-    func = _expr.Function(analysis.free_vars(outexpr), outexpr)
+    func = _function.Function(analysis.free_vars(outexpr), outexpr)
     params = {k: _nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
     return IRModule.from_expr(func), params
python/tvm/relay/frontend/mxnet.py

@@ -25,6 +25,7 @@ from tvm import relay
 from topi.util import get_const_tuple
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from .. import scope_builder as _scope_builder
 from ... import nd as _nd

@@ -1096,7 +1097,7 @@ def _mx_cond(inputs, attrs, subgraphs):
     else_arg_dtype_info = [arg.type_annotation.dtype for arg in else_args]
     else_func = _from_mxnet_impl(subgraphs[2], else_arg_shapes, else_arg_dtype_info)
     sb.ret(_expr.Call(else_func, else_args))
-    func = _expr.Function(input_args, sb.get())
+    func = _function.Function(input_args, sb.get())
     ret = _expr.Call(func, inputs)
     if num_outputs > 1:
         ret = _expr.TupleWrapper(ret, num_outputs)

@@ -1969,7 +1970,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info, params=None, mod=None):
     outputs = [node_map[e[0]][e[1]] for e in jgraph["heads"]]
     outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
-    func = _expr.Function(analysis.free_vars(outputs), outputs)
+    func = _function.Function(analysis.free_vars(outputs), outputs)
     return func
python/tvm/relay/frontend/onnx.py

@@ -24,6 +24,7 @@ from tvm.ir import IRModule
 from ... import nd as _nd
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from .common import AttrCvt, Renamer
 from .common import get_relay_op, new_var, infer_shape, infer_channels

@@ -1708,7 +1709,7 @@ class GraphProto(object):
         # now return the outputs
         outputs = [self._nodes[self._parse_value_proto(i)] for i in graph.output]
         outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
-        func = _expr.Function(analysis.free_vars(outputs), outputs)
+        func = _function.Function(analysis.free_vars(outputs), outputs)
         return IRModule.from_expr(func), self._params

     def _parse_value_proto(self, value_proto):

@@ -1774,7 +1775,7 @@ class GraphProto(object):
         ----------
         op_name : str
             Operator name, such as Convolution, FullyConnected
-        inputs : list of tvm.relay.expr.Function
+        inputs : list of tvm.relay.function.Function
             List of inputs.
         attrs : dict
             Dict of operator attributes

@@ -1783,7 +1784,7 @@ class GraphProto(object):
         Returns
         -------
-        sym : tvm.relay.expr.Function
+        sym : tvm.relay.function.Function
             Converted relay function
         """
         convert_map = _get_convert_map(opset)
python/tvm/relay/frontend/tensorflow.py

@@ -31,6 +31,7 @@ from tvm.relay.prelude import Prelude
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from ..expr_functor import ExprMutator
 from .common import AttrCvt, get_relay_op

@@ -2461,7 +2462,7 @@ class GraphProto(object):
             out.append(out_rnn)

         out = out[0] if len(out) == 1 else _expr.Tuple(out)
-        func = _expr.Function(analysis.free_vars(out), out)
+        func = _function.Function(analysis.free_vars(out), out)
         self._mod["main"] = func
         return self._mod, self._params
python/tvm/relay/frontend/tflite.py

@@ -24,6 +24,7 @@ from tvm.ir import IRModule
 from tvm import relay
 from .. import analysis
 from .. import expr as _expr
+from .. import function as _function
 from .. import op as _op
 from .. import qnn as _qnn
 from ... import nd as _nd

@@ -2365,6 +2366,6 @@ def from_tflite(model, shape_dict, dtype_dict):
     params = {k: _nd.array(np.array(v)) for k, v in exp_tab.params.items()}
     outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in model_outputs]
     outputs = outputs[0] if len(outputs) == 1 else _expr.Tuple(outputs)
-    func = _expr.Function(analysis.free_vars(outputs), outputs)
+    func = _function.Function(analysis.free_vars(outputs), outputs)
     mod = IRModule.from_expr(func)
     return mod, params
python/tvm/relay/function.py (new file, mode 100644)

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, invalid-name, unused-import
"""The expression nodes of Relay."""
from __future__ import absolute_import

import tvm._ffi
from tvm.runtime import convert
from tvm.ir import BaseFunc

from .expr import Call
from . import _ffi_api


@tvm._ffi.register_object("relay.Function")
class Function(BaseFunc):
    """A function declaration expression.

    Parameters
    ----------
    params: List[tvm.relay.Var]
        List of input parameters to the function.

    body: tvm.relay.Expr
        The body of the function.

    ret_type: Optional[tvm.relay.Type]
        The return type annotation of the function.

    type_params: Optional[List[tvm.relay.TypeParam]]
        The additional type parameters, this is only
        used in advanced usecase of template functions.
    """
    def __init__(self, params, body, ret_type=None, type_params=None, attrs=None):
        if type_params is None:
            type_params = convert([])

        self.__init_handle_by_constructor__(
            _ffi_api.Function, params, body, ret_type, type_params, attrs)

    def __call__(self, *args):
        """Invoke the global function.

        Parameters
        ----------
        args: List[relay.Expr]
            Arguments.
        """
        return Call(self, args, None, None)

    def with_attr(self, attr_key, attr_value):
        """Create a new copy of the function and update the attribute

        Parameters
        ----------
        attr_key : str
            The attribute key to use.

        attr_value : Object
            The new attribute value.

        Returns
        -------
        func : Function
            A new copy of the function
        """
        return _ffi_api.FunctionWithAttr(
            self, attr_key, convert(attr_value))
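A small usage sketch of the new module on its own, exercising the constructor, the call helper, and with_attr (the attribute key and shapes here are illustrative):

import tvm
from tvm import relay
from tvm.relay.function import Function

x = relay.var("x", shape=(8,), dtype="float32")
f = Function([x], x * relay.const(2.0))

# with_attr returns an updated copy; the original function is unchanged.
prim = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1))

# __call__ wraps the function in a relay Call node.
call = f(x)
print(call)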
python/tvm/relay/loops.py

@@ -20,6 +20,7 @@ Utilities for building Relay loops.
 """
 from .scope_builder import ScopeBuilder
 from . import expr as _expr
+from . import function as _function

 def while_loop(cond, loop_vars, loop_bodies):
     """

@@ -60,6 +61,6 @@ def while_loop(cond, loop_vars, loop_bodies):
     with sb.else_scope():
         sb.ret(_expr.Tuple(fresh_vars))

-    func = _expr.Function(fresh_vars, sb.get())
+    func = _function.Function(fresh_vars, sb.get())
     let = _expr.Let(loop, func, loop)
     return let
python/tvm/relay/prelude.py

@@ -19,7 +19,8 @@
 from tvm.ir import IRModule

 from .ty import GlobalTypeVar, TensorType, Any, scalar_type
-from .expr import Var, Function, GlobalVar, If, const
+from .expr import Var, GlobalVar, If, const
+from .function import Function
 from .op.tensor import add, subtract, equal
 from .adt import Constructor, TypeData, Clause, Match
 from .adt import PatternConstructor, PatternVar, PatternWildcard
python/tvm/relay/testing/nat.py

@@ -21,7 +21,8 @@ test cases for recursion and pattern matching."""
 from tvm.relay.adt import Constructor, TypeData, Clause, Match, PatternConstructor, PatternVar
 from tvm.relay.backend.interpreter import ConstructorValue
-from tvm.relay.expr import Var, Function, GlobalVar
+from tvm.relay.expr import Var, GlobalVar
+from tvm.relay.function import Function
 from tvm.relay.ty import GlobalTypeVar, TypeVar, FuncType

 def define_nat_adt(prelude):
python/tvm/relay/testing/py_converter.py

@@ -23,7 +23,8 @@ import tvm
 from tvm import relay
 from tvm.relay.adt import Pattern
 from tvm.relay.backend import compile_engine
-from tvm.relay.expr import Expr, Function, GlobalVar, Var
+from tvm.relay.expr import Expr, GlobalVar, Var
+from tvm.relay.function import Function
 from tvm.relay.expr_functor import ExprFunctor

 OUTPUT_VAR_NAME = '_py_out'
src/relay/ir/function.cc

@@ -27,10 +27,10 @@ namespace tvm {
 namespace relay {

 Function::Function(tvm::Array<Var> params,
                    Expr body,
                    Type ret_type,
                    tvm::Array<TypeVar> type_params,
                    DictAttrs attrs) {
   ObjectPtr<FunctionNode> n = make_object<FunctionNode>();
   CHECK(params.defined());
   CHECK(type_params.defined());

@@ -66,7 +66,6 @@ TVM_REGISTER_GLOBAL("relay.ir.Function")
   return Function(params, body, ret_type, ty_params, attrs);
 });

 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 .set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
   auto* node = static_cast<const FunctionNode*>(ref.get());