Unverified commit 275e317c by Tianqi Chen, committed by GitHub

[RELAY] Remove re-exports of tvm.transform (#5337)

parent f08d5d78
@@ -21,3 +21,11 @@ tvm.ir
    :members:
    :imported-members:
    :autosummary:
+
+
+tvm.transform
+-------------
+.. automodule:: tvm.transform
+   :members:
+   :imported-members:
+   :autosummary:
@@ -227,7 +227,7 @@ ConvertLayout pass is extremely easy to use. The pass is not a part of default r
     # Convert the layout to NCHW
     # RemoveUnusedFunctions is used to clean up the graph.
-    seq = relay.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
+    seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                       relay.transform.ConvertLayout('NCHW')])
     with relay.transform.PassContext(opt_level=3):
         mod = seq(mod)
...
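For orientation, a minimal runnable sketch of the ConvertLayout recipe above, with a hypothetical NHWC graph standing in for a frontend-imported model (the variable names are illustrative, not part of the commit):

.. code:: python

    import tvm
    from tvm import relay

    # Hypothetical NHWC conv2d standing in for an imported graph.
    data = relay.var("data", shape=(1, 56, 56, 64))
    weight = relay.var("weight", shape=(3, 3, 64, 64))
    conv = relay.nn.conv2d(data, weight, padding=(1, 1),
                           data_layout="NHWC", kernel_layout="HWIO")
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

    seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                    relay.transform.ConvertLayout('NCHW')])
    with relay.transform.PassContext(opt_level=3):
        mod = seq(mod)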
@@ -582,7 +582,7 @@ using ``Sequential`` associated with other types of passes.
     func = relay.Function([x], z2)

     # Customize the optimization pipeline.
-    seq = _transform.Sequential([
+    seq = tvm.transform.Sequential([
         relay.transform.InferType(),
         relay.transform.FoldConstant(),
         relay.transform.EliminateCommonSubexpr(),
@@ -609,7 +609,7 @@ sequential pass example could be like the following to enable IR dumping for
 .. code:: python

-    seq = _transform.Sequential([
+    seq = tvm.transform.Sequential([
         relay.transform.InferType(),
         relay.transform.FoldConstant(),
         relay.transform.PrintIR(),
...
@@ -361,9 +361,11 @@ TVM_DLL Pass CreateModulePass(
 /*!
  * \brief A special trace pass that prints the header and IR to LOG(INFO).
+ * \param header The header to be attached to the output.
+ * \param show_meta_data Whether to show meta data.
  * \return The pass.
  */
-TVM_DLL Pass PrintIR(std::string header);
+TVM_DLL Pass PrintIR(std::string header = "", bool show_meta_data = false);

 }  // namespace transform
 }  // namespace tvm
...
@@ -106,7 +106,7 @@ def create_updater_06_to_07():
         "relay.PassInfo": _rename("transform.PassInfo"),
         "relay.PassContext": _rename("transform.PassContext"),
         "relay.ModulePass": _rename("transform.ModulePass"),
-        "relay.Sequantial": _rename("transform.Sequantial"),
+        "relay.Sequential": _rename("transform.Sequential"),
         # TIR
         "Variable": _update_tir_var("tir.Var"),
         "SizeVar": _update_tir_var("tir.SizeVar"),
...
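Note the fix folded into this hunk: the old table entry misspelled both sides as "Sequantial"; the new entry maps the correct "relay.Sequential" key to "transform.Sequential". For readers unfamiliar with the upgrader, a sketch of what a `_rename`-style updater does to serialized nodes; this is a hypothetical standalone version for illustration, not TVM's actual `json_compact` code:

.. code:: python

    def rename(new_name):
        # Build an updater that rewrites a serialized node's type_key.
        def update(item):
            item["type_key"] = new_name
            return item
        return update

    # A v0.6 node whose type key must move to the transform namespace.
    node = {"type_key": "relay.Sequential", "attrs": {}}
    node = rename("transform.Sequential")(node)
    assert node["type_key"] == "transform.Sequential"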
@@ -329,7 +329,7 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None):
     return create_module_pass


-def PrintIR(header):
+def PrintIR(header="", show_meta_data=False):
     """A special trace pass that prints the header and IR.

     Parameters
@@ -337,8 +337,11 @@ def PrintIR(header):
     header : str
         The header to be displayed along with the dump.

+    show_meta_data : bool
+        A boolean flag to indicate if meta data should be printed.
+
     Returns
     --------
         The pass
     """
-    return _ffi_transform_api.PrintIR(header)
+    return _ffi_transform_api.PrintIR(header, show_meta_data)
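With the widened signature, the trace pass can be dropped anywhere in a pipeline with an optional header; a minimal sketch, assuming a trivial module just to have something to print:

.. code:: python

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 2))
    mod = tvm.IRModule.from_expr(relay.Function([x], x + x))

    seq = tvm.transform.Sequential([
        relay.transform.InferType(),
        tvm.transform.PrintIR(header="after InferType", show_meta_data=False),
    ])
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(mod)  # logs "PrintIR(after InferType):" plus the module text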
@@ -128,20 +128,9 @@ Prelude = prelude.Prelude
 # Scope builder
 ScopeBuilder = scope_builder.ScopeBuilder

-module_pass = transform.module_pass
-function_pass = transform.function_pass
-
 # Parser
 fromtext = parser.fromtext

 # Param Serialization
 save_param_dict = param_dict.save_param_dict
 load_param_dict = param_dict.load_param_dict
-
-# Pass manager
-PassInfo = transform.PassInfo
-PassContext = transform.PassContext
-Pass = transform.Pass
-ModulePass = transform.ModulePass
-FunctionPass = transform.FunctionPass
-Sequential = transform.Sequential
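Since these aliases are gone, code that reached the pass manager through `tvm.relay` migrates to the `tvm.transform` namespace; a hedged before/after sketch:

.. code:: python

    import tvm
    from tvm import relay

    # Before this commit, the re-exports allowed:
    #   seq = relay.transform.Sequential([...])
    #   with relay.transform.PassContext(opt_level=3): ...
    # After it, the pass manager is addressed directly:
    seq = tvm.transform.Sequential([relay.transform.FoldConstant()])
    with tvm.transform.PassContext(opt_level=3):
        mod = seq(tvm.IRModule())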
@@ -210,7 +210,7 @@ class Interpreter(Executor):
         opt_mod : tvm.IRModule
             The optimized module.
         """
-        seq = transform.Sequential([transform.SimplifyInference(),
+        seq = tvm.transform.Sequential([transform.SimplifyInference(),
                                     transform.FuseOps(0),
                                     transform.ToANormalForm(),
                                     transform.InferType()])
...
@@ -60,7 +60,7 @@ def CanonicalizeOps():
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that canonicalizes QNN ops to Relay ops.
     """
@@ -108,7 +108,7 @@ def Legalize():
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that legalizes QNN ops.
     """
...
@@ -17,6 +17,7 @@
 #pylint: disable=unused-argument, not-context-manager
 """Automatic quantization toolkit."""
 import tvm.ir
+import tvm
 from tvm.runtime import Object

 from . import _quantize
@@ -240,7 +241,7 @@ def partition():
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass for VTA rewrite.
     """
     return _quantize.QuantizePartition()
@@ -253,7 +254,7 @@ def annotate():
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass for quantization annotation.
     """
     return _quantize.QuantizeAnnotate()
@@ -267,7 +268,7 @@ def realize():
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass for quantization realization.
     """
     return _quantize.QuantizeRealize()
@@ -298,7 +299,8 @@ def prerequisite_optimize(mod, params=None):
     """ Prerequisite optimization passes for quantization. Perform
     "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
     "CanonicalizeOps" optimization before quantization. """
-    optimize = _transform.Sequential([_transform.SimplifyInference(),
+    optimize = tvm.transform.Sequential(
+        [_transform.SimplifyInference(),
          _transform.FoldConstant(),
          _transform.FoldScaleAxis(),
          _transform.CanonicalizeOps(),
@@ -336,7 +338,8 @@ def quantize(mod, params=None, dataset=None):
     """
     mod = prerequisite_optimize(mod, params)

-    calibrate_pass = _transform.module_pass(calibrate(dataset), opt_level=1,
+    calibrate_pass = tvm.transform.module_pass(
+        calibrate(dataset), opt_level=1,
         name="QuantizeCalibrate")
     quant_passes = [partition(),
                     annotate(),
@@ -344,8 +347,8 @@ def quantize(mod, params=None, dataset=None):
     if not current_qconfig().do_simulation:
         quant_passes.append(realize())
         quant_passes.append(_transform.FoldConstant())
-    quantize_seq = _transform.Sequential(quant_passes)
-    with _transform.PassContext(opt_level=3,
+    quantize_seq = tvm.transform.Sequential(quant_passes)
+    with tvm.transform.PassContext(opt_level=3,
                                 required_pass=["QuantizeAnnotate",
                                                "QuantizeCalibrate",
                                                "QuantizeRealize"]):
...
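As the `calibrate_pass` change shows, `tvm.transform.module_pass` works both as a plain call on a function and as a decorator; a minimal self-contained sketch:

.. code:: python

    import tvm

    @tvm.transform.module_pass(opt_level=1, name="Identity")
    def identity_pass(mod, ctx):
        # A module pass receives the IRModule and PassContext and
        # must return an IRModule; this one returns it unchanged.
        return mod

    mod = tvm.IRModule()
    mod = identity_pass(mod)
    assert isinstance(identity_pass, tvm.transform.ModulePass)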
@@ -47,7 +47,7 @@ from .py_converter import to_python, run_as_python
 from ..transform import gradient

 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     entry = mod["main"]
...
@@ -95,7 +95,7 @@ class PythonConverter(ExprFunctor):
         # necessary pass: SimplifyInference (otherwise we can't generate code for some operators)
         # and fusion (to get primitive functions)
-        opts = relay.transform.Sequential([relay.transform.SimplifyInference(),
+        opts = tvm.transform.Sequential([relay.transform.SimplifyInference(),
                                            relay.transform.FuseOps(fuse_opt_level=0)])
         mod = opts(mod)
         optimized = mod['main']
...
@@ -22,10 +22,9 @@ import types
 import inspect
 import functools

-import tvm
+import tvm.ir
 from tvm import te
 from tvm.runtime import ndarray as _nd
-from tvm.ir.transform import PassInfo, PassContext, Pass, ModulePass, Sequential, module_pass
 from tvm import relay
 from . import _ffi_api
@@ -78,12 +77,13 @@ def build_config(opt_level=2,
     pass_context: PassContext
         The pass context for optimizations.
     """
-    return PassContext(opt_level, fallback_device, required_pass,
+    return tvm.ir.transform.PassContext(
+        opt_level, fallback_device, required_pass,
                        disabled_pass, trace)


 @tvm._ffi.register_object("relay.FunctionPass")
-class FunctionPass(Pass):
+class FunctionPass(tvm.ir.transform.Pass):
     """A pass that works on each tvm.relay.Function in a module. A function
     pass class should be created through `function_pass`.
     """
@@ -94,7 +94,7 @@ def InferType():
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered type inference pass.
     """
     return _ffi_api.InferType()
@@ -106,7 +106,7 @@ def FoldScaleAxis():
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass to fold expressions.

     Note
@@ -123,7 +123,7 @@ def BackwardFoldScaleAxis():
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass to backward fold expressions.

     Note
@@ -144,7 +144,7 @@ def RemoveUnusedFunctions(entry_functions=None):
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass to remove unused functions.
     """
     if entry_functions is None:
@@ -156,7 +156,7 @@ def ForwardFoldScaleAxis():
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass to forward fold expressions.

     Note
@@ -174,7 +174,7 @@ def SimplifyInference():
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass to perform operator simplification.
     """
     return _ffi_api.SimplifyInference()
@@ -185,7 +185,7 @@ def FastMath():
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass to perform fast math operations.
     """
     return _ffi_api.FastMath()
@@ -198,7 +198,7 @@ def CanonicalizeOps():
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass performing the canonicalization.
     """
     return _ffi_api.CanonicalizeOps()
@@ -214,7 +214,7 @@ def DeadCodeElimination(inline_once=False):
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass that eliminates the dead code in a Relay program.
     """
     return _ffi_api.DeadCodeElimination(inline_once)
@@ -227,7 +227,7 @@ def LazyGradientInit():
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         A pass which delays and/or reduces memory allocation,
         by lazily allocating 0 or one filled tensors.
     """
@@ -238,7 +238,7 @@ def FoldConstant():
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass for constant folding.
     """
     return _ffi_api.FoldConstant()
@@ -255,7 +255,7 @@ def FuseOps(fuse_opt_level=-1):
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass for operator fusion.
     """
     return _ffi_api.FuseOps(fuse_opt_level)
@@ -272,7 +272,7 @@ def CombineParallelConv2D(min_num_branches=3):
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass that combines parallel conv2d operators.
     """
     return _ffi_api.CombineParallelConv2D(min_num_branches)
@@ -304,7 +304,7 @@ def CombineParallelDense(min_num_branches=3):
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass that combines parallel dense operators.
     """
     return _ffi_api.CombineParallelDense(min_num_branches)
@@ -318,7 +318,7 @@ def AlterOpLayout():
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that alters the layout of operators.
     """
     return _ffi_api.AlterOpLayout()
@@ -366,7 +366,7 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"):
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that rewrites an expr.
     """
     return _ffi_api.Legalize(legalize_map_attr_name)
@@ -387,7 +387,7 @@ def MergeComposite(pattern_table):
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that merges operators into a single composite
        relay function.
     """
@@ -413,7 +413,7 @@ def MergeCompilerRegions():
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that merges compiler regions.
     """
     return _ffi_api.MergeCompilerRegions()
@@ -433,7 +433,7 @@ def RewriteAnnotatedOps(fallback_device):
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass that rewrites an expression with annotated
        `on_device` operators.
     """
@@ -448,7 +448,7 @@ def ToANormalForm():
     Returns
     -------
-    ret: Union[tvm.relay.Pass, tvm.relay.Expr]
+    ret: Union[tvm.transform.Pass, tvm.relay.Expr]
         The registered pass that transforms an expression into A Normal Form.
     """
     return _ffi_api.ToANormalForm()
@@ -462,7 +462,7 @@ def ToCPS(expr, mod=None):
     Returns
     -------
-    result: tvm.relay.Pass
+    result: tvm.transform.Pass
         The registered pass that transforms an expression into CPS.
     """
     return _ffi_api.to_cps(expr, mod)
@@ -481,7 +481,7 @@ def EtaExpand(expand_constructor=False, expand_global_var=False):
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass that eta expands an expression.
     """
     return _ffi_api.EtaExpand(expand_constructor, expand_global_var)
@@ -492,7 +492,7 @@ def ToGraphNormalForm():
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that transforms an expression into Graph Normal Form.
     """
     return _ffi_api.ToGraphNormalForm()
@@ -509,7 +509,7 @@ def EliminateCommonSubexpr(fskip=None):
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that eliminates common subexpressions.
     """
     return _ffi_api.EliminateCommonSubexpr(fskip)
@@ -527,7 +527,7 @@ def PartialEvaluate():
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass that performs partial evaluation on an expression.
     """
     return _ffi_api.PartialEvaluate()
@@ -539,7 +539,7 @@ def CanonicalizeCast():
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that canonicalizes cast expression.
     """
     return _ffi_api.CanonicalizeCast()
@@ -551,36 +551,19 @@ def LambdaLift():
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The registered pass that lifts the lambda function.
     """
     return _ffi_api.LambdaLift()


-def PrintIR(show_meta_data=True):
-    """
-    Print the IR for a module to help debugging.
-
-    Parameters
-    ----------
-    show_meta_data : bool
-        A boolean flag to indicate if meta data should be printed.
-
-    Returns
-    -------
-    ret : tvm.relay.Pass
-        The registered pass that prints the module IR.
-    """
-    return _ffi_api.PrintIR(show_meta_data)
-
-
 def PartitionGraph():
     """Partition a Relay program into regions that can be executed on different
     backends.

     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass that partitions the Relay program.
     """
     return _ffi_api.PartitionGraph()
@@ -598,7 +581,7 @@ def AnnotateTarget(targets):
     Returns
     -------
-    ret : tvm.relay.Pass
+    ret : tvm.transform.Pass
         The annotated pass that wraps ops with subgraph_start and
        subgraph_end.
     """
@@ -614,7 +597,7 @@ def Inline():
     Returns
     -------
-    ret: tvm.relay.Pass
+    ret: tvm.transform.Pass
         The registered pass that performs inlining for a Relay IR module.
     """
     return _ffi_api.Inline()
@@ -809,7 +792,7 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
     def create_function_pass(pass_arg):
         """Internal function that creates a function pass"""
         fname = name if name else pass_arg.__name__
-        info = PassInfo(opt_level, fname, required)
+        info = tvm.transform.PassInfo(opt_level, fname, required)
         if inspect.isclass(pass_arg):
             return _wrap_class_function_pass(pass_arg, info)
         if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
...
@@ -474,10 +474,10 @@ TVM_REGISTER_GLOBAL("transform.ExitPassContext")
 .set_body_typed(PassContext::Internal::ExitScope);

-Pass PrintIR(std::string header) {
-  auto pass_func =[header](IRModule mod, const PassContext& ctx) {
+Pass PrintIR(std::string header, bool show_meta_data) {
+  auto pass_func =[header, show_meta_data](IRModule mod, const PassContext& ctx) {
     LOG(INFO) << "PrintIR(" << header << "):\n"
-              << mod;
+              << AsText(mod, show_meta_data);
     return mod;
   };
   return CreateModulePass(pass_func, 0, "PrintIR", {});
...

deleted file src/relay/transforms/print_ir.cc (its relay-specific PrintIR is subsumed by the tvm.transform version above):
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
*
* \file src/relay/transforms/print_ir.cc
*
* \brief Print the module IR to help debugging.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
namespace tvm {
namespace relay {
namespace transform {
Pass PrintIR(bool show_meta_data) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) {
LOG(INFO) << "Dumping the module IR: " << std::endl << AsText(m, show_meta_data);
return m;
};
return CreateModulePass(pass_func, 0, "PrintIR", {});
}
TVM_REGISTER_GLOBAL("relay._transform.PrintIR")
.set_body_typed(PrintIR);
} // namespace transform
} // namespace relay
} // namespace tvm
@@ -53,10 +53,10 @@ def test_checkpoint_alpha_equal():
     df = transform.gradient(run_infer_type(f))

     # run PE and DCE
-    with transform.PassContext(opt_level=3):
+    with tvm.transform.PassContext(opt_level=3):
         passes = [transform.PartialEvaluate(),
                   transform.DeadCodeElimination(inline_once=True)]
-        mod = transform.Sequential(passes)(tvm.IRModule.from_expr(df))
+        mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
         df = mod["main"]

     df_parsed = relay.parser.fromtext(
@@ -109,10 +109,10 @@ def test_checkpoint_alpha_equal_tuple():
     df = transform.gradient(run_infer_type(f))

     # run PE and DCE
-    with transform.PassContext(opt_level=3):
+    with tvm.transform.PassContext(opt_level=3):
         passes = [transform.PartialEvaluate(),
                   transform.DeadCodeElimination(inline_once=True)]
-        mod = transform.Sequential(passes)(tvm.IRModule.from_expr(df))
+        mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
         df = mod["main"]

     df_parsed = relay.parser.fromtext(
...
@@ -26,8 +26,8 @@ from tvm.relay.testing.temp_op_attr import TempOpAttr
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
...
@@ -28,8 +28,8 @@ from tvm.relay import transform
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     return mod["main"]
...
@@ -54,9 +54,9 @@ def test_canonicalize_cast():
     bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32")
     y = before(data, conv_weight, bias1, bias2)
     mod = tvm.IRModule.from_expr(y)
-    seq = _transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(),
+    seq = tvm.transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(),
                                  _transform.InferType()])
-    with _transform.PassContext(opt_level=3):
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     y = mod["main"]
     y_expected = expected(data, conv_weight, bias1, bias2)
...
@@ -26,7 +26,7 @@ def run_combine_parallel(expr, min_num_branches=3):
     return mod["main"]

 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     return mod["main"]
...
@@ -26,7 +26,7 @@ def run_combine_parallel(expr, min_num_branches=3):
     return mod["main"]

 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     return mod["main"]
...
@@ -26,8 +26,8 @@ from tvm.relay import transform, analysis
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
...
@@ -47,7 +47,7 @@ e = env()

 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     entry = mod["main"]
...
@@ -24,7 +24,7 @@ from tvm.relay import transform, analysis

 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     entry = mod["main"]
...
@@ -33,8 +33,8 @@ def test_eta_expand_global_var():
             @aux
         }
     """)
-    seq = _transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
-    with _transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     expected = relay.fromtext(r"""
         v0.0.4
@@ -62,8 +62,8 @@ def test_eta_expand_constructor():
             Cons
         }
     """)
-    seq = _transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
-    with _transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     expected = relay.fromtext(r"""
         v0.0.4
...
@@ -24,7 +24,7 @@ from tvm.relay.testing import run_infer_type, create_workload

 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
@@ -174,7 +174,7 @@ def test_fold_batch_norm():
         add = relay.add(conv, bias)
         return relay.Function(relay.analysis.free_vars(add), add)

-    remove_bn_pass = transform.Sequential([
+    remove_bn_pass = tvm.transform.Sequential([
         relay.transform.InferType(),
         relay.transform.SimplifyInference(),
         relay.transform.FoldConstant(),
...
@@ -26,7 +26,7 @@ def _get_positive_scale(size):

 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     entry = mod["main"]
...
@@ -80,7 +80,7 @@ def test_add_tuple():
     mod["main"] = y
     mod = transform.LazyGradientInit()(mod)
-    mod = transform.PrintIR(show_meta_data=True)(mod)
+    mod = tvm.transform.PrintIR(show_meta_data=True)(mod)
     y = mod["main"]

     assert mod["main"].checked_type == relay.FuncType([t], tensor_type)
@@ -248,7 +248,7 @@ def test_after_partial_eval():
     mod["main"] = back_func
     back_func = mod["main"]

-    seq = transform.Sequential([
+    seq = tvm.transform.Sequential([
         transform.PartialEvaluate(),
         transform.LazyGradientInit(),
         transform.DeadCodeElimination()
@@ -284,7 +284,7 @@ def test_before_partial_eval():
     back_func = run_infer_type(back_func)
     mod["main"] = back_func

-    seq = transform.Sequential([
+    seq = tvm.transform.Sequential([
         transform.LazyGradientInit(),
         transform.PartialEvaluate(),
         transform.DeadCodeElimination()
...
@@ -28,8 +28,8 @@ from tvm.relay.testing.temp_op_attr import TempOpAttr
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
...
@@ -23,7 +23,7 @@ from tvm.relay import analysis, transform

 def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     entry = mod["main"]
...
@@ -129,13 +129,13 @@ def test_module_pass():
     opt_tester = OptTester(mod)
     pass_ctx = None

-    @_transform.module_pass(opt_level=opt_level, name=pass_name)
+    @tvm.transform.module_pass(opt_level=opt_level, name=pass_name)
     def transform(expr, ctx):
         return opt_tester.transform(expr, ctx)

     def test_pass_registration():
         mod_pass = transform
-        assert isinstance(mod_pass, _transform.ModulePass)
+        assert isinstance(mod_pass, tvm.transform.ModulePass)
         pass_info = mod_pass.info
         assert pass_info.name == pass_name
         assert pass_info.opt_level == opt_level
@@ -143,8 +143,8 @@ def test_module_pass():
     def test_pass_registration_no_decorator():
         def direct_transform(expr, ctx):
             return opt_tester.transform(expr, ctx)
-        mod_pass = _transform.module_pass(direct_transform, opt_level=3)
-        assert isinstance(mod_pass, _transform.ModulePass)
+        mod_pass = tvm.transform.module_pass(direct_transform, opt_level=3)
+        assert isinstance(mod_pass, tvm.transform.ModulePass)
         pass_info = mod_pass.info
         assert pass_info.name == "direct_transform"
         assert pass_info.opt_level == 3
@@ -285,7 +285,7 @@ def test_function_pass():

 def test_module_class_pass():
-    @relay.transform.module_pass(opt_level=1)
+    @tvm.transform.module_pass(opt_level=1)
     class TestPipeline:
         """Simple test function to replace one argument to another."""
         def __init__(self, new_mod, replace):
@@ -309,7 +309,7 @@ def test_module_class_pass():

 def test_pass_info():
-    info = relay.transform.PassInfo(opt_level=1, name="xyz")
+    info = tvm.transform.PassInfo(opt_level=1, name="xyz")
     assert info.opt_level == 1
     assert info.name == "xyz"
@@ -350,7 +350,7 @@ def test_sequential_pass():
     opt_tester = OptTester(mod)
     pass_ctx = None

-    @_transform.module_pass(opt_level=1)
+    @tvm.transform.module_pass(opt_level=1)
     def mod_transform(expr, ctx):
         return opt_tester.transform(expr, ctx)
@@ -367,21 +367,21 @@ def test_sequential_pass():
         passes = [module_pass, function_pass]
         opt_level = 2
         pass_name = "sequential"
-        sequential = _transform.Sequential(passes=passes, opt_level=opt_level)
+        sequential = tvm.transform.Sequential(passes=passes, opt_level=opt_level)
         pass_info = sequential.info
         assert pass_info.name == pass_name
         assert pass_info.opt_level == opt_level

     def test_no_pass():
         passes = []
-        sequential = _transform.Sequential(opt_level=1, passes=passes)
+        sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
         ret_mod = sequential(mod)
         mod_func = ret_mod[v_sub]
         check_func(sub, mod_func)

     def test_only_module_pass():
         passes = [module_pass]
-        sequential = _transform.Sequential(opt_level=1, passes=passes)
+        sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
         with relay.build_config(required_pass=["mod_transform"]):
             ret_mod = sequential(mod)
         # Check the subtract function.
@@ -396,7 +396,7 @@ def test_sequential_pass():
     def test_only_function_pass():
         # Check the subtract function.
         passes = [function_pass]
-        sequential = _transform.Sequential(opt_level=1, passes=passes)
+        sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
         with relay.build_config(required_pass=["func_transform"]):
             ret_mod = sequential(mod)
             _, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
@@ -411,7 +411,7 @@ def test_sequential_pass():
         # function pass.
         mod = tvm.IRModule({v_sub: sub, v_log: log})
         passes = [module_pass, function_pass]
-        sequential = _transform.Sequential(opt_level=1, passes=passes)
+        sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
         required = ["mod_transform", "func_transform"]
         with relay.build_config(required_pass=required):
             ret_mod = sequential(mod)
@@ -482,7 +482,7 @@ def test_sequential_with_scoping():
         z1 = relay.add(z, z)
         return relay.Function([x], z1)

-    seq = _transform.Sequential([
+    seq = tvm.transform.Sequential([
         relay.transform.InferType(),
         relay.transform.FoldConstant(),
         relay.transform.EliminateCommonSubexpr(),
@@ -507,10 +507,10 @@ def test_print_ir(capfd):
     y = relay.multiply(y, relay.const(2, "float32"))
     func = relay.Function([x], y)

-    seq = _transform.Sequential([
+    seq = tvm.transform.Sequential([
         relay.transform.InferType(),
         relay.transform.FoldConstant(),
-        relay.transform.PrintIR(),
+        tvm.transform.PrintIR(),
         relay.transform.DeadCodeElimination()
     ])
@@ -520,7 +520,7 @@ def test_print_ir(capfd):
     out = capfd.readouterr().err

-    assert "Dumping the module IR" in out
+    assert "PrintIR" in out
     assert "multiply" in out

 __TRACE_COUNTER__ = 0
@@ -539,7 +539,7 @@ def test_print_debug_callback():
     y = relay.multiply(y, relay.const(2, "float32"))
     func = relay.Function([x], y)

-    seq = _transform.Sequential([
+    seq = tvm.transform.Sequential([
         relay.transform.InferType(),
         relay.transform.FoldConstant(),
         relay.transform.DeadCodeElimination()
...
@@ -38,8 +38,8 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07):
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
@@ -58,7 +58,7 @@ def dcpe(expr, mod=None, grad=False):
     if mod:
         assert isinstance(expr, Function)
         mod["main"] = expr
-        seq = transform.Sequential(passes)
+        seq = tvm.transform.Sequential(passes)
         mod = seq(mod)
         return mod["main"]
     return run_opt_pass(expr, passes)
...
@@ -496,7 +496,7 @@ def test_function_lifting():
         op_list = ["nn.batch_norm", "nn.conv2d"]
         mod = WhiteListAnnotator(op_list, "test_compiler")(mod)

-        opt_pass = transform.Sequential([
+        opt_pass = tvm.transform.Sequential([
             transform.InferType(),
             transform.PartitionGraph(),
             transform.SimplifyInference(),
@@ -578,7 +578,7 @@ def test_function_lifting_inline():
         op_list = ["nn.batch_norm", "nn.conv2d"]
         mod = WhiteListAnnotator(op_list, "test_compiler")(mod)

-        opt_pass = transform.Sequential([
+        opt_pass = tvm.transform.Sequential([
             transform.InferType(),
             transform.PartitionGraph(),
             transform.SimplifyInference(),
@@ -878,13 +878,13 @@ def test_dnnl_fuse():
         # This is required for constant folding
         mod["main"] = bind_params_by_name(mod["main"], params)

-        remove_bn_pass = transform.Sequential([
+        remove_bn_pass = tvm.transform.Sequential([
             transform.InferType(),
             transform.SimplifyInference(),
             transform.FoldConstant(),
             transform.FoldScaleAxis(),
         ])
-        composite_partition = transform.Sequential([
+        composite_partition = tvm.transform.Sequential([
             remove_bn_pass,
             transform.MergeComposite(pattern_table),
             transform.AnnotateTarget("dnnl"),
...
@@ -37,8 +37,8 @@ def alpha_equal(x, y):
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
...
@@ -28,8 +28,8 @@ from tvm.relay.analysis import Feature
 def run_opt_pass(expr, passes):
     passes = passes if isinstance(passes, list) else [passes]
     mod = tvm.IRModule.from_expr(expr)
-    seq = transform.Sequential(passes)
-    with transform.PassContext(opt_level=3):
+    seq = tvm.transform.Sequential(passes)
+    with tvm.transform.PassContext(opt_level=3):
         mod = seq(mod)
     entry = mod["main"]
     return entry if isinstance(expr, relay.Function) else entry.body
...
@@ -71,7 +71,8 @@ def test_cps_pe():
     x = run_infer_type(x)
     y = un_cps(x)
     y = run_infer_type(y)
-    x = run_opt_pass(x, transform.Sequential([transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]))
+    x = run_opt_pass(x, tvm.transform.Sequential(
+        [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]))
     assert Feature.fRefCreate not in detect_feature(x)
     unit = relay.Function([], relay.const(0., dtype='float32'))
     f_ref = relay.Var("f_ref")
...
@@ -29,7 +29,7 @@ introduced an infrastructure to manage the optimization passes.
 The optimizations of a Relay program could be applied at various granularity,
 namely function-level and module-level using :py:class:`tvm.relay.transform.FunctionPass`
 and py:class:`tvm.relay.transform.ModulePass`
-respectively. Or users can rely on py:class:`tvm.relay.transform.Sequential` to apply a sequence of passes
+respectively. Or users can rely on py:class:`tvm.transform.Sequential` to apply a sequence of passes
 on a Relay program where the dependencies between passes can be resolved by the
 pass infra. For more details about each type of these passes, please refer to
 the :ref:`relay-pass-infra`
@@ -130,22 +130,22 @@ print(mod)
 # fusion, as this pass generates let bindings for each expression to
 # canonicalize a Relay program.
 #
-# Relay, hence, provides :py:class:`tvm.relay.transform.Sequential` to alleviate developers from handling
+# Relay, hence, provides :py:class:`tvm.transform.Sequential` to alleviate developers from handling
 # these issues explicitly by specifying the required passes of each pass and
 # packing them as a whole to execute. For example, the same passes can now be
-# applied using the sequential style as the following. :py:class:`tvm.relay.transform.Sequential` is
+# applied using the sequential style as the following. :py:class:`tvm.transform.Sequential` is
 # similar to `torch.nn.sequential <https://pytorch.org/docs/stable/nn.html#torch.nn.Sequential>`_
 # and `mxnet.gluon.block <https://mxnet.incubator.apache.org/api/python/docs/_modules/mxnet/gluon/block.html>`_.
 # For example, `torch.nn.sequential` is used to contain a sequence of PyTorch
 # `Modules` that will be added to build a network. It focuses on the network
-# layers. Instead, the :py:class:`tvm.relay.transform.Sequential` in our pass infra works on the optimizing
+# layers. Instead, the :py:class:`tvm.transform.Sequential` in our pass infra works on the optimizing
 # pass.
-# Now let's execute some passes through :py:class:`tvm.relay.transform.Sequential`
+# Now let's execute some passes through :py:class:`tvm.transform.Sequential`
 f = example()
 mod = tvm.IRModule.from_expr(f)
 # Glob the interested passes.
-seq = relay.transform.Sequential([relay.transform.FoldConstant(),
+seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
                                   relay.transform.EliminateCommonSubexpr(),
                                   relay.transform.FuseOps(fuse_opt_level=2)])
 mod1 = seq(mod)
@@ -156,7 +156,7 @@ print(mod1)
 # identical addition operations. This is because `EliminateCommonSubexpr`
 # was not actually performed. The reason is because only the passes that have
 # optimization level less or equal to 2 will be executed by default under
-# :py:class:`tvm.relay.transform.Sequential`. The pass infra,
+# :py:class:`tvm.transform.Sequential`. The pass infra,
 # however, provides a configuration interface
 # for users to customize the optimization level that they want to execute.
@@ -186,7 +186,7 @@ with relay.build_config(opt_level=3):
     mod4 = seq(mod)
 print(mod4)

-seq1 = relay.transform.Sequential([relay.transform.AlterOpLayout()])
+seq1 = tvm.transform.Sequential([relay.transform.AlterOpLayout()])
 with relay.build_config(opt_level=3):
     with tvm.target.create("llvm"):
         mod5 = seq1(mod)
@@ -237,11 +237,11 @@ print(mod3)
 f = example()
 mod = tvm.IRModule.from_expr(f)

-seq = relay.transform.Sequential([relay.transform.FoldConstant(),
-                                  relay.transform.PrintIR(False),
+seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
+                                tvm.transform.PrintIR(),
                                   relay.transform.EliminateCommonSubexpr(),
                                   relay.transform.FuseOps(),
-                                  relay.transform.PrintIR(False)])
+                                tvm.transform.PrintIR()])
 with relay.build_config(opt_level=3):
     mod = seq(mod)
...
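The tutorial's point about opt_level gating can be made concrete; a sketch assuming `example()` builds the Relay function used throughout the tutorial:

.. code:: python

    f = example()
    mod = tvm.IRModule.from_expr(f)

    # EliminateCommonSubexpr is registered above opt_level 2, so under the
    # default PassContext, Sequential silently skips it.
    seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
                                    relay.transform.EliminateCommonSubexpr(),
                                    relay.transform.FuseOps(fuse_opt_level=2)])
    mod1 = seq(mod)            # common subexpressions survive

    with relay.build_config(opt_level=3):
        mod2 = seq(mod)        # all three passes run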
@@ -24,7 +24,7 @@ from tvm.relay import ExprMutator

 def run_opt_pass(expr, opt_pass):
     """Execute a relay pass."""
-    assert isinstance(opt_pass, transform.Pass)
+    assert isinstance(opt_pass, tvm.transform.Pass)
     mod = tvm.IRModule.from_expr(expr)
     mod = opt_pass(mod)
     entry = mod["main"]
...