Unverified commit 275e317c by Tianqi Chen, committed by GitHub

[RELAY] Remove re-exports of tvm.transform (#5337)

parent f08d5d78
......@@ -21,3 +21,11 @@ tvm.ir
   :members:
   :imported-members:
   :autosummary:
+
+tvm.transform
+-------------
+.. automodule:: tvm.transform
+   :members:
+   :imported-members:
+   :autosummary:
......@@ -227,7 +227,7 @@ ConvertLayout pass is extremely easy to use. The pass is not a part of default r
# Convert the layout to NCHW
# RemoveUnusedFunctions is used to clean up the graph.
-seq = relay.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
+seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
relay.transform.ConvertLayout('NCHW')])
with relay.transform.PassContext(opt_level=3):
mod = seq(mod)
......
......@@ -582,7 +582,7 @@ using ``Sequential`` associated with other types of passes.
func = relay.Function([x], z2)
# Customize the optimization pipeline.
-seq = _transform.Sequential([
+seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.EliminateCommonSubexpr(),
......@@ -609,7 +609,7 @@ sequential pass example could be like the following to enable IR dumping for
.. code:: python
-seq = _transform.Sequential([
+seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.PrintIR(),
......
......@@ -361,9 +361,11 @@ TVM_DLL Pass CreateModulePass(
/*!
* \brief A special trace pass that prints the header and IR to LOG(INFO).
* \param header The header to be attached to the output.
+* \param show_meta_data Whether to show the meta data.
* \return The pass.
*/
-TVM_DLL Pass PrintIR(std::string header);
+TVM_DLL Pass PrintIR(std::string header = "", bool show_meta_data = false);
} // namespace transform
} // namespace tvm
......
......@@ -106,7 +106,7 @@ def create_updater_06_to_07():
"relay.PassInfo": _rename("transform.PassInfo"),
"relay.PassContext": _rename("transform.PassContext"),
"relay.ModulePass": _rename("transform.ModulePass"),
"relay.Sequantial": _rename("transform.Sequantial"),
"relay.Sequential": _rename("transform.Sequential"),
# TIR
"Variable": _update_tir_var("tir.Var"),
"SizeVar": _update_tir_var("tir.SizeVar"),
......
......@@ -329,7 +329,7 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None):
return create_module_pass
-def PrintIR(header):
+def PrintIR(header="", show_meta_data=False):
"""A special trace pass that prints the header and IR.
Parameters
......@@ -337,8 +337,11 @@ def PrintIR(header):
header : str
The header to be displayed along with the dump.
+show_meta_data : bool
+A boolean flag to indicate if meta data should be printed.
Returns
--------
The pass
"""
-return _ffi_transform_api.PrintIR(header)
+return _ffi_transform_api.PrintIR(header, show_meta_data)
......@@ -128,20 +128,9 @@ Prelude = prelude.Prelude
# Scope builder
ScopeBuilder = scope_builder.ScopeBuilder
-module_pass = transform.module_pass
-function_pass = transform.function_pass
# Parser
fromtext = parser.fromtext
# Param Serialization
save_param_dict = param_dict.save_param_dict
load_param_dict = param_dict.load_param_dict
-# Pass manager
-PassInfo = transform.PassInfo
-PassContext = transform.PassContext
-Pass = transform.Pass
-ModulePass = transform.ModulePass
-FunctionPass = transform.FunctionPass
-Sequential = transform.Sequential
......@@ -210,10 +210,10 @@ class Interpreter(Executor):
opt_mod : tvm.IRModule
The optimized module.
"""
-seq = transform.Sequential([transform.SimplifyInference(),
-                            transform.FuseOps(0),
-                            transform.ToANormalForm(),
-                            transform.InferType()])
+seq = tvm.transform.Sequential([transform.SimplifyInference(),
+                                transform.FuseOps(0),
+                                transform.ToANormalForm(),
+                                transform.InferType()])
return seq(self.mod)
def _make_executor(self, expr=None):
......
......@@ -60,7 +60,7 @@ def CanonicalizeOps():
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The registered pass that canonicalizes QNN ops to Relay ops.
"""
......@@ -108,7 +108,7 @@ def Legalize():
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The registered pass that legalizes QNN ops.
"""
......
......@@ -17,6 +17,7 @@
#pylint: disable=unused-argument, not-context-manager
"""Automatic quantization toolkit."""
import tvm.ir
+import tvm
from tvm.runtime import Object
from . import _quantize
......@@ -240,7 +241,7 @@ def partition():
Returns
-------
-ret: tvm.relay.Pass
+ret: tvm.transform.Pass
The registered pass for VTA rewrite.
"""
return _quantize.QuantizePartition()
......@@ -253,7 +254,7 @@ def annotate():
Returns
-------
-ret: tvm.relay.Pass
+ret: tvm.transform.Pass
The registered pass for quantization annotation.
"""
return _quantize.QuantizeAnnotate()
......@@ -267,7 +268,7 @@ def realize():
Returns
-------
-ret: tvm.relay.Pass
+ret: tvm.transform.Pass
The registered pass for quantization realization.
"""
return _quantize.QuantizeRealize()
......@@ -298,11 +299,12 @@ def prerequisite_optimize(mod, params=None):
""" Prerequisite optimization passes for quantization. Perform
"SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"CanonicalizeOps" optimization before quantization. """
-optimize = _transform.Sequential([_transform.SimplifyInference(),
-                                  _transform.FoldConstant(),
-                                  _transform.FoldScaleAxis(),
-                                  _transform.CanonicalizeOps(),
-                                  _transform.FoldConstant()])
+optimize = tvm.transform.Sequential(
+    [_transform.SimplifyInference(),
+     _transform.FoldConstant(),
+     _transform.FoldScaleAxis(),
+     _transform.CanonicalizeOps(),
+     _transform.FoldConstant()])
if params:
mod['main'] = _bind_params(mod['main'], params)
......@@ -336,19 +338,20 @@ def quantize(mod, params=None, dataset=None):
"""
mod = prerequisite_optimize(mod, params)
-calibrate_pass = _transform.module_pass(calibrate(dataset), opt_level=1,
-                                        name="QuantizeCalibrate")
+calibrate_pass = tvm.transform.module_pass(
+    calibrate(dataset), opt_level=1,
+    name="QuantizeCalibrate")
quant_passes = [partition(),
annotate(),
calibrate_pass]
if not current_qconfig().do_simulation:
quant_passes.append(realize())
quant_passes.append(_transform.FoldConstant())
-quantize_seq = _transform.Sequential(quant_passes)
-with _transform.PassContext(opt_level=3,
-                            required_pass=["QuantizeAnnotate",
-                                           "QuantizeCalibrate",
-                                           "QuantizeRealize"]):
+quantize_seq = tvm.transform.Sequential(quant_passes)
+with tvm.transform.PassContext(opt_level=3,
+                               required_pass=["QuantizeAnnotate",
+                                              "QuantizeCalibrate",
+                                              "QuantizeRealize"]):
with quantize_context():
mod = quantize_seq(mod)
......
......@@ -47,7 +47,7 @@ from .py_converter import to_python, run_as_python
from ..transform import gradient
def run_opt_pass(expr, opt_pass):
-assert isinstance(opt_pass, transform.Pass)
+assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
......
......@@ -95,8 +95,8 @@ class PythonConverter(ExprFunctor):
# necessary pass: SimplifyInference (otherwise we can't generate code for some operators)
# and fusion (to get primitive functions)
-opts = relay.transform.Sequential([relay.transform.SimplifyInference(),
-                                   relay.transform.FuseOps(fuse_opt_level=0)])
+opts = tvm.transform.Sequential([relay.transform.SimplifyInference(),
+                                 relay.transform.FuseOps(fuse_opt_level=0)])
mod = opts(mod)
optimized = mod['main']
return optimized if isinstance(unwrapped, Function) else optimized.body
......
......@@ -22,10 +22,9 @@ import types
import inspect
import functools
import tvm
+import tvm.ir
from tvm import te
from tvm.runtime import ndarray as _nd
-from tvm.ir.transform import PassInfo, PassContext, Pass, ModulePass, Sequential, module_pass
from tvm import relay
from . import _ffi_api
......@@ -78,12 +77,13 @@ def build_config(opt_level=2,
pass_context: PassContext
The pass context for optimizations.
"""
-return PassContext(opt_level, fallback_device, required_pass,
-                   disabled_pass, trace)
+return tvm.ir.transform.PassContext(
+    opt_level, fallback_device, required_pass,
+    disabled_pass, trace)
@tvm._ffi.register_object("relay.FunctionPass")
-class FunctionPass(Pass):
+class FunctionPass(tvm.ir.transform.Pass):
"""A pass that works on each tvm.relay.Function in a module. A function
pass class should be created through `function_pass`.
"""
......@@ -94,7 +94,7 @@ def InferType():
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The registered type inference pass.
"""
return _ffi_api.InferType()
......@@ -106,7 +106,7 @@ def FoldScaleAxis():
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The registered pass to fold expressions.
Note
......@@ -123,7 +123,7 @@ def BackwardFoldScaleAxis():
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The registered pass to backward fold expressions.
Note
......@@ -144,7 +144,7 @@ def RemoveUnusedFunctions(entry_functions=None):
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The registered pass to remove unused functions.
"""
if entry_functions is None:
......@@ -156,7 +156,7 @@ def ForwardFoldScaleAxis():
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The registered pass to forward fold expressions.
Note
......@@ -174,7 +174,7 @@ def SimplifyInference():
Returns
-------
-ret: tvm.relay.Pass
+ret: tvm.transform.Pass
The registered pass to perform operator simplification.
"""
return _ffi_api.SimplifyInference()
......@@ -185,7 +185,7 @@ def FastMath():
Returns
-------
-ret: tvm.relay.Pass
+ret: tvm.transform.Pass
The registered pass to perform fast math operations.
"""
return _ffi_api.FastMath()
......@@ -198,7 +198,7 @@ def CanonicalizeOps():
Returns
-------
-ret: tvm.relay.Pass
+ret: tvm.transform.Pass
The registered pass performing the canonicalization.
"""
return _ffi_api.CanonicalizeOps()
......@@ -214,7 +214,7 @@ def DeadCodeElimination(inline_once=False):
Returns
-------
-ret: tvm.relay.Pass
+ret: tvm.transform.Pass
The registered pass that eliminates the dead code in a Relay program.
"""
return _ffi_api.DeadCodeElimination(inline_once)
......@@ -227,7 +227,7 @@ def LazyGradientInit():
Returns
-------
-ret: tvm.relay.Pass
+ret: tvm.transform.Pass
A pass which delays and/or reduces memory allocation,
by lazily allocating 0 or one filled tensors.
"""
......@@ -238,7 +238,7 @@ def FoldConstant():
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The registered pass for constant folding.
"""
return _ffi_api.FoldConstant()
......@@ -255,7 +255,7 @@ def FuseOps(fuse_opt_level=-1):
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The registered pass for operator fusion.
"""
return _ffi_api.FuseOps(fuse_opt_level)
......@@ -272,7 +272,7 @@ def CombineParallelConv2D(min_num_branches=3):
Returns
-------
-ret: tvm.relay.Pass
+ret: tvm.transform.Pass
The registered pass that combines parallel conv2d operators.
"""
return _ffi_api.CombineParallelConv2D(min_num_branches)
......@@ -304,7 +304,7 @@ def CombineParallelDense(min_num_branches=3):
Returns
-------
-ret: tvm.relay.Pass
+ret: tvm.transform.Pass
The registered pass that combines parallel dense operators.
"""
return _ffi_api.CombineParallelDense(min_num_branches)
......@@ -318,7 +318,7 @@ def AlterOpLayout():
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The registered pass that alters the layout of operators.
"""
return _ffi_api.AlterOpLayout()
......@@ -366,7 +366,7 @@ def Legalize(legalize_map_attr_name="FTVMLegalize"):
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The registered pass that rewrites an expr.
"""
return _ffi_api.Legalize(legalize_map_attr_name)
......@@ -387,7 +387,7 @@ def MergeComposite(pattern_table):
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The registered pass that merges operators into a single composite
relay function.
"""
......@@ -413,7 +413,7 @@ def MergeCompilerRegions():
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The registered pass that merges compiler regions.
"""
return _ffi_api.MergeCompilerRegions()
......@@ -433,7 +433,7 @@ def RewriteAnnotatedOps(fallback_device):
Returns
-------
-ret: tvm.relay.Pass
+ret: tvm.transform.Pass
The registered pass that rewrites an expression with annotated
`on_device` operators.
"""
......@@ -448,7 +448,7 @@ def ToANormalForm():
Returns
-------
-ret: Union[tvm.relay.Pass, tvm.relay.Expr]
+ret: Union[tvm.transform.Pass, tvm.relay.Expr]
The registered pass that transforms an expression into A Normal Form.
"""
return _ffi_api.ToANormalForm()
......@@ -462,7 +462,7 @@ def ToCPS(expr, mod=None):
Returns
-------
-result: tvm.relay.Pass
+result: tvm.transform.Pass
The registered pass that transforms an expression into CPS.
"""
return _ffi_api.to_cps(expr, mod)
......@@ -481,7 +481,7 @@ def EtaExpand(expand_constructor=False, expand_global_var=False):
Returns
-------
-ret: tvm.relay.Pass
+ret: tvm.transform.Pass
The registered pass that eta expands an expression.
"""
return _ffi_api.EtaExpand(expand_constructor, expand_global_var)
......@@ -492,7 +492,7 @@ def ToGraphNormalForm():
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The registered pass that transforms an expression into Graph Normal Form.
"""
return _ffi_api.ToGraphNormalForm()
......@@ -509,7 +509,7 @@ def EliminateCommonSubexpr(fskip=None):
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The registered pass that eliminates common subexpressions.
"""
return _ffi_api.EliminateCommonSubexpr(fskip)
......@@ -527,7 +527,7 @@ def PartialEvaluate():
Returns
-------
-ret: tvm.relay.Pass
+ret: tvm.transform.Pass
The registered pass that performs partial evaluation on an expression.
"""
return _ffi_api.PartialEvaluate()
......@@ -539,7 +539,7 @@ def CanonicalizeCast():
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The registered pass that canonicalizes cast expression.
"""
return _ffi_api.CanonicalizeCast()
......@@ -551,36 +551,19 @@ def LambdaLift():
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The registered pass that lifts the lambda function.
"""
return _ffi_api.LambdaLift()
-def PrintIR(show_meta_data=True):
-    """
-    Print the IR for a module to help debugging.
-    Parameters
-    ----------
-    show_meta_data : bool
-        A boolean flag to indicate if meta data should be printed.
-    Returns
-    -------
-    ret : tvm.relay.Pass
-        The registered pass that prints the module IR.
-    """
-    return _ffi_api.PrintIR(show_meta_data)
def PartitionGraph():
"""Partition a Relay program into regions that can be executed on different
backends.
Returns
-------
-ret: tvm.relay.Pass
+ret: tvm.transform.Pass
The registered pass that partitions the Relay program.
"""
return _ffi_api.PartitionGraph()
......@@ -598,7 +581,7 @@ def AnnotateTarget(targets):
Returns
-------
-ret : tvm.relay.Pass
+ret : tvm.transform.Pass
The annotated pass that wraps ops with subgraph_start and
subgraph_end.
"""
......@@ -614,7 +597,7 @@ def Inline():
Returns
-------
-ret: tvm.relay.Pass
+ret: tvm.transform.Pass
The registered pass that performs inlining for a Relay IR module.
"""
return _ffi_api.Inline()
......@@ -809,7 +792,7 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None):
def create_function_pass(pass_arg):
"""Internal function that creates a function pass"""
fname = name if name else pass_arg.__name__
-info = PassInfo(opt_level, fname, required)
+info = tvm.transform.PassInfo(opt_level, fname, required)
if inspect.isclass(pass_arg):
return _wrap_class_function_pass(pass_arg, info)
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
......
......@@ -474,10 +474,10 @@ TVM_REGISTER_GLOBAL("transform.ExitPassContext")
.set_body_typed(PassContext::Internal::ExitScope);
-Pass PrintIR(std::string header) {
-  auto pass_func =[header](IRModule mod, const PassContext& ctx) {
+Pass PrintIR(std::string header, bool show_meta_data) {
+  auto pass_func =[header, show_meta_data](IRModule mod, const PassContext& ctx) {
LOG(INFO) << "PrintIR(" << header << "):\n"
-              << mod;
+              << AsText(mod, show_meta_data);
return mod;
};
return CreateModulePass(pass_func, 0, "PrintIR", {});
......
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-/*!
- *
- * \file src/relay/transforms/print_ir.cc
- *
- * \brief Print the module IR to help debugging.
- */
-#include <tvm/relay/expr.h>
-#include <tvm/relay/transform.h>
-namespace tvm {
-namespace relay {
-namespace transform {
-Pass PrintIR(bool show_meta_data) {
-  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
-    [=](IRModule m, PassContext pc) {
-      LOG(INFO) << "Dumping the module IR: " << std::endl << AsText(m, show_meta_data);
-      return m;
-    };
-  return CreateModulePass(pass_func, 0, "PrintIR", {});
-}
-TVM_REGISTER_GLOBAL("relay._transform.PrintIR")
-.set_body_typed(PrintIR);
-}  // namespace transform
-}  // namespace relay
-}  // namespace tvm
......@@ -53,10 +53,10 @@ def test_checkpoint_alpha_equal():
df = transform.gradient(run_infer_type(f))
# run PE and DCE
-with transform.PassContext(opt_level=3):
+with tvm.transform.PassContext(opt_level=3):
passes = [transform.PartialEvaluate(),
transform.DeadCodeElimination(inline_once=True)]
-mod = transform.Sequential(passes)(tvm.IRModule.from_expr(df))
+mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
df = mod["main"]
df_parsed = relay.parser.fromtext(
......@@ -109,10 +109,10 @@ def test_checkpoint_alpha_equal_tuple():
df = transform.gradient(run_infer_type(f))
# run PE and DCE
-with transform.PassContext(opt_level=3):
+with tvm.transform.PassContext(opt_level=3):
passes = [transform.PartialEvaluate(),
transform.DeadCodeElimination(inline_once=True)]
-mod = transform.Sequential(passes)(tvm.IRModule.from_expr(df))
+mod = tvm.transform.Sequential(passes)(tvm.IRModule.from_expr(df))
df = mod["main"]
df_parsed = relay.parser.fromtext(
......
......@@ -26,8 +26,8 @@ from tvm.relay.testing.temp_op_attr import TempOpAttr
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
-seq = transform.Sequential(passes)
-with transform.PassContext(opt_level=3):
+seq = tvm.transform.Sequential(passes)
+with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
......
......@@ -28,8 +28,8 @@ from tvm.relay import transform
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
-seq = transform.Sequential(passes)
-with transform.PassContext(opt_level=3):
+seq = tvm.transform.Sequential(passes)
+with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
return mod["main"]
......
......@@ -54,9 +54,9 @@ def test_canonicalize_cast():
bias2 = relay.var("bias2", shape=(16, 1, 1), dtype="int32")
y = before(data, conv_weight, bias1, bias2)
mod = tvm.IRModule.from_expr(y)
-seq = _transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(),
+seq = tvm.transform.Sequential([_transform.InferType(), _transform.CanonicalizeCast(),
_transform.InferType()])
-with _transform.PassContext(opt_level=3):
+with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
y = mod["main"]
y_expected = expected(data, conv_weight, bias1, bias2)
......
......@@ -26,7 +26,7 @@ def run_combine_parallel(expr, min_num_branches=3):
return mod["main"]
def run_opt_pass(expr, opt_pass):
-assert isinstance(opt_pass, transform.Pass)
+assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
return mod["main"]
......
......@@ -26,7 +26,7 @@ def run_combine_parallel(expr, min_num_branches=3):
return mod["main"]
def run_opt_pass(expr, opt_pass):
-assert isinstance(opt_pass, transform.Pass)
+assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
return mod["main"]
......
......@@ -26,8 +26,8 @@ from tvm.relay import transform, analysis
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
-seq = transform.Sequential(passes)
-with transform.PassContext(opt_level=3):
+seq = tvm.transform.Sequential(passes)
+with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
......
......@@ -47,7 +47,7 @@ e = env()
def run_opt_pass(expr, opt_pass):
-assert isinstance(opt_pass, transform.Pass)
+assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
......
......@@ -24,7 +24,7 @@ from tvm.relay import transform, analysis
def run_opt_pass(expr, opt_pass):
-assert isinstance(opt_pass, transform.Pass)
+assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
......
......@@ -33,8 +33,8 @@ def test_eta_expand_global_var():
@aux
}
""")
-seq = _transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
-with _transform.PassContext(opt_level=3):
+seq = tvm.transform.Sequential([_transform.EtaExpand(expand_global_var=True)])
+with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
expected = relay.fromtext(r"""
v0.0.4
......@@ -62,8 +62,8 @@ def test_eta_expand_constructor():
Cons
}
""")
-seq = _transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
-with _transform.PassContext(opt_level=3):
+seq = tvm.transform.Sequential([_transform.EtaExpand(expand_constructor=True)])
+with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
expected = relay.fromtext(r"""
v0.0.4
......
......@@ -24,7 +24,7 @@ from tvm.relay.testing import run_infer_type, create_workload
def run_opt_pass(expr, opt_pass):
-assert isinstance(opt_pass, transform.Pass)
+assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
......@@ -174,7 +174,7 @@ def test_fold_batch_norm():
add = relay.add(conv, bias)
return relay.Function(relay.analysis.free_vars(add), add)
-remove_bn_pass = transform.Sequential([
+remove_bn_pass = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.SimplifyInference(),
relay.transform.FoldConstant(),
......
......@@ -26,7 +26,7 @@ def _get_positive_scale(size):
def run_opt_pass(expr, opt_pass):
-assert isinstance(opt_pass, transform.Pass)
+assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
......
......@@ -80,7 +80,7 @@ def test_add_tuple():
mod["main"] = y
mod = transform.LazyGradientInit()(mod)
-mod = transform.PrintIR(show_meta_data=True)(mod)
+mod = tvm.transform.PrintIR(show_meta_data=True)(mod)
y = mod["main"]
assert mod["main"].checked_type == relay.FuncType([t], tensor_type)
......@@ -116,7 +116,7 @@ def test_mult():
def test_ret_tuple():
"""Test tuple return type. Check types and semantic equivalence."""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
......@@ -141,7 +141,7 @@ def test_ret_tuple():
def test_add_broadcast():
"""Test adding matrices of different size. Check types and semantic equivalence."""
mod = tvm.IRModule()
shape1 = (3, 4, 1)
shape2 = (1, 5)
dtype = 'float32'
......@@ -173,7 +173,7 @@ def test_reverse_ad_identity():
"""Simple test with reverse mode ad."""
# of f(x) = x
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
......@@ -201,7 +201,7 @@ def test_reverse_ad_identity():
def test_multivar_reverse_ad():
"""Simple test with multivariate reverse mode ad."""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
......@@ -232,7 +232,7 @@ def test_multivar_reverse_ad():
def test_after_partial_eval():
"""Test transformation following reverse mode ad and PartialEval"""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
......@@ -248,7 +248,7 @@ def test_after_partial_eval():
mod["main"] = back_func
back_func = mod["main"]
-seq = transform.Sequential([
+seq = tvm.transform.Sequential([
transform.PartialEvaluate(),
transform.LazyGradientInit(),
transform.DeadCodeElimination()
......@@ -270,7 +270,7 @@ def test_after_partial_eval():
def test_before_partial_eval():
"""Test transformation before PartialEval"""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
......@@ -284,7 +284,7 @@ def test_before_partial_eval():
back_func = run_infer_type(back_func)
mod["main"] = back_func
-seq = transform.Sequential([
+seq = tvm.transform.Sequential([
transform.LazyGradientInit(),
transform.PartialEvaluate(),
transform.DeadCodeElimination()
......@@ -306,7 +306,7 @@ def test_before_partial_eval():
def test_zeros():
"""Simple test using "zeros" op"""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
......@@ -328,7 +328,7 @@ def test_zeros():
def test_ones():
"""Simple test using "ones" op"""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
......@@ -350,7 +350,7 @@ def test_ones():
def test_zeros_like():
"""Simple test using "zeros_like" op"""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
......@@ -372,7 +372,7 @@ def test_zeros_like():
def test_ones_like():
"""Simple test using "ones_like" op"""
mod = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
......
......@@ -28,8 +28,8 @@ from tvm.relay.testing.temp_op_attr import TempOpAttr
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
-seq = transform.Sequential(passes)
-with transform.PassContext(opt_level=3):
+seq = tvm.transform.Sequential(passes)
+with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
......
......@@ -23,7 +23,7 @@ from tvm.relay import analysis, transform
def run_opt_pass(expr, opt_pass):
-assert isinstance(opt_pass, transform.Pass)
+assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
......
......@@ -129,13 +129,13 @@ def test_module_pass():
opt_tester = OptTester(mod)
pass_ctx = None
-@_transform.module_pass(opt_level=opt_level, name=pass_name)
+@tvm.transform.module_pass(opt_level=opt_level, name=pass_name)
def transform(expr, ctx):
return opt_tester.transform(expr, ctx)
def test_pass_registration():
mod_pass = transform
-assert isinstance(mod_pass, _transform.ModulePass)
+assert isinstance(mod_pass, tvm.transform.ModulePass)
pass_info = mod_pass.info
assert pass_info.name == pass_name
assert pass_info.opt_level == opt_level
......@@ -143,8 +143,8 @@ def test_module_pass():
def test_pass_registration_no_decorator():
def direct_transform(expr, ctx):
return opt_tester.transform(expr, ctx)
-mod_pass = _transform.module_pass(direct_transform, opt_level=3)
-assert isinstance(mod_pass, _transform.ModulePass)
+mod_pass = tvm.transform.module_pass(direct_transform, opt_level=3)
+assert isinstance(mod_pass, tvm.transform.ModulePass)
pass_info = mod_pass.info
assert pass_info.name == "direct_transform"
assert pass_info.opt_level == 3
......@@ -285,7 +285,7 @@ def test_function_pass():
def test_module_class_pass():
-@relay.transform.module_pass(opt_level=1)
+@tvm.transform.module_pass(opt_level=1)
class TestPipeline:
"""Simple test function to replace one argument to another."""
def __init__(self, new_mod, replace):
......@@ -309,7 +309,7 @@ def test_module_class_pass():
def test_pass_info():
-info = relay.transform.PassInfo(opt_level=1, name="xyz")
+info = tvm.transform.PassInfo(opt_level=1, name="xyz")
assert info.opt_level == 1
assert info.name == "xyz"
......@@ -350,7 +350,7 @@ def test_sequential_pass():
opt_tester = OptTester(mod)
pass_ctx = None
-@_transform.module_pass(opt_level=1)
+@tvm.transform.module_pass(opt_level=1)
def mod_transform(expr, ctx):
return opt_tester.transform(expr, ctx)
......@@ -367,21 +367,21 @@ def test_sequential_pass():
passes = [module_pass, function_pass]
opt_level = 2
pass_name = "sequential"
-sequential = _transform.Sequential(passes=passes, opt_level=opt_level)
+sequential = tvm.transform.Sequential(passes=passes, opt_level=opt_level)
pass_info = sequential.info
assert pass_info.name == pass_name
assert pass_info.opt_level == opt_level
def test_no_pass():
passes = []
-sequential = _transform.Sequential(opt_level=1, passes=passes)
+sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
ret_mod = sequential(mod)
mod_func = ret_mod[v_sub]
check_func(sub, mod_func)
def test_only_module_pass():
passes = [module_pass]
-sequential = _transform.Sequential(opt_level=1, passes=passes)
+sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
with relay.build_config(required_pass=["mod_transform"]):
ret_mod = sequential(mod)
# Check the subtract function.
......@@ -396,7 +396,7 @@ def test_sequential_pass():
def test_only_function_pass():
# Check the subtract function.
passes = [function_pass]
-sequential = _transform.Sequential(opt_level=1, passes=passes)
+sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
with relay.build_config(required_pass=["func_transform"]):
ret_mod = sequential(mod)
_, new_sub = extract_var_func(ret_mod, v_sub.name_hint)
......@@ -411,7 +411,7 @@ def test_sequential_pass():
# function pass.
mod = tvm.IRModule({v_sub: sub, v_log: log})
passes = [module_pass, function_pass]
-sequential = _transform.Sequential(opt_level=1, passes=passes)
+sequential = tvm.transform.Sequential(opt_level=1, passes=passes)
required = ["mod_transform", "func_transform"]
with relay.build_config(required_pass=required):
ret_mod = sequential(mod)
......@@ -482,7 +482,7 @@ def test_sequential_with_scoping():
z1 = relay.add(z, z)
return relay.Function([x], z1)
-seq = _transform.Sequential([
+seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.EliminateCommonSubexpr(),
......@@ -507,10 +507,10 @@ def test_print_ir(capfd):
y = relay.multiply(y, relay.const(2, "float32"))
func = relay.Function([x], y)
-seq = _transform.Sequential([
+seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
-relay.transform.PrintIR(),
+tvm.transform.PrintIR(),
relay.transform.DeadCodeElimination()
])
......@@ -520,7 +520,7 @@ def test_print_ir(capfd):
out = capfd.readouterr().err
assert "Dumping the module IR" in out
assert "PrintIR" in out
assert "multiply" in out
__TRACE_COUNTER__ = 0
......@@ -539,7 +539,7 @@ def test_print_debug_callback():
y = relay.multiply(y, relay.const(2, "float32"))
func = relay.Function([x], y)
-seq = _transform.Sequential([
+seq = tvm.transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.DeadCodeElimination()
......
......@@ -38,8 +38,8 @@ def check_eval(expr, expected_result, mod=None, rtol=1e-07):
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
-seq = transform.Sequential(passes)
-with transform.PassContext(opt_level=3):
+seq = tvm.transform.Sequential(passes)
+with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
......@@ -58,7 +58,7 @@ def dcpe(expr, mod=None, grad=False):
if mod:
assert isinstance(expr, Function)
mod["main"] = expr
-seq = transform.Sequential(passes)
+seq = tvm.transform.Sequential(passes)
mod = seq(mod)
return mod["main"]
return run_opt_pass(expr, passes)
......
......@@ -496,7 +496,7 @@ def test_function_lifting():
op_list = ["nn.batch_norm", "nn.conv2d"]
mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
-opt_pass = transform.Sequential([
+opt_pass = tvm.transform.Sequential([
transform.InferType(),
transform.PartitionGraph(),
transform.SimplifyInference(),
......@@ -578,7 +578,7 @@ def test_function_lifting_inline():
op_list = ["nn.batch_norm", "nn.conv2d"]
mod = WhiteListAnnotator(op_list, "test_compiler")(mod)
-opt_pass = transform.Sequential([
+opt_pass = tvm.transform.Sequential([
transform.InferType(),
transform.PartitionGraph(),
transform.SimplifyInference(),
......@@ -878,13 +878,13 @@ def test_dnnl_fuse():
# This is required for constant folding
mod["main"] = bind_params_by_name(mod["main"], params)
-remove_bn_pass = transform.Sequential([
+remove_bn_pass = tvm.transform.Sequential([
transform.InferType(),
transform.SimplifyInference(),
transform.FoldConstant(),
transform.FoldScaleAxis(),
])
-composite_partition = transform.Sequential([
+composite_partition = tvm.transform.Sequential([
remove_bn_pass,
transform.MergeComposite(pattern_table),
transform.AnnotateTarget("dnnl"),
......
......@@ -37,8 +37,8 @@ def alpha_equal(x, y):
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
-seq = transform.Sequential(passes)
-with transform.PassContext(opt_level=3):
+seq = tvm.transform.Sequential(passes)
+with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
......
......@@ -28,8 +28,8 @@ from tvm.relay.analysis import Feature
def run_opt_pass(expr, passes):
passes = passes if isinstance(passes, list) else [passes]
mod = tvm.IRModule.from_expr(expr)
-seq = transform.Sequential(passes)
-with transform.PassContext(opt_level=3):
+seq = tvm.transform.Sequential(passes)
+with tvm.transform.PassContext(opt_level=3):
mod = seq(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
......
......@@ -71,7 +71,8 @@ def test_cps_pe():
x = run_infer_type(x)
y = un_cps(x)
y = run_infer_type(y)
-x = run_opt_pass(x, transform.Sequential([transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]))
+x = run_opt_pass(x, tvm.transform.Sequential(
+    [transform.PartialEvaluate(), transform.DeadCodeElimination(inline_once=True)]))
assert Feature.fRefCreate not in detect_feature(x)
unit = relay.Function([], relay.const(0., dtype='float32'))
f_ref = relay.Var("f_ref")
......
......@@ -29,7 +29,7 @@ introduced an infrastructure to manage the optimization passes.
The optimizations of a Relay program can be applied at various granularities,
namely function-level and module-level using :py:class:`tvm.relay.transform.FunctionPass`
and :py:class:`tvm.relay.transform.ModulePass`
-respectively. Or users can rely on :py:class:`tvm.relay.transform.Sequential` to apply a sequence of passes
+respectively. Or users can rely on :py:class:`tvm.transform.Sequential` to apply a sequence of passes
on a Relay program where the dependencies between passes can be resolved by the
pass infra. For more details about each type of these passes, please refer to
the :ref:`relay-pass-infra`
......@@ -130,22 +130,22 @@ print(mod)
# fusion, as this pass generates let bindings for each expression to
# canonicalize a Relay program.
#
-# Relay, hence, provides :py:class:`tvm.relay.transform.Sequential` to alleviate developers from handling
+# Relay, hence, provides :py:class:`tvm.transform.Sequential` to alleviate developers from handling
# these issues explicitly by specifying the required passes of each pass and
# packing them as a whole to execute. For example, the same passes can now be
-# applied using the sequential style as the following. :py:class:`tvm.relay.transform.Sequential` is
+# applied using the sequential style as the following. :py:class:`tvm.transform.Sequential` is
# similar to `torch.nn.sequential <https://pytorch.org/docs/stable/nn.html#torch.nn.Sequential>`_
# and `mxnet.gluon.block <https://mxnet.incubator.apache.org/api/python/docs/_modules/mxnet/gluon/block.html>`_.
# For example, `torch.nn.sequential` is used to contain a sequence of PyTorch
# `Modules` that will be added to build a network. It focuses on the network
-# layers. Instead, the :py:class:`tvm.relay.transform.Sequential` in our pass infra works on the optimizing
+# layers. Instead, the :py:class:`tvm.transform.Sequential` in our pass infra works on the optimizing
# pass.
-# Now let's execute some passes through :py:class:`tvm.relay.transform.Sequential`
+# Now let's execute some passes through :py:class:`tvm.transform.Sequential`
f = example()
mod = tvm.IRModule.from_expr(f)
# Glob the interested passes.
-seq = relay.transform.Sequential([relay.transform.FoldConstant(),
+seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
relay.transform.EliminateCommonSubexpr(),
relay.transform.FuseOps(fuse_opt_level=2)])
mod1 = seq(mod)
......@@ -156,7 +156,7 @@ print(mod1)
# identical addition operations. This is because `EliminateCommonSubexpr`
# was not actually performed. The reason is because only the passes that have
# optimization level less or equal to 2 will be executed by default under
-# :py:class:`tvm.relay.transform.Sequential`. The pass infra,
+# :py:class:`tvm.transform.Sequential`. The pass infra,
# however, provides a configuration interface
# for users to customize the optimization level that they want to execute.
......@@ -186,7 +186,7 @@ with relay.build_config(opt_level=3):
mod4 = seq(mod)
print(mod4)
-seq1 = relay.transform.Sequential([relay.transform.AlterOpLayout()])
+seq1 = tvm.transform.Sequential([relay.transform.AlterOpLayout()])
with relay.build_config(opt_level=3):
with tvm.target.create("llvm"):
mod5 = seq1(mod)
......@@ -237,11 +237,11 @@ print(mod3)
f = example()
mod = tvm.IRModule.from_expr(f)
-seq = relay.transform.Sequential([relay.transform.FoldConstant(),
-                                  relay.transform.PrintIR(False),
-                                  relay.transform.EliminateCommonSubexpr(),
-                                  relay.transform.FuseOps(),
-                                  relay.transform.PrintIR(False)])
+seq = tvm.transform.Sequential([relay.transform.FoldConstant(),
+                                tvm.transform.PrintIR(),
+                                relay.transform.EliminateCommonSubexpr(),
+                                relay.transform.FuseOps(),
+                                tvm.transform.PrintIR()])
with relay.build_config(opt_level=3):
mod = seq(mod)
......
......@@ -24,7 +24,7 @@ from tvm.relay import ExprMutator
def run_opt_pass(expr, opt_pass):
"""Exectue a relay pass."""
-assert isinstance(opt_pass, transform.Pass)
+assert isinstance(opt_pass, tvm.transform.Pass)
mod = tvm.IRModule.from_expr(expr)
mod = opt_pass(mod)
entry = mod["main"]
......