Commit fbb74892 by Logan Weber Committed by Jared Roesch

[Relay] Add compiler pass tutorial docs (#2746)

* Add Relay compiler pass tutorial docs

* Add Python API hook wrapping step

* Incorporate feedback

* More doc iteration

* Mooooore iteration

* Rewrite `runtime.md` in rst
parent 5293c6bf
......@@ -31,4 +31,5 @@ In this part of documentation, we share the rationale for the specific choices m
hybrid_script
relay_intro
relay_add_op
relay_add_pass
codebase_walkthrough
.. _relay-add-pass:
Adding a Compiler Pass to Relay
===============================
Compiler passes are the primary interface for both extending Relay's feature
set and for performing optimizations on Relay programs. By writing a compiler
pass, you can then modify the AST and/or collect information about the AST,
depending on your goal. Indeed, some of Relay's most important "built-in"
features (e.g., autodiff and type inference) are nothing more than compiler
passes.
At a high level, there are three key components to writing a pass:
- Creating one or more C++ classes that traverse the program
- Registering an API endpoint (a TVM packed function) with the
``TVM_REGISTER_API`` macro that performs the pass
- Wrapping the Python API hook in a neater interface
To begin, we'll give an overview of the key mechanisms for writing a compiler
pass. Then, we'll walk through a concrete example of the constant-folding
pass in Relay.
AST Traversers
--------------
The base class used to traverse Relay programs is ``ExprFunctor``. The public
interface it provides is a ``VisitExpr`` method that takes an expression and
zero or more arguments and returns an instance of some type. When you extend
this class, you define the AST traversal pattern by overriding
implementations of ``VisitExpr_`` for each type of expression.
The relation between ``VisitExpr`` and ``VisitExpr_`` has to do with
dispatch. Each ``VisitExpr_`` definition targets a specific type of
expression, but you don't always know which node type you'll be visiting.
To remedy this, ``ExprFunctor`` provides a ``VisitExpr`` function which
routes from the given expression to the ``VisitExpr_`` case that handles it.
Although C++ already provides dynamic dispatch, ``ExprFunctor`` defines its
own vtable, which ``VisitExpr`` uses. By defining our own vtable, we have
more control over dispatch. For example, if we wanted to define a
``PrintVisitor`` traverser that printed "Here" before every visit, we
could override ``VisitExpr``:
.. code:: c
void PrintVisitor::VisitExpr(const Expr& expr) {
std::cout << "Here" << std::endl;
ExprFunctor::VisitExpr(expr);
}
``ExprFunctor`` itself is a very general class, which is why more often than
not, you will be extending ``ExprVisitor`` or ``ExprMutator``. These classes
extend ``ExprFunctor`` and provide default implementations of ``VisitExpr_``
that capture common traversal patterns for each expression type. Having these
default implementations means we only need to provide overriding
implementations for the expression types where we want different behavior. We
describe each subclass on its own in the following sections.
Expression Visitors
~~~~~~~~~~~~~~~~~~~
``ExprVisitor`` is for passes that don't modify the program and instead
perform program analyses and collect information. With this class,
``VisitExpr`` and the private counterparts return nothing. The ``VisitExpr_``
implementations provided by this class simply visit all of the expression's
fields that are expressions. The default implementation for ``IfNode`` is
shown below.
.. code:: c
void ExprVisitor::VisitExpr_(const IfNode* op) {
this->VisitExpr(op->cond);
this->VisitExpr(op->true_branch);
this->VisitExpr(op->false_branch);
}
Note that we're calling ``VisitExpr`` and not ``VisitExpr_`` here, so we can
use the vtable in ``ExprFunctor`` for routing.
Now, if we wanted to write a class ``CallChecker`` that checks if any
function calls appear in the program, we would only need to extend
``ExprVisitor`` and define the following ``VisitExpr_`` method:
.. code:: c
void VisitExpr_(const CallNode* n) final {
result_ = true;
}
where ``result_`` is a field. In this case, we don't need to further recurse
on the fields of the ``CallNode``, because ``result_`` is already true and we
now know the original expression contains a call. To make this visitor
usable, we would provide the following public method:
.. code:: c
bool Check(const Expr& expr) final {
result_ = false;
VisitExpr(expr);
return result_;
}
And that's all we need. It is very common to define a public interface that
performs some bookkeeping before invoking the top-level recursion. We could
of course further wrap the API by making a standalone procedure that creates
a ``CallChecker`` instance and calls ``Check`` on it, but the takeaway is
that we've achieved our goal with very little effort.
Expression Mutators
~~~~~~~~~~~~~~~~~~~
``ExprMutator`` is for passes that transform the program in some way. With
this class, ``VisitExpr`` and its private counterparts return ``Expr``. The
default ``VisitExpr_`` implementations provided by this class visit all of
the expression's fields that are expressions and set the fields to be the
result of visiting them. The default implementation for ``TupleGetItemNode``
is shown below.
.. code:: c
Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
auto t = this->Mutate(g->tuple);
if (g->tuple == t) {
return GetRef<Expr>(g);
} else {
return TupleGetItemNode::make(t, g->index);
}
}
There are a few things to notice here. First, ``Mutate`` is an alias for
``VisitExpr`` in ``ExprMutator``. Second, we only return a new node if the
call to ``Mutate`` modified the ``tuple`` field. This method of update is
called a functional update and doing so avoids unnecessary allocations.
One feature ``ExprMutator`` has that ``ExprVisitor`` doesn't is a built-in
``memo_`` field for caching results. It makes sense that ``ExprMutator`` has
a memoizer, because we know which types of results we're caching (i.e.,
``Expr``), whereas the visit methods of ``ExprVisitor`` don't return
anything. Usually, when we want to cache results in a subclass of
``ExprVisitor``, we need to define the cache ourselves.
Now, if we wanted to write a class ``IfCollapser`` that replaces every if
statement with its true branch, we would override ``VisitExpr_`` for
``IfNode``:
.. code:: c
Expr ExprMutator::VisitExpr_(const IfNode* op) {
return this->Mutate(op->true_branch);
}
Note that the returned expression will not necessarily be an ``IfNode``, and
this is fine, because the return type is ``Expr``. Now, we create the public
interface:
.. code:: c
Expr CollapseIfs(const Expr& expr) final {
return this->Mutate(expr);
}
With this mutator, we didn't need to do any bookkeeping, but we still want to
follow the convention of having a descriptive method as the interface.
Example: Constant Folding
-------------------------
In order to better understand the process of writing a pass, we will look at
the constant folding pass (found in ``src/relay/pass/fold_constant.cc`` and
in ``python/tvm/relay/ir_pass.py``) as a guide, because it is a relatively
simple pass that incorporates both types of traversals.
Constant folding involves evaluating expressions in the program that only
involve constant values, then replacing those expressions with the result
of evaluating them. The goal of this pass is to frontload all of the
computations that we can. To achieve this, the constant folding pass makes
use of a visitor (``ConstantChecker``) and a mutator (``ConstantFolder``).
The ``ConstantChecker`` Visitor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This visitor is used to check if an expression is constant. In Relay, we
define an expression to be constant if it is a ``ConstantNode`` or it is a
``TupleNode`` with only constant fields.
We use a ``memo_`` field to map from nodes to whether they are constant and
to cache these results. Below are the ``VisitExpr_`` definitions in the
``ConstantChecker``.
.. code:: c
void VisitExpr_(const ConstantNode* n) final {
memo_[GetRef<Constant>(n)] = true;
}
void VisitExpr_(const TupleNode* n) final {
bool result = true;
for (const auto& field : n->fields) {
if (!Check(field)) {
result = false;
break;
}
}
memo_[GetRef<Tuple>(n)] = result;
}
The bookkeeping used to coordinate these definitions is a ``Check`` method
that returns whether the given expression is considered constant.
.. code:: c
bool Check(const Expr& expr) {
const auto it = memo_.find(expr);
if (it != memo_.end())
return it->second;
VisitExpr(expr);
return memo_[expr];
}
We don't modify ``memo_`` for every node we encounter; instead we only modify
``memo_`` when the encountered node could potentially be constant. Then we
rely on the default value being false when ``memo_`` doesn't contain
``expr``.
The ``ConstantFolder`` Mutator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
This mutator performs the bulk of the constant folding pass and internally
uses ``ConstantChecker``. In Relay, there are three node types that are
involved in constant folding: ``LetNode``, ``TupleItemGetNode``, and
``CallNode``. In the following paragraphs, we explain the roles of each in
the pass.
.. code:: c
Expr VisitExpr_(const LetNode* op) final {
Expr value = this->Mutate(op->value);
if (value.as<ConstantNode>()) {
memo_[op->var] = value;
return this->Mutate(op->body);
} else {
Var var = Downcast<Var>(this->Mutate(op->var));
Expr body = this->Mutate(op->body);
if (var.same_as(op->var) &&
value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return LetNode::make(var, value, body);
}
}
}
In the ``LetNode`` case, we first attempt to const-fold the value being bound
in the expression. If we can, then we populate ``memo_`` and return the
result of visiting the body---essentially, propagating the bound value to its
use sites in the body. If we can't const-fold the bound value, we mimic the
default implementation.
.. code:: c
Expr VisitExpr_(const TupleGetItemNode* op) final {
Expr res = ExprMutator::VisitExpr_(op);
op = res.as<TupleGetItemNode>();
if (const auto* tuple = op->tuple.as<TupleNode>()) {
return tuple->fields[op->index];
} else {
return res;
}
}
In the ``TupleItemGetNode`` case, we check if ``op->tuple`` field is a
``TupleNode``. If so, we replace the tuple get with the field of the tuple
pointed to by ``op->index``. The reason we need to check is because
``op->tuple`` might evaluate to a tuple, without itself being a tuple.
.. code:: c
Expr VisitExpr_(const CallNode* call) final {
static auto op_stateful = Op::GetAttr<TOpIsStateful>("TOpIsStateful");
Expr res = ExprMutator::VisitExpr_(call);
call = res.as<CallNode>();
// We don't constant fold function with zero arguments.
// This is a heuristic that is useful.
// For example it is harmful to fold ones(shape=(4, 5)).
if (call->args.size() == 0) return res;
const OpNode* op = call->op.as<OpNode>();
if (op == nullptr) return res;
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
bool all_const_args = true;
for (Expr arg : call->args) {
if (!checker_.Check(arg)) {
all_const_args = false;
}
}
if (all_const_args) {
return ConstEvaluate(res);
} else {
return res;
}
}
In the ``CallNode`` case, we first use the ``VisitExpr_`` of ``ExprMutator``
to visit the call, which const-folds all of the fields of the call. We use
``ExprMutator::VisitExpr_`` instead of ``VisitExpr``, because we want to
bypass the vtable (to avoid an infinite loop) and use the default
implementation provided by ``ExprMutator``. Then we evaluate the call only if
all of the arguments are constant (using ``ConstantChecker``). Evaluating the
call produces a **value**, so we use a helper method ``ValueToExpr`` to allow
us to place the evaluated expression back into the AST.
Now, we construct the public interface ``FoldConstant`` to our constant
folder, which is a standalone function outside of the ``ConstantFolder``
class. ``FoldConstant`` takes an expression and internally creates and uses a
``ConstantFolder`` instance (the full definition can be found in
``include/tvm/relay/pass.h``).
To allow other C++ modules to use our pass, we declare the public interface
in ``src/relay/pass/pass.h``:
.. code:: c
TVM_DLL Expr FoldConstant(const Expr& expr);
Registering an API Endpoint
~~~~~~~~~~~~~~~~~~~~~~~~~~~
With the AST traversers written, the pass can be registered to become a TVM
API endpoint with the following code snippet:
.. code:: c
TVM_REGISTER_API("relay._ir_pass.FoldConstant")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FoldConstant(args[0]);
});
And the pass can now be used in C++ and Python, though it's a good idea to
wrap the API in Python, as described in :ref:`relay-add-op`. More detail
about registration can be found in :ref:`tvm-runtime-system`.
<!--- Licensed to the Apache Software Foundation (ASF) under one -->
<!--- or more contributor license agreements. See the NOTICE file -->
<!--- distributed with this work for additional information -->
<!--- regarding copyright ownership. The ASF licenses this file -->
<!--- to you under the Apache License, Version 2.0 (the -->
<!--- "License"); you may not use this file except in compliance -->
<!--- with the License. You may obtain a copy of the License at -->
<!--- http://www.apache.org/licenses/LICENSE-2.0 -->
<!--- Unless required by applicable law or agreed to in writing, -->
<!--- software distributed under the License is distributed on an -->
<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -->
<!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations -->
<!--- under the License. -->
# TVM Runtime System
.. Licensed to the Apache Software Foundation (ASF) under one
.. or more contributor license agreements. See the NOTICE file
.. distributed with this work for additional information
.. regarding copyright ownership. The ASF licenses this file
.. to you under the Apache License, Version 2.0 (the
.. "License"); you may not use this file except in compliance
.. with the License. You may obtain a copy of the License at
..
.. http://www.apache.org/licenses/LICENSE-2.0
..
.. Unless required by applicable law or agreed to in writing,
.. software distributed under the License is distributed on an
.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
.. KIND, either express or implied. See the License for the
.. specific language governing permissions and limitations
.. under the License.
.. _tvm-runtime-system:
TVM Runtime System
==================
TVM supports multiple programming languages for the compiler stack development and deployment.
In this note, we explain the key elements of the TVM runtime.
![](http://www.tvm.ai/images/release/tvm_flexible.png)
.. image:: http://www.tvm.ai/images/release/tvm_flexible.png
We need to satisfy quite a few interesting requirements
We need to satisfy quite a few interesting requirements:
- Deployment: invoke the compiled function from python/javascript/c++ language.
- Debug: define a function in python and call that from a compiled function.
......@@ -34,30 +37,34 @@ We need to satisfy quite a few interesting requirements
We want to be able to define a function from any language and call from another.
We also want the runtime core to be minimal to deploy to embedded devices.
## PackedFunc
PackedFunc
----------
[PackedFunc](https://github.com/dmlc/tvm/blob/master/include/tvm/runtime/packed_func.h) is a simple but elegant solution
`PackedFunc`_ is a simple but elegant solution
we find to solve the challenges listed. The following code block provides an example in C++
```c++
#include <tvm/runtime/packed_func.h>
void MyAdd(TVMArgs args, TVMRetValue* rv) {
// automatically convert arguments to desired type.
int a = args[0];
int b = args[1];
// automatically assign value return to rv
*rv = a + b;
}
void CallPacked() {
PackedFunc myadd = PackedFunc(MyAdd);
// get back 3
int c = myadd(1, 2);
}
```
.. _PackedFunc: https://github.com/dmlc/tvm/blob/master/include/tvm/runtime/packed_func.h
.. code:: c
#include <tvm/runtime/packed_func.h>
void MyAdd(TVMArgs args, TVMRetValue* rv) {
// automatically convert arguments to desired type.
int a = args[0];
int b = args[1];
// automatically assign value return to rv
*rv = a + b;
}
void CallPacked() {
PackedFunc myadd = PackedFunc(MyAdd);
// get back 3
int c = myadd(1, 2);
}
In the above codeblock, we defined a PackedFunc MyAdd. It takes two arguments
: ```args``` represents input arguments and ```rv``` represents return value.
: ``args`` represents input arguments and ``rv`` represents return value.
The function is type-erased, which means that the function signature does not restrict which input type to pass in or type to return.
Under the hood, when we call a PackedFunc, it packs the input arguments to TVMArgs on stack,
and gets the result back via TVMRetValue.
......@@ -65,21 +72,23 @@ and gets the result back via TVMRetValue.
Thanks to template tricks in C++, we can call a PackedFunc just like a normal function. Because of its type-erased nature, we can call a PackedFunc from dynamic languages like python, without additional glue code for each new type function created.
The following example registers PackedFunc in C++ and calls from python.
```c++
// register a global packed function in c++
TVM_REGISTER_GLOBAL("myadd")
.set_body(MyAdd);
```
```python
import tvm
.. code:: c
myadd = tvm.get_global_func("myadd")
# prints 3
print(myadd(1, 2))
```
// register a global packed function in c++
TVM_REGISTER_GLOBAL("myadd")
.set_body(MyAdd);
Most of the magic of PackedFunc lies in ```TVMArgs``` and ```TVMRetValue``` structure.
We restrict a list of possible types which can be passed, here are the common ones
.. code:: python
import tvm
myadd = tvm.get_global_func("myadd")
# prints 3
print(myadd(1, 2))
Most of the magic of PackedFunc lies in ``TVMArgs`` and ``TVMRetValue`` structure.
We restrict a list of possible types which can be passed.
Here are the common ones:
- int, float and string
- PackedFunc itself
......@@ -92,43 +101,54 @@ Despite being minimum, the PackedFunc is sufficient for the use-case of deep lea
most functions only take DLTensor or numbers.
Since one PackedFunc can take another PackedFunc as an argument,
we can pass functions from python(as PackedFunc) to C++.
```c++
TVM_REGISTER_GLOBAL("callhello")
.set_body([](TVMArgs args, TVMRetValue* rv) {
PackedFunc f = args[0];
f("hello world");
});
```
```python
import tvm
def callback(msg):
print(msg)
# convert to PackedFunc
f = tvm.convert(callback)
callhello = tvm.get_global_func("callhello")
# prints hello world
callhello(f)
```
TVM provides a [minimum C API](https://github.com/dmlc/tvm/blob/master/include/tvm/runtime/c_runtime_api.h),
we can pass functions from python (as PackedFunc) to C++.
.. code:: c
TVM_REGISTER_GLOBAL("callhello")
.set_body([](TVMArgs args, TVMRetValue* rv) {
PackedFunc f = args[0];
f("hello world");
});
.. code:: python
import tvm
def callback(msg):
print(msg)
# convert to PackedFunc
f = tvm.convert(callback)
callhello = tvm.get_global_func("callhello")
# prints hello world
callhello(f)
TVM provides a `minimum C API`_,
which allows us to embed the PackedFunc into any languages. Besides python, so far we supported
[java](https://github.com/dmlc/tvm/tree/master/jvm) and [javascript](https://github.com/dmlc/tvm/tree/master/web).
`java`_ and `javascript`_.
This philosophy of embedded API is very like Lua, except that we don't have a new language but use C++.
.. _minimum C API: https://github.com/dmlc/tvm/blob/master/include/tvm/runtime/c_runtime_api.h
.. _java: https://github.com/dmlc/tvm/tree/master/jvm
.. _javascript: https://github.com/dmlc/tvm/tree/master/web
One fun fact about PackedFunc is that we use it for both compiler and deployment stack.
- All TVM's compiler pass functions are exposed to frontend as PackedFunc, see [here](https://github.com/dmlc/tvm/tree/master/src/api)
- All TVM's compiler pass functions are exposed to frontend as PackedFunc, see `here`_
- The compiled module also returns the compiled function as PackedFunc
.. _here: https://github.com/dmlc/tvm/tree/master/src/api
To keep the runtime minimum, we isolated the IR Node support from the deployment runtime. The resulting runtime takes around 200K - 600K depending on how many runtime driver modules (e.g., CUDA) get included.
The overhead of calling into PackedFunc vs. a normal function is small, as it is only saving a few values on the stack.
So it is OK as long as we don't wrap small functions.
In summary, the PackedFunc is the universal glue in TVM where we use it extensively to support our compiler and deployment.
## Module
Module
------
Since TVM supports multiple types of devices, we need to support different type of drivers.
We have to use the driver API to load the kernel, set up the argument in packed format and perform kernel launch.
......@@ -136,28 +156,34 @@ We also need to patch up the driver API so that the exposed functions are thread
So we often need to implement these driver glues in C++ and expose them to the user.
We can certainly not do it for each type of functions, so again PackedFunc is our answer.
TVM defines the compiled object as [Module](https://github.com/dmlc/tvm/blob/master/include/tvm/runtime/module.h).
TVM defines the compiled object as `Module`_.
The user can get the compiled function from Module as PackedFunc.
The generated compiled code can dynamically get function from Module in runtime. It caches the function handle in the first call and reuses in subsequent calls. We use this to link device code and callback into any PackedFunc(e.g., python) from generated code.
.. _Module: https://github.com/dmlc/tvm/blob/master/include/tvm/runtime/module.h
The ModuleNode is an abstract class that can be implemented by each type of device.
So far we support modules for CUDA, Metal, OpenCL and loading dynamic shared libraries. This abstraction makes introduction
of new device easy, and we do not need to redo the host code generation for each type of device.
## Remote Deployment
Remote Deployment
-----------------
The PackedFunc and Module system also makes it easy to ship the function into remote devices directly.
Under the hood, we have an RPCModule that serializes the arguments to do the data movement and launches the computation on the remote.
![](http://www.tvm.ai/images/release/tvm_rpc.png)
.. image:: http://www.tvm.ai/images/release/tvm_rpc.png
The RPC server itself is minimum and can be bundled into the runtime. We can start a minimum TVM
RPC server on iPhone/android/raspberry pi or even the browser. The cross compilation on server and shipping of the module for testing can be done in the same script. Checkout
[Cross compilation and RPC tutorial](http://docs.tvm.ai/tutorials/deployment/cross_compilation_and_rpc.html#sphx-glr-tutorials-deployment-cross-compilation-and-rpc-py) for more details.
`Cross compilation and RPC tutorial`_ for more details.
.. _Cross compilation and RPC tutorial: http://docs.tvm.ai/tutorials/deployment/cross_compilation_and_rpc.html#sphx-glr-tutorials-deployment-cross-compilation-and-rpc-py
This instant feedback gives us a lot of advantages. For example, to test the correctness of generated code on iPhone, we no longer have to write test-cases in swift/objective-c from scratch -- We can use RPC to execute on iPhone, copy the result back and do verification on the host via numpy. We can also do the profiling using the same script.
## TVM Node and Compiler Stack
TVM Node and Compiler Stack
---------------------------
As we mentioned earlier, we build compiler stack API on top of the PackedFunc runtime system.
We faced a constant changing of the compiler API for the need of research. We need a new language object or IR node whenever we want to test out new primitives.
......@@ -166,89 +192,101 @@ However, we don't want to change our API from time to time. Besides that, we als
- be able to serialize any language object and IRs
- be able to explore, print, and manipulate the IR objects in front-end language to do quick prototyping.
We introduced a base class, called [Node](https://github.com/dmlc/HalideIR/blob/master/src/tvm/node.h#L52) to solve this problem.
We introduced a base class, called `Node`_ to solve this problem.
All the language object in the compiler stack is a subclass of Node. Each node contains a string type_key that uniquely identifies
the type of object. We choose string instead of int as type key so new Node class can be added in the decentralized fashion without
adding the code back to the central repo. To ease the speed of dispatching, we allocate an integer type_index at runtime for each type_key.
.. _Node: https://github.com/dmlc/HalideIR/blob/master/src/tvm/node.h#L52
Since usually one Node object could be referenced in multiple places in the language, we use a shared_ptr to keep
track of reference. We use NodeRef class to represent a reference to the Node.
We can roughly view NodeRef class as shared_ptr to the Node container.
We can also define subclass NodeRef to hold each subtypes of Node. Each Node class needs to define the VisitAttr function.
```c++
class AttrVisitor {
public:
virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, uint64_t* value) = 0;
virtual void Visit(const char* key, int* value) = 0;
virtual void Visit(const char* key, bool* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, void** value) = 0;
virtual void Visit(const char* key, Type* value) = 0;
virtual void Visit(const char* key, NodeRef* value) = 0;
// ...
};
class Node {
public:
virtual void VisitAttrs(AttrVisitor* visitor) {}
// ...
};
```
.. code:: c
class AttrVisitor {
public:
virtual void Visit(const char* key, double* value) = 0;
virtual void Visit(const char* key, int64_t* value) = 0;
virtual void Visit(const char* key, uint64_t* value) = 0;
virtual void Visit(const char* key, int* value) = 0;
virtual void Visit(const char* key, bool* value) = 0;
virtual void Visit(const char* key, std::string* value) = 0;
virtual void Visit(const char* key, void** value) = 0;
virtual void Visit(const char* key, Type* value) = 0;
virtual void Visit(const char* key, NodeRef* value) = 0;
// ...
};
class Node {
public:
virtual void VisitAttrs(AttrVisitor* visitor) {}
// ...
};
Each Node subclass will override this to visit its members. Here is an example implementation of TensorNode.
```c++
class TensorNode : public Node {
public:
/*! \brief The shape of the tensor */
Array<Expr> shape;
/*! \brief data type in the content of the tensor */
Type dtype;
/*! \brief the source operation, can be None */
Operation op;
/*! \brief the output index from source operation */
int value_index{0};
/*! \brief constructor */
TensorNode() {}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
v->Visit("op", &op);
v->Visit("value_index", &value_index);
}
};
```
In the above examples, both ```Operation``` and ```Array<Expr>``` are NodeRef.
.. code:: c
class TensorNode : public Node {
public:
/*! \brief The shape of the tensor */
Array<Expr> shape;
/*! \brief data type in the content of the tensor */
Type dtype;
/*! \brief the source operation, can be None */
Operation op;
/*! \brief the output index from source operation */
int value_index{0};
/*! \brief constructor */
TensorNode() {}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
v->Visit("op", &op);
v->Visit("value_index", &value_index);
}
};
In the above examples, both ``Operation`` and ``Array<Expr>`` are NodeRef.
The VisitAttrs gives us a reflection API to visit each member of the object.
We can use this function to visit the node and serialize any language object recursively.
It also allows us to get members of an object easily in front-end language.
For example, in the following code, we accessed the op field of the TensorNode.
```python
import tvm
.. code:: python
import tvm
x = tvm.placeholder((3,4), name="x")
# access the op field of TensorNode
print(x.op.name)
```
x = tvm.placeholder((3,4), name="x")
# access the op field of TensorNode
print(x.op.name)
New Node can be added to C++ without changing the front-end runtime, making it easy to make extensions to the compiler stack.
Note that this is not the fastest way to expose members to front-end language, but might be one of the simplest
approaches possible. We also find that it fits our purposes as we mainly use python for testing and prototyping and still use c++
to do the heavy lifting job.
## Implementation Details
Implementation Details
----------------------
Each argument in PackedFunc contains a union value [TVMValue](https://github.com/dmlc/tvm/blob/master/include/tvm/runtime/c_runtime_api.h#L122)
Each argument in PackedFunc contains a union value `TVMValue`_
and a type code. This design allows the dynamically typed language to convert to the corresponding type directly, and statically typed language to
do runtime type checking during conversion.
.. _TVMValue: https://github.com/dmlc/tvm/blob/master/include/tvm/runtime/c_runtime_api.h#L122
The relevant files are
- [packed_func.h](https://github.com/dmlc/tvm/blob/master/include/tvm/runtime/packed_func.h) for C++ API
- [c_runtime_api.cc](https://github.com/dmlc/tvm/blob/master/src/runtime/c_runtime_api.cc#L262) for C API and how to provide callback.
- `packed_func.h`_ for C++ API
- `c_runtime_api.cc`_ for C API and how to provide callback.
.. _packed_func.h: https://github.com/dmlc/tvm/blob/master/include/tvm/runtime/packed_func.h
.. _c_runtime_api.cc: https://github.com/dmlc/tvm/blob/master/src/runtime/c_runtime_api.cc#L262
To support extension types, we used a registry system to register type related information, like support of any
in C++, see [Extension types](https://github.com/dmlc/tvm/tree/master/apps/extension) for more details.
in C++, see `Extension types`_ for more details.
.. _Extension types: https://github.com/dmlc/tvm/tree/master/apps/extension
......@@ -35,8 +35,11 @@ using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>;
class ConstantChecker : private ExprVisitor {
public:
// Check whether an expression is constant. The results are memorized.
// Check whether an expression is constant. The results are memoized.
bool Check(const Expr& expr) {
// The `ConstantNode` case is common enough that we check directly for the
// case here, to avoid the time overhead of dispatching through the vtable
// and the space overhead of memoizing always-true results.
if (expr.as<ConstantNode>()) {
return true;
}
......@@ -44,7 +47,7 @@ class ConstantChecker : private ExprVisitor {
if (it != memo_.end())
return it->second;
VisitExpr(expr);
return memo_[expr]; // return memorized result or the default value false
return memo_[expr]; // return memoized result or the default value false
}
private:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment