Commit 02631f67 by Jared Roesch Committed by Tianqi Chen

[Relay] Add generic & informative Relay error reporting (#2408)

parent 4e573235
......@@ -7,3 +7,6 @@
[submodule "dlpack"]
path = 3rdparty/dlpack
url = https://github.com/dmlc/dlpack
[submodule "3rdparty/rang"]
path = 3rdparty/rang
url = https://github.com/agauniyal/rang
Subproject commit cabe04d6d6b05356fa8f9741704924788f0dd762
......@@ -53,6 +53,7 @@ tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
include_directories("include")
include_directories("3rdparty/dlpack/include")
include_directories("3rdparty/dmlc-core/include")
include_directories("3rdparty/rang/include")
include_directories("3rdparty/compiler-rt")
# initial variables
......
......@@ -7,25 +7,134 @@
#define TVM_RELAY_ERROR_H_
#include <string>
#include <vector>
#include <sstream>
#include "./base.h"
#include "./expr.h"
#include "./module.h"
namespace tvm {
namespace relay {
struct Error : public dmlc::Error {
explicit Error(const std::string &msg) : dmlc::Error(msg) {}
};
#define RELAY_ERROR(msg) (RelayErrorStream() << msg)
// Forward declaratio for RelayErrorStream.
struct Error;
/*! \brief A wrapper around std::stringstream.
*
* This is designed to avoid platform specific
* issues compiling and using std::stringstream
* for error reporting.
*/
struct RelayErrorStream {
std::stringstream ss;
struct InternalError : public Error {
explicit InternalError(const std::string &msg) : Error(msg) {}
template<typename T>
RelayErrorStream& operator<<(const T& t) {
ss << t;
return *this;
}
std::string str() const {
return ss.str();
}
void Raise() const;
};
struct FatalTypeError : public Error {
explicit FatalTypeError(const std::string &s) : Error(s) {}
struct Error : public dmlc::Error {
Span sp;
explicit Error(const std::string& msg) : dmlc::Error(msg), sp() {}
Error(const std::stringstream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp() {} // NOLINT(*)
};
struct TypecheckerError : public Error {
explicit TypecheckerError(const std::string &msg) : Error(msg) {}
/*! \brief An abstraction around how errors are stored and reported.
* Designed to be opaque to users, so we can support a robust and simpler
* error reporting mode, as well as a more complex mode.
*
* The first mode is the most accurate: we report a Relay error at a specific
* Span, and then render the error message directly against a textual representation
* of the program, highlighting the exact lines in which it occurs. This mode is not
* implemented in this PR and will not work.
*
* The second mode is a general-purpose mode, which attempts to annotate the program's
* textual format with errors.
*
* The final mode represents the old mode, if we report an error that has no span or
* expression, we will default to throwing an exception with a textual representation
* of the error and no indication of where it occured in the original program.
*
* The latter mode is not ideal, and the goal of the new error reporting machinery is
* to avoid ever reporting errors in this style.
*/
class ErrorReporter {
public:
ErrorReporter() : errors_(), node_to_error_() {}
/*! \brief Report a tvm::relay::Error.
*
* This API is useful for reporting spanned errors.
*
* \param err The error to report.
*/
void Report(const Error& err) {
if (!err.sp.defined()) {
throw err;
}
this->errors_.push_back(err);
}
/*! \brief Report an error against a program, using the full program
* error reporting strategy.
*
* This error reporting method requires the global function in which
* to report an error, the expression to report the error on,
* and the error object.
*
* \param global The global function in which the expression is contained.
* \param node The expression or type to report the error at.
* \param err The error message to report.
*/
inline void ReportAt(const GlobalVar& global, const NodeRef& node, std::stringstream& err) {
this->ReportAt(global, node, Error(err));
}
/*! \brief Report an error against a program, using the full program
* error reporting strategy.
*
* This error reporting method requires the global function in which
* to report an error, the expression to report the error on,
* and the error object.
*
* \param global The global function in which the expression is contained.
* \param node The expression or type to report the error at.
* \param err The error to report.
*/
void ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err);
/*! \brief Render all reported errors and exit the program.
*
* This function should be used after executing a pass to render reported errors.
*
* It will build an error message from the set of errors, depending on the error
* reporting strategy.
*
* \param module The module to report errors on.
* \param use_color Controls whether to colorize the output.
*/
void RenderErrors(const Module& module, bool use_color = true);
inline bool AnyErrors() {
return errors_.size() != 0;
}
private:
std::vector<Error> errors_;
std::unordered_map<NodeRef, std::vector<size_t>, NodeHash, NodeEqual> node_to_error_;
std::unordered_map<NodeRef, GlobalVar, NodeHash, NodeEqual> node_to_gv_;
};
} // namespace relay
......
......@@ -43,11 +43,15 @@ class ModuleNode : public RelayNode {
/*! \brief A map from ids to all global functions. */
tvm::Map<GlobalVar, Function> functions;
/*! \brief The entry function (i.e. "main"). */
GlobalVar entry_func;
ModuleNode() {}
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("functions", &functions);
v->Visit("global_var_map_", &global_var_map_);
v->Visit("entry_func", &entry_func);
}
TVM_DLL static Module make(tvm::Map<GlobalVar, Function> global_funcs);
......@@ -111,6 +115,20 @@ class ModuleNode : public RelayNode {
*/
void Update(const Module& other);
/*! \brief Construct a module from a standalone expression.
*
* Allows one to optionally pass a global function map as
* well.
*
* \param expr The expression to set as the entry point to the module.
* \param global_funcs The global function map.
*
* \returns A module with expr set as the entry point.
*/
static Module FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs = {});
static constexpr const char* _type_key = "relay.Module";
TVM_DECLARE_NODE_TYPE_INFO(ModuleNode, Node);
......@@ -132,6 +150,7 @@ struct Module : public NodeRef {
using ContainerType = ModuleNode;
};
} // namespace relay
} // namespace tvm
......
......@@ -6,8 +6,8 @@
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_
#include <tvm/relay/module.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h>
#include <string>
......
......@@ -295,6 +295,12 @@ class TypeReporterNode : public Node {
*/
TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0;
/*!
* \brief Set the location at which to report unification errors.
* \param ref The program node to report the error.
*/
TVM_DLL virtual void SetLocation(const NodeRef& ref) = 0;
// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) final {}
......
/*!
* Copyright (c) 2018 by Contributors
* \file error_reporter.h
* \brief The set of errors raised by Relay.
*/
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/error.h>
#include <string>
#include <vector>
#include <rang.hpp>
namespace tvm {
namespace relay {
void RelayErrorStream::Raise() const {
throw Error(*this);
}
template<typename T, typename U>
using NodeMap = std::unordered_map<T, U, NodeHash, NodeEqual>;
void ErrorReporter::RenderErrors(const Module& module, bool use_color) {
// First we pick an error reporting strategy for each error.
// TODO(@jroesch): Spanned errors are currently not supported.
for (auto err : this->errors_) {
CHECK(!err.sp.defined()) << "attempting to use spanned errors, currently not supported";
}
NodeMap<GlobalVar, NodeMap<NodeRef, std::string>> error_maps;
// Set control mode in order to produce colors;
if (use_color) {
rang::setControlMode(rang::control::Force);
}
for (auto pair : this->node_to_gv_) {
auto node = pair.first;
auto global = Downcast<GlobalVar>(pair.second);
auto has_errs = this->node_to_error_.find(node);
CHECK(has_errs != this->node_to_error_.end());
const auto& error_indicies = has_errs->second;
std::stringstream err_msg;
err_msg << rang::fg::red;
for (auto index : error_indicies) {
err_msg << this->errors_[index].what() << "; ";
}
err_msg << rang::fg::reset;
// Setup error map.
auto it = error_maps.find(global);
if (it != error_maps.end()) {
it->second.insert({ node, err_msg.str() });
} else {
error_maps.insert({ global, { { node, err_msg.str() }}});
}
}
// Now we will construct the fully-annotated program to display to
// the user.
std::stringstream annotated_prog;
// First we output a header for the errors.
annotated_prog <<
rang::style::bold << std::endl <<
"Error(s) have occurred. We have annotated the program with them:"
<< std::endl << std::endl << rang::style::reset;
// For each global function which contains errors, we will
// construct an annotated function.
for (auto pair : error_maps) {
auto global = pair.first;
auto err_map = pair.second;
auto func = module->Lookup(global);
// We output the name of the function before displaying
// the annotated program.
annotated_prog <<
rang::style::bold <<
"In `" << global->name_hint << "`: " <<
std::endl <<
rang::style::reset;
// We then call into the Relay printer to generate the program.
//
// The annotation callback will annotate the error messages
// contained in the map.
annotated_prog << RelayPrint(func, false, [&err_map](tvm::relay::Expr expr) {
auto it = err_map.find(expr);
if (it != err_map.end()) {
return it->second;
} else {
return std::string("");
}
});
}
auto msg = annotated_prog.str();
if (use_color) {
rang::setControlMode(rang::control::Auto);
}
// Finally we report the error, currently we do so to LOG(FATAL),
// it may be good to instead report it to std::cout.
LOG(FATAL) << annotated_prog.str() << std::endl;
}
void ErrorReporter::ReportAt(const GlobalVar& global, const NodeRef& node, const Error& err) {
size_t index_to_insert = this->errors_.size();
this->errors_.push_back(err);
auto it = this->node_to_error_.find(node);
if (it != this->node_to_error_.end()) {
it->second.push_back(index_to_insert);
} else {
this->node_to_error_.insert({ node, { index_to_insert }});
}
this->node_to_gv_.insert({ node, global });
}
} // namespace relay
} // namespace tvm
......@@ -23,6 +23,8 @@ Module ModuleNode::make(tvm::Map<GlobalVar, Function> global_funcs) {
<< "Duplicate global function name " << kv.first->name_hint;
n->global_var_map_.Set(kv.first->name_hint, kv.first);
}
n->entry_func = GlobalVarNode::make("main");
return Module(n);
}
......@@ -96,6 +98,21 @@ void ModuleNode::Update(const Module& mod) {
}
}
Module ModuleNode::FromExpr(
const Expr& expr,
const tvm::Map<GlobalVar, Function>& global_funcs) {
auto mod = ModuleNode::make(global_funcs);
auto func_node = expr.as<FunctionNode>();
Function func;
if (func_node) {
func = GetRef<Function>(func_node);
} else {
func = FunctionNode::make({}, expr, Type(), {}, {});
}
mod->Add(mod->entry_func, func);
return mod;
}
TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_API("relay._make.Module")
......
......@@ -70,9 +70,12 @@ Type ConcreteBroadcast(const TensorType& t1,
} else if (EqualConstInt(s2, 1)) {
oshape.push_back(s1);
} else {
LOG(FATAL) << "Incompatible broadcast type " << t1 << " and " << t2;
RELAY_ERROR(
"Incompatible broadcast type "
<< t1 << " and " << t2).Raise();
}
}
size_t max_ndim = std::max(ndim1, ndim2);
auto& rshape = (ndim1 > ndim2) ? t1->shape : t2->shape;
for (; i <= max_ndim; ++i) {
......@@ -92,7 +95,8 @@ bool BroadcastRel(const Array<Type>& types,
if (auto t0 = ToTensorType(types[0])) {
if (auto t1 = ToTensorType(types[1])) {
CHECK_EQ(t0->dtype, t1->dtype);
reporter->Assign(types[2], ConcreteBroadcast(t0, t1, t0->dtype));
reporter->Assign(types[2],
ConcreteBroadcast(t0, t1, t0->dtype));
return true;
}
}
......
......@@ -16,7 +16,7 @@ class TypeSolver::Reporter : public TypeReporterNode {
: solver_(solver) {}
void Assign(const Type& dst, const Type& src) final {
solver_->Unify(dst, src);
solver_->Unify(dst, src, location);
}
bool Assert(const IndexExpr& cond) final {
......@@ -35,7 +35,14 @@ class TypeSolver::Reporter : public TypeReporterNode {
return true;
}
TVM_DLL void SetLocation(const NodeRef& ref) final {
location = ref;
}
private:
/*! \brief The location to report unification errors at. */
mutable NodeRef location;
TypeSolver* solver_;
};
......@@ -329,8 +336,10 @@ class TypeSolver::Merger : public TypeFunctor<void(const Type&)> {
};
// constructor
TypeSolver::TypeSolver()
: reporter_(make_node<Reporter>(this)) {
TypeSolver::TypeSolver(const GlobalVar &current_func, ErrorReporter* err_reporter)
: reporter_(make_node<Reporter>(this)),
current_func(current_func),
err_reporter_(err_reporter) {
}
// destructor
......@@ -351,16 +360,26 @@ void TypeSolver::MergeFromTo(TypeNode* src, TypeNode* dst) {
}
// Add equality constraint
Type TypeSolver::Unify(const Type& dst, const Type& src) {
Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef&) {
// NB(@jroesch): we should probably pass location into the unifier to do better
// error reporting as well.
Unifier unifier(this);
return unifier.Unify(dst, src);
}
void TypeSolver::ReportError(const Error& err, const NodeRef& location) {
this->err_reporter_->ReportAt(
this->current_func,
location,
err);
}
// Add type constraint to the solver.
void TypeSolver::AddConstraint(const TypeConstraint& constraint) {
void TypeSolver::AddConstraint(const TypeConstraint& constraint, const NodeRef& loc) {
if (auto *op = constraint.as<TypeRelationNode>()) {
// create a new relation node.
RelationNode* rnode = arena_.make<RelationNode>();
rnode->location = loc;
rnode->rel = GetRef<TypeRelation>(op);
rel_nodes_.push_back(rnode);
// populate the type information.
......@@ -404,29 +423,52 @@ bool TypeSolver::Solve() {
args.push_back(Resolve(tlink->value->FindRoot()->resolved_type));
CHECK_LE(args.size(), rel->args.size());
}
// call the function
CHECK(rnode->location.defined())
<< "undefined location, should be set when constructing relation node";
// We need to set this in order to understand where unification
// errors generated by the error reporting are coming from.
reporter_->SetLocation(rnode->location);
try {
// Call the Type Relation's function.
bool resolved = rel->func(args, rel->num_inputs, rel->attrs, reporter_);
// mark inqueue as false after the function call
// so that rnode itself won't get enqueued again.
rnode->inqueue = false;
if (resolved) {
++num_resolved_rels_;
}
rnode->resolved = resolved;
} catch (const Error& err) {
this->ReportError(err, rnode->location);
rnode->resolved = false;
} catch (const dmlc::Error& err) {
rnode->resolved = false;
this->ReportError(
RELAY_ERROR(
"an internal invariant was violdated while" \
"typechecking your program" <<
err.what()), rnode->location);
}
// Mark inqueue as false after the function call
// so that rnode itself won't get enqueued again.
rnode->inqueue = false;
}
// This criterion is not necessarily right for all the possible cases
// TODO(tqchen): We should also count the number of in-complete types.
return num_resolved_rels_ == rel_nodes_.size();
}
// Expose type solver only for debugging purposes.
TVM_REGISTER_API("relay._ir_pass._test_type_solver")
.set_body([](runtime::TVMArgs args, runtime::TVMRetValue* ret) {
using runtime::PackedFunc;
using runtime::TypedPackedFunc;
auto solver = std::make_shared<TypeSolver>();
ErrorReporter err_reporter;
auto solver = std::make_shared<TypeSolver>(GlobalVarNode::make("test"), &err_reporter);
auto mod = [solver](std::string name) -> PackedFunc {
if (name == "Solve") {
......@@ -435,7 +477,7 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver")
});
} else if (name == "Unify") {
return TypedPackedFunc<Type(Type, Type)>([solver](Type lhs, Type rhs) {
return solver->Unify(lhs, rhs);
return solver->Unify(lhs, rhs, lhs);
});
} else if (name == "Resolve") {
return TypedPackedFunc<Type(Type)>([solver](Type t) {
......@@ -443,7 +485,9 @@ TVM_REGISTER_API("relay._ir_pass._test_type_solver")
});
} else if (name == "AddConstraint") {
return TypedPackedFunc<void(TypeConstraint)>([solver](TypeConstraint c) {
return solver->AddConstraint(c);
Expr e = VarNode::make("dummy_var",
IncompleteTypeNode::make(TypeVarNode::Kind::kType));
return solver->AddConstraint(c, e);
});
} else {
return PackedFunc();
......
......@@ -6,8 +6,10 @@
#ifndef TVM_RELAY_PASS_TYPE_SOLVER_H_
#define TVM_RELAY_PASS_TYPE_SOLVER_H_
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/error.h>
#include <vector>
#include <queue>
#include "../../common/arena.h"
......@@ -40,13 +42,14 @@ using common::LinkedList;
*/
class TypeSolver {
public:
TypeSolver();
TypeSolver(const GlobalVar& current_func, ErrorReporter* err_reporter);
~TypeSolver();
/*!
* \brief Add a type constraint to the solver.
* \param constraint The constraint to be added.
* \param location The location at which the constraint was incurred.
*/
void AddConstraint(const TypeConstraint& constraint);
void AddConstraint(const TypeConstraint& constraint, const NodeRef& lcoation);
/*!
* \brief Resolve type to the solution type in the solver.
* \param type The type to be resolved.
......@@ -62,8 +65,16 @@ class TypeSolver {
* \brief Unify lhs and rhs.
* \param lhs The left operand.
* \param rhs The right operand
* \param location The location at which the unification problem arose.
*/
Type Unify(const Type& lhs, const Type& rhs);
Type Unify(const Type& lhs, const Type& rhs, const NodeRef& location);
/*!
* \brief Report an error at the provided location.
* \param err The error to report.
* \param loc The location at which to report the error.
*/
void ReportError(const Error& err, const NodeRef& location);
private:
class OccursChecker;
......@@ -112,6 +123,7 @@ class TypeSolver {
return root;
}
};
/*! \brief relation node */
struct RelationNode {
/*! \brief Whether the relation is in the queue to be solved */
......@@ -122,7 +134,10 @@ class TypeSolver {
TypeRelation rel;
/*! \brief list types to this relation */
LinkedList<TypeNode*> type_list;
/*! \brief The location this type relation originated from. */
NodeRef location;
};
/*! \brief List of all allocated type nodes */
std::vector<TypeNode*> type_nodes_;
/*! \brief List of all allocated relation nodes */
......@@ -137,6 +152,11 @@ class TypeSolver {
common::Arena arena_;
/*! \brief Reporter that reports back to self */
TypeReporter reporter_;
/*! \brief The global representing the current function. */
GlobalVar current_func;
/*! \brief Error reporting. */
ErrorReporter* err_reporter_;
/*!
* \brief GetTypeNode that is corresponds to t.
* if it do not exist, create a new one.
......
import tvm
from tvm import relay
def check_type_err(expr, msg):
try:
expr = relay.ir_pass.infer_type(expr)
assert False
except tvm.TVMError as err:
assert msg in str(err)
def test_too_many_args():
x = relay.var('x', shape=(10, 10))
f = relay.Function([x], x)
y = relay.var('y', shape=(10, 10))
check_type_err(
f(x, y),
"the function is provided too many arguments expected 1, found 2;")
def test_too_few_args():
x = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(10, 10))
f = relay.Function([x, y], x)
check_type_err(f(x), "the function is provided too few arguments expected 2, found 1;")
def test_rel_fail():
x = relay.var('x', shape=(10, 10))
y = relay.var('y', shape=(11, 10))
f = relay.Function([x, y], x + y)
check_type_err(f(x, y), "Incompatible broadcast type TensorType([10, 10], float32) and TensorType([11, 10], float32);")
if __name__ == "__main__":
test_too_many_args()
test_too_few_args()
test_rel_fail()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment