..  Licensed to the Apache Software Foundation (ASF) under one
    or more contributor license agreements.  See the NOTICE file
    distributed with this work for additional information
    regarding copyright ownership.  The ASF licenses this file
    to you under the Apache License, Version 2.0 (the
    "License"); you may not use this file except in compliance
    with the License.  You may obtain a copy of the License at

..    http://www.apache.org/licenses/LICENSE-2.0

..  Unless required by applicable law or agreed to in writing,
    software distributed under the License is distributed on an
    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    KIND, either express or implied.  See the License for the
    specific language governing permissions and limitations
    under the License.

.. _relay-add-pass:

Adding a Compiler Pass to Relay
===============================

Compiler passes are the primary interface for both extending Relay's feature
set and for performing optimizations on Relay programs. By writing a compiler
pass, you can modify the AST or collect information about the AST,
depending on your goal. Indeed, some of Relay's most important built-in
features (e.g., autodiff and type inference) are nothing more than "standard"
compiler passes.

At a high level, there are two key components to writing a pass:

- Creating one or more C++ classes that traverse the program
- Wrapping the traversal implementation and its metadata in the pass manager API so it can neatly interface with the :ref:`relay-pass-infra`

To begin, we'll give an overview of the key mechanisms for writing a compiler
pass. Then, we'll walk through a concrete example of the constant-folding
pass in Relay.

AST Traversers
--------------

The base class used to traverse Relay programs is ``ExprFunctor``. The public
interface it provides is a ``VisitExpr`` method that takes an expression and
zero or more arguments and returns an instance of some type. When you extend
this class, you define the AST traversal pattern by overriding
implementations of ``VisitExpr_`` for each type of expression.

The relation between ``VisitExpr`` and ``VisitExpr_`` has to do with
dispatch. Each ``VisitExpr_`` definition targets a specific type of
expression, but you don't always know which node type you'll be visiting.
To remedy this, ``ExprFunctor`` provides a ``VisitExpr`` function which
routes from the given expression to the ``VisitExpr_`` case that handles it.
Although C++ already provides dynamic dispatch, ``ExprFunctor`` defines its
own vtable, which ``VisitExpr`` uses. By defining our own vtable, we have
more control over dispatch. For example, if we wanted to define a
``PrintVisitor`` traverser that printed "Here" before every visit, we
could override ``VisitExpr``:

.. code:: c

    void PrintVisitor::VisitExpr(const Expr& expr) {
      std::cout << "Here" << std::endl;
      ExprFunctor::VisitExpr(expr);
    }
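
To make the dispatch idea concrete, here is a minimal, self-contained sketch of a hand-rolled dispatch table. It does not use TVM: ``ToyFunctor``, ``TracingVisitor``, ``NegNode``, and the helper functions are hypothetical names, and keying the table on ``std::type_index`` is an assumption chosen for brevity, not a description of how ``ExprFunctor`` builds its vtable.

```cpp
#include <cassert>
#include <functional>
#include <memory>
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#include <utility>

// Toy AST for illustration only; these are not Relay's node classes.
struct Node { virtual ~Node() = default; };
using Expr = std::shared_ptr<const Node>;
struct ConstNode : Node { int value = 0; };
struct NegNode : Node { Expr operand; };

class ToyFunctor {
 public:
  virtual ~ToyFunctor() = default;

  // Public entry point: route to the VisitExpr_ overload registered for the
  // node's dynamic type, using a table this class controls.
  virtual void VisitExpr(const Expr& expr) {
    assert(expr != nullptr);
    const Node& node = *expr;
    auto it = Table().find(std::type_index(typeid(node)));
    assert(it != Table().end() && "no handler registered for this node type");
    it->second(this, &node);
  }

 protected:
  virtual void VisitExpr_(const ConstNode*) {}
  virtual void VisitExpr_(const NegNode* op) { VisitExpr(op->operand); }

 private:
  using Handler = std::function<void(ToyFunctor*, const Node*)>;
  static const std::unordered_map<std::type_index, Handler>& Table() {
    static const std::unordered_map<std::type_index, Handler> table = {
        {std::type_index(typeid(ConstNode)),
         [](ToyFunctor* self, const Node* n) {
           self->VisitExpr_(static_cast<const ConstNode*>(n));
         }},
        {std::type_index(typeid(NegNode)),
         [](ToyFunctor* self, const Node* n) {
           self->VisitExpr_(static_cast<const NegNode*>(n));
         }},
    };
    return table;
  }
};

// Overriding the public entry point hooks every visit, whatever the node type.
class TracingVisitor : public ToyFunctor {
 public:
  int visits = 0;
  void VisitExpr(const Expr& expr) override {
    ++visits;  // stands in for printing "Here"
    ToyFunctor::VisitExpr(expr);
  }
};

Expr MakeConst(int v) {
  auto n = std::make_shared<ConstNode>();
  n->value = v;
  return n;
}

Expr MakeNeg(Expr operand) {
  auto n = std::make_shared<NegNode>();
  n->operand = std::move(operand);
  return n;
}

// Number of times VisitExpr ran while traversing expr.
int CountVisits(const Expr& expr) {
  TracingVisitor v;
  v.VisitExpr(expr);
  return v.visits;
}
```

Because the default ``VisitExpr_`` for ``NegNode`` recurses through the public ``VisitExpr``, the override in ``TracingVisitor`` sees every node, including nested ones.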

``ExprFunctor`` itself is a very general class, which is why more often than
not, you will be extending ``ExprVisitor`` or ``ExprMutator``. These classes
extend ``ExprFunctor`` and provide default implementations of ``VisitExpr_``
that capture common traversal patterns for each expression type. Having these
default implementations means we only need to provide overriding
implementations for the expression types where we want different behavior. We
describe each subclass on its own in the following sections.

Expression Visitors
~~~~~~~~~~~~~~~~~~~

``ExprVisitor`` is for passes that don't modify the program and instead
perform program analyses and collect information. With this class,
``VisitExpr`` and the private counterparts return nothing. The ``VisitExpr_``
implementations provided by this class simply visit all of the expression's
fields that are expressions. The default implementation for ``IfNode`` is
shown below.

.. code:: c

    void ExprVisitor::VisitExpr_(const IfNode* op) {
      this->VisitExpr(op->cond);
      this->VisitExpr(op->true_branch);
      this->VisitExpr(op->false_branch);
    }

Note that we're calling ``VisitExpr`` and not ``VisitExpr_`` here, so we can
use the vtable in ``ExprFunctor`` for routing.

Now, if we wanted to write a class ``CallChecker`` that checks if any
function calls appear in the program, we would only need to extend
``ExprVisitor`` and define the following ``VisitExpr_`` method:

.. code:: c

    void VisitExpr_(const CallNode* n) final {
      result_ = true;
    }

where ``result_`` is a field. In this case, we don't need to further recurse
on the fields of the ``CallNode``, because ``result_`` is already true and we
now know the original expression contains a call. To make this visitor
usable, we would provide the following public method:

.. code:: c

    bool Check(const Expr& expr) {
      result_ = false;
      VisitExpr(expr);
      return result_;
    }

And that's all we need. It is very common to define a public interface that
performs some bookkeeping before invoking the top-level recursion. We could
of course further wrap the API by making a standalone procedure that creates
a ``CallChecker`` instance and calls ``Check`` on it, but the takeaway is
that we've achieved our goal with very little effort.
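
For readers who want to experiment with the pattern outside of TVM, the following is a self-contained sketch of the same visitor over a toy AST. ``ToyVisitor`` is a hypothetical stand-in for ``ExprVisitor`` (it dispatches with ``dynamic_cast`` rather than a vtable), and the node structs and ``ContainsCall`` helper are likewise assumptions made for illustration.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Toy AST for illustration only; these are not Relay's node classes.
struct Node { virtual ~Node() = default; };
using Expr = std::shared_ptr<const Node>;
struct VarNode : Node {};
struct CallNode : Node { std::vector<Expr> args; };
struct IfNode : Node { Expr cond, true_branch, false_branch; };

// Stand-in for ExprVisitor: the defaults recurse into every child expression.
class ToyVisitor {
 public:
  virtual ~ToyVisitor() = default;
  virtual void VisitExpr(const Expr& expr) {
    assert(expr != nullptr);
    const Node* n = expr.get();
    if (auto* call = dynamic_cast<const CallNode*>(n)) {
      VisitExpr_(call);
    } else if (auto* iff = dynamic_cast<const IfNode*>(n)) {
      VisitExpr_(iff);
    } else if (auto* var = dynamic_cast<const VarNode*>(n)) {
      VisitExpr_(var);
    }
  }

 protected:
  virtual void VisitExpr_(const VarNode*) {}
  virtual void VisitExpr_(const CallNode* op) {
    for (const Expr& arg : op->args) VisitExpr(arg);
  }
  virtual void VisitExpr_(const IfNode* op) {
    VisitExpr(op->cond);
    VisitExpr(op->true_branch);
    VisitExpr(op->false_branch);
  }
};

class CallChecker : public ToyVisitor {
 public:
  // Bookkeeping (reset the flag) followed by the top-level recursion.
  bool Check(const Expr& expr) {
    result_ = false;
    VisitExpr(expr);
    return result_;
  }

 protected:
  using ToyVisitor::VisitExpr_;  // keep the inherited overloads visible
  // No recursion into the arguments: a single call already answers the query.
  void VisitExpr_(const CallNode*) final { result_ = true; }

 private:
  bool result_ = false;
};

Expr MakeVar() { return std::make_shared<VarNode>(); }

Expr MakeCall(std::vector<Expr> args) {
  auto n = std::make_shared<CallNode>();
  n->args = std::move(args);
  return n;
}

Expr MakeIf(Expr c, Expr t, Expr f) {
  auto n = std::make_shared<IfNode>();
  n->cond = std::move(c);
  n->true_branch = std::move(t);
  n->false_branch = std::move(f);
  return n;
}

// Standalone wrapper that creates a checker and runs it once.
bool ContainsCall(const Expr& expr) { return CallChecker().Check(expr); }
```

Only the ``CallNode`` case is overridden; the ``IfNode`` and ``VarNode`` cases fall through to the defaults, which is what lets the checker find a call nested inside a branch.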

Expression Mutators
~~~~~~~~~~~~~~~~~~~

``ExprMutator`` is for passes that transform the program in some way. With
this class, ``VisitExpr`` and its private counterparts return ``Expr``. The
default ``VisitExpr_`` implementations provided by this class visit all of
the expression's fields that are expressions and set the fields to be the
result of visiting them. The default implementation for ``TupleGetItemNode``
is shown below.

.. code:: c

    Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
      auto t = this->Mutate(g->tuple);
      if (g->tuple == t) {
        return GetRef<Expr>(g);
      } else {
        return TupleGetItemNode::make(t, g->index);
      }
    }

There are a few things to notice here. First, ``Mutate`` is an alias for
``VisitExpr`` in ``ExprMutator``. Second, we only return a new node if the
call to ``Mutate`` modified the ``tuple`` field. This style is called a
functional update: the original node is never changed in place, and reusing
it when nothing changed avoids unnecessary allocations.
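
The effect of a functional update can be observed directly by comparing pointers. The sketch below is a toy, not TVM code: ``RenameVar`` and the node structs are hypothetical, and ``std::shared_ptr`` identity stands in for the reference comparison used in the snippet above.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Toy AST for illustration only; these are not Relay's node classes.
struct Node { virtual ~Node() = default; };
using Expr = std::shared_ptr<const Node>;
struct VarNode : Node { std::string name; };
struct TupleGetItemNode : Node { Expr tuple; int index = 0; };

Expr MakeVar(std::string name) {
  auto n = std::make_shared<VarNode>();
  n->name = std::move(name);
  return n;
}

Expr MakeGet(Expr tuple, int index) {
  auto n = std::make_shared<TupleGetItemNode>();
  n->tuple = std::move(tuple);
  n->index = index;
  return n;
}

// Rename variable `from` to `to`. Nodes are never modified in place; a new
// node is allocated only along the path to an actual change.
Expr RenameVar(const Expr& expr, const std::string& from,
               const std::string& to) {
  assert(expr != nullptr);
  if (auto* var = dynamic_cast<const VarNode*>(expr.get())) {
    if (var->name != from) return expr;
    return MakeVar(to);
  }
  if (auto* get = dynamic_cast<const TupleGetItemNode*>(expr.get())) {
    Expr t = RenameVar(get->tuple, from, to);
    if (t == get->tuple) return expr;  // child untouched: reuse this node
    return MakeGet(t, get->index);     // child changed: build a new parent
  }
  return expr;
}
```

When the rename does not match anything, the caller gets back the very same pointer it passed in, so unchanged subtrees stay shared between the old and new programs.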

One feature ``ExprMutator`` has that ``ExprVisitor`` doesn't is a built-in
``memo_`` field for caching results. It makes sense that ``ExprMutator`` has
a memoizer, because we know which types of results we're caching (i.e.,
``Expr``), whereas the visit methods of ``ExprVisitor`` don't return
anything. Usually, when we want to cache results in a subclass of
``ExprVisitor``, we need to define the cache ourselves.

Now, if we wanted to write a class ``IfCollapser`` that replaces every ``if``
expression with its true branch, we would override ``VisitExpr_`` for
``IfNode``:

.. code:: c

    Expr VisitExpr_(const IfNode* op) final {
      return this->Mutate(op->true_branch);
    }

Note that the returned expression will not necessarily be an ``IfNode``, and
this is fine, because the return type is ``Expr``. Now, we create the public
interface:

.. code:: c

    Expr CollapseIfs(const Expr& expr) {
      return this->Mutate(expr);
    }

With this mutator, we didn't need to do any bookkeeping, but we still want to
follow the convention of having a descriptive method as the interface.
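
As a self-contained illustration, the sketch below rebuilds the mutator pattern on a toy AST. ``ToyMutator`` is a hypothetical stand-in for ``ExprMutator`` with a memo table and a functional-update default for ``IfNode``; passing the owning pointer as a second ``self`` argument is an assumption of this sketch that plays the role of ``GetRef`` in the snippets above.

```cpp
#include <cassert>
#include <memory>
#include <unordered_map>
#include <utility>

// Toy AST for illustration only; these are not Relay's node classes.
struct Node { virtual ~Node() = default; };
using Expr = std::shared_ptr<const Node>;
struct VarNode : Node {};
struct IfNode : Node { Expr cond, true_branch, false_branch; };

Expr MakeVar() { return std::make_shared<VarNode>(); }

Expr MakeIf(Expr c, Expr t, Expr f) {
  auto n = std::make_shared<IfNode>();
  n->cond = std::move(c);
  n->true_branch = std::move(t);
  n->false_branch = std::move(f);
  return n;
}

// Stand-in for ExprMutator: memoized entry point plus a default IfNode case.
class ToyMutator {
 public:
  virtual ~ToyMutator() = default;

  Expr Mutate(const Expr& expr) {
    assert(expr != nullptr);
    auto it = memo_.find(expr);
    if (it != memo_.end()) return it->second;  // already rewritten once
    Expr result = expr;                        // leaves map to themselves
    if (auto* iff = dynamic_cast<const IfNode*>(expr.get())) {
      result = VisitExpr_(iff, expr);
    }
    memo_[expr] = result;
    return result;
  }

 protected:
  // `self` is the shared pointer that owns `op`, so an unchanged node can be
  // returned without allocating.
  virtual Expr VisitExpr_(const IfNode* op, const Expr& self) {
    Expr c = Mutate(op->cond);
    Expr t = Mutate(op->true_branch);
    Expr f = Mutate(op->false_branch);
    if (c == op->cond && t == op->true_branch && f == op->false_branch) {
      return self;
    }
    return MakeIf(c, t, f);
  }

 private:
  std::unordered_map<Expr, Expr> memo_;
};

class IfCollapser : public ToyMutator {
 public:
  // Descriptive public interface; no extra bookkeeping is needed here.
  Expr CollapseIfs(const Expr& expr) { return Mutate(expr); }

 protected:
  // The result is whatever the true branch rewrites to, not an IfNode.
  Expr VisitExpr_(const IfNode* op, const Expr&) final {
    return Mutate(op->true_branch);
  }
};
```

Because the override recurses through ``Mutate``, nested conditionals collapse all the way down to the innermost true branch.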

Example: Constant Folding
-------------------------

In order to better understand the process of writing a pass, we will look at
the constant folding pass (found in `src/relay/pass/fold_constant.cc`_)
as a guide, because it is a relatively simple pass that incorporates
both types of traversals.

Constant folding involves evaluating expressions in the program that only
involve constant values, then replacing those expressions with the result
of evaluating them. The goal of this pass is to frontload all of the
computations that we can. To achieve this, the constant folding pass makes
use of a visitor (``ConstantChecker``) and a mutator (``ConstantFolder``).

The ``ConstantChecker`` Visitor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This visitor is used to check if an expression is constant. In Relay, we
define an expression to be constant if it is a ``ConstantNode`` or it is a
``TupleNode`` with only constant fields.

We use a ``memo_`` field to map from nodes to whether they are constant and
to cache these results. Below are the ``VisitExpr_`` definitions in the
``ConstantChecker``.

.. code:: c

    void VisitExpr_(const ConstantNode* n) final {
      memo_[GetRef<Constant>(n)] = true;
    }

    void VisitExpr_(const TupleNode* n) final {
      bool result = true;
      for (const auto& field : n->fields) {
        if (!Check(field)) {
          result = false;
          break;
        }
      }
      memo_[GetRef<Tuple>(n)] = result;
    }

The bookkeeping used to coordinate these definitions is a ``Check`` method
that returns whether the given expression is considered constant.

.. code:: c

    bool Check(const Expr& expr) {
      const auto it = memo_.find(expr);
      if (it != memo_.end())
        return it->second;
      VisitExpr(expr);
      return memo_[expr];
    }

We don't modify ``memo_`` for every node we encounter; instead we only modify
``memo_`` when the encountered node could potentially be constant. Then we
rely on the default value being false when ``memo_`` doesn't contain
``expr``.
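
The reliance on a default of false can be seen in a small standalone sketch. The class below mirrors the structure of the checker on a toy AST; ``ToyConstantChecker``, the node structs, and ``IsConstant`` are hypothetical names, and ``std::unordered_map`` is used here as an assumed stand-in for the real ``memo_`` container.

```cpp
#include <cassert>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

// Toy AST for illustration only; these are not Relay's node classes.
struct Node { virtual ~Node() = default; };
using Expr = std::shared_ptr<const Node>;
struct VarNode : Node {};
struct ConstantNode : Node { int value = 0; };
struct TupleNode : Node { std::vector<Expr> fields; };

Expr MakeVar() { return std::make_shared<VarNode>(); }

Expr MakeConst(int v) {
  auto n = std::make_shared<ConstantNode>();
  n->value = v;
  return n;
}

Expr MakeTuple(std::vector<Expr> fields) {
  auto n = std::make_shared<TupleNode>();
  n->fields = std::move(fields);
  return n;
}

class ToyConstantChecker {
 public:
  bool Check(const Expr& expr) {
    assert(expr != nullptr);
    const auto it = memo_.find(expr);
    if (it != memo_.end()) return it->second;
    VisitExpr(expr);
    // operator[] value-initializes a missing entry, so node kinds that
    // VisitExpr ignored are reported as false.
    return memo_[expr];
  }

 private:
  void VisitExpr(const Expr& expr) {
    if (dynamic_cast<const ConstantNode*>(expr.get())) {
      memo_[expr] = true;
    } else if (auto* tuple = dynamic_cast<const TupleNode*>(expr.get())) {
      bool result = true;
      for (const Expr& field : tuple->fields) {
        if (!Check(field)) {
          result = false;
          break;
        }
      }
      memo_[expr] = result;
    }
    // Any other node kind: leave memo_ untouched.
  }

  std::unordered_map<Expr, bool> memo_;
};

bool IsConstant(const Expr& expr) { return ToyConstantChecker().Check(expr); }
```

A variable never gets an explicit entry from ``VisitExpr``, yet ``Check`` still returns false for it because the lookup on the last line of ``Check`` creates the entry with a value-initialized ``bool``.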

The ``ConstantFolder`` Mutator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

This mutator performs the bulk of the constant folding pass and internally
uses ``ConstantChecker``. In Relay, three node types are involved in constant
folding: ``LetNode``, ``TupleGetItemNode``, and ``CallNode``. In the
following paragraphs, we explain the role of each in the pass.

.. code:: c

    Expr VisitExpr_(const LetNode* op) final {
      Expr value = this->Mutate(op->value);
      if (value.as<ConstantNode>()) {
        memo_[op->var] = value;
        return this->Mutate(op->body);
      } else {
        Var var = Downcast<Var>(this->Mutate(op->var));
        Expr body = this->Mutate(op->body);
        if (var.same_as(op->var) &&
            value.same_as(op->value) &&
            body.same_as(op->body)) {
          return GetRef<Expr>(op);
        } else {
          return LetNode::make(var, value, body);
        }
      }
    }

In the ``LetNode`` case, we first attempt to const-fold the value being bound
in the expression. If we can, then we populate ``memo_`` and return the
result of visiting the body---essentially, propagating the bound value to its
use sites in the body. If we can't const-fold the bound value, we mimic the
default implementation.
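
The way ``memo_`` carries a bound constant to its use sites can be reproduced in a few lines. The following is a simplified sketch on a toy AST, not the TVM implementation: ``LetConstPropagator`` and the node structs are hypothetical, and unlike the code above it does not re-visit the bound variable in the non-constant case.

```cpp
#include <cassert>
#include <memory>
#include <unordered_map>
#include <utility>

// Toy AST for illustration only; these are not Relay's node classes.
struct Node { virtual ~Node() = default; };
using Expr = std::shared_ptr<const Node>;
struct VarNode : Node {};
struct ConstantNode : Node { int value = 0; };
struct LetNode : Node { Expr var, value, body; };

Expr MakeVar() { return std::make_shared<VarNode>(); }

Expr MakeConst(int v) {
  auto n = std::make_shared<ConstantNode>();
  n->value = v;
  return n;
}

Expr MakeLet(Expr var, Expr value, Expr body) {
  auto n = std::make_shared<LetNode>();
  n->var = std::move(var);
  n->value = std::move(value);
  n->body = std::move(body);
  return n;
}

class LetConstPropagator {
 public:
  Expr Mutate(const Expr& expr) {
    assert(expr != nullptr);
    auto it = memo_.find(expr);
    if (it != memo_.end()) return it->second;
    Expr result = expr;  // variables and constants map to themselves
    if (auto* let = dynamic_cast<const LetNode*>(expr.get())) {
      result = VisitLet(let, expr);
    }
    memo_[expr] = result;
    return result;
  }

 private:
  Expr VisitLet(const LetNode* op, const Expr& self) {
    Expr value = Mutate(op->value);
    if (dynamic_cast<const ConstantNode*>(value.get())) {
      // Seed the memo so every later visit of the variable yields the
      // constant, then drop the binding by returning the rewritten body.
      memo_[op->var] = value;
      return Mutate(op->body);
    }
    Expr body = Mutate(op->body);
    if (value == op->value && body == op->body) return self;
    return MakeLet(op->var, value, body);
  }

  std::unordered_map<Expr, Expr> memo_;
};
```

Seeding the memo table before visiting the body is what turns an ordinary cache into a substitution: the variable is never looked at again, only its cached replacement.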

.. code:: c

    Expr VisitExpr_(const TupleGetItemNode* op) final {
      Expr res = ExprMutator::VisitExpr_(op);
      op = res.as<TupleGetItemNode>();
      if (const auto* tuple = op->tuple.as<TupleNode>()) {
        return tuple->fields[op->index];
      } else {
        return res;
      }
    }

In the ``TupleGetItemNode`` case, we check whether the ``op->tuple`` field is a
``TupleNode``. If so, we replace the tuple get with the field of the tuple
selected by ``op->index``. We need the check because ``op->tuple`` might
evaluate to a tuple without itself being a ``TupleNode`` in the AST.
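
A standalone sketch of this projection rule follows. It is written as a plain recursive function on a toy AST rather than as a mutator class; ``FoldTupleGetItem`` and the node structs are hypothetical, and only the operand of a tuple get is rewritten (tuple fields themselves are left as they are).

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// Toy AST for illustration only; these are not Relay's node classes.
struct Node { virtual ~Node() = default; };
using Expr = std::shared_ptr<const Node>;
struct VarNode : Node {};
struct TupleNode : Node { std::vector<Expr> fields; };
struct TupleGetItemNode : Node { Expr tuple; std::size_t index = 0; };

Expr MakeVar() { return std::make_shared<VarNode>(); }

Expr MakeTuple(std::vector<Expr> fields) {
  auto n = std::make_shared<TupleNode>();
  n->fields = std::move(fields);
  return n;
}

Expr MakeGet(Expr tuple, std::size_t index) {
  auto n = std::make_shared<TupleGetItemNode>();
  n->tuple = std::move(tuple);
  n->index = index;
  return n;
}

Expr FoldTupleGetItem(const Expr& expr) {
  assert(expr != nullptr);
  auto* get = dynamic_cast<const TupleGetItemNode*>(expr.get());
  if (get == nullptr) return expr;
  // Rewrite the operand first, so a nested projection can expose a literal.
  Expr tuple = FoldTupleGetItem(get->tuple);
  if (auto* literal = dynamic_cast<const TupleNode*>(tuple.get())) {
    assert(get->index < literal->fields.size());
    return literal->fields[get->index];  // project statically
  }
  // The operand is not a tuple literal (for example, a variable), so the
  // projection has to stay in the program.
  if (tuple == get->tuple) return expr;
  return MakeGet(tuple, get->index);
}
```

The variable case shows why the check matters: a variable may well hold a tuple, but there is no field list to index into until the program runs.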

.. code:: c

    Expr VisitExpr_(const CallNode* call) final {
      static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
      Expr res = ExprMutator::VisitExpr_(call);
      call = res.as<CallNode>();
      // We don't constant fold function with zero arguments.
      // This is a heuristic that is useful.
      // For example it is harmful to fold ones(shape=(4, 5)).
      if (call->args.size() == 0) return res;
      const OpNode* op = call->op.as<OpNode>();
      if (op == nullptr) return res;
      // skip stateful ops.
      if (op_stateful.get(GetRef<Op>(op), false)) return res;
      bool all_const_args = true;
      for (Expr arg : call->args) {
        if (!checker_.Check(arg)) {
          all_const_args = false;
        }
      }
      if (all_const_args) {
        return ConstEvaluate(res);
      } else {
        return res;
      }
    }

In the ``CallNode`` case, we first use the ``VisitExpr_`` of ``ExprMutator``
to visit the call, which const-folds all of the fields of the call. We use
``ExprMutator::VisitExpr_`` instead of ``VisitExpr``, because we want to
bypass the vtable (to avoid an infinite loop) and use the default
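implementation provided by ``ExprMutator``. Calls with zero arguments, calls
whose callee is not an operator, and calls to stateful operators are returned
without being evaluated. Otherwise, we evaluate the call (via
``ConstEvaluate``) only if all of the arguments are constant, as determined by
``ConstantChecker``. Evaluating the call produces a **value**, so a helper
method ``ValueToExpr`` is used to place the evaluated result back into the AST.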
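
The overall shape of the call case (fold the arguments, bail out early, evaluate only when every argument is constant, and turn the value back into a node) can be sketched without TVM. In the toy below every ``CallNode`` means "add all arguments", which is an assumption made purely so the example can evaluate something; ``FoldCalls`` and the node structs are hypothetical names.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Toy AST for illustration only; these are not Relay's node classes.
struct Node { virtual ~Node() = default; };
using Expr = std::shared_ptr<const Node>;
struct VarNode : Node {};
struct ConstantNode : Node { int value = 0; };
struct CallNode : Node { std::vector<Expr> args; };  // toy: sum of the args

Expr MakeVar() { return std::make_shared<VarNode>(); }

Expr MakeConst(int v) {
  auto n = std::make_shared<ConstantNode>();
  n->value = v;
  return n;
}

Expr MakeCall(std::vector<Expr> args) {
  auto n = std::make_shared<CallNode>();
  n->args = std::move(args);
  return n;
}

Expr FoldCalls(const Expr& expr) {
  assert(expr != nullptr);
  auto* call = dynamic_cast<const CallNode*>(expr.get());
  if (call == nullptr) return expr;

  // Step 1: fold the arguments first, rebuilding the call only on change.
  std::vector<Expr> args;
  bool changed = false;
  for (const Expr& arg : call->args) {
    args.push_back(FoldCalls(arg));
    changed = changed || args.back() != arg;
  }
  Expr res = expr;
  if (changed) res = MakeCall(args);

  // Step 2: heuristic from the text, zero-argument calls are left alone.
  if (args.empty()) return res;

  // Step 3: evaluate only if every argument is a constant.
  int sum = 0;
  for (const Expr& arg : args) {
    auto* c = dynamic_cast<const ConstantNode*>(arg.get());
    if (c == nullptr) return res;
    sum += c->value;
  }
  // Step 4: place the computed value back into the AST as a node.
  return MakeConst(sum);
}
```

Folding the arguments before checking them is what lets ``(1 + 2) + 3`` collapse in one traversal: by the time the outer call is examined, its first argument is already the constant ``3``.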

Now, we construct a more convenient interface ``FoldConstant`` for our constant
folder. ``FoldConstant`` is a standalone function outside of the ``ConstantFolder``
class that takes an expression and internally creates and uses a
``ConstantFolder`` instance (the full definition can be found in
`src/relay/pass/fold_constant.cc`_).


Registering a Pass with the Pass Manager
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Note: please see the documentation on the :ref:`relay-pass-infra` for more specific detail on this subject.

With the AST traversers written, the pass can be registered to become a TVM
API endpoint with the following code:

.. code:: c

    namespace transform {

    Pass FoldConstant() {
      runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
        [=](Function f, Module m, PassContext pc) {
          return Downcast<Function>(FoldConstant(f));
      };
      return CreateFunctionPass(pass_func, 2, "FoldConstant", {});
    }

    }  // namespace transform

If the ``Pass`` object produced by the above code is given to the pass infrastructure,
it will ensure that the AST traversal is applied to every function in the
given Relay module, which is the behavior one would expect for a constant folding
pass (it should fold all constants where possible).

In addition to the pass function itself, ``CreateFunctionPass`` takes three
pieces of metadata: the optimization level of the pass (in this case, 2), which
can be used to group passes by their general utility; a name for the pass; and
a list of the pass's dependencies. A pass's dependencies are the passes
whose results are needed before the current pass can run. ``FoldConstant`` does not
have any dependencies, but many Relay passes do depend on having type information,
so ``InferType`` is a common dependency; others may depend on the program being in
A-normal form, via the ``ToANormalForm`` pass.

Note that the ``PassContext`` object contains information a pass uses for
error reporting and configuration options; ``FoldConstant`` does not need
this information but other passes may reference their ``PassContext`` objects.

The pass can now be invoked via the pass infrastructure, though it's a good idea to
also add a Python binding for the pass, as in this code snippet:

.. code:: c

   TVM_REGISTER_API("relay._transform.FoldConstant")
   .set_body_typed(FoldConstant);

Once ``Pass`` objects are defined in the above fashion, they can be invoked using the
pass infrastructure's ``Sequential`` construct, which takes a list of passes and applies
them in sequence to a Relay module, obtaining a transformed module as a result. For example,
the below code applies both the ``FoldConstant`` and ``ToANormalForm`` passes
(one after the other) to each function in ``mod`` and obtains a new module.

.. code:: python

    from tvm import relay

    # `mod` is assumed to be an existing Relay module.
    seq = relay.transform.Sequential([
        relay.transform.FoldConstant(),
        relay.transform.ToANormalForm()
    ])
    new_mod = seq(mod)
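
To show what "apply a function-level transformation to every function, one pass after another" amounts to, here is a toy model in C++. It is not the TVM pass infrastructure: ``ToyPass``, ``MakeToyFunctionPass``, and ``RunSequential`` are hypothetical, a "function" is just a string, the ``required`` list is stored but not acted on, and skipping passes above a chosen optimization level is an assumption of this sketch about how levels might be used.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Toy stand-ins: a "function" is a string body, a module maps names to bodies.
using Function = std::string;
using Module = std::map<std::string, Function>;

struct ToyPass {
  std::string name;
  int opt_level;
  std::vector<std::string> required;  // recorded only; not used in this sketch
  std::function<Module(const Module&)> run;
};

// Lift a per-function transformation to a pass over the whole module, and
// attach the metadata (level, name, dependencies) alongside it.
ToyPass MakeToyFunctionPass(std::function<Function(const Function&)> fn,
                            int opt_level, std::string name,
                            std::vector<std::string> required) {
  auto run = [fn](const Module& mod) {
    Module out;
    for (const auto& entry : mod) out[entry.first] = fn(entry.second);
    return out;
  };
  return ToyPass{std::move(name), opt_level, std::move(required), run};
}

// Apply the passes in list order, skipping any pass whose level is above
// max_opt_level, and return the transformed module.
Module RunSequential(const std::vector<ToyPass>& passes, Module mod,
                     int max_opt_level) {
  for (const ToyPass& pass : passes) {
    assert(pass.run != nullptr);
    if (pass.opt_level > max_opt_level) continue;
    mod = pass.run(mod);
  }
  return mod;
}
```

The input module is taken by value and each pass returns a fresh module, so the caller's original module is left untouched, matching the description of obtaining a new module as the result.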

More detail about registration can be found in :ref:`tvm-runtime-system` and more
information about the pass manager interface can be found in :ref:`relay-pass-infra`.
Relay's standard passes are listed in `include/tvm/relay/transform.h`_ and implemented
in `src/relay/pass/`_.

.. _include/tvm/relay/transform.h: https://github.com/apache/incubator-tvm/blob/master/include/tvm/relay/transform.h

.. _src/relay/pass/: https://github.com/apache/incubator-tvm/tree/master/src/relay/pass

.. _src/relay/pass/fold_constant.cc: https://github.com/apache/incubator-tvm/blob/master/src/relay/pass/fold_constant.cc