Commit 9911044b by Steven S. Lyubomirsky Committed by Zhi

[Relay][Pass][Docs] Update the doc for adding a Relay pass to mention the pass infra (#3583)

* Update the Relay adding pass doc to reference the new pass infrastructure

* Correct pass name

Co-Authored-By: Zhi <5145158+zhiics@users.noreply.github.com>

* Align header equals signs
parent 3ada7c0e
......@@ -22,17 +22,15 @@ Adding a Compiler Pass to Relay
Compiler passes are the primary interface for both extending Relay's feature
set and for performing optimizations on Relay programs. By writing a compiler
pass, you can then modify the AST and/or collect information about the AST,
depending on your goal. Indeed, some of Relay's most important "built-in"
features (e.g., autodiff and type inference) are nothing more than compiler
passes.
pass, you can modify the AST or collect information about the AST,
depending on your goal. Indeed, some of Relay's most important built-in
features (e.g., autodiff and type inference) are nothing more than "standard"
compiler passes.
At a high level, there are three key components to writing a pass:
At a high level, there are two key components to writing a pass:
- Creating one or more C++ classes that traverse the program
- Registering an API endpoint (a TVM packed function) with the
``TVM_REGISTER_API`` macro that performs the pass
- Wrapping the Python API hook in a neater interface
- Wrapping the traversal implementation and its metadata in the pass manager API so it can neatly interface with the :ref:`relay-pass-infra`
To begin, we'll give an overview of the key mechanisms for writing a compiler
pass. Then, we'll walk through a concrete example of the constant-folding
......@@ -183,9 +181,9 @@ Example: Constant Folding
-------------------------
In order to better understand the process of writing a pass, we will look at
the constant folding pass (found in ``src/relay/pass/fold_constant.cc`` and
in ``python/tvm/relay/ir_pass.py``) as a guide, because it is a relatively
simple pass that incorporates both types of traversals.
the constant folding pass (found in `src/relay/pass/fold_constant.cc`_)
as a guide, because it is a relatively simple pass that incorporates
both types of traversals.
Constant folding involves evaluating expressions in the program that only
involve constant values, then replacing those expressions with the result
......@@ -327,32 +325,82 @@ all of the arguments are constant (using ``ConstantChecker``). Evaluating the
call produces a **value**, so we use a helper method ``ValueToExpr`` to allow
us to place the evaluated expression back into the AST.
Now, we construct the public interface ``FoldConstant`` to our constant
folder, which is a standalone function outside of the ``ConstantFolder``
class. ``FoldConstant`` takes an expression and internally creates and uses a
Now, we construct a more convenient interface ``FoldConstant`` for our constant
folder. ``FoldConstant`` is a standalone function outside of the ``ConstantFolder``
class that takes an expression and internally creates and uses a
``ConstantFolder`` instance (the full definition can be found in
``include/tvm/relay/pass.h``).
`src/relay/pass/fold_constant.cc`_).
To allow other C++ modules to use our pass, we declare the public interface
in ``src/relay/pass/pass.h``:
Registering a Pass with the Pass Manager
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
*Note: please see the documentation on the :ref:`relay-pass-infra` for more specific detail on this subject.*
With the AST traversers written, the pass can be registered to become a TVM
API endpoint with the following code:
.. code:: c
TVM_DLL Expr FoldConstant(const Expr& expr);
namespace transform {
Registering an API Endpoint
~~~~~~~~~~~~~~~~~~~~~~~~~~~
Pass FoldConstant() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(FoldConstant(f));
};
return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
}
With the AST traversers written, the pass can be registered to become a TVM
API endpoint with the following code snippet:
} // namespace transform
If the ``Pass`` object produced by the above code is given to the pass infrastructure,
it will ensure that the AST traversal is applied to every function in the
given Relay module, which is the behavior one would expect for a constant folding
pass (it should fold all constants where possible).
The function ``CreateFunctionPass``
allows for registering the optimization level of the pass (in this case, 2), which can
be used to group together passes based on their general utility, a name for the pass,
and any dependencies for the pass. A pass's dependencies are given as a list of any passes
whose results are necessary to be able to run the current pass. ``FoldConstant`` does not
have any dependencies, but many Relay passes do depend on having type information,
so ``InferType`` is a common dependency; others may depend on the program's being in
A-normal form, via the ``ToANormalForm`` pass.
Note that the ``PassContext`` object contains information a pass uses for
error reporting and configuration options; ``FoldConstant`` does not need
this information but other passes may reference their ``PassContext`` objects.
The pass can now be invoked via the pass infrastructure, though it's a good idea to
also add a Python binding for the pass, as in this code snippet:
.. code:: c
TVM_REGISTER_API("relay._ir_pass.FoldConstant")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FoldConstant(args[0]);
});
TVM_REGISTER_API("relay._transform.FoldConstant")
.set_body_typed(FoldConstant);
Once ``Pass`` objects are defined in the above fashion, they can be invoked using the
pass infrastructure's ``Sequential`` construct, which takes a list of passes and applies
them in sequence to a Relay module, obtaining a transformed module as a result. For example,
the below code applies both the ``FoldConstant`` and ``ToANormalForm`` passes
(one after the other) to each function in ``mod`` and obtains a new module.
.. code:: python
seq = transform.Sequential([
relay.transform.FoldConstant(),
relay.transform.ToANormalForm()
])
new_mod = seq(mod)
More detail about registration can be found in :ref:`tvm-runtime-system` and more
information about the pass manager interface can be found in :ref:`relay-pass-infra`.
Relay's standard passes are listed in `include/tvm/relay/transform.h`_ and implemented
in `src/relay/pass/`_.
.. _include/tvm/relay/transform.h: https://github.com/dmlc/tvm/blob/master/include/tvm/relay/transform.h
.. _src/relay/pass: https://github.com/dmlc/tvm/tree/master/src/relay/pass
And the pass can now be used in C++ and Python, though it's a good idea to
wrap the API in Python, as described in :ref:`relay-add-op`. More detail
about registration can be found in :ref:`tvm-runtime-system`.
.. _src/relay/pass/fold_constant.cc: https://github.com/dmlc/tvm/blob/master/src/relay/pass/fold_constant.cc
......@@ -15,8 +15,10 @@
specific language governing permissions and limitations
under the License.
Relay Pass Infra
==================================
.. _relay-pass-infra:
Relay Pass Infrastructure
=========================
Relay features a series of optimization passes which improve performance metrics
of models such as mean inference, memory footprint, or power consumption for
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment