Commit 9c383f64 by Jared Roesch Committed by Zhi

[PassManager] Implement pass manager tracing API (#4782)

* Implement pass tracing API

* Set is_before correctly

* Add docs for trace function

* Fix lint

* Remove PDB

* Ensure trace_func is set before calling

* Fix conditional
parent d54036a9
...@@ -621,6 +621,26 @@ By inserting the ``PrintIR`` pass after ``FoldConstant``, the pass infra will ...@@ -621,6 +621,26 @@ By inserting the ``PrintIR`` pass after ``FoldConstant``, the pass infra will
dump out the module IR when ``FoldConstant`` is done. Users can plug in this dump out the module IR when ``FoldConstant`` is done. Users can plug in this
pass after any pass they want to debug for viewing the optimization effect. pass after any pass they want to debug for viewing the optimization effect.
There is a more flexible debugging mechanism also exposed by the build configuration
object. One can pass a tracing function which can be used to execute arbitrary code
before and/or after each pass. A tracing function will receive a ``IRModule``, ``PassInfo``,
and a boolean indicating whether you are executing before, or after a pass.
An example is below.
.. code:: python
def print_ir(mod, info, is_before):
"""Print the name of the pass, the IR, only before passes execute."""
if is_before:
print(f"Running pass: {}", info)
print(mod)
with relay.build_config(opt_level=3, trace=print_ir):
with tvm.target.create("llvm"):
# Perform the optimizations.
mod = seq(mod)
For more pass infra related examples in Python and C++, please refer to For more pass infra related examples in Python and C++, please refer to
`tests/python/relay/test_pass_manager.py`_ and `tests/python/relay/test_pass_manager.py`_ and
`tests/cpp/relay_transform_sequential.cc`_, respectively. `tests/cpp/relay_transform_sequential.cc`_, respectively.
......
...@@ -65,6 +65,17 @@ ...@@ -65,6 +65,17 @@
namespace tvm { namespace tvm {
namespace transform { namespace transform {
// Forward declare for TraceFunc.
class PassInfo;
/*! \brief A callback for tracing passes, useful for debugging and logging.
*
*/
using TraceFunc =
runtime::TypedPackedFunc<void(const IRModule& ir_module,
const PassInfo& ctx,
bool is_before)>;
/*! /*!
* \brief PassContextNode contains the information that a pass can rely on, * \brief PassContextNode contains the information that a pass can rely on,
* such as analysis results. * such as analysis results.
...@@ -88,6 +99,8 @@ class PassContextNode : public Object { ...@@ -88,6 +99,8 @@ class PassContextNode : public Object {
/*! \brief The list of disabled passes. */ /*! \brief The list of disabled passes. */
Array<PrimExpr> disabled_pass; Array<PrimExpr> disabled_pass;
TraceFunc trace_func;
PassContextNode() = default; PassContextNode() = default;
void VisitAttrs(AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
...@@ -101,6 +114,7 @@ class PassContextNode : public Object { ...@@ -101,6 +114,7 @@ class PassContextNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object); TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object);
}; };
/*! /*!
* \brief PassContext that is used to configure the pass behavior. * \brief PassContext that is used to configure the pass behavior.
* *
...@@ -146,6 +160,14 @@ class PassContext : public ObjectRef { ...@@ -146,6 +160,14 @@ class PassContext : public ObjectRef {
*/ */
TVM_DLL static PassContext Current(); TVM_DLL static PassContext Current();
/*!
* \brief Apply the tracing functions of the context to the module, with the info.
* \param module The IRModule to trace.
* \param info The pass information.
* \param is_before Indicated whether the tracing is before or after a pass.
*/
TVM_DLL void Trace(const IRModule& module, const PassInfo& info, bool is_before) const;
// accessor. // accessor.
using ContainerType = PassContextNode; using ContainerType = PassContextNode;
class Internal; class Internal;
......
...@@ -78,7 +78,8 @@ class PassContext(RelayNode): ...@@ -78,7 +78,8 @@ class PassContext(RelayNode):
opt_level=2, opt_level=2,
fallback_device=_nd.cpu(), fallback_device=_nd.cpu(),
required_pass=None, required_pass=None,
disabled_pass=None): disabled_pass=None,
trace=None):
if isinstance(fallback_device, str): if isinstance(fallback_device, str):
fallback_device = _nd.context(fallback_device).device_type fallback_device = _nd.context(fallback_device).device_type
elif isinstance(fallback_device, TVMContext): elif isinstance(fallback_device, TVMContext):
...@@ -99,7 +100,7 @@ class PassContext(RelayNode): ...@@ -99,7 +100,7 @@ class PassContext(RelayNode):
self.__init_handle_by_constructor__(_transform.PassContext, opt_level, self.__init_handle_by_constructor__(_transform.PassContext, opt_level,
fallback_device, required, fallback_device, required,
disabled) disabled, trace)
def __enter__(self): def __enter__(self):
_transform.EnterPassContext(self) _transform.EnterPassContext(self)
...@@ -117,7 +118,8 @@ class PassContext(RelayNode): ...@@ -117,7 +118,8 @@ class PassContext(RelayNode):
def build_config(opt_level=2, def build_config(opt_level=2,
fallback_device=_nd.cpu(), fallback_device=_nd.cpu(),
required_pass=None, required_pass=None,
disabled_pass=None): disabled_pass=None,
trace=None):
"""Configure the build behavior by setting config variables. """Configure the build behavior by setting config variables.
Parameters Parameters
...@@ -151,13 +153,16 @@ def build_config(opt_level=2, ...@@ -151,13 +153,16 @@ def build_config(opt_level=2,
disabled_pass: set of str, optional disabled_pass: set of str, optional
Optimization passes to be disabled during optimization. Optimization passes to be disabled during optimization.
trace: Callable[[IRModule, PassInfo, bool], None]
A tracing function for debugging or introspection.
Returns Returns
------- -------
pass_context: PassContext pass_context: PassContext
The pass context for optimizations. The pass context for optimizations.
""" """
return PassContext(opt_level, fallback_device, required_pass, return PassContext(opt_level, fallback_device, required_pass,
disabled_pass) disabled_pass, trace)
@register_relay_node @register_relay_node
......
...@@ -84,6 +84,13 @@ PassContext PassContext::Create() { ...@@ -84,6 +84,13 @@ PassContext PassContext::Create() {
return PassContext(make_object<PassContextNode>()); return PassContext(make_object<PassContextNode>());
} }
void PassContext::Trace(const IRModule& module, const PassInfo& info, bool is_before) const {
auto pass_ctx_node = this->operator->();
if (pass_ctx_node->trace_func != nullptr) {
pass_ctx_node->trace_func(module, info, is_before);
}
}
class ModulePass; class ModulePass;
/*! /*!
...@@ -231,8 +238,10 @@ IRModule ModulePassNode::operator()(const IRModule& mod, ...@@ -231,8 +238,10 @@ IRModule ModulePassNode::operator()(const IRModule& mod,
<< " with opt level: " << " with opt level: "
<< pass_info->opt_level; << pass_info->opt_level;
CHECK(mod.defined()); CHECK(mod.defined());
pass_ctx.Trace(mod, pass_info, true);
IRModule updated_mod = pass_func(mod, pass_ctx); IRModule updated_mod = pass_func(mod, pass_ctx);
CHECK(updated_mod.defined()); CHECK(updated_mod.defined());
pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod; return updated_mod;
} }
...@@ -414,10 +423,12 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext") ...@@ -414,10 +423,12 @@ TVM_REGISTER_GLOBAL("relay._transform.PassContext")
int fallback_device = args[1]; int fallback_device = args[1];
tvm::Array<tvm::PrimExpr> required = args[2]; tvm::Array<tvm::PrimExpr> required = args[2];
tvm::Array<tvm::PrimExpr> disabled = args[3]; tvm::Array<tvm::PrimExpr> disabled = args[3];
TraceFunc trace_func = args[4];
pctx->opt_level = opt_level; pctx->opt_level = opt_level;
pctx->fallback_device = fallback_device; pctx->fallback_device = fallback_device;
pctx->required_pass = std::move(required); pctx->required_pass = std::move(required);
pctx->disabled_pass = std::move(disabled); pctx->disabled_pass = std::move(disabled);
pctx->trace_func = std::move(trace_func);
*ret = pctx; *ret = pctx;
}); });
......
...@@ -116,7 +116,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod, ...@@ -116,7 +116,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
<< pass_info->name << pass_info->name
<< " with opt level: " << " with opt level: "
<< pass_info->opt_level; << pass_info->opt_level;
pass_ctx.Trace(mod, pass_info, true);
// Execute the pass function and return a new module. // Execute the pass function and return a new module.
IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports()); IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports());
std::vector<std::pair<GlobalVar, Function> > updates; std::vector<std::pair<GlobalVar, Function> > updates;
...@@ -134,6 +134,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod, ...@@ -134,6 +134,7 @@ IRModule FunctionPassNode::operator()(const IRModule& mod,
for (const auto& pair : updates) { for (const auto& pair : updates) {
updated_mod->Add(pair.first, pair.second, true); updated_mod->Add(pair.first, pair.second, true);
} }
pass_ctx.Trace(updated_mod, pass_info, false);
return updated_mod; return updated_mod;
} }
......
...@@ -522,6 +522,36 @@ def test_print_ir(capfd): ...@@ -522,6 +522,36 @@ def test_print_ir(capfd):
assert "Dumping the module IR" in out assert "Dumping the module IR" in out
assert "multiply" in out assert "multiply" in out
__TRACE_COUNTER__ = 0
def _tracer(module, info, is_before):
global __TRACE_COUNTER__
if bool(is_before):
__TRACE_COUNTER__ += 1
def test_print_debug_callback():
global __TRACE_COUNTER__
shape = (1, 2, 3)
tp = relay.TensorType(shape, "float32")
x = relay.var("x", tp)
y = relay.add(x, x)
y = relay.multiply(y, relay.const(2, "float32"))
func = relay.Function([x], y)
seq = _transform.Sequential([
relay.transform.InferType(),
relay.transform.FoldConstant(),
relay.transform.DeadCodeElimination()
])
assert __TRACE_COUNTER__ == 0
mod = relay.Module({"main": func})
with relay.build_config(opt_level=3, trace=_tracer):
mod = seq(mod)
assert __TRACE_COUNTER__ == 4
if __name__ == "__main__": if __name__ == "__main__":
pytest.main() pytest.main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment