Commit 275e317c (unverified), authored by Tianqi Chen on Apr 14, 2020, committed by GitHub on Apr 14, 2020
[RELAY] Remove re-exports of tvm.transform (#5337)
parent f08d5d78
Showing 38 changed files with 169 additions and 229 deletions.
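The practical effect of this commit is an API migration: the generic pass-infrastructure names (Sequential, PassContext, module_pass, ModulePass, PassInfo, PrintIR) are no longer re-exported under tvm.relay, and user code reaches them through tvm.transform instead, while Relay-specific passes stay under relay.transform. A minimal before/after sketch, assuming a TVM build that includes this commit (the 0.7-era API):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 2), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], relay.add(x, x)))

    # Before: seq = relay.transform.Sequential([...]) and
    #         relay.transform.PassContext(opt_level=3)
    # After: the pass manager comes from tvm.transform; the individual
    #        Relay passes are still found under relay.transform.
    seq = tvm.transform.Sequential([
        relay.transform.InferType(),
        relay.transform.FoldConstant(),
    ])
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)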
docs/api/python/ir.rst (+8, -0)
docs/dev/convert_layout.rst (+1, -1)
docs/dev/relay_pass_infra.rst (+2, -2)
include/tvm/ir/transform.h (+3, -1)
python/tvm/ir/json_compact.py (+1, -1)
python/tvm/ir/transform.py (+5, -2)
python/tvm/relay/__init__.py (+0, -11)
python/tvm/relay/backend/interpreter.py (+4, -4)
python/tvm/relay/qnn/transform.py (+2, -2)
python/tvm/relay/quantize/quantize.py (+18, -15)
python/tvm/relay/testing/__init__.py (+1, -1)
python/tvm/relay/testing/py_converter.py (+2, -2)
python/tvm/relay/transform/transform.py (+36, -53)
src/ir/transform.cc (+3, -3)
src/relay/transforms/print_ir.cc (+0, -49)
tests/python/relay/test_op_level10.py (+4, -4)
tests/python/relay/test_pass_alter_op_layout.py (+2, -2)
tests/python/relay/test_pass_annotation.py (+2, -2)
tests/python/relay/test_pass_canonicalize_cast.py (+2, -2)
tests/python/relay/test_pass_combine_parallel_conv2d.py (+1, -1)
tests/python/relay/test_pass_combine_parallel_dense.py (+1, -1)
tests/python/relay/test_pass_convert_op_layout.py (+2, -2)
tests/python/relay/test_pass_dead_code_elimination.py (+1, -1)
tests/python/relay/test_pass_eliminate_common_subexpr.py (+1, -1)
tests/python/relay/test_pass_eta_expand.py (+4, -4)
tests/python/relay/test_pass_fold_constant.py (+2, -2)
tests/python/relay/test_pass_fold_scale_axis.py (+1, -1)
tests/python/relay/test_pass_lazy_gradient_init.py (+13, -13)
tests/python/relay/test_pass_legalize.py (+2, -2)
tests/python/relay/test_pass_mac_count.py (+1, -1)
tests/python/relay/test_pass_manager.py (+17, -17)
tests/python/relay/test_pass_partial_eval.py (+3, -3)
tests/python/relay/test_pass_partition_graph.py (+4, -4)
tests/python/relay/test_pass_qnn_legalize.py (+2, -2)
tests/python/relay/test_pass_to_a_normal_form.py (+2, -2)
tests/python/relay/test_pass_to_cps.py (+2, -1)
tutorials/dev/relay_pass_infra.py (+13, -13)
vta/python/vta/top/graphpack.py (+1, -1)
docs/api/python/ir.rst
@@ -21,3 +21,11 @@ tvm.ir
    :members:
    :imported-members:
    :autosummary:
+
+tvm.transform
+-------------
+
+.. automodule:: tvm.transform
+   :members:
+   :imported-members:
+   :autosummary:
docs/dev/convert_layout.rst
@@ -227,7 +227,7 @@ ConvertLayout pass is extremely easy to use. The pass is not a part of default r
     # Convert the layout to NCHW
     # RemoveUnunsedFunctions is used to clean up the graph.
-    seq = relay.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
+    seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                       relay.transform.ConvertLayout('NCHW')])
     with relay.transform.PassContext(opt_level=3):
         mod = seq(mod)
...
docs/dev/relay_pass_infra.rst
@@ -582,7 +582,7 @@ using ``Sequential`` associated with other types of passes.
     func = relay.Function([x], z2)

     # Customize the optimization pipeline.
-    seq = _transform.Sequential([
+    seq = tvm.transform.Sequential([
         relay.transform.InferType(),
         relay.transform.FoldConstant(),
         relay.transform.EliminateCommonSubexpr(),
...
@@ -609,7 +609,7 @@ sequential pass example could be like the following to enable IR dumping for
 .. code:: python

-    seq = _transform.Sequential([
+    seq = tvm.transform.Sequential([
         relay.transform.InferType(),
         relay.transform.FoldConstant(),
         relay.transform.PrintIR(),
...
include/tvm/ir/transform.h
@@ -361,9 +361,11 @@ TVM_DLL Pass CreateModulePass(
 /*!
  * \brief A special trace pass that prints the header and IR to LOG(INFO).
  * \param header The header to be attached to the output.
+ * \param show_meta_data Whether should we show meta data.
  * \return The pass.
  */
-TVM_DLL Pass PrintIR(std::string header);
+TVM_DLL Pass PrintIR(std::string header = "", bool show_meta_data = false);

 }  // namespace transform
 }  // namespace tvm
...
python/tvm/ir/json_compact.py
@@ -106,7 +106,7 @@ def create_updater_06_to_07():
         "relay.PassInfo": _rename("transform.PassInfo"),
         "relay.PassContext": _rename("transform.PassContext"),
         "relay.ModulePass": _rename("transform.ModulePass"),
-        "relay.Sequantial": _rename("transform.Sequantial"),
+        "relay.Sequential": _rename("transform.Sequential"),
         # TIR
         "Variable": _update_tir_var("tir.Var"),
         "SizeVar": _update_tir_var("tir.SizeVar"),
...
python/tvm/ir/transform.py
@@ -329,7 +329,7 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None):
     return create_module_pass

-def PrintIR(header):
+def PrintIR(header="", show_meta_data=False):
     """A special trace pass that prints the header and IR.

     Parameters
@@ -337,8 +337,11 @@ def PrintIR(header):
     header : str
         The header to be displayed along with the dump.

+    show_meta_data : bool
+        A boolean flag to indicate if meta data should be printed.
+
     Returns
     --------
         The pass
     """
-    return _ffi_transform_api.PrintIR(header)
+    return _ffi_transform_api.PrintIR(header, show_meta_data)
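A hedged usage sketch for the extended signature above: PrintIR can now be dropped into a pipeline with a header that identifies the dump point, and show_meta_data controls whether the meta-data section is included (both arguments default to the quiet form).

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(4,), dtype="float32")
    mod = tvm.IRModule.from_expr(
        relay.Function([x], relay.multiply(x, relay.const(2.0))))

    seq = tvm.transform.Sequential([
        relay.transform.InferType(),
        # Logs "PrintIR(after InferType):" plus the module text to LOG(INFO);
        # show_meta_data=True also dumps the meta-data section.
        tvm.transform.PrintIR("after InferType", show_meta_data=True),
        relay.transform.FoldConstant(),
    ])
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)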
python/tvm/relay/__init__.py
@@ -128,20 +128,9 @@ Prelude = prelude.Prelude
 # Scope builder
 ScopeBuilder = scope_builder.ScopeBuilder

-module_pass = transform.module_pass
-function_pass = transform.function_pass
-
 # Parser
 fromtext = parser.fromtext

 # Param Serialization
 save_param_dict = param_dict.save_param_dict
 load_param_dict = param_dict.load_param_dict
-
-# Pass manager
-PassInfo = transform.PassInfo
-PassContext = transform.PassContext
-Pass = transform.Pass
-ModulePass = transform.ModulePass
-FunctionPass = transform.FunctionPass
-Sequential = transform.Sequential
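For downstream code, the deleted aliases map one-to-one onto the canonical names; a sketch of the rename table (the function-level machinery stays in Relay):

    # relay.PassInfo     -> tvm.transform.PassInfo
    # relay.PassContext  -> tvm.transform.PassContext
    # relay.Pass         -> tvm.transform.Pass
    # relay.ModulePass   -> tvm.transform.ModulePass
    # relay.FunctionPass -> relay.transform.FunctionPass
    # relay.Sequential   -> tvm.transform.Sequential
    # relay.module_pass  -> tvm.transform.module_pass
    import tvm
    assert isinstance(tvm.transform.Sequential([]), tvm.transform.Pass)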
python/tvm/relay/backend/interpreter.py
@@ -210,10 +210,10 @@ class Interpreter(Executor):
         opt_mod : tvm.IRModule
             The optimized module.
         """
-        seq = transform.Sequential([transform.SimplifyInference(),
-                                    transform.FuseOps(0),
-                                    transform.ToANormalForm(),
-                                    transform.InferType()])
+        seq = tvm.transform.Sequential([transform.SimplifyInference(),
+                                        transform.FuseOps(0),
+                                        transform.ToANormalForm(),
+                                        transform.InferType()])
         return seq(self.mod)

     def _make_executor(self, expr=None):
...
python/tvm/relay/qnn/transform.py
@@ -60,7 +60,7 @@ def CanonicalizeOps():
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that canonicalizes QNN ops to Relay ops.
     """
...
@@ -108,7 +108,7 @@ def Legalize():
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that legalizes QNN ops.
     """
...
python/tvm/relay/quantize/quantize.py
@@ -17,6 +17,7 @@
 #pylint: disable=unused-argument, not-context-manager
 """Automatic quantization toolkit."""
 import tvm.ir
+import tvm
 from tvm.runtime import Object

 from . import _quantize
...
@@ -240,7 +241,7 @@ def partition():
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass for VTA rewrite.
     """
     return _quantize.QuantizePartition()
...
@@ -253,7 +254,7 @@ def annotate():
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass for quantization annotation.
     """
     return _quantize.QuantizeAnnotate()
...
@@ -267,7 +268,7 @@ def realize():
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass for quantization realization.
     """
     return _quantize.QuantizeRealize()
...
@@ -298,11 +299,12 @@ def prerequisite_optimize(mod, params=None):
     """ Prerequisite optimization passes for quantization. Perform
     "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
     "CanonicalizeOps" optimization before quantization. """
-    optimize = _transform.Sequential([_transform.SimplifyInference(),
-                                      _transform.FoldConstant(),
-                                      _transform.FoldScaleAxis(),
-                                      _transform.CanonicalizeOps(),
-                                      _transform.FoldConstant()])
+    optimize = tvm.transform.Sequential(
+        [_transform.SimplifyInference(),
+         _transform.FoldConstant(),
+         _transform.FoldScaleAxis(),
+         _transform.CanonicalizeOps(),
+         _transform.FoldConstant()])

     if params:
         mod['main'] = _bind_params(mod['main'], params)
...
@@ -336,19 +338,20 @@ def quantize(mod, params=None, dataset=None):
     """
     mod = prerequisite_optimize(mod, params)

-    calibrate_pass = _transform.module_pass(calibrate(dataset), opt_level=1,
-                                            name="QuantizeCalibrate")
+    calibrate_pass = tvm.transform.module_pass(
+        calibrate(dataset), opt_level=1,
+        name="QuantizeCalibrate")
     quant_passes = [partition(),
                     annotate(),
                     calibrate_pass]
     if not current_qconfig().do_simulation:
         quant_passes.append(realize())
     quant_passes.append(_transform.FoldConstant())
-    quantize_seq = _transform.Sequential(quant_passes)
-    with _transform.PassContext(opt_level=3,
-                                required_pass=["QuantizeAnnotate",
-                                               "QuantizeCalibrate",
-                                               "QuantizeRealize"]):
+    quantize_seq = tvm.transform.Sequential(quant_passes)
+    with tvm.transform.PassContext(opt_level=3,
+                                   required_pass=["QuantizeAnnotate",
+                                                  "QuantizeCalibrate",
+                                                  "QuantizeRealize"]):
         with quantize_context():
             mod = quantize_seq(mod)
...
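The quantization pipeline above wraps its calibration callback with tvm.transform.module_pass. A minimal sketch of that pattern, with a hypothetical no-op body standing in for the real calibration logic:

    import tvm

    def calibrate_sketch(mod, ctx):
        # Stand-in for calibrate(dataset): takes (IRModule, PassContext)
        # and returns the (possibly rewritten) IRModule.
        return mod

    calibrate_pass = tvm.transform.module_pass(
        calibrate_sketch, opt_level=1, name="QuantizeCalibrate")
    assert isinstance(calibrate_pass, tvm.transform.ModulePass)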
python/tvm/relay/testing/__init__.py
@@ -47,7 +47,7 @@ from .py_converter import to_python, run_as_python
 from ..transform import gradient

 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     entry = mod["main"]
...
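A short usage sketch of the updated helper: the pass handed to run_opt_pass must now be an instance of tvm.transform.Pass, which every registered Relay pass already is.

    import tvm
    from tvm import relay
    from tvm.relay.testing import run_opt_pass

    x = relay.var("x", shape=(3,), dtype="float32")
    typed_body = run_opt_pass(relay.add(x, x), relay.transform.InferType())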
python/tvm/relay/testing/py_converter.py
@@ -95,8 +95,8 @@ class PythonConverter(ExprFunctor):
         # necessary pass: SimplifyInference (otherwise we can't generate code for some operators)
         # and fusion (to get primitive functions)
-        opts = relay.transform.Sequential([relay.transform.SimplifyInference(),
-                                           relay.transform.FuseOps(fuse_opt_level=0)])
+        opts = tvm.transform.Sequential([relay.transform.SimplifyInference(),
+                                         relay.transform.FuseOps(fuse_opt_level=0)])
         mod = opts(mod)
         optimized = mod['main']
         return optimized if isinstance(unwrapped, Function) else optimized.body
...
python/tvm/relay/transform/transform.py
@@ -22,10 +22,9 @@ import types
 import inspect
 import functools

-import tvm
+import tvm.ir
 from tvm import te
 from tvm.runtime import ndarray as _nd
-from tvm.ir.transform import PassInfo, PassContext, Pass, ModulePass, Sequential, module_pass
 from tvm import relay
 from . import _ffi_api
...
@@ -78,12 +77,13 @@ def build_config(opt_level=2,
     pass_context: PassContext
         The pass context for optimizations.
     """
-    return PassContext(opt_level, fallback_device, required_pass,
-                       disabled_pass, trace)
+    return tvm.ir.transform.PassContext(
+        opt_level, fallback_device, required_pass,
+        disabled_pass, trace)

 @tvm._ffi.register_object("relay.FunctionPass")
-class FunctionPass(Pass):
+class FunctionPass(tvm.ir.transform.Pass):
     """A pass that works on each tvm.relay.Function in a module. A function
     pass class should be created through `function_pass`.
     """
...
A long run of hunks (@@ -94 InferType, -106 FoldScaleAxis, -123 BackwardFoldScaleAxis, -144 RemoveUnusedFunctions, -156 ForwardFoldScaleAxis, -174 SimplifyInference, -185 FastMath, -198 CanonicalizeOps, -214 DeadCodeElimination, -227 LazyGradientInit, -238 FoldConstant, -255 FuseOps, -272 CombineParallelConv2D, -304 CombineParallelDense, -318 AlterOpLayout, -366 Legalize, -387 MergeComposite, -413 MergeCompilerRegions, -433 RewriteAnnotatedOps, -448 ToANormalForm, -462 ToCPS, -481 EtaExpand, -492 ToGraphNormalForm, -509 EliminateCommonSubexpr, -527 PartialEvaluate, -539 CanonicalizeCast, -598 AnnotateTarget, -614 Inline) applies the same one-line docstring fix to each pass constructor, changing the documented return type from tvm.relay.Pass to tvm.transform.Pass (and Union[tvm.relay.Pass, tvm.relay.Expr] to Union[tvm.transform.Pass, tvm.relay.Expr] for ToANormalForm).
...
@@ -551,36 +551,19 @@ def LambdaLift():
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that lifts the lambda function.
     """
     return _ffi_api.LambdaLift()

-def PrintIR(show_meta_data=True):
-    """
-    Print the IR for a module to help debugging.
-
-    Parameters
-    ----------
-    show_meta_data : bool
-        A boolean flag to indicate if meta data should be printed.
-
-    Returns
-    -------
-    ret : tvm.relay.Pass
-        The registered pass that prints the module IR.
-    """
-    return _ffi_api.PrintIR(show_meta_data)
-
 def PartitionGraph():
     """Partition a Relay program into regions that can be executed on different
     backends.

     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass that partitions the Relay program.
     """
     return _ffi_api.PartitionGraph()
...
@@ -809,7 +792,7 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
     def create_function_pass(pass_arg):
         """Internal function that creates a function pass"""
         fname = name if name else pass_arg.__name__
-        info = PassInfo(opt_level, fname, required)
+        info = tvm.transform.PassInfo(opt_level, fname, required)
         if inspect.isclass(pass_arg):
             return _wrap_class_function_pass(pass_arg, info)
         if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
...
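Relay keeps its function-level pass machinery; only the module-level infrastructure now resolves through tvm.ir.transform / tvm.transform. A sketch of the surviving relay.transform.function_pass decorator, which this file still defines:

    import tvm
    from tvm import relay

    @relay.transform.function_pass(opt_level=1, name="IdentitySketch")
    def identity_sketch(func, mod, ctx):
        # Called once per relay.Function in the module; this hypothetical
        # stand-in returns each function unchanged.
        return func

    x = relay.var("x", shape=(2,), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.Function([x], x))
    mod = identity_sketch(mod)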
src/ir/transform.cc
@@ -474,10 +474,10 @@ TVM_REGISTER_GLOBAL("transform.ExitPassContext")
 .set_body_typed(PassContext::Internal::ExitScope);

-Pass PrintIR(std::string header) {
-  auto pass_func = [header](IRModule mod, const PassContext& ctx) {
+Pass PrintIR(std::string header, bool show_meta_data) {
+  auto pass_func = [header, show_meta_data](IRModule mod, const PassContext& ctx) {
     LOG(INFO) << "PrintIR(" << header << "):\n"
-              << mod;
+              << AsText(mod, show_meta_data);
     return mod;
   };
   return CreateModulePass(pass_func, 0, "PrintIR", {});
...
src/relay/transforms/print_ir.cc
deleted (100644 → 0)
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
*
* \file src/relay/transforms/print_ir.cc
*
* \brief Print the module IR to help debugging.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
namespace tvm {
namespace relay {
namespace transform {

Pass PrintIR(bool show_meta_data) {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
    [=](IRModule m, PassContext pc) {
      LOG(INFO) << "Dumping the module IR: " << std::endl
                << AsText(m, show_meta_data);
      return m;
    };
  return CreateModulePass(pass_func, 0, "PrintIR", {});
}

TVM_REGISTER_GLOBAL("relay._transform.PrintIR")
.set_body_typed(PrintIR);

}  // namespace transform
}  // namespace relay
}  // namespace tvm
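With this file gone, the Relay-specific PrintIR no longer exists as a separate pass; the unified tvm.transform.PrintIR covers the same use. If the old "Dumping the module IR" behavior is needed verbatim, it can be approximated with a custom module pass; a hedged sketch:

    import tvm

    @tvm.transform.module_pass(opt_level=0, name="DumpModuleIR")
    def dump_module_ir(mod, ctx):
        # Mirrors the deleted pass: print the module text (including
        # meta data) and return the module unchanged.
        print(mod.astext(show_meta_data=True))
        return mod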
tests/python/relay/test_op_level10.py
@@ -53,10 +53,10 @@ def test_checkpoint_alpha_equal():
     df = transform.gradient(run_infer_type(f))

     # run PE and DCE
-    with transform.PassContext(opt_level=3):
+    with tvm.transform.PassContext(opt_level=3):
         passes = [transform.PartialEvaluate(),
                   transform.DeadCodeElimination(inline_once=True)]
-        mod = transform.Sequential(passes)(tvm.IRModule.from_expr(df))
+        mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
         df = mod["main"]

     df_parsed = relay.parser.fromtext(
...
@@ -109,10 +109,10 @@ def test_checkpoint_alpha_equal_tuple():
     df = transform.gradient(run_infer_type(f))

     # run PE and DCE
-    with transform.PassContext(opt_level=3):
+    with tvm.transform.PassContext(opt_level=3):
         passes = [transform.PartialEvaluate(),
                   transform.DeadCodeElimination(inline_once=True)]
-        mod = transform.Sequential(passes)(tvm.IRModule.from_expr(df))
+        mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
         df = mod["main"]

     df_parsed = relay.parser.fromtext(
...
tests/python/relay/test_pass_alter_op_layout.py
@@ -26,8 +26,8 @@ from tvm.relay.testing.temp_op_attr import TempOpAttr
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
...
tests/python/relay/test_pass_annotation.py
@@ -28,8 +28,8 @@ from tvm.relay import transform
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     return mod["main"]
...
tests/python/relay/test_pass_canonicalize_cast.py
@@ -54,9 +54,9 @@ def test_canonicalize_cast():
     bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32")
     y = before(data, conv_weight, bias1, bias2)
     mod = tvm.IRModule.from_expr(y)
-    seq = _transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(),
-                                 _transform.InferType()])
-    with _transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(),
+                                    _transform.InferType()])
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     y = mod["main"]
     y_expected = expected(data, conv_weight, bias1, bias2)
...
tests/python/relay/test_pass_combine_parallel_conv2d.py
@@ -26,7 +26,7 @@ def run_combine_parallel(expr, min_num_branches=3):
     return mod["main"]

 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     return mod["main"]
...
tests/python/relay/test_pass_combine_parallel_dense.py
@@ -26,7 +26,7 @@ def run_combine_parallel(expr, min_num_branches=3):
     return mod["main"]

 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     return mod["main"]
...
tests/python/relay/test_pass_convert_op_layout.py
@@ -26,8 +26,8 @@ from tvm.relay import transform, analysis
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
...
tests/python/relay/test_pass_dead_code_elimination.py
@@ -47,7 +47,7 @@ e = env()
 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     entry = mod["main"]
...
tests/python/relay/test_pass_eliminate_common_subexpr.py
@@ -24,7 +24,7 @@ from tvm.relay import transform, analysis
 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     entry = mod["main"]
...
tests/python/relay/test_pass_eta_expand.py
@@ -33,8 +33,8 @@ def test_eta_expand_global_var():
             @aux
         }
     """)
-    seq = _transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
-    with _transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     expected = relay.fromtext(r"""
         v0.0.4
...
@@ -62,8 +62,8 @@ def test_eta_expand_constructor():
             Cons
         }
     """)
-    seq = _transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
-    with _transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     expected = relay.fromtext(r"""
         v0.0.4
...
tests/python/relay/test_pass_fold_constant.py
@@ -24,7 +24,7 @@ from tvm.relay.testing import run_infer_type, create_workload
 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
...
@@ -174,7 +174,7 @@ def test_fold_batch_norm():
         add = relay.add(conv, bias)
         return relay.Function(relay.analysis.free_vars(add), add)

-    remove_bn_pass = transform.Sequential([
+    remove_bn_pass = tvm.transform.Sequential([
         relay.transform.InferType(),
         relay.transform.SimplifyInference(),
         relay.transform.FoldConstant(),
...
tests/python/relay/test_pass_fold_scale_axis.py
@@ -26,7 +26,7 @@ def _get_positive_scale(size):
 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     entry = mod["main"]
...
tests/python/relay/test_pass_lazy_gradient_init.py
@@ -80,7 +80,7 @@ def test_add_tuple():
     mod["main"] = y
     mod = transform.LazyGradientInit()(mod)
-    mod = transform.PrintIR(show_meta_data=True)(mod)
+    mod = tvm.transform.PrintIR(show_meta_data=True)(mod)
     y = mod["main"]

     assert mod["main"].checked_type == relay.FuncType([t], tensor_type)
...
@@ -116,7 +116,7 @@ def test_mult():
 def test_ret_tuple():
     """Test tuple return type. Check types and semantic equivalence."""
     mod = tvm.IRModule()
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
...
@@ -141,7 +141,7 @@ def test_ret_tuple():
 def test_add_broadcast():
     """Test adding matrices of different size. Check types and semantic equivalence."""
     mod = tvm.IRModule()
     shape1 = (3, 4, 1)
     shape2 = (1, 5)
     dtype = 'float32'
...
@@ -173,7 +173,7 @@ def test_reverse_ad_identity():
     """Simple test with reverse mode ad."""
     # of f(x) = x
     mod = tvm.IRModule()
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
...
@@ -201,7 +201,7 @@ def test_reverse_ad_identity():
 def test_multivar_reverse_ad():
     """Simple test with multivariate reverse mode ad."""
     mod = tvm.IRModule()
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
...
@@ -232,7 +232,7 @@ def test_multivar_reverse_ad():
 def test_after_partial_eval():
     """Test transformation following reverse mode ad and PartialEval"""
     mod = tvm.IRModule()
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
...
@@ -248,7 +248,7 @@ def test_after_partial_eval():
     mod["main"] = back_func
     back_func = mod["main"]

-    seq = transform.Sequential([
+    seq = tvm.transform.Sequential([
         transform.PartialEvaluate(),
         transform.LazyGradientInit(),
         transform.DeadCodeElimination()
...
@@ -270,7 +270,7 @@ def test_after_partial_eval():
 def test_before_partial_eval():
     """Test transformation before PartialEval"""
     mod = tvm.IRModule()
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
...
@@ -284,7 +284,7 @@ def test_before_partial_eval():
     back_func = run_infer_type(back_func)
     mod["main"] = back_func

-    seq = transform.Sequential([
+    seq = tvm.transform.Sequential([
         transform.LazyGradientInit(),
         transform.PartialEvaluate(),
         transform.DeadCodeElimination()
...
@@ -306,7 +306,7 @@ def test_before_partial_eval():
 def test_zeros():
     """Simple test using "zeros" op"""
     mod = tvm.IRModule()
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
...
@@ -328,7 +328,7 @@ def test_zeros():
 def test_ones():
     """Simple test using "ones" op"""
     mod = tvm.IRModule()
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
...
@@ -350,7 +350,7 @@ def test_ones():
 def test_zeros_like():
     """Simple test using "zeros_like" op"""
     mod = tvm.IRModule()
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
...
@@ -372,7 +372,7 @@ def test_zeros_like():
 def test_ones_like():
     """Simple test using "ones_like" op"""
     mod = tvm.IRModule()
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
...
tests/python/relay/test_pass_legalize.py
@@ -28,8 +28,8 @@ from tvm.relay.testing.temp_op_attr import TempOpAttr
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
...
tests/python/relay/test_pass_mac_count.py
@@ -23,7 +23,7 @@ from tvm.relay import analysis, transform
 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     entry = mod["main"]
...
tests/python/relay/test_pass_manager.py
View file @
275e317c
...
@@ -129,13 +129,13 @@ def test_module_pass():
...
@@ -129,13 +129,13 @@ def test_module_pass():
opt_tester
=
OptTester
(
mod
)
opt_tester
=
OptTester
(
mod
)
pass_ctx
=
None
pass_ctx
=
None
@
_
transform.module_pass
(
opt_level
=
opt_level
,
name
=
pass_name
)
@
tvm.
transform.module_pass
(
opt_level
=
opt_level
,
name
=
pass_name
)
def
transform
(
expr
,
ctx
):
def
transform
(
expr
,
ctx
):
return
opt_tester
.
transform
(
expr
,
ctx
)
return
opt_tester
.
transform
(
expr
,
ctx
)
def
test_pass_registration
():
def
test_pass_registration
():
mod_pass
=
transform
mod_pass
=
transform
assert
isinstance
(
mod_pass
,
_
transform
.
ModulePass
)
assert
isinstance
(
mod_pass
,
tvm
.
transform
.
ModulePass
)
pass_info
=
mod_pass
.
info
pass_info
=
mod_pass
.
info
assert
pass_info
.
name
==
pass_name
assert
pass_info
.
name
==
pass_name
assert
pass_info
.
opt_level
==
opt_level
assert
pass_info
.
opt_level
==
opt_level
...
@@ -143,8 +143,8 @@ def test_module_pass():
...
@@ -143,8 +143,8 @@ def test_module_pass():
def
test_pass_registration_no_decorator
():
def
test_pass_registration_no_decorator
():
def
direct_transform
(
expr
,
ctx
):
def
direct_transform
(
expr
,
ctx
):
return
opt_tester
.
transform
(
expr
,
ctx
)
return
opt_tester
.
transform
(
expr
,
ctx
)
mod_pass
=
_
transform
.
module_pass
(
direct_transform
,
opt_level
=
3
)
mod_pass
=
tvm
.
transform
.
module_pass
(
direct_transform
,
opt_level
=
3
)
assert
isinstance
(
mod_pass
,
_
transform
.
ModulePass
)
assert
isinstance
(
mod_pass
,
tvm
.
transform
.
ModulePass
)
pass_info
=
mod_pass
.
info
pass_info
=
mod_pass
.
info
assert
pass_info
.
name
==
"direct_transform"
assert
pass_info
.
name
==
"direct_transform"
assert
pass_info
.
opt_level
==
3
assert
pass_info
.
opt_level
==
3
...
@@ -285,7 +285,7 @@ def test_function_pass():
...
@@ -285,7 +285,7 @@ def test_function_pass():
def
test_module_class_pass
():
def
test_module_class_pass
():
@
relay
.transform.module_pass
(
opt_level
=
1
)
@
tvm
.transform.module_pass
(
opt_level
=
1
)
class
TestPipeline
:
class
TestPipeline
:
"""Simple test function to replace one argument to another."""
"""Simple test function to replace one argument to another."""
def
__init__
(
self
,
new_mod
,
replace
):
def
__init__
(
self
,
new_mod
,
replace
):
...
@@ -309,7 +309,7 @@ def test_module_class_pass():
...
@@ -309,7 +309,7 @@ def test_module_class_pass():
def
test_pass_info
():
def
test_pass_info
():
info
=
relay
.
transform
.
PassInfo
(
opt_level
=
1
,
name
=
"xyz"
)
info
=
tvm
.
transform
.
PassInfo
(
opt_level
=
1
,
name
=
"xyz"
)
assert
info
.
opt_level
==
1
assert
info
.
opt_level
==
1
assert
info
.
name
==
"xyz"
assert
info
.
name
==
"xyz"
...
@@ -350,7 +350,7 @@ def test_sequential_pass():
...
@@ -350,7 +350,7 @@ def test_sequential_pass():
opt_tester
=
OptTester
(
mod
)
opt_tester
=
OptTester
(
mod
)
pass_ctx
=
None
pass_ctx
=
None
@
_
transform.module_pass
(
opt_level
=
1
)
@
tvm.
transform.module_pass
(
opt_level
=
1
)
def
mod_transform
(
expr
,
ctx
):
def
mod_transform
(
expr
,
ctx
):
return
opt_tester
.
transform
(
expr
,
ctx
)
return
opt_tester
.
transform
(
expr
,
ctx
)
...
@@ -367,21 +367,21 @@ def test_sequential_pass():
...
@@ -367,21 +367,21 @@ def test_sequential_pass():
passes
=
[
module_pass
,
function_pass
]
passes
=
[
module_pass, function_pass]
    opt_level = 2
    pass_name = "sequential"
-    sequential = _transform.Sequential(passes=passes, opt_level=opt_level)
+    sequential = tvm.transform.Sequential(passes=passes, opt_level=opt_level)
    pass_info = sequential.info
    assert pass_info.name == pass_name
    assert pass_info.opt_level == opt_level

    def test_no_pass():
        passes = []
-        sequential = _transform.Sequential(opt_level=1, passes=passes)
+        sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
        ret_mod = sequential(mod)
        mod_func = ret_mod[v_sub]
        check_func(sub, mod_func)

    def test_only_module_pass():
        passes = [module_pass]
-        sequential = _transform.Sequential(opt_level=1, passes=passes)
+        sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
        with relay.build_config(required_pass=["mod_transform"]):
            ret_mod = sequential(mod)
        # Check the subtract function.
...
@@ -396,7 +396,7 @@ def test_sequential_pass():
    def test_only_function_pass():
        # Check the subtract function.
        passes = [function_pass]
-        sequential = _transform.Sequential(opt_level=1, passes=passes)
+        sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
        with relay.build_config(required_pass=["func_transform"]):
            ret_mod = sequential(mod)
        _, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
...
@@ -411,7 +411,7 @@ def test_sequential_pass():
    # function pass.
    mod = tvm.IRModule({v_sub: sub, v_log: log})
    passes = [module_pass, function_pass]
-    sequential = _transform.Sequential(opt_level=1, passes=passes)
+    sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
    required = ["mod_transform", "func_transform"]
    with relay.build_config(required_pass=required):
        ret_mod = sequential(mod)
...
@@ -482,7 +482,7 @@ def test_sequential_with_scoping():
        z1 = relay.add(z, z)
        return relay.Function([x], z1)

-    seq = _transform.Sequential([
+    seq = tvm.transform.Sequential([
        relay.transform.InferType(),
        relay.transform.FoldConstant(),
        relay.transform.EliminateCommonSubexpr(),
...
@@ -507,10 +507,10 @@ def test_print_ir(capfd):
    y = relay.multiply(y, relay.const(2, "float32"))
    func = relay.Function([x], y)

-    seq = _transform.Sequential([
+    seq = tvm.transform.Sequential([
        relay.transform.InferType(),
        relay.transform.FoldConstant(),
-        relay.transform.PrintIR(),
+        tvm.transform.PrintIR(),
        relay.transform.DeadCodeElimination()
    ])
...
@@ -520,7 +520,7 @@ def test_print_ir(capfd):
    out = capfd.readouterr().err
-    assert "Dumping the module IR" in out
+    assert "Print IR" in out
    assert "multiply" in out

__TRACE_COUNTER__ = 0
...
@@ -539,7 +539,7 @@ def test_print_debug_callback():
    y = relay.multiply(y, relay.const(2, "float32"))
    func = relay.Function([x], y)

-    seq = _transform.Sequential([
+    seq = tvm.transform.Sequential([
        relay.transform.InferType(),
        relay.transform.FoldConstant(),
        relay.transform.DeadCodeElimination()
...
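Across all of the test_pass_manager.py hunks the change is mechanical: the private `_transform` re-export inside `tvm.relay` gives way to the canonical `tvm.transform` namespace, while Relay-specific passes stay under `relay.transform`. A minimal sketch of the new spelling, assuming a trivial module (the variable names are illustrative only):

import tvm
from tvm import relay

x = relay.var("x", shape=(2,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], relay.add(x, x)))

# Generic pass machinery (Sequential, PrintIR, PassContext) now comes from
# tvm.transform; relay.transform keeps only the Relay-specific passes.
seq = tvm.transform.Sequential(passes=[relay.transform.InferType(),
                                       relay.transform.FoldConstant()],
                               opt_level=1)
mod = seq(mod)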
tests/python/relay/test_pass_partial_eval.py
View file @
275e317c
...
@@ -38,8 +38,8 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07):
def run_opt_pass(expr, passes):
    passes = passes if isinstance(passes, list) else [passes]
    mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body
...
@@ -58,7 +58,7 @@ def dcpe(expr, mod=None, grad=False):
    if mod:
        assert isinstance(expr, Function)
        mod["main"] = expr
-        seq = transform.Sequential(passes)
+        seq = tvm.transform.Sequential(passes)
        mod = seq(mod)
        return mod["main"]
    return run_opt_pass(expr, passes)
...
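The `run_opt_pass` helper above recurs in several of these test files. A minimal standalone sketch of the same flow, assuming a constant-foldable expression (the values are illustrative only):

import tvm
from tvm import relay

expr = relay.add(relay.const(1.0), relay.const(2.0))
mod = tvm.IRModule.from_expr(expr)  # wraps expr in a "main" function
seq = tvm.transform.Sequential([relay.transform.FoldConstant()])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
print(mod["main"].body)  # should show the folded constant 3.0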
tests/python/relay/test_pass_partition_graph.py
View file @
275e317c
...
@@ -496,7 +496,7 @@ def test_function_lifting():
    op_list = ["nn.batch_norm", "nn.conv2d"]
    mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
-    opt_pass = transform.Sequential([
+    opt_pass = tvm.transform.Sequential([
        transform.InferType(),
        transform.PartitionGraph(),
        transform.SimplifyInference(),
...
@@ -578,7 +578,7 @@ def test_function_lifting_inline():
    op_list = ["nn.batch_norm", "nn.conv2d"]
    mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
-    opt_pass = transform.Sequential([
+    opt_pass = tvm.transform.Sequential([
        transform.InferType(),
        transform.PartitionGraph(),
        transform.SimplifyInference(),
...
@@ -878,13 +878,13 @@ def test_dnnl_fuse():
    # This is required for constant folding
    mod["main"] = bind_params_by_name(mod["main"], params)
-    remove_bn_pass = transform.Sequential([
+    remove_bn_pass = tvm.transform.Sequential([
        transform.InferType(),
        transform.SimplifyInference(),
        transform.FoldConstant(),
        transform.FoldScaleAxis(),
    ])
-    composite_partition = transform.Sequential([
+    composite_partition = tvm.transform.Sequential([
        remove_bn_pass,
        transform.MergeComposite(pattern_table),
        transform.AnnotateTarget("dnnl"),
...
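The `test_dnnl_fuse` hunk also shows that a `Sequential` is itself a pass, so one pipeline (`remove_bn_pass`) can be embedded in another. A minimal sketch of that nesting, with stand-in passes:

import tvm
from tvm import relay

inner = tvm.transform.Sequential([relay.transform.InferType(),
                                  relay.transform.FoldConstant()])
# A Sequential is a tvm.transform.Pass, so it composes like any single pass.
outer = tvm.transform.Sequential([inner,
                                  relay.transform.EliminateCommonSubexpr()])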
tests/python/relay/test_pass_qnn_legalize.py
View file @
275e317c
...
@@ -37,8 +37,8 @@ def alpha_equal(x, y):
def run_opt_pass(expr, passes):
    passes = passes if isinstance(passes, list) else [passes]
    mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body
...
tests/python/relay/test_pass_to_a_normal_form.py
View file @
275e317c
...
@@ -28,8 +28,8 @@ from tvm.relay.analysis import Feature
def run_opt_pass(expr, passes):
    passes = passes if isinstance(passes, list) else [passes]
    mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)
    entry = mod["main"]
    return entry if isinstance(expr, relay.Function) else entry.body
...
tests/python/relay/test_pass_to_cps.py
View file @
275e317c
...
@@ -71,7 +71,8 @@ def test_cps_pe():
        x = run_infer_type(x)
        y = un_cps(x)
        y = run_infer_type(y)
-        x = run_opt_pass(x, transform.Sequential([transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]))
+        x = run_opt_pass(x, tvm.transform.Sequential(
+            [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]))
        assert Feature.fRefCreate not in detect_feature(x)
    unit = relay.Function([], relay.const(0., dtype='float32'))
    f_ref = relay.Var("f_ref")
...
tutorials/dev/relay_pass_infra.py
View file @
275e317c
...
@@ -29,7 +29,7 @@ introduced an infrastructure to manage the optimization passes.
The optimizations of a Relay program can be applied at various granularities,
namely function-level and module-level, using :py:class:`tvm.relay.transform.FunctionPass`
and :py:class:`tvm.relay.transform.ModulePass`
-respectively. Or users can rely on :py:class:`tvm.relay.transform.Sequential` to apply a sequence of passes
+respectively. Or users can rely on :py:class:`tvm.transform.Sequential` to apply a sequence of passes
on a Relay program, where the dependencies between passes can be resolved by the
pass infra. For more details about each type of these passes, please refer to
the :ref:`relay-pass-infra`
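For readers following along, a function-level pass can be defined with the decorator form of this API. A minimal sketch, assuming the `function_pass` decorator described in the pass-infra docs (the class name and rewrite are hypothetical):

import tvm
from tvm import relay

@relay.transform.function_pass(opt_level=1)
class HypotheticalIdentityPass:
    """A stand-in rewrite that returns each function unchanged."""
    def transform_function(self, func, mod, ctx):
        return func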
...
@@ -130,22 +130,22 @@ print(mod)
# fusion, as this pass generates let bindings for each expression to
# canonicalize a Relay program.
#
-# Relay, hence, provides :py:class:`tvm.relay.transform.Sequential` to relieve developers of handling
+# Relay, hence, provides :py:class:`tvm.transform.Sequential` to relieve developers of handling
# these issues explicitly by specifying the required passes of each pass and
# packing them as a whole to execute. For example, the same passes can now be
-# applied using the sequential style as follows. :py:class:`tvm.relay.transform.Sequential` is
+# applied using the sequential style as follows. :py:class:`tvm.transform.Sequential` is
# similar to `torch.nn.sequential <https://pytorch.org/docs/stable/nn.html#torch.nn.Sequential>`_
# and `mxnet.gluon.block <https://mxnet.incubator.apache.org/api/python/docs/_modules/mxnet/gluon/block.html>`_.
# For example, `torch.nn.sequential` is used to contain a sequence of PyTorch
# `Modules` that will be added to build a network. It focuses on the network
-# layers. Instead, the :py:class:`tvm.relay.transform.Sequential` in our pass infra works on the optimizing
+# layers. Instead, the :py:class:`tvm.transform.Sequential` in our pass infra works on the optimizing
# pass.
-# Now let's execute some passes through :py:class:`tvm.relay.transform.Sequential`
+# Now let's execute some passes through :py:class:`tvm.transform.Sequential`
f = example()
mod = tvm.IRModule.from_expr(f)
# Gather the passes of interest.
-seq = relay.transform.Sequential([relay.transform.FoldConstant(),
+seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
                                  relay.transform.EliminateCommonSubexpr(),
                                  relay.transform.FuseOps(fuse_opt_level=2)])
mod1 = seq(mod)
...
@@ -156,7 +156,7 @@ print(mod1)
# identical addition operations. This is because `EliminateCommonSubexpr`
# was not actually performed. The reason is that only the passes with an
# optimization level less than or equal to 2 are executed by default under
-# :py:class:`tvm.relay.transform.Sequential`. The pass infra,
+# :py:class:`tvm.transform.Sequential`. The pass infra,
# however, provides a configuration interface
# for users to customize the optimization level that they want to execute.
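Concretely, raising the optimization level in the pass context lets `EliminateCommonSubexpr` (registered above level 2) run. A minimal sketch, assuming the `seq` and `mod` defined earlier in this tutorial; in this release the tutorial spells the context `relay.build_config`, with `tvm.transform.PassContext` as the namespace-unified form:

with tvm.transform.PassContext(opt_level=3):
    mod2 = seq(mod)  # EliminateCommonSubexpr now takes effect
print(mod2)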
...
@@ -186,7 +186,7 @@ with relay.build_config(opt_level=3):
    mod4 = seq(mod)
print(mod4)
-seq1 = relay.transform.Sequential([relay.transform.AlterOpLayout()])
+seq1 = tvm.transform.Sequential([relay.transform.AlterOpLayout()])
with relay.build_config(opt_level=3):
    with tvm.target.create("llvm"):
        mod5 = seq1(mod)
...
@@ -237,11 +237,11 @@ print(mod3)
f = example()
mod = tvm.IRModule.from_expr(f)
-seq = relay.transform.Sequential([relay.transform.FoldConstant(),
-                                  relay.transform.PrintIR(False),
+seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
+                                tvm.transform.PrintIR(),
                                  relay.transform.EliminateCommonSubexpr(),
                                  relay.transform.FuseOps(),
-                                  relay.transform.PrintIR(False)])
+                                  tvm.transform.PrintIR()])
with relay.build_config(opt_level=3):
    mod = seq(mod)
...
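After this change, IR dumping between passes goes through the global `tvm.transform.PrintIR` pass rather than a Relay-local one. A minimal sketch of inserting it as a debugging probe, assuming any module `mod`:

seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
                                tvm.transform.PrintIR(),  # dumps the module, tagged "Print IR"
                                relay.transform.EliminateCommonSubexpr()])
with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)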
vta/python/vta/top/graphpack.py
View file @
275e317c
...
@@ -24,7 +24,7 @@ from tvm.relay import ExprMutator
def run_opt_pass(expr, opt_pass):
    """Exectue a relay pass."""
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
    mod = tvm.IRModule.from_expr(expr)
    mod = opt_pass(mod)
    entry = mod["main"]
...
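Since sequentials, module passes, and function passes all derive from `tvm.transform.Pass`, the updated assertion accepts any of them. A quick sanity check, assuming any Relay pass:

import tvm
from tvm import relay

p = relay.transform.FoldConstant()
assert isinstance(p, tvm.transform.Pass)  # relay passes subclass the base Pass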