Commit 8a46444f by Zhi, committed by Tianqi Chen

fix some pass docs (#3767)

parent e518fe1c
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
tvm.relay.ir_pass
-----------------
.. automodule:: tvm.relay.ir_pass
:members:
\ No newline at end of file
@@ -26,6 +26,50 @@ tvm.relay.transform
 .. autofunction:: tvm.relay.transform.function_pass
+.. autofunction:: tvm.relay.transform.InferType
+.. autofunction:: tvm.relay.transform.FoldScaleAxis
+.. autofunction:: tvm.relay.transform.BackwardFoldScaleAxis
+.. autofunction:: tvm.relay.transform.ForwardFoldScaleAxis
+.. autofunction:: tvm.relay.transform.SimplifyInference
+.. autofunction:: tvm.relay.transform.CanonicalizeOps
+.. autofunction:: tvm.relay.transform.DeadCodeElimination
+.. autofunction:: tvm.relay.transform.FoldConstant
+.. autofunction:: tvm.relay.transform.FuseOps
+.. autofunction:: tvm.relay.transform.CombineParallelConv2D
+.. autofunction:: tvm.relay.transform.AlterOpLayout
+.. autofunction:: tvm.relay.transform.Legalize
+.. autofunction:: tvm.relay.transform.RewriteAnnotatedOps
+.. autofunction:: tvm.relay.transform.ToANormalForm
+.. autofunction:: tvm.relay.transform.ToCPS
+.. autofunction:: tvm.relay.transform.EtaExpand
+.. autofunction:: tvm.relay.transform.ToGraphNormalForm
+.. autofunction:: tvm.relay.transform.EliminateCommonSubexpr
+.. autofunction:: tvm.relay.transform.PartialEvaluate
+.. autofunction:: tvm.relay.transform.CanonicalizeCast
+.. autofunction:: tvm.relay.transform.LambdaLift
+.. autofunction:: tvm.relay.transform.PrintIR
 .. autoclass:: tvm.relay.transform.Pass
    :members:
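These passes operate on a relay.Module and are typically composed with tvm.relay.transform.Sequential. The following is a minimal sketch of combining a few of the passes listed above, assuming the relay.Module and build_config API of this TVM version; it is an illustration only, not part of this commit:

    import tvm
    from tvm import relay
    from tvm.relay import transform

    # A trivial module to run the passes on.
    x = relay.var("x", shape=(1, 8), dtype="float32")
    mod = relay.Module.from_expr(relay.Function([x], relay.add(x, x)))

    # Compose several of the documented passes into one pipeline.
    seq = transform.Sequential([
        transform.InferType(),
        transform.SimplifyInference(),
        transform.FoldConstant(),
        transform.EliminateCommonSubexpr(),
        transform.FuseOps(fuse_opt_level=2),
    ])
    with relay.build_config(opt_level=3):
        mod = seq(mod)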
@@ -26,11 +26,13 @@ In this part of documentation, we share the rationale for the specific choices m
    runtime
    debugger
-   nnvm_json_spec
-   nnvm_overview
    hybrid_script
    relay_intro
    relay_add_op
+   relay_pass_infra
    relay_add_pass
+   virtual_machine
    codebase_walkthrough
    inferbound
+   nnvm_json_spec
+   nnvm_overview
@@ -32,7 +32,7 @@ from .. import nd as _nd
 @register_relay_node
 class PassInfo(RelayNode):
-    """The class that contains the meta data required by a pass. It is the
+    """The class contains the meta data required by a pass. It is the
     container of information needed by running an optimization or analysis.
     This class can be extended by adding new members when more meta data is
     needed.
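In practice a PassInfo is usually created indirectly, for example through the relay.transform.function_pass decorator. A hedged sketch, assuming the decorator API of this TVM version (the pass name and class are made up for illustration):

    import tvm
    from tvm import relay

    # Hypothetical no-op function pass. The decorator records the metadata
    # (opt_level, name, required passes) and exposes it via the pass's .info,
    # which is a PassInfo object.
    @relay.transform.function_pass(opt_level=1, name="IdentityRewrite")
    class IdentityRewrite:
        def transform_function(self, func, mod, ctx):
            # A real pass would rewrite and return a new function here.
            return func

    ident = IdentityRewrite()
    print(ident.info.name, ident.info.opt_level)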
@@ -132,11 +132,12 @@ def build_config(opt_level=2,
             "SimplifyInference": 0,
             "OpFusion": 1,
             "FoldConstant": 2,
+            "CombineParallelConv2D": 3,
             "FoldScaleAxis": 3,
             "AlterOpLayout": 3,
             "CanonicalizeOps": 3,
+            "CanonicalizeCast": 3,
             "EliminateCommonSubexpr": 3,
-            "CombineParallelConv2D": 4,
         }

     fallback_device : int, str, or tvm.TVMContext, optional
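The table above maps each pass to the lowest opt_level at which it is enabled by default. A sketch of how that interacts with build_config, assuming the required_pass argument and relay.build behavior of this TVM version (the toy model and parameter names are made up):

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 8), dtype="float32")
    w = relay.var("w", shape=(8, 8), dtype="float32")
    mod = relay.Module.from_expr(relay.Function([x, w], relay.nn.dense(x, w)))
    params = {"w": tvm.nd.array(np.random.uniform(size=(8, 8)).astype("float32"))}

    # At opt_level=2, level-3 passes such as FoldScaleAxis do not run unless
    # explicitly required; passes can likewise be disabled by name.
    with relay.build_config(opt_level=2, required_pass=["FoldScaleAxis"]):
        graph, lib, params = relay.build(mod, target="llvm", params=params)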
@@ -250,30 +251,6 @@ class Sequential(Pass):
         passes, opt_level, name, required)

-def infer_type(expr, mod=None):
-    """Infer the type of an expr.
-    Adding Function into a Module will change it's binding,
-    and some passes need type inference to work without binding modification.
-    However, InferType() work by putting stuff into a Module, thus changing all the binding.
-    This is an escape patch that allow type inference without binding changing.
-    Parameters
-    ----------
-    expr : tvm.relay.Expr
-        The input expression.
-    mod : Optional[tvm.relay.Module]
-        The input module
-    Returns
-    -------
-    ret : tvm.relay.Expr
-        The output expression.
-    """
-    return _transform.infer_type(expr, mod)
-
 def InferType():
     """Infer the type of an expr.
@@ -297,7 +274,7 @@ def FoldScaleAxis():
     Note
     ----
     Internally, we will call backward_fold_scale_axis before using
-    forward_fold_scale_axis. As backward folding targets common conv-bn
+    forward_fold_scale_axis as backward folding targets the common conv->bn
     pattern.
     """
     return _transform.FoldScaleAxis()
@@ -314,8 +291,8 @@ def BackwardFoldScaleAxis():
     Note
     ----
     It is recommended to call backward_fold_scale_axis
-    before using forward_fold_scale_axis.
-    As backward folding targets common conv-bn pattern.
+    before using forward_fold_scale_axis as backward folding targets the common
+    conv->bn pattern.
     """
     return _transform.BackwardFoldScaleAxis()
@@ -331,8 +308,8 @@ def ForwardFoldScaleAxis():
     Note
     ----
     It is recommended to call backward_fold_scale_axis
-    before using forward_fold_scale_axis.
-    As backward folding targets common conv-bn pattern.
+    before using forward_fold_scale_axis, as backward folding targets the
+    common conv->bn pattern.
     """
     return _transform.ForwardFoldScaleAxis()
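The conv->bn pattern referred to in these notes reduces, after SimplifyInference, to a per-channel scale adjacent to a convolution, which the folding passes push into the weights. The sketch below is an assumed illustration of such a graph; the shapes, the explicit InferType call, and the choice of BackwardFoldScaleAxis are made for the example, not taken from this diff:

    import numpy as np
    import tvm
    from tvm import relay

    data = relay.var("data", shape=(1, 16, 56, 56), dtype="float32")
    weight = relay.var("weight", shape=(32, 16, 3, 3), dtype="float32")
    conv = relay.nn.conv2d(data, weight, channels=32,
                           kernel_size=(3, 3), padding=(1, 1))
    # Per-output-channel scale, e.g. what remains of a BatchNorm at inference time.
    scale = relay.const(np.random.uniform(size=(32, 1, 1)).astype("float32"))
    out = relay.multiply(conv, scale)

    mod = relay.Module.from_expr(relay.Function([data, weight], out))
    mod = relay.transform.InferType()(mod)
    # Backward folding: the scale sits after the conv and is folded into its weights.
    mod = relay.transform.BackwardFoldScaleAxis()(mod)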
@@ -350,9 +327,9 @@ def SimplifyInference():
 def CanonicalizeOps():
-    """ Canonicalize special operators to basic operators.
-    This can simplify followed analysis. (e.g. expanding bias_add to
-    expand_dims and broadcast_add.)
+    """Canonicalize special operators to basic operators.
+    This can simplify followed analysis, e.g. expanding bias_add to
+    expand_dims and broadcast_add.

     Returns
     -------
def DeadCodeElimination(inline_once=False): def DeadCodeElimination(inline_once=False):
"""Remove expressions which does not effect the program result (dead code). """Remove expressions that do not have any users (dead code).
Parameters Parameters
---------- ----------
...@@ -379,7 +356,7 @@ def DeadCodeElimination(inline_once=False): ...@@ -379,7 +356,7 @@ def DeadCodeElimination(inline_once=False):
def FoldConstant(): def FoldConstant():
"""Fold the constant expression in expr. """Fold the constant expressions in a Relay program.
Returns Returns
------- -------
...@@ -513,7 +490,7 @@ def EtaExpand(): ...@@ -513,7 +490,7 @@ def EtaExpand():
def ToGraphNormalForm(): def ToGraphNormalForm():
"""Turn A Normal Form expression into Graph Normal Form expression """Turn a Relay program in A Normal Form into Graph Normal Form
Returns Returns
------- -------
......
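For orientation, a sketch of round-tripping a small program between the two forms, assuming the module-level pass API of this TVM version (not part of the diff):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(4,), dtype="float32")
    shared = relay.exp(x)
    mod = relay.Module.from_expr(relay.Function([x], relay.add(shared, shared)))

    mod = relay.transform.ToANormalForm()(mod)      # introduce explicit let bindings
    mod = relay.transform.ToGraphNormalForm()(mod)  # inline bindings back into a dataflow graph
    print(mod["main"])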
@@ -826,9 +826,6 @@ Function InferType(const Function& func,
   return Downcast<Function>(func_ret);
 }

-TVM_REGISTER_API("relay._transform.infer_type")
-.set_body_typed<Expr(Expr, Module)>([](Expr l, Module r) { return InferType(l, r); });
-
 namespace transform {

 Pass InferType() {