Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
275e317c
Unverified
Commit
275e317c
authored
Apr 14, 2020
by
Tianqi Chen
Committed by
GitHub
Apr 14, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY] Remove re-exports of tvm.transform (#5337)
parent
f08d5d78
Hide whitespace changes
Inline
Side-by-side
Showing
38 changed files
with
169 additions
and
229 deletions
+169
-229
docs/api/python/ir.rst
+8
-0
docs/dev/convert_layout.rst
+1
-1
docs/dev/relay_pass_infra.rst
+2
-2
include/tvm/ir/transform.h
+3
-1
python/tvm/ir/json_compact.py
+1
-1
python/tvm/ir/transform.py
+5
-2
python/tvm/relay/__init__.py
+0
-11
python/tvm/relay/backend/interpreter.py
+4
-4
python/tvm/relay/qnn/transform.py
+2
-2
python/tvm/relay/quantize/quantize.py
+18
-15
python/tvm/relay/testing/__init__.py
+1
-1
python/tvm/relay/testing/py_converter.py
+2
-2
python/tvm/relay/transform/transform.py
+36
-53
src/ir/transform.cc
+3
-3
src/relay/transforms/print_ir.cc
+0
-49
tests/python/relay/test_op_level10.py
+4
-4
tests/python/relay/test_pass_alter_op_layout.py
+2
-2
tests/python/relay/test_pass_annotation.py
+2
-2
tests/python/relay/test_pass_canonicalize_cast.py
+2
-2
tests/python/relay/test_pass_combine_parallel_conv2d.py
+1
-1
tests/python/relay/test_pass_combine_parallel_dense.py
+1
-1
tests/python/relay/test_pass_convert_op_layout.py
+2
-2
tests/python/relay/test_pass_dead_code_elimination.py
+1
-1
tests/python/relay/test_pass_eliminate_common_subexpr.py
+1
-1
tests/python/relay/test_pass_eta_expand.py
+4
-4
tests/python/relay/test_pass_fold_constant.py
+2
-2
tests/python/relay/test_pass_fold_scale_axis.py
+1
-1
tests/python/relay/test_pass_lazy_gradient_init.py
+13
-13
tests/python/relay/test_pass_legalize.py
+2
-2
tests/python/relay/test_pass_mac_count.py
+1
-1
tests/python/relay/test_pass_manager.py
+17
-17
tests/python/relay/test_pass_partial_eval.py
+3
-3
tests/python/relay/test_pass_partition_graph.py
+4
-4
tests/python/relay/test_pass_qnn_legalize.py
+2
-2
tests/python/relay/test_pass_to_a_normal_form.py
+2
-2
tests/python/relay/test_pass_to_cps.py
+2
-1
tutorials/dev/relay_pass_infra.py
+13
-13
vta/python/vta/top/graphpack.py
+1
-1
No files found.
docs/api/python/ir.rst
View file @
275e317c
...
...
@@ -21,3 +21,11 @@ tvm.ir
:members:
:imported-members:
:autosummary:
tvm.transform
-------------
.. automodule:: tvm.transform
:members:
:imported-members:
:autosummary:
docs/dev/convert_layout.rst
View file @
275e317c
...
...
@@ -227,7 +227,7 @@ ConvertLayout pass is extremely easy to use. The pass is not a part of default r
# Convert the layout to NCHW
# RemoveUnunsedFunctions is used to clean up the graph.
seq =
relay
.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
seq =
tvm
.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
relay.transform.ConvertLayout('NCHW')])
with relay.transform.PassContext(opt_level=3):
mod = seq(mod)
...
...
docs/dev/relay_pass_infra.rst
View file @
275e317c
...
...
@@ -582,7 +582,7 @@ using ``Sequential`` associated with other types of passes.
func = relay.Function([x], z2)
# Customize the optimization pipeline.
seq =
_
transform.Sequential([
seq =
tvm.
transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.EliminateCommonSubexpr(),
...
...
@@ -609,7 +609,7 @@ sequential pass example could be like the following to enable IR dumping for
.. code:: python
seq =
_
transform.Sequential([
seq =
tvm.
transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.PrintIR(),
...
...
include/tvm/ir/transform.h
View file @
275e317c
...
...
@@ -361,9 +361,11 @@ TVM_DLL Pass CreateModulePass(
/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
* \param header The header to be attached to the output.
* \param show_meta_data Whether should we show meta data.
* \return The pass.
*/
TVM_DLL
Pass
PrintIR
(
std
::
string
header
);
TVM_DLL
Pass
PrintIR
(
std
::
string
header
=
""
,
bool
show_meta_data
=
false
);
}
// namespace transform
}
// namespace tvm
...
...
python/tvm/ir/json_compact.py
View file @
275e317c
...
...
@@ -106,7 +106,7 @@ def create_updater_06_to_07():
"relay.PassInfo"
:
_rename
(
"transform.PassInfo"
),
"relay.PassContext"
:
_rename
(
"transform.PassContext"
),
"relay.ModulePass"
:
_rename
(
"transform.ModulePass"
),
"relay.Sequ
antial"
:
_rename
(
"transform.Sequa
ntial"
),
"relay.Sequ
ential"
:
_rename
(
"transform.Seque
ntial"
),
# TIR
"Variable"
:
_update_tir_var
(
"tir.Var"
),
"SizeVar"
:
_update_tir_var
(
"tir.SizeVar"
),
...
...
python/tvm/ir/transform.py
View file @
275e317c
...
...
@@ -329,7 +329,7 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None):
return
create_module_pass
def
PrintIR
(
header
):
def
PrintIR
(
header
=
""
,
show_meta_data
=
False
):
"""A special trace pass that prints the header and IR.
Parameters
...
...
@@ -337,8 +337,11 @@ def PrintIR(header):
header : str
The header to be displayed along with the dump.
show_meta_data : bool
A boolean flag to indicate if meta data should be printed.
Returns
--------
The pass
"""
return
_ffi_transform_api
.
PrintIR
(
header
)
return
_ffi_transform_api
.
PrintIR
(
header
,
show_meta_data
)
python/tvm/relay/__init__.py
View file @
275e317c
...
...
@@ -128,20 +128,9 @@ Prelude = prelude.Prelude
# Scope builder
ScopeBuilder
=
scope_builder
.
ScopeBuilder
module_pass
=
transform
.
module_pass
function_pass
=
transform
.
function_pass
# Parser
fromtext
=
parser
.
fromtext
# Param Serialization
save_param_dict
=
param_dict
.
save_param_dict
load_param_dict
=
param_dict
.
load_param_dict
# Pass manager
PassInfo
=
transform
.
PassInfo
PassContext
=
transform
.
PassContext
Pass
=
transform
.
Pass
ModulePass
=
transform
.
ModulePass
FunctionPass
=
transform
.
FunctionPass
Sequential
=
transform
.
Sequential
python/tvm/relay/backend/interpreter.py
View file @
275e317c
...
...
@@ -210,10 +210,10 @@ class Interpreter(Executor):
opt_mod : tvm.IRModule
The optimized module.
"""
seq
=
transform
.
Sequential
([
transform
.
SimplifyInference
(),
transform
.
FuseOps
(
0
),
transform
.
ToANormalForm
(),
transform
.
InferType
()])
seq
=
t
vm
.
t
ransform
.
Sequential
([
transform
.
SimplifyInference
(),
transform
.
FuseOps
(
0
),
transform
.
ToANormalForm
(),
transform
.
InferType
()])
return
seq
(
self
.
mod
)
def
_make_executor
(
self
,
expr
=
None
):
...
...
python/tvm/relay/qnn/transform.py
View file @
275e317c
...
...
@@ -60,7 +60,7 @@ def CanonicalizeOps():
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The registered pass that canonicalizes QNN ops to Relay ops.
"""
...
...
@@ -108,7 +108,7 @@ def Legalize():
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The registered pass that legalizes QNN ops.
"""
...
...
python/tvm/relay/quantize/quantize.py
View file @
275e317c
...
...
@@ -17,6 +17,7 @@
#pylint: disable=unused-argument, not-context-manager
"""Automatic quantization toolkit."""
import
tvm.ir
import
tvm
from
tvm.runtime
import
Object
from
.
import
_quantize
...
...
@@ -240,7 +241,7 @@ def partition():
Returns
-------
ret: tvm.
relay
.Pass
ret: tvm.
transform
.Pass
The registered pass for VTA rewrite.
"""
return
_quantize
.
QuantizePartition
()
...
...
@@ -253,7 +254,7 @@ def annotate():
Returns
-------
ret: tvm.
relay
.Pass
ret: tvm.
transform
.Pass
The registered pass for quantization annotation.
"""
return
_quantize
.
QuantizeAnnotate
()
...
...
@@ -267,7 +268,7 @@ def realize():
Returns
-------
ret: tvm.
relay
.Pass
ret: tvm.
transform
.Pass
The registered pass for quantization realization.
"""
return
_quantize
.
QuantizeRealize
()
...
...
@@ -298,11 +299,12 @@ def prerequisite_optimize(mod, params=None):
""" Prerequisite optimization passes for quantization. Perform
"SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"CanonicalizeOps" optimization before quantization. """
optimize
=
_transform
.
Sequential
([
_transform
.
SimplifyInference
(),
_transform
.
FoldConstant
(),
_transform
.
FoldScaleAxis
(),
_transform
.
CanonicalizeOps
(),
_transform
.
FoldConstant
()])
optimize
=
tvm
.
transform
.
Sequential
(
[
_transform
.
SimplifyInference
(),
_transform
.
FoldConstant
(),
_transform
.
FoldScaleAxis
(),
_transform
.
CanonicalizeOps
(),
_transform
.
FoldConstant
()])
if
params
:
mod
[
'main'
]
=
_bind_params
(
mod
[
'main'
],
params
)
...
...
@@ -336,19 +338,20 @@ def quantize(mod, params=None, dataset=None):
"""
mod
=
prerequisite_optimize
(
mod
,
params
)
calibrate_pass
=
_transform
.
module_pass
(
calibrate
(
dataset
),
opt_level
=
1
,
name
=
"QuantizeCalibrate"
)
calibrate_pass
=
tvm
.
transform
.
module_pass
(
calibrate
(
dataset
),
opt_level
=
1
,
name
=
"QuantizeCalibrate"
)
quant_passes
=
[
partition
(),
annotate
(),
calibrate_pass
]
if
not
current_qconfig
()
.
do_simulation
:
quant_passes
.
append
(
realize
())
quant_passes
.
append
(
_transform
.
FoldConstant
())
quantize_seq
=
_
transform
.
Sequential
(
quant_passes
)
with
_
transform
.
PassContext
(
opt_level
=
3
,
required_pass
=
[
"QuantizeAnnotate"
,
"QuantizeCalibrate"
,
"QuantizeRealize"
]):
quantize_seq
=
tvm
.
transform
.
Sequential
(
quant_passes
)
with
tvm
.
transform
.
PassContext
(
opt_level
=
3
,
required_pass
=
[
"QuantizeAnnotate"
,
"QuantizeCalibrate"
,
"QuantizeRealize"
]):
with
quantize_context
():
mod
=
quantize_seq
(
mod
)
...
...
python/tvm/relay/testing/__init__.py
View file @
275e317c
...
...
@@ -47,7 +47,7 @@ from .py_converter import to_python, run_as_python
from
..transform
import
gradient
def
run_opt_pass
(
expr
,
opt_pass
):
assert
isinstance
(
opt_pass
,
transform
.
Pass
)
assert
isinstance
(
opt_pass
,
t
vm
.
t
ransform
.
Pass
)
mod
=
tvm
.
IRModule
.
from_expr
(
expr
)
mod
=
opt_pass
(
mod
)
entry
=
mod
[
"main"
]
...
...
python/tvm/relay/testing/py_converter.py
View file @
275e317c
...
...
@@ -95,8 +95,8 @@ class PythonConverter(ExprFunctor):
# necessary pass: SimplifyInference (otherwise we can't generate code for some operators)
# and fusion (to get primitive functions)
opts
=
relay
.
transform
.
Sequential
([
relay
.
transform
.
SimplifyInference
(),
relay
.
transform
.
FuseOps
(
fuse_opt_level
=
0
)])
opts
=
tvm
.
transform
.
Sequential
([
relay
.
transform
.
SimplifyInference
(),
relay
.
transform
.
FuseOps
(
fuse_opt_level
=
0
)])
mod
=
opts
(
mod
)
optimized
=
mod
[
'main'
]
return
optimized
if
isinstance
(
unwrapped
,
Function
)
else
optimized
.
body
...
...
python/tvm/relay/transform/transform.py
View file @
275e317c
...
...
@@ -22,10 +22,9 @@ import types
import
inspect
import
functools
import
tvm
import
tvm
.ir
from
tvm
import
te
from
tvm.runtime
import
ndarray
as
_nd
from
tvm.ir.transform
import
PassInfo
,
PassContext
,
Pass
,
ModulePass
,
Sequential
,
module_pass
from
tvm
import
relay
from
.
import
_ffi_api
...
...
@@ -78,12 +77,13 @@ def build_config(opt_level=2,
pass_context: PassContext
The pass context for optimizations.
"""
return
PassContext
(
opt_level
,
fallback_device
,
required_pass
,
disabled_pass
,
trace
)
return
tvm
.
ir
.
transform
.
PassContext
(
opt_level
,
fallback_device
,
required_pass
,
disabled_pass
,
trace
)
@tvm._ffi.register_object
(
"relay.FunctionPass"
)
class
FunctionPass
(
Pass
):
class
FunctionPass
(
tvm
.
ir
.
transform
.
Pass
):
"""A pass that works on each tvm.relay.Function in a module. A function
pass class should be created through `function_pass`.
"""
...
...
@@ -94,7 +94,7 @@ def InferType():
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The registered type inference pass.
"""
return
_ffi_api
.
InferType
()
...
...
@@ -106,7 +106,7 @@ def FoldScaleAxis():
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The registered pass to fold expressions.
Note
...
...
@@ -123,7 +123,7 @@ def BackwardFoldScaleAxis():
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The registered pass to backward fold expressions.
Note
...
...
@@ -144,7 +144,7 @@ def RemoveUnusedFunctions(entry_functions=None):
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The registered pass to remove unused functions.
"""
if
entry_functions
is
None
:
...
...
@@ -156,7 +156,7 @@ def ForwardFoldScaleAxis():
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The registered pass to forward fold expressions.
Note
...
...
@@ -174,7 +174,7 @@ def SimplifyInference():
Returns
-------
ret: tvm.
relay
.Pass
ret: tvm.
transform
.Pass
The registered pass to perform operator simplification.
"""
return
_ffi_api
.
SimplifyInference
()
...
...
@@ -185,7 +185,7 @@ def FastMath():
Returns
-------
ret: tvm.
relay
.Pass
ret: tvm.
transform
.Pass
The registered pass to perform fast math operations.
"""
return
_ffi_api
.
FastMath
()
...
...
@@ -198,7 +198,7 @@ def CanonicalizeOps():
Returns
-------
ret: tvm.
relay
.Pass
ret: tvm.
transform
.Pass
The registered pass performing the canonicalization.
"""
return
_ffi_api
.
CanonicalizeOps
()
...
...
@@ -214,7 +214,7 @@ def DeadCodeElimination(inline_once=False):
Returns
-------
ret: tvm.
relay
.Pass
ret: tvm.
transform
.Pass
The registered pass that eliminates the dead code in a Relay program.
"""
return
_ffi_api
.
DeadCodeElimination
(
inline_once
)
...
...
@@ -227,7 +227,7 @@ def LazyGradientInit():
Returns
-------
ret: tvm.
relay
.Pass
ret: tvm.
transform
.Pass
A pass which delays and/or reduces memory allocation,
by lazily allocating 0 or one filled tensors.
"""
...
...
@@ -238,7 +238,7 @@ def FoldConstant():
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The registered pass for constant folding.
"""
return
_ffi_api
.
FoldConstant
()
...
...
@@ -255,7 +255,7 @@ def FuseOps(fuse_opt_level=-1):
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The registered pass for operator fusion.
"""
return
_ffi_api
.
FuseOps
(
fuse_opt_level
)
...
...
@@ -272,7 +272,7 @@ def CombineParallelConv2D(min_num_branches=3):
Returns
-------
ret: tvm.
relay
.Pass
ret: tvm.
transform
.Pass
The registered pass that combines parallel conv2d operators.
"""
return
_ffi_api
.
CombineParallelConv2D
(
min_num_branches
)
...
...
@@ -304,7 +304,7 @@ def CombineParallelDense(min_num_branches=3):
Returns
-------
ret: tvm.
relay
.Pass
ret: tvm.
transform
.Pass
The registered pass that combines parallel dense operators.
"""
return
_ffi_api
.
CombineParallelDense
(
min_num_branches
)
...
...
@@ -318,7 +318,7 @@ def AlterOpLayout():
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The registered pass that alters the layout of operators.
"""
return
_ffi_api
.
AlterOpLayout
()
...
...
@@ -366,7 +366,7 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"):
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The registered pass that rewrites an expr.
"""
return
_ffi_api
.
Legalize
(
legalize_map_attr_name
)
...
...
@@ -387,7 +387,7 @@ def MergeComposite(pattern_table):
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The registered pass that merges operators into a single composite
relay function.
"""
...
...
@@ -413,7 +413,7 @@ def MergeCompilerRegions():
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The registered pass that merges compiler regions.
"""
return
_ffi_api
.
MergeCompilerRegions
()
...
...
@@ -433,7 +433,7 @@ def RewriteAnnotatedOps(fallback_device):
Returns
-------
ret: tvm.
relay
.Pass
ret: tvm.
transform
.Pass
The registered pass that rewrites an expression with annotated
`on_device` operators.
"""
...
...
@@ -448,7 +448,7 @@ def ToANormalForm():
Returns
-------
ret: Union[tvm.
relay
.Pass, tvm.relay.Expr]
ret: Union[tvm.
transform
.Pass, tvm.relay.Expr]
The registered pass that transforms an expression into A Normal Form.
"""
return
_ffi_api
.
ToANormalForm
()
...
...
@@ -462,7 +462,7 @@ def ToCPS(expr, mod=None):
Returns
-------
result: tvm.
relay
.Pass
result: tvm.
transform
.Pass
The registered pass that transforms an expression into CPS.
"""
return
_ffi_api
.
to_cps
(
expr
,
mod
)
...
...
@@ -481,7 +481,7 @@ def EtaExpand(expand_constructor=False, expand_global_var=False):
Returns
-------
ret: tvm.
relay
.Pass
ret: tvm.
transform
.Pass
The registered pass that eta expands an expression.
"""
return
_ffi_api
.
EtaExpand
(
expand_constructor
,
expand_global_var
)
...
...
@@ -492,7 +492,7 @@ def ToGraphNormalForm():
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The registered pass that transforms an expression into Graph Normal Form.
"""
return
_ffi_api
.
ToGraphNormalForm
()
...
...
@@ -509,7 +509,7 @@ def EliminateCommonSubexpr(fskip=None):
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The registered pass that eliminates common subexpressions.
"""
return
_ffi_api
.
EliminateCommonSubexpr
(
fskip
)
...
...
@@ -527,7 +527,7 @@ def PartialEvaluate():
Returns
-------
ret: tvm.
relay
.Pass
ret: tvm.
transform
.Pass
The registered pass that performs partial evaluation on an expression.
"""
return
_ffi_api
.
PartialEvaluate
()
...
...
@@ -539,7 +539,7 @@ def CanonicalizeCast():
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The registered pass that canonicalizes cast expression.
"""
return
_ffi_api
.
CanonicalizeCast
()
...
...
@@ -551,36 +551,19 @@ def LambdaLift():
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The registered pass that lifts the lambda function.
"""
return
_ffi_api
.
LambdaLift
()
def
PrintIR
(
show_meta_data
=
True
):
"""
Print the IR for a module to help debugging.
Parameters
----------
show_meta_data : bool
A boolean flag to indicate if meta data should be printed.
Returns
-------
ret : tvm.relay.Pass
The registered pass that prints the module IR.
"""
return
_ffi_api
.
PrintIR
(
show_meta_data
)
def
PartitionGraph
():
"""Partition a Relay program into regions that can be executed on different
backends.
Returns
-------
ret: tvm.
relay
.Pass
ret: tvm.
transform
.Pass
The registered pass that partitions the Relay program.
"""
return
_ffi_api
.
PartitionGraph
()
...
...
@@ -598,7 +581,7 @@ def AnnotateTarget(targets):
Returns
-------
ret : tvm.
relay
.Pass
ret : tvm.
transform
.Pass
The annotated pass that wrapps ops with subgraph_start and
subgraph_end.
"""
...
...
@@ -614,7 +597,7 @@ def Inline():
Returns
-------
ret: tvm.
relay
.Pass
ret: tvm.
transform
.Pass
The registered pass that performs inlining for a Relay IR module.
"""
return
_ffi_api
.
Inline
()
...
...
@@ -809,7 +792,7 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
def
create_function_pass
(
pass_arg
):
"""Internal function that creates a function pass"""
fname
=
name
if
name
else
pass_arg
.
__name__
info
=
PassInfo
(
opt_level
,
fname
,
required
)
info
=
tvm
.
transform
.
PassInfo
(
opt_level
,
fname
,
required
)
if
inspect
.
isclass
(
pass_arg
):
return
_wrap_class_function_pass
(
pass_arg
,
info
)
if
not
isinstance
(
pass_arg
,
(
types
.
FunctionType
,
types
.
LambdaType
)):
...
...
src/ir/transform.cc
View file @
275e317c
...
...
@@ -474,10 +474,10 @@ TVM_REGISTER_GLOBAL("transform.ExitPassContext")
.
set_body_typed
(
PassContext
::
Internal
::
ExitScope
);
Pass
PrintIR
(
std
::
string
header
)
{
auto
pass_func
=
[
header
](
IRModule
mod
,
const
PassContext
&
ctx
)
{
Pass
PrintIR
(
std
::
string
header
,
bool
show_meta_data
)
{
auto
pass_func
=
[
header
,
show_meta_data
](
IRModule
mod
,
const
PassContext
&
ctx
)
{
LOG
(
INFO
)
<<
"PrintIR("
<<
header
<<
"):
\n
"
<<
mod
;
<<
AsText
(
mod
,
show_meta_data
)
;
return
mod
;
};
return
CreateModulePass
(
pass_func
,
0
,
"PrintIR"
,
{});
...
...
src/relay/transforms/print_ir.cc
deleted
100644 → 0
View file @
f08d5d78
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
*
* \file src/relay/transforms/print_ir.cc
*
* \brief Print the module IR to help debugging.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
namespace
tvm
{
namespace
relay
{
namespace
transform
{
Pass
PrintIR
(
bool
show_meta_data
)
{
runtime
::
TypedPackedFunc
<
IRModule
(
IRModule
,
PassContext
)
>
pass_func
=
[
=
](
IRModule
m
,
PassContext
pc
)
{
LOG
(
INFO
)
<<
"Dumping the module IR: "
<<
std
::
endl
<<
AsText
(
m
,
show_meta_data
);
return
m
;
};
return
CreateModulePass
(
pass_func
,
0
,
"PrintIR"
,
{});
}
TVM_REGISTER_GLOBAL
(
"relay._transform.PrintIR"
)
.
set_body_typed
(
PrintIR
);
}
// namespace transform
}
// namespace relay
}
// namespace tvm
tests/python/relay/test_op_level10.py
View file @
275e317c
...
...
@@ -53,10 +53,10 @@ def test_checkpoint_alpha_equal():
df
=
transform
.
gradient
(
run_infer_type
(
f
))
# run PE and DCE
with
transform
.
PassContext
(
opt_level
=
3
):
with
t
vm
.
t
ransform
.
PassContext
(
opt_level
=
3
):
passes
=
[
transform
.
PartialEvaluate
(),
transform
.
DeadCodeElimination
(
inline_once
=
True
)]
mod
=
transform
.
Sequential
(
passes
)(
tvm
.
IRModule
.
from_expr
(
df
))
mod
=
t
vm
.
t
ransform
.
Sequential
(
passes
)(
tvm
.
IRModule
.
from_expr
(
df
))
df
=
mod
[
"main"
]
df_parsed
=
relay
.
parser
.
fromtext
(
...
...
@@ -109,10 +109,10 @@ def test_checkpoint_alpha_equal_tuple():
df
=
transform
.
gradient
(
run_infer_type
(
f
))
# run PE and DCE
with
transform
.
PassContext
(
opt_level
=
3
):
with
t
vm
.
t
ransform
.
PassContext
(
opt_level
=
3
):
passes
=
[
transform
.
PartialEvaluate
(),
transform
.
DeadCodeElimination
(
inline_once
=
True
)]
mod
=
transform
.
Sequential
(
passes
)(
tvm
.
IRModule
.
from_expr
(
df
))
mod
=
t
vm
.
t
ransform
.
Sequential
(
passes
)(
tvm
.
IRModule
.
from_expr
(
df
))
df
=
mod
[
"main"
]
df_parsed
=
relay
.
parser
.
fromtext
(
...
...
tests/python/relay/test_pass_alter_op_layout.py
View file @
275e317c
...
...
@@ -26,8 +26,8 @@ from tvm.relay.testing.temp_op_attr import TempOpAttr
def
run_opt_pass
(
expr
,
passes
):
passes
=
passes
if
isinstance
(
passes
,
list
)
else
[
passes
]
mod
=
tvm
.
IRModule
.
from_expr
(
expr
)
seq
=
transform
.
Sequential
(
passes
)
with
transform
.
PassContext
(
opt_level
=
3
):
seq
=
t
vm
.
t
ransform
.
Sequential
(
passes
)
with
t
vm
.
t
ransform
.
PassContext
(
opt_level
=
3
):
mod
=
seq
(
mod
)
entry
=
mod
[
"main"
]
return
entry
if
isinstance
(
expr
,
relay
.
Function
)
else
entry
.
body
...
...
tests/python/relay/test_pass_annotation.py
View file @
275e317c
...
...
@@ -28,8 +28,8 @@ from tvm.relay import transform
def
run_opt_pass
(
expr
,
passes
):
passes
=
passes
if
isinstance
(
passes
,
list
)
else
[
passes
]
mod
=
tvm
.
IRModule
.
from_expr
(
expr
)
seq
=
transform
.
Sequential
(
passes
)
with
transform
.
PassContext
(
opt_level
=
3
):
seq
=
t
vm
.
t
ransform
.
Sequential
(
passes
)
with
t
vm
.
t
ransform
.
PassContext
(
opt_level
=
3
):
mod
=
seq
(
mod
)
return
mod
[
"main"
]
...
...
tests/python/relay/test_pass_canonicalize_cast.py
View file @
275e317c
...
...
@@ -54,9 +54,9 @@ def test_canonicalize_cast():
bias2
=
relay
.
var
(
"bias2"
,
shape
=
(
16
,
1
,
1
),
dtype
=
"int32"
)
y
=
before
(
data
,
conv_weight
,
bias1
,
bias2
)
mod
=
tvm
.
IRModule
.
from_expr
(
y
)
seq
=
_
transform
.
Sequential
([
_transform
.
InferType
(),
_transform
.
CanonicalizeCast
(),
seq
=
tvm
.
transform
.
Sequential
([
_transform
.
InferType
(),
_transform
.
CanonicalizeCast
(),
_transform
.
InferType
()])
with
_
transform
.
PassContext
(
opt_level
=
3
):
with
tvm
.
transform
.
PassContext
(
opt_level
=
3
):
mod
=
seq
(
mod
)
y
=
mod
[
"main"
]
y_expected
=
expected
(
data
,
conv_weight
,
bias1
,
bias2
)
...
...
tests/python/relay/test_pass_combine_parallel_conv2d.py
View file @
275e317c
...
...
@@ -26,7 +26,7 @@ def run_combine_parallel(expr, min_num_branches=3):
return
mod
[
"main"
]
def
run_opt_pass
(
expr
,
opt_pass
):
assert
isinstance
(
opt_pass
,
transform
.
Pass
)
assert
isinstance
(
opt_pass
,
t
vm
.
t
ransform
.
Pass
)
mod
=
tvm
.
IRModule
.
from_expr
(
expr
)
mod
=
opt_pass
(
mod
)
return
mod
[
"main"
]
...
...
tests/python/relay/test_pass_combine_parallel_dense.py
View file @
275e317c
...
...
@@ -26,7 +26,7 @@ def run_combine_parallel(expr, min_num_branches=3):
return
mod
[
"main"
]
def
run_opt_pass
(
expr
,
opt_pass
):
assert
isinstance
(
opt_pass
,
transform
.
Pass
)
assert
isinstance
(
opt_pass
,
t
vm
.
t
ransform
.
Pass
)
mod
=
tvm
.
IRModule
.
from_expr
(
expr
)
mod
=
opt_pass
(
mod
)
return
mod
[
"main"
]
...
...
tests/python/relay/test_pass_convert_op_layout.py
View file @
275e317c
...
...
@@ -26,8 +26,8 @@ from tvm.relay import transform, analysis
def
run_opt_pass
(
expr
,
passes
):
passes
=
passes
if
isinstance
(
passes
,
list
)
else
[
passes
]
mod
=
tvm
.
IRModule
.
from_expr
(
expr
)
seq
=
transform
.
Sequential
(
passes
)
with
transform
.
PassContext
(
opt_level
=
3
):
seq
=
t
vm
.
t
ransform
.
Sequential
(
passes
)
with
t
vm
.
t
ransform
.
PassContext
(
opt_level
=
3
):
mod
=
seq
(
mod
)
entry
=
mod
[
"main"
]
return
entry
if
isinstance
(
expr
,
relay
.
Function
)
else
entry
.
body
...
...
tests/python/relay/test_pass_dead_code_elimination.py
View file @
275e317c
...
...
@@ -47,7 +47,7 @@ e = env()
def
run_opt_pass
(
expr
,
opt_pass
):
assert
isinstance
(
opt_pass
,
transform
.
Pass
)
assert
isinstance
(
opt_pass
,
t
vm
.
t
ransform
.
Pass
)
mod
=
tvm
.
IRModule
.
from_expr
(
expr
)
mod
=
opt_pass
(
mod
)
entry
=
mod
[
"main"
]
...
...
tests/python/relay/test_pass_eliminate_common_subexpr.py
View file @
275e317c
...
...
@@ -24,7 +24,7 @@ from tvm.relay import transform, analysis
def
run_opt_pass
(
expr
,
opt_pass
):
assert
isinstance
(
opt_pass
,
transform
.
Pass
)
assert
isinstance
(
opt_pass
,
t
vm
.
t
ransform
.
Pass
)
mod
=
tvm
.
IRModule
.
from_expr
(
expr
)
mod
=
opt_pass
(
mod
)
entry
=
mod
[
"main"
]
...
...
tests/python/relay/test_pass_eta_expand.py
View file @
275e317c
...
...
@@ -33,8 +33,8 @@ def test_eta_expand_global_var():
@aux
}
"""
)
seq
=
_
transform
.
Sequential
([
_transform
.
EtaExpand
(
expand_global_var
=
True
)])
with
_
transform
.
PassContext
(
opt_level
=
3
):
seq
=
tvm
.
transform
.
Sequential
([
_transform
.
EtaExpand
(
expand_global_var
=
True
)])
with
tvm
.
transform
.
PassContext
(
opt_level
=
3
):
mod
=
seq
(
mod
)
expected
=
relay
.
fromtext
(
r"""
v0.0.4
...
...
@@ -62,8 +62,8 @@ def test_eta_expand_constructor():
Cons
}
"""
)
seq
=
_
transform
.
Sequential
([
_transform
.
EtaExpand
(
expand_constructor
=
True
)])
with
_
transform
.
PassContext
(
opt_level
=
3
):
seq
=
tvm
.
transform
.
Sequential
([
_transform
.
EtaExpand
(
expand_constructor
=
True
)])
with
tvm
.
transform
.
PassContext
(
opt_level
=
3
):
mod
=
seq
(
mod
)
expected
=
relay
.
fromtext
(
r"""
v0.0.4
...
...
tests/python/relay/test_pass_fold_constant.py
View file @
275e317c
...
...
@@ -24,7 +24,7 @@ from tvm.relay.testing import run_infer_type, create_workload
def
run_opt_pass
(
expr
,
opt_pass
):
assert
isinstance
(
opt_pass
,
transform
.
Pass
)
assert
isinstance
(
opt_pass
,
t
vm
.
t
ransform
.
Pass
)
mod
=
tvm
.
IRModule
.
from_expr
(
expr
)
mod
=
opt_pass
(
mod
)
...
...
@@ -174,7 +174,7 @@ def test_fold_batch_norm():
add
=
relay
.
add
(
conv
,
bias
)
return
relay
.
Function
(
relay
.
analysis
.
free_vars
(
add
),
add
)
remove_bn_pass
=
transform
.
Sequential
([
remove_bn_pass
=
t
vm
.
t
ransform
.
Sequential
([
relay
.
transform
.
InferType
(),
relay
.
transform
.
SimplifyInference
(),
relay
.
transform
.
FoldConstant
(),
...
...
tests/python/relay/test_pass_fold_scale_axis.py
View file @
275e317c
...
...
@@ -26,7 +26,7 @@ def _get_positive_scale(size):
def
run_opt_pass
(
expr
,
opt_pass
):
assert
isinstance
(
opt_pass
,
transform
.
Pass
)
assert
isinstance
(
opt_pass
,
t
vm
.
t
ransform
.
Pass
)
mod
=
tvm
.
IRModule
.
from_expr
(
expr
)
mod
=
opt_pass
(
mod
)
entry
=
mod
[
"main"
]
...
...
tests/python/relay/test_pass_lazy_gradient_init.py
View file @
275e317c
...
...
@@ -80,7 +80,7 @@ def test_add_tuple():
mod
[
"main"
]
=
y
mod
=
transform
.
LazyGradientInit
()(
mod
)
mod
=
transform
.
PrintIR
(
show_meta_data
=
True
)(
mod
)
mod
=
t
vm
.
t
ransform
.
PrintIR
(
show_meta_data
=
True
)(
mod
)
y
=
mod
[
"main"
]
assert
mod
[
"main"
]
.
checked_type
==
relay
.
FuncType
([
t
],
tensor_type
)
...
...
@@ -116,7 +116,7 @@ def test_mult():
def
test_ret_tuple
():
"""Test tuple return type. Check types and semantic equivalence."""
mod
=
tvm
.
IRModule
()
shape
=
(
10
,
10
)
dtype
=
'float32'
t
=
relay
.
TensorType
(
shape
,
dtype
)
...
...
@@ -141,7 +141,7 @@ def test_ret_tuple():
def
test_add_broadcast
():
"""Test adding matrices of different size. Check types and semantic equivalence."""
mod
=
tvm
.
IRModule
()
shape1
=
(
3
,
4
,
1
)
shape2
=
(
1
,
5
)
dtype
=
'float32'
...
...
@@ -173,7 +173,7 @@ def test_reverse_ad_identity():
"""Simple test with reverse mode ad."""
# of f(x) = x
mod
=
tvm
.
IRModule
()
shape
=
(
10
,
10
)
dtype
=
'float32'
t
=
relay
.
TensorType
(
shape
,
dtype
)
...
...
@@ -201,7 +201,7 @@ def test_reverse_ad_identity():
def
test_multivar_reverse_ad
():
"""Simple test with multivariate reverse mode ad."""
mod
=
tvm
.
IRModule
()
shape
=
(
10
,
10
)
dtype
=
'float32'
t
=
relay
.
TensorType
(
shape
,
dtype
)
...
...
@@ -232,7 +232,7 @@ def test_multivar_reverse_ad():
def
test_after_partial_eval
():
"""Test transformation following reverse mode ad and PartialEval"""
mod
=
tvm
.
IRModule
()
shape
=
(
10
,
10
)
dtype
=
'float32'
t
=
relay
.
TensorType
(
shape
,
dtype
)
...
...
@@ -248,7 +248,7 @@ def test_after_partial_eval():
mod
[
"main"
]
=
back_func
back_func
=
mod
[
"main"
]
seq
=
transform
.
Sequential
([
seq
=
t
vm
.
t
ransform
.
Sequential
([
transform
.
PartialEvaluate
(),
transform
.
LazyGradientInit
(),
transform
.
DeadCodeElimination
()
...
...
@@ -270,7 +270,7 @@ def test_after_partial_eval():
def
test_before_partial_eval
():
"""Test transformation before PartialEval"""
mod
=
tvm
.
IRModule
()
shape
=
(
10
,
10
)
dtype
=
'float32'
t
=
relay
.
TensorType
(
shape
,
dtype
)
...
...
@@ -284,7 +284,7 @@ def test_before_partial_eval():
back_func
=
run_infer_type
(
back_func
)
mod
[
"main"
]
=
back_func
seq
=
transform
.
Sequential
([
seq
=
t
vm
.
t
ransform
.
Sequential
([
transform
.
LazyGradientInit
(),
transform
.
PartialEvaluate
(),
transform
.
DeadCodeElimination
()
...
...
@@ -306,7 +306,7 @@ def test_before_partial_eval():
def
test_zeros
():
"""Simple test using "zeros" op"""
mod
=
tvm
.
IRModule
()
shape
=
(
10
,
10
)
dtype
=
'float32'
t
=
relay
.
TensorType
(
shape
,
dtype
)
...
...
@@ -328,7 +328,7 @@ def test_zeros():
def
test_ones
():
"""Simple test using "ones" op"""
mod
=
tvm
.
IRModule
()
shape
=
(
10
,
10
)
dtype
=
'float32'
t
=
relay
.
TensorType
(
shape
,
dtype
)
...
...
@@ -350,7 +350,7 @@ def test_ones():
def
test_zeros_like
():
"""Simple test using "zeros_like" op"""
mod
=
tvm
.
IRModule
()
shape
=
(
10
,
10
)
dtype
=
'float32'
t
=
relay
.
TensorType
(
shape
,
dtype
)
...
...
@@ -372,7 +372,7 @@ def test_zeros_like():
def
test_ones_like
():
"""Simple test using "ones_like" op"""
mod
=
tvm
.
IRModule
()
shape
=
(
10
,
10
)
dtype
=
'float32'
t
=
relay
.
TensorType
(
shape
,
dtype
)
...
...
tests/python/relay/test_pass_legalize.py
View file @
275e317c
...
...
@@ -28,8 +28,8 @@ from tvm.relay.testing.temp_op_attr import TempOpAttr
def
run_opt_pass
(
expr
,
passes
):
passes
=
passes
if
isinstance
(
passes
,
list
)
else
[
passes
]
mod
=
tvm
.
IRModule
.
from_expr
(
expr
)
seq
=
transform
.
Sequential
(
passes
)
with
transform
.
PassContext
(
opt_level
=
3
):
seq
=
t
vm
.
t
ransform
.
Sequential
(
passes
)
with
t
vm
.
t
ransform
.
PassContext
(
opt_level
=
3
):
mod
=
seq
(
mod
)
entry
=
mod
[
"main"
]
return
entry
if
isinstance
(
expr
,
relay
.
Function
)
else
entry
.
body
...
...
tests/python/relay/test_pass_mac_count.py
View file @
275e317c
...
...
@@ -23,7 +23,7 @@ from tvm.relay import analysis, transform
def
run_opt_pass
(
expr
,
opt_pass
):
assert
isinstance
(
opt_pass
,
transform
.
Pass
)
assert
isinstance
(
opt_pass
,
t
vm
.
t
ransform
.
Pass
)
mod
=
tvm
.
IRModule
.
from_expr
(
expr
)
mod
=
opt_pass
(
mod
)
entry
=
mod
[
"main"
]
...
...
tests/python/relay/test_pass_manager.py
View file @
275e317c
...
...
@@ -129,13 +129,13 @@ def test_module_pass():
opt_tester
=
OptTester
(
mod
)
pass_ctx
=
None
@
_
transform.module_pass
(
opt_level
=
opt_level
,
name
=
pass_name
)
@
tvm.
transform.module_pass
(
opt_level
=
opt_level
,
name
=
pass_name
)
def
transform
(
expr
,
ctx
):
return
opt_tester
.
transform
(
expr
,
ctx
)
def
test_pass_registration
():
mod_pass
=
transform
assert
isinstance
(
mod_pass
,
_
transform
.
ModulePass
)
assert
isinstance
(
mod_pass
,
tvm
.
transform
.
ModulePass
)
pass_info
=
mod_pass
.
info
assert
pass_info
.
name
==
pass_name
assert
pass_info
.
opt_level
==
opt_level
...
...
@@ -143,8 +143,8 @@ def test_module_pass():
def
test_pass_registration_no_decorator
():
def
direct_transform
(
expr
,
ctx
):
return
opt_tester
.
transform
(
expr
,
ctx
)
mod_pass
=
_
transform
.
module_pass
(
direct_transform
,
opt_level
=
3
)
assert
isinstance
(
mod_pass
,
_
transform
.
ModulePass
)
mod_pass
=
tvm
.
transform
.
module_pass
(
direct_transform
,
opt_level
=
3
)
assert
isinstance
(
mod_pass
,
tvm
.
transform
.
ModulePass
)
pass_info
=
mod_pass
.
info
assert
pass_info
.
name
==
"direct_transform"
assert
pass_info
.
opt_level
==
3
...
...
@@ -285,7 +285,7 @@ def test_function_pass():
def
test_module_class_pass
():
@
relay
.transform.module_pass
(
opt_level
=
1
)
@
tvm
.transform.module_pass
(
opt_level
=
1
)
class
TestPipeline
:
"""Simple test function to replace one argument to another."""
def
__init__
(
self
,
new_mod
,
replace
):
...
...
@@ -309,7 +309,7 @@ def test_module_class_pass():
def
test_pass_info
():
info
=
relay
.
transform
.
PassInfo
(
opt_level
=
1
,
name
=
"xyz"
)
info
=
tvm
.
transform
.
PassInfo
(
opt_level
=
1
,
name
=
"xyz"
)
assert
info
.
opt_level
==
1
assert
info
.
name
==
"xyz"
...
...
@@ -350,7 +350,7 @@ def test_sequential_pass():
opt_tester
=
OptTester
(
mod
)
pass_ctx
=
None
@
_
transform.module_pass
(
opt_level
=
1
)
@
tvm.
transform.module_pass
(
opt_level
=
1
)
def
mod_transform
(
expr
,
ctx
):
return
opt_tester
.
transform
(
expr
,
ctx
)
...
...
@@ -367,21 +367,21 @@ def test_sequential_pass():
passes
=
[
module_pass
,
function_pass
]
opt_level
=
2
pass_name
=
"sequential"
sequential
=
_
transform
.
Sequential
(
passes
=
passes
,
opt_level
=
opt_level
)
sequential
=
tvm
.
transform
.
Sequential
(
passes
=
passes
,
opt_level
=
opt_level
)
pass_info
=
sequential
.
info
assert
pass_info
.
name
==
pass_name
assert
pass_info
.
opt_level
==
opt_level
def
test_no_pass
():
passes
=
[]
sequential
=
_
transform
.
Sequential
(
opt_level
=
1
,
passes
=
passes
)
sequential
=
tvm
.
transform
.
Sequential
(
opt_level
=
1
,
passes
=
passes
)
ret_mod
=
sequential
(
mod
)
mod_func
=
ret_mod
[
v_sub
]
check_func
(
sub
,
mod_func
)
def
test_only_module_pass
():
passes
=
[
module_pass
]
sequential
=
_
transform
.
Sequential
(
opt_level
=
1
,
passes
=
passes
)
sequential
=
tvm
.
transform
.
Sequential
(
opt_level
=
1
,
passes
=
passes
)
with
relay
.
build_config
(
required_pass
=
[
"mod_transform"
]):
ret_mod
=
sequential
(
mod
)
# Check the subtract function.
...
...
@@ -396,7 +396,7 @@ def test_sequential_pass():
def
test_only_function_pass
():
# Check the subtract function.
passes
=
[
function_pass
]
sequential
=
_
transform
.
Sequential
(
opt_level
=
1
,
passes
=
passes
)
sequential
=
tvm
.
transform
.
Sequential
(
opt_level
=
1
,
passes
=
passes
)
with
relay
.
build_config
(
required_pass
=
[
"func_transform"
]):
ret_mod
=
sequential
(
mod
)
_
,
new_sub
=
extract_var_func
(
ret_mod
,
v_sub
.
name_hint
)
...
...
@@ -411,7 +411,7 @@ def test_sequential_pass():
# function pass.
mod
=
tvm
.
IRModule
({
v_sub
:
sub
,
v_log
:
log
})
passes
=
[
module_pass
,
function_pass
]
sequential
=
_
transform
.
Sequential
(
opt_level
=
1
,
passes
=
passes
)
sequential
=
tvm
.
transform
.
Sequential
(
opt_level
=
1
,
passes
=
passes
)
required
=
[
"mod_transform"
,
"func_transform"
]
with
relay
.
build_config
(
required_pass
=
required
):
ret_mod
=
sequential
(
mod
)
...
...
@@ -482,7 +482,7 @@ def test_sequential_with_scoping():
z1
=
relay
.
add
(
z
,
z
)
return
relay
.
Function
([
x
],
z1
)
seq
=
_
transform
.
Sequential
([
seq
=
tvm
.
transform
.
Sequential
([
relay
.
transform
.
InferType
(),
relay
.
transform
.
FoldConstant
(),
relay
.
transform
.
EliminateCommonSubexpr
(),
...
...
@@ -507,10 +507,10 @@ def test_print_ir(capfd):
y
=
relay
.
multiply
(
y
,
relay
.
const
(
2
,
"float32"
))
func
=
relay
.
Function
([
x
],
y
)
seq
=
_
transform
.
Sequential
([
seq
=
tvm
.
transform
.
Sequential
([
relay
.
transform
.
InferType
(),
relay
.
transform
.
FoldConstant
(),
relay
.
transform
.
PrintIR
(),
tvm
.
transform
.
PrintIR
(),
relay
.
transform
.
DeadCodeElimination
()
])
...
...
@@ -520,7 +520,7 @@ def test_print_ir(capfd):
out
=
capfd
.
readouterr
()
.
err
assert
"
Dumping the module
IR"
in
out
assert
"
Print
IR"
in
out
assert
"multiply"
in
out
__TRACE_COUNTER__
=
0
...
...
@@ -539,7 +539,7 @@ def test_print_debug_callback():
y
=
relay
.
multiply
(
y
,
relay
.
const
(
2
,
"float32"
))
func
=
relay
.
Function
([
x
],
y
)
seq
=
_
transform
.
Sequential
([
seq
=
tvm
.
transform
.
Sequential
([
relay
.
transform
.
InferType
(),
relay
.
transform
.
FoldConstant
(),
relay
.
transform
.
DeadCodeElimination
()
...
...
tests/python/relay/test_pass_partial_eval.py
View file @
275e317c
...
...
@@ -38,8 +38,8 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07):
def
run_opt_pass
(
expr
,
passes
):
passes
=
passes
if
isinstance
(
passes
,
list
)
else
[
passes
]
mod
=
tvm
.
IRModule
.
from_expr
(
expr
)
seq
=
transform
.
Sequential
(
passes
)
with
transform
.
PassContext
(
opt_level
=
3
):
seq
=
t
vm
.
t
ransform
.
Sequential
(
passes
)
with
t
vm
.
t
ransform
.
PassContext
(
opt_level
=
3
):
mod
=
seq
(
mod
)
entry
=
mod
[
"main"
]
return
entry
if
isinstance
(
expr
,
relay
.
Function
)
else
entry
.
body
...
...
@@ -58,7 +58,7 @@ def dcpe(expr, mod=None, grad=False):
if
mod
:
assert
isinstance
(
expr
,
Function
)
mod
[
"main"
]
=
expr
seq
=
transform
.
Sequential
(
passes
)
seq
=
t
vm
.
t
ransform
.
Sequential
(
passes
)
mod
=
seq
(
mod
)
return
mod
[
"main"
]
return
run_opt_pass
(
expr
,
passes
)
...
...
tests/python/relay/test_pass_partition_graph.py
View file @
275e317c
...
...
@@ -496,7 +496,7 @@ def test_function_lifting():
op_list
=
[
"nn.batch_norm"
,
"nn.conv2d"
]
mod
=
WhiteListAnnotator
(
op_list
,
"test_compiler"
)(
mod
)
opt_pass
=
transform
.
Sequential
([
opt_pass
=
t
vm
.
t
ransform
.
Sequential
([
transform
.
InferType
(),
transform
.
PartitionGraph
(),
transform
.
SimplifyInference
(),
...
...
@@ -578,7 +578,7 @@ def test_function_lifting_inline():
op_list
=
[
"nn.batch_norm"
,
"nn.conv2d"
]
mod
=
WhiteListAnnotator
(
op_list
,
"test_compiler"
)(
mod
)
opt_pass
=
transform
.
Sequential
([
opt_pass
=
t
vm
.
t
ransform
.
Sequential
([
transform
.
InferType
(),
transform
.
PartitionGraph
(),
transform
.
SimplifyInference
(),
...
...
@@ -878,13 +878,13 @@ def test_dnnl_fuse():
# This is required for constant folding
mod
[
"main"
]
=
bind_params_by_name
(
mod
[
"main"
],
params
)
remove_bn_pass
=
transform
.
Sequential
([
remove_bn_pass
=
t
vm
.
t
ransform
.
Sequential
([
transform
.
InferType
(),
transform
.
SimplifyInference
(),
transform
.
FoldConstant
(),
transform
.
FoldScaleAxis
(),
])
composite_partition
=
transform
.
Sequential
([
composite_partition
=
t
vm
.
t
ransform
.
Sequential
([
remove_bn_pass
,
transform
.
MergeComposite
(
pattern_table
),
transform
.
AnnotateTarget
(
"dnnl"
),
...
...
tests/python/relay/test_pass_qnn_legalize.py
View file @
275e317c
...
...
@@ -37,8 +37,8 @@ def alpha_equal(x, y):
def
run_opt_pass
(
expr
,
passes
):
passes
=
passes
if
isinstance
(
passes
,
list
)
else
[
passes
]
mod
=
tvm
.
IRModule
.
from_expr
(
expr
)
seq
=
transform
.
Sequential
(
passes
)
with
transform
.
PassContext
(
opt_level
=
3
):
seq
=
t
vm
.
t
ransform
.
Sequential
(
passes
)
with
t
vm
.
t
ransform
.
PassContext
(
opt_level
=
3
):
mod
=
seq
(
mod
)
entry
=
mod
[
"main"
]
return
entry
if
isinstance
(
expr
,
relay
.
Function
)
else
entry
.
body
...
...
tests/python/relay/test_pass_to_a_normal_form.py
View file @
275e317c
...
...
@@ -28,8 +28,8 @@ from tvm.relay.analysis import Feature
def
run_opt_pass
(
expr
,
passes
):
passes
=
passes
if
isinstance
(
passes
,
list
)
else
[
passes
]
mod
=
tvm
.
IRModule
.
from_expr
(
expr
)
seq
=
transform
.
Sequential
(
passes
)
with
transform
.
PassContext
(
opt_level
=
3
):
seq
=
t
vm
.
t
ransform
.
Sequential
(
passes
)
with
t
vm
.
t
ransform
.
PassContext
(
opt_level
=
3
):
mod
=
seq
(
mod
)
entry
=
mod
[
"main"
]
return
entry
if
isinstance
(
expr
,
relay
.
Function
)
else
entry
.
body
...
...
tests/python/relay/test_pass_to_cps.py
View file @
275e317c
...
...
@@ -71,7 +71,8 @@ def test_cps_pe():
x
=
run_infer_type
(
x
)
y
=
un_cps
(
x
)
y
=
run_infer_type
(
y
)
x
=
run_opt_pass
(
x
,
transform
.
Sequential
([
transform
.
PartialEvaluate
(),
transform
.
DeadCodeElimination
(
inline_once
=
True
)]))
x
=
run_opt_pass
(
x
,
tvm
.
transform
.
Sequential
(
[
transform
.
PartialEvaluate
(),
transform
.
DeadCodeElimination
(
inline_once
=
True
)]))
assert
Feature
.
fRefCreate
not
in
detect_feature
(
x
)
unit
=
relay
.
Function
([],
relay
.
const
(
0.
,
dtype
=
'float32'
))
f_ref
=
relay
.
Var
(
"f_ref"
)
...
...
tutorials/dev/relay_pass_infra.py
View file @
275e317c
...
...
@@ -29,7 +29,7 @@ introduced an infrastructure to manage the optimization passes.
The optimizations of a Relay program could be applied at various granularity,
namely function-level and module-level using :py:class:`tvm.relay.transform.FunctionPass`
and py:class:`tvm.relay.transform.ModulePass`
respectively. Or users can rely on py:class:`tvm.
relay.
transform.Sequential` to apply a sequence of passes
respectively. Or users can rely on py:class:`tvm.transform.Sequential` to apply a sequence of passes
on a Relay program where the dependencies between passes can be resolved by the
pass infra. For more details about each type of these passes, please refer to
the :ref:`relay-pass-infra`
...
...
@@ -130,22 +130,22 @@ print(mod)
# fusion, as this pass generates let bindings for each expression to
# canonicalize a Relay program.
#
# Relay, hence, provides :py:class:`tvm.
relay.
transform.Sequential` to alleviate developers from handling
# Relay, hence, provides :py:class:`tvm.transform.Sequential` to alleviate developers from handling
# these issues explicitly by specifying the required passes of each pass and
# packing them as a whole to execute. For example, the same passes can now be
# applied using the sequential style as the following. :py:class:`tvm.
relay.
transform.Sequential` is
# applied using the sequential style as the following. :py:class:`tvm.transform.Sequential` is
# similiar to `torch.nn.sequential <https://pytorch.org/docs/stable/nn.html#torch.nn.Sequential>`_
# and `mxnet.gluon.block <https://mxnet.incubator.apache.org/api/python/docs/_modules/mxnet/gluon/block.html>`_.
# For example, `torch.nn.sequential` is used to contain a sequence of PyTorch
# `Modules` that will be added to build a network. It focuses on the network
# layers. Instead, the :py:class:`tvm.
relay.
transform.Sequential` in our pass infra works on the optimizing
# layers. Instead, the :py:class:`tvm.transform.Sequential` in our pass infra works on the optimizing
# pass.
# Now let's execute some passes through :py:class:`tvm.
relay.
transform.Sequential`
# Now let's execute some passes through :py:class:`tvm.transform.Sequential`
f
=
example
()
mod
=
tvm
.
IRModule
.
from_expr
(
f
)
# Glob the interested passes.
seq
=
relay
.
transform
.
Sequential
([
relay
.
transform
.
FoldConstant
(),
seq
=
tvm
.
transform
.
Sequential
([
relay
.
transform
.
FoldConstant
(),
relay
.
transform
.
EliminateCommonSubexpr
(),
relay
.
transform
.
FuseOps
(
fuse_opt_level
=
2
)])
mod1
=
seq
(
mod
)
...
...
@@ -156,7 +156,7 @@ print(mod1)
# identical addition operations. This is because `EliminateCommonSubexpr`
# was not actually performed. The reason is because only the passes that have
# optimization level less or equal to 2 will be executed by default under
# :py:class:`tvm.
relay.
transform.Sequential`. The pass infra,
# :py:class:`tvm.transform.Sequential`. The pass infra,
# however, provides a configuration interface
# for users to customize the optimization level that they want to execute.
...
...
@@ -186,7 +186,7 @@ with relay.build_config(opt_level=3):
mod4
=
seq
(
mod
)
print
(
mod4
)
seq1
=
relay
.
transform
.
Sequential
([
relay
.
transform
.
AlterOpLayout
()])
seq1
=
tvm
.
transform
.
Sequential
([
relay
.
transform
.
AlterOpLayout
()])
with
relay
.
build_config
(
opt_level
=
3
):
with
tvm
.
target
.
create
(
"llvm"
):
mod5
=
seq1
(
mod
)
...
...
@@ -237,11 +237,11 @@ print(mod3)
f
=
example
()
mod
=
tvm
.
IRModule
.
from_expr
(
f
)
seq
=
relay
.
transform
.
Sequential
([
relay
.
transform
.
FoldConstant
(),
relay
.
transform
.
PrintIR
(
False
),
relay
.
transform
.
EliminateCommonSubexpr
(),
relay
.
transform
.
FuseOps
(),
relay
.
transform
.
PrintIR
(
False
)])
seq
=
tvm
.
transform
.
Sequential
([
relay
.
transform
.
FoldConstant
(),
tvm
.
transform
.
PrintIR
(
),
relay
.
transform
.
EliminateCommonSubexpr
(),
relay
.
transform
.
FuseOps
(),
tvm
.
transform
.
PrintIR
(
)])
with
relay
.
build_config
(
opt_level
=
3
):
mod
=
seq
(
mod
)
...
...
vta/python/vta/top/graphpack.py
View file @
275e317c
...
...
@@ -24,7 +24,7 @@ from tvm.relay import ExprMutator
def
run_opt_pass
(
expr
,
opt_pass
):
"""Exectue a relay pass."""
assert
isinstance
(
opt_pass
,
transform
.
Pass
)
assert
isinstance
(
opt_pass
,
t
vm
.
t
ransform
.
Pass
)
mod
=
tvm
.
IRModule
.
from_expr
(
expr
)
mod
=
opt_pass
(
mod
)
entry
=
mod
[
"main"
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment