Commit 00e6108e by Jared Roesch Committed by Tianqi Chen

Add a the ability to trigger debugging in the interpreter without recompiling (#2219)

parent 993fe12f
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/relay/attrs/debug.h
* \brief Auxiliary attributes for debug operators.
*/
#ifndef TVM_RELAY_ATTRS_DEBUG_H_
#define TVM_RELAY_ATTRS_DEBUG_H_
#include <tvm/attrs.h>
#include <string>
namespace tvm {
namespace relay {
/*!
* \brief Options for the debug operators.
*/
struct DebugAttrs : public tvm::AttrsNode<DebugAttrs> {
EnvFunc debug_func;
TVM_DECLARE_ATTRS(DebugAttrs, "relay.attrs.DebugAttrs") {
TVM_ATTR_FIELD(debug_func)
.describe("The function to use when debugging.");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_DEBUG_H_
...@@ -49,6 +49,11 @@ using TOpPattern = int; ...@@ -49,6 +49,11 @@ using TOpPattern = int;
using TOpIsStateful = bool; using TOpIsStateful = bool;
/*! /*!
* \brief Mark the operator as non-computational.
*/
using TNonComputational = bool;
/*!
* \brief Computation description interface. * \brief Computation description interface.
* *
* \note This function have a special convention * \note This function have a special convention
......
...@@ -10,6 +10,7 @@ from . import module ...@@ -10,6 +10,7 @@ from . import module
from . import ir_pass from . import ir_pass
from .build_module import build, build_config, create_executor from .build_module import build, build_config, create_executor
from . import parser from . import parser
from . import debug
# Root operators # Root operators
from .op import Op from .op import Op
...@@ -63,11 +64,5 @@ var = expr.var ...@@ -63,11 +64,5 @@ var = expr.var
const = expr.const const = expr.const
bind = expr.bind bind = expr.bind
# pylint: disable=unused-argument
@register_func("relay.debug")
def _debug(*args):
import pdb
pdb.set_trace()
# Parser # Parser
fromtext = parser.fromtext fromtext = parser.fromtext
# pylint: disable=wildcard-import, redefined-builtin, invalid-name
"""The Relay IR namespace containing the IR definition and compiler."""
from __future__ import absolute_import
from .base import NodeBase, register_relay_node
from ..api import register_func
@register_relay_node
class InterpreterState(NodeBase):
pass
# pylint: disable=unused-argument
def _debugger_init(expr, stack):
import pdb
pdb.set_trace()
# pylint: disable=unused-argument
@register_func("relay.debug")
def _debug(*args):
_, _, _, ist = args
print("Relay Debugger")
print(" You can manipulate the expression under evaluation with the name `expr`.")
print(" You can manipulate the call stack with the name `stack`.")
print("--------------")
print("--------------")
_debugger_init(ist.current_expr, ist.stack)
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# operator defs # operator defs
from .op import get, register, register_schedule, register_compute, register_alter_op_layout, \ from .op import get, register, register_schedule, register_compute, register_alter_op_layout, \
Op Op
from .op import debug
# Operators # Operators
from .reduce import * from .reduce import *
...@@ -13,6 +14,7 @@ from . import image ...@@ -13,6 +14,7 @@ from . import image
from . import vision from . import vision
from . import op_attrs from . import op_attrs
# operator registry # operator registry
from . import _tensor from . import _tensor
from . import _transform from . import _transform
......
...@@ -8,6 +8,7 @@ from ..base import register_relay_node ...@@ -8,6 +8,7 @@ from ..base import register_relay_node
from ..expr import Expr from ..expr import Expr
from ...api import register_func from ...api import register_func
from ...build_module import lower, build from ...build_module import lower, build
from . import _make
@register_relay_node @register_relay_node
class Op(Expr): class Op(Expr):
...@@ -183,3 +184,18 @@ def schedule_injective(attrs, outputs, target): ...@@ -183,3 +184,18 @@ def schedule_injective(attrs, outputs, target):
"""Generic schedule for binary broadcast.""" """Generic schedule for binary broadcast."""
with target: with target:
return topi.generic.schedule_injective(outputs) return topi.generic.schedule_injective(outputs)
__DEBUG_COUNTER__ = 0
def debug(expr, debug_func=None):
"""The main entry point to the debugger."""
global __DEBUG_COUNTER__
if debug_func:
name = "debugger_func{}".format(__DEBUG_COUNTER__)
register_func(name, debug_func)
__DEBUG_COUNTER__ += 1
else:
name = ''
return _make.debug(expr, name)
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h> #include <tvm/relay/interpreter.h>
#include <tvm/relay/pass.h> #include <tvm/relay/pass.h>
#include <tvm/relay/attrs/debug.h>
#include "compile_engine.h" #include "compile_engine.h"
namespace tvm { namespace tvm {
...@@ -124,13 +125,48 @@ struct Stack { ...@@ -124,13 +125,48 @@ struct Stack {
}; };
}; };
/*! \brief A representation of the interpreter state which can be passed back to Python. */
class InterpreterState;
/*! \brief A container capturing the state of the interpreter. */
class InterpreterStateNode : public Node {
public:
using Frame = tvm::Map<Var, Value>;
using Stack = tvm::Array<Frame>;
/*! \brief The current expression under evaluation. */
Expr current_expr;
/*! \brief The call stack of the interpreter. */
Stack stack;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("current_expr", &current_expr);
v->Visit("stack", &stack);
}
TVM_DLL static InterpreterState make(Expr current_expr, Stack stack);
static constexpr const char* _type_key = "relay.InterpreterState";
TVM_DECLARE_NODE_TYPE_INFO(InterpreterStateNode, Node);
};
RELAY_DEFINE_NODE_REF(InterpreterState, InterpreterStateNode, NodeRef);
InterpreterState InterpreterStateNode::make(Expr current_expr, Stack stack) {
NodePtr<InterpreterStateNode> n = make_node<InterpreterStateNode>();
n->current_expr = std::move(current_expr);
n->stack = std::move(stack);
return InterpreterState(n);
}
// NOTE: the current interpreter assumes A-normal form. // NOTE: the current interpreter assumes A-normal form.
// which is better for execution. // which is better for execution.
// //
// It will run duplicated computations when taking program that // It will run duplicated computations when taking program that
// contains DAG in dataflow-form. // contains DAG in dataflow-form.
// Conversion to ANF is recommended before running the interpretation.
// //
// Conversion to ANF is recommended before running the interpretation.
class Interpreter : class Interpreter :
public ExprFunctor<Value(const Expr& n)> { public ExprFunctor<Value(const Expr& n)> {
public: public:
...@@ -209,6 +245,21 @@ class Interpreter : ...@@ -209,6 +245,21 @@ class Interpreter :
Value InvokePrimitiveOp(Function func, Value InvokePrimitiveOp(Function func,
const Array<Value>& args) { const Array<Value>& args) {
auto call_node = func->body.as<CallNode>();
if (call_node && call_node->op == Op::Get("debug")) {
auto dattrs = call_node->attrs.as<DebugAttrs>();
auto interp_state = this->get_state(call_node->args[0]);
if (dattrs->debug_func.defined()) {
dattrs->debug_func(interp_state);
} else {
RELAY_DEBUG(interp_state);
}
return args[0];
}
// Marshal the arguments. // Marshal the arguments.
// Handle tuple input/output by flattening them. // Handle tuple input/output by flattening them.
size_t arg_len = 0; size_t arg_len = 0;
...@@ -381,6 +432,16 @@ class Interpreter : ...@@ -381,6 +432,16 @@ class Interpreter :
} }
} }
InterpreterState get_state(Expr e = Expr()) const {
InterpreterStateNode::Stack stack;
for (auto fr : this->stack_.frames) {
InterpreterStateNode::Frame frame = fr.locals;
stack.push_back(frame);
}
auto state = InterpreterStateNode::make(e, stack);
return state;
}
private: private:
// module // module
Module mod_; Module mod_;
......
/*!
* Copyright (c) 2018 by Contributors
* \file nn.cc
* \brief Property def of nn operators.
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/debug.h>
#include <topi/elemwise.h>
#include <vector>
#include "./type_relations.h"
#include "./op_common.h"
#include "./layout.h"
namespace tvm {
namespace relay {
Array<Tensor> DebugCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return Array<Tensor>{ topi::identity(inputs[0]) };
}
RELAY_REGISTER_OP("debug")
.describe(R"code(Enter the interpreter's debugger.
)code" TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("program", "Tuple", "The program to execute before debugging.")
.set_support_level(1)
.add_type_rel("Debug", IdentityRel)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<FTVMCompute>("FTVMCompute", DebugCompute);
Expr MakeDebug(Expr expr, std::string name) {
auto dattrs = make_node<DebugAttrs>();
if (name.size() > 0) {
dattrs->debug_func = EnvFunc::Get(name);
} else {
dattrs->debug_func = EnvFunc();
}
static const Op& op = Op::Get("debug");
return CallNode::make(op, {expr}, Attrs(dattrs), {});
}
TVM_REGISTER_API("relay.op._make.debug")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeDebug, args, rv);
});
} // namespace relay
} // namespace tvm
from tvm.relay import var, const, create_executor
from tvm.relay.op import debug
_test_debug_hit = False
def test_debug():
global _test_debug_hit
ex = create_executor()
x = var('x', shape=(), dtype='int32')
_test_debug_hit = False
def did_exec(x):
global _test_debug_hit
_test_debug_hit = True
prog = debug(x, debug_func=did_exec)
result = ex.evaluate(prog, { x: const(1) })
assert _test_debug_hit
assert result.asnumpy() == 1
def test_debug_with_expr():
global _test_debug_hit
_test_debug_hit = False
ex = create_executor()
x = var('x', shape=(), dtype='int32')
_test_debug_hit = False
def did_exec(x):
global _test_debug_hit
_test_debug_hit = True
prog = debug(x + x * x, debug_func=did_exec)
result = ex.evaluate(prog, { x: const(2) })
assert _test_debug_hit
assert result.asnumpy() == 6
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment