Unverified Commit edc3674d by Tianqi Chen Committed by GitHub

[REFACTOR][IR] Move error.h into ir (#4701)

We will use a single ErrorReporter to report errors during
program transformations.
parent 8037fc82
...@@ -18,12 +18,13 @@ ...@@ -18,12 +18,13 @@
*/ */
/*! /*!
* \file error.h * \file tvm/ir/error.h
* \brief The set of errors raised by Relay. * \brief Utilities for error tracking and reporting.
*/ */
#ifndef TVM_RELAY_ERROR_H_ #ifndef TVM_IR_ERROR_H_
#define TVM_RELAY_ERROR_H_ #define TVM_IR_ERROR_H_
#include <tvm/ir/span.h>
#include <tvm/ir/module.h> #include <tvm/ir/module.h>
#include <string> #include <string>
...@@ -31,49 +32,65 @@ ...@@ -31,49 +32,65 @@
#include <sstream> #include <sstream>
#include <unordered_map> #include <unordered_map>
#include "./base.h"
#include "./expr.h"
namespace tvm { namespace tvm {
namespace relay { /*!
* \brief A wrapper around std::stringstream to build error.
#define RELAY_ERROR(msg) (RelayErrorStream() << msg)
// Forward declaratio for RelayErrorStream.
struct Error;
/*! \brief A wrapper around std::stringstream.
* *
* This is designed to avoid platform specific * Can be consumed by Error to construct an error.
* issues compiling and using std::stringstream *
* for error reporting. * \code
*
* void ReportError(const Error& err);
*
* void Test(int number) {
* // Use error reporter to construct an error.
* ReportError(ErrorBuilder() << "This is an error number=" << number);
* }
*
* \endcode
*/ */
struct RelayErrorStream { struct ErrorBuilder {
std::stringstream ss; public:
template<typename T> template<typename T>
RelayErrorStream& operator<<(const T& t) { ErrorBuilder& operator<<(const T& val) { // NOLINT(*)
ss << t; stream_ << val;
return *this; return *this;
} }
std::string str() const { private:
return ss.str(); std::stringstream stream_;
} friend class Error;
void Raise() const;
}; };
struct Error : public dmlc::Error { /*!
Span sp; * \brief Custom Error class to be thrown during compilation.
explicit Error(const std::string& msg) : dmlc::Error(msg), sp(nullptr) {} */
Error(const RelayErrorStream& msg) : dmlc::Error(msg.str()), sp(nullptr) {} // NOLINT(*) class Error : public dmlc::Error {
Error(const Error& err) : dmlc::Error(err.what()), sp(nullptr) {} public:
Error() : dmlc::Error(""), sp(nullptr) {} /*! \brief Location of the error */
Span span;
/*!
* \brief construct error from message.
* \param msg The message
*/
explicit Error(const std::string& msg) : dmlc::Error(msg), span(nullptr) {}
/*!
* \brief construct error from error builder.
* \param err The error builder
*/
Error(const ErrorBuilder& err) : dmlc::Error(err.stream_.str()), span(nullptr) {} // NOLINT(*)
/*!
* \brief copy constructor.
* \param other The other ereor.
*/
Error(const Error& other) : dmlc::Error(other.what()), span(other.span) {} // NOLINT(*)
/*!
* \brief default constructor. */
Error() : dmlc::Error(""), span(nullptr) {}
}; };
/*! \brief An abstraction around how errors are stored and reported. /*!
* \brief An abstraction around how errors are stored and reported.
* Designed to be opaque to users, so we can support a robust and simpler * Designed to be opaque to users, so we can support a robust and simpler
* error reporting mode, as well as a more complex mode. * error reporting mode, as well as a more complex mode.
* *
...@@ -94,23 +111,26 @@ struct Error : public dmlc::Error { ...@@ -94,23 +111,26 @@ struct Error : public dmlc::Error {
*/ */
class ErrorReporter { class ErrorReporter {
public: public:
/*! \brief default constructor. */
ErrorReporter() : errors_(), node_to_error_() {} ErrorReporter() : errors_(), node_to_error_() {}
/*! \brief Report a tvm::relay::Error. /*!
* \brief Report a tvm::Error.
* *
* This API is useful for reporting spanned errors. * This API is useful for reporting spanned errors.
* *
* \param err The error to report. * \param err The error to report.
*/ */
void Report(const Error& err) { void Report(const Error& err) {
if (!err.sp.defined()) { if (!err.span.defined()) {
throw err; throw err;
} }
this->errors_.push_back(err); this->errors_.push_back(err);
} }
/*! \brief Report an error against a program, using the full program /*!
* \brief Report an error against a program, using the full program
* error reporting strategy. * error reporting strategy.
* *
* This error reporting method requires the global function in which * This error reporting method requires the global function in which
...@@ -121,12 +141,13 @@ class ErrorReporter { ...@@ -121,12 +141,13 @@ class ErrorReporter {
* \param node The expression or type to report the error at. * \param node The expression or type to report the error at.
* \param err The error message to report. * \param err The error message to report.
*/ */
inline void ReportAt(const GlobalVar& global, const ObjectRef& node, std::stringstream& err) { void ReportAt(const GlobalVar& global, const ObjectRef& node, std::stringstream& err) {
std::string err_msg = err.str(); std::string err_msg = err.str();
this->ReportAt(global, node, Error(err_msg)); this->ReportAt(global, node, Error(err_msg));
} }
/*! \brief Report an error against a program, using the full program /*!
* \brief Report an error against a program, using the full program
* error reporting strategy. * error reporting strategy.
* *
* This error reporting method requires the global function in which * This error reporting method requires the global function in which
...@@ -139,7 +160,8 @@ class ErrorReporter { ...@@ -139,7 +160,8 @@ class ErrorReporter {
*/ */
void ReportAt(const GlobalVar& global, const ObjectRef& node, const Error& err); void ReportAt(const GlobalVar& global, const ObjectRef& node, const Error& err);
/*! \brief Render all reported errors and exit the program. /*!
* \brief Render all reported errors and exit the program.
* *
* This function should be used after executing a pass to render reported errors. * This function should be used after executing a pass to render reported errors.
* *
...@@ -161,7 +183,5 @@ class ErrorReporter { ...@@ -161,7 +183,5 @@ class ErrorReporter {
std::unordered_map<ObjectRef, GlobalVar, ObjectHash, ObjectEqual> node_to_gv_; std::unordered_map<ObjectRef, GlobalVar, ObjectHash, ObjectEqual> node_to_gv_;
}; };
} // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_IR_ERROR_H_
#endif // TVM_RELAY_ERROR_H_
...@@ -26,13 +26,16 @@ ...@@ -26,13 +26,16 @@
#define TVM_RELAY_EXPR_FUNCTOR_H_ #define TVM_RELAY_EXPR_FUNCTOR_H_
#include <tvm/node/functor.h> #include <tvm/node/functor.h>
#include <tvm/ir/error.h>
#include <string> #include <string>
#include <utility> #include <utility>
#include <unordered_map> #include <unordered_map>
#include "./expr.h" #include "./expr.h"
#include "./adt.h" #include "./adt.h"
#include "./op.h" #include "./op.h"
#include "./error.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
......
...@@ -26,12 +26,14 @@ ...@@ -26,12 +26,14 @@
#define TVM_RELAY_PATTERN_FUNCTOR_H_ #define TVM_RELAY_PATTERN_FUNCTOR_H_
#include <tvm/node/functor.h> #include <tvm/node/functor.h>
#include <tvm/ir/error.h>
#include <string> #include <string>
#include <utility> #include <utility>
#include <unordered_map> #include <unordered_map>
#include "./expr.h" #include "./expr.h"
#include "./op.h" #include "./op.h"
#include "./error.h"
#include "./adt.h" #include "./adt.h"
namespace tvm { namespace tvm {
......
...@@ -59,7 +59,7 @@ ...@@ -59,7 +59,7 @@
#include <tvm/base.h> #include <tvm/base.h>
#include <tvm/packed_func_ext.h> #include <tvm/packed_func_ext.h>
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include <tvm/relay/error.h> #include <tvm/ir/error.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/ir/module.h> #include <tvm/ir/module.h>
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
......
...@@ -18,23 +18,24 @@ ...@@ -18,23 +18,24 @@
*/ */
/*! /*!
* \file error_reporter.h * \file ir/error.cc
* \brief The set of errors raised by Relay. * \brief Utilities for error tracking and reporting.
*/ */
#include <tvm/relay/expr.h>
#include <tvm/ir/module.h> #include <tvm/ir/module.h>
#include <tvm/relay/error.h> #include <tvm/ir/error.h>
// NOTE on dependencies on relay AsText.
// We calls into relay's printing module for better rendering.
// These dependency does not happen at the interface-level.
// And is only used to enhance developer experiences when relay
// functions are presented.
#include <tvm/relay/expr.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include <rang.hpp> #include <rang.hpp>
namespace tvm { namespace tvm {
namespace relay {
void RelayErrorStream::Raise() const {
throw Error(*this);
}
template<typename T, typename U> template<typename T, typename U>
using NodeMap = std::unordered_map<T, U, ObjectHash, ObjectEqual>; using NodeMap = std::unordered_map<T, U, ObjectHash, ObjectEqual>;
...@@ -43,7 +44,7 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { ...@@ -43,7 +44,7 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) {
// First we pick an error reporting strategy for each error. // First we pick an error reporting strategy for each error.
// TODO(@jroesch): Spanned errors are currently not supported. // TODO(@jroesch): Spanned errors are currently not supported.
for (auto err : this->errors_) { for (auto err : this->errors_) {
CHECK(!err.sp.defined()) << "attempting to use spanned errors, currently not supported"; CHECK(!err.span.defined()) << "attempting to use spanned errors, currently not supported";
} }
NodeMap<GlobalVar, NodeMap<ObjectRef, std::string>> error_maps; NodeMap<GlobalVar, NodeMap<ObjectRef, std::string>> error_maps;
...@@ -110,7 +111,7 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { ...@@ -110,7 +111,7 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) {
// //
// The annotation callback will annotate the error messages // The annotation callback will annotate the error messages
// contained in the map. // contained in the map.
annotated_prog << AsText(func, false, [&err_map](tvm::relay::Expr expr) { annotated_prog << relay::AsText(func, false, [&err_map](tvm::relay::Expr expr) {
auto it = err_map.find(expr); auto it = err_map.find(expr);
if (it != err_map.end()) { if (it != err_map.end()) {
CHECK_NE(it->second.size(), 0); CHECK_NE(it->second.size(), 0);
...@@ -144,5 +145,4 @@ void ErrorReporter::ReportAt(const GlobalVar& global, const ObjectRef& node, con ...@@ -144,5 +145,4 @@ void ErrorReporter::ReportAt(const GlobalVar& global, const ObjectRef& node, con
this->node_to_gv_.insert({ node, global }); this->node_to_gv_.insert({ node, global });
} }
} // namespace relay
} // namespace tvm } // namespace tvm
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
#include <tvm/operation.h> #include <tvm/operation.h>
#include <tvm/relay/error.h> #include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h> #include <tvm/relay/interpreter.h>
#include <tvm/relay/qnn/transform.h> #include <tvm/relay/qnn/transform.h>
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#ifndef TVM_RELAY_BACKEND_VM_COMPILER_H_ #ifndef TVM_RELAY_BACKEND_VM_COMPILER_H_
#define TVM_RELAY_BACKEND_VM_COMPILER_H_ #define TVM_RELAY_BACKEND_VM_COMPILER_H_
#include <tvm/relay/error.h> #include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h> #include <tvm/relay/interpreter.h>
#include <tvm/logging.h> #include <tvm/logging.h>
......
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
* \brief Transform operators. * \brief Transform operators.
*/ */
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/error.h> #include <tvm/ir/error.h>
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include <tvm/expr_operator.h> #include <tvm/expr_operator.h>
#include <tvm/ir.h> #include <tvm/ir.h>
...@@ -392,7 +392,7 @@ bool StackRel(const Array<Type>& types, ...@@ -392,7 +392,7 @@ bool StackRel(const Array<Type>& types,
for (size_t j = 0; j < first->shape.size(); ++j) { for (size_t j = 0; j < first->shape.size(); ++j) {
if (j == static_cast<size_t>(axis)) continue; if (j == static_cast<size_t>(axis)) continue;
if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue; if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue;
throw relay::Error("relay.stack requires all tensors have the same shape " throw Error("relay.stack requires all tensors have the same shape "
"on non-stacking axes"); "on non-stacking axes");
} }
} }
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_ #ifndef TVM_RELAY_OP_TENSOR_TRANSFORM_H_
#define TVM_RELAY_OP_TENSOR_TRANSFORM_H_ #define TVM_RELAY_OP_TENSOR_TRANSFORM_H_
#include <tvm/relay/error.h> #include <tvm/ir/error.h>
#include <vector> #include <vector>
#include <algorithm> #include <algorithm>
#include <limits> #include <limits>
...@@ -48,10 +48,10 @@ bool ConcatenateRel(const Array<Type>& types, ...@@ -48,10 +48,10 @@ bool ConcatenateRel(const Array<Type>& types,
*/ */
const auto* tensor_tuple = types[0].as<TupleTypeNode>(); const auto* tensor_tuple = types[0].as<TupleTypeNode>();
if (tensor_tuple == nullptr) { if (tensor_tuple == nullptr) {
throw relay::Error( throw Error(
RELAY_ERROR( ErrorBuilder()
"concatenate requires a tuple of tensors as the first argument, found " << "concatenate requires a tuple of tensors as the first argument, found "
<< PrettyPrint(types[0]))); << PrettyPrint(types[0]));
} else if (types[0].as<IncompleteTypeNode>() != nullptr) { } else if (types[0].as<IncompleteTypeNode>() != nullptr) {
return false; return false;
} }
...@@ -68,10 +68,10 @@ bool ConcatenateRel(const Array<Type>& types, ...@@ -68,10 +68,10 @@ bool ConcatenateRel(const Array<Type>& types,
// Sanity check: axis // Sanity check: axis
int axis = param->axis; int axis = param->axis;
if (!(-ndim <= axis && axis < ndim)) { if (!(-ndim <= axis && axis < ndim)) {
throw relay::Error(RELAY_ERROR( throw Error(ErrorBuilder() <<
"concatenate only accepts `axis` in [-ndim, ndim)" << "concatenate only accepts `axis` in [-ndim, ndim)" <<
", but got axis = " << axis << ", but got axis = " << axis <<
", and ndim = " << ndim)); ", and ndim = " << ndim);
} }
axis = axis < 0 ? ndim + axis : axis; axis = axis < 0 ? ndim + axis : axis;
...@@ -85,16 +85,16 @@ bool ConcatenateRel(const Array<Type>& types, ...@@ -85,16 +85,16 @@ bool ConcatenateRel(const Array<Type>& types,
int e_ndim = static_cast<int>(e->shape.size()); int e_ndim = static_cast<int>(e->shape.size());
const DataType& e_dtype = e->dtype; const DataType& e_dtype = e->dtype;
if (e_ndim != ndim) { if (e_ndim != ndim) {
throw relay::Error("relay.concatenate requires all tensors have the same ndim"); throw Error("relay.concatenate requires all tensors have the same ndim");
} }
if (e_dtype != dtype) { if (e_dtype != dtype) {
throw relay::Error("relay.concatenate requires all tensors have the same dtype"); throw Error("relay.concatenate requires all tensors have the same dtype");
} }
for (size_t j = 0; j < first->shape.size(); ++j) { for (size_t j = 0; j < first->shape.size(); ++j) {
if (j == static_cast<size_t>(axis)) continue; if (j == static_cast<size_t>(axis)) continue;
if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue; if (reporter->AssertEQ(first->shape[j], e->shape[j])) continue;
throw relay::Error("relay.concatenate requires all tensors have the same shape " throw Error("relay.concatenate requires all tensors have the same shape "
"on non-concatenating axes"); "on non-concatenating axes");
} }
} }
......
...@@ -93,9 +93,9 @@ Type ConcreteBroadcast(const TensorType& t1, ...@@ -93,9 +93,9 @@ Type ConcreteBroadcast(const TensorType& t1,
} else if (EqualCheck(s1, s2)) { } else if (EqualCheck(s1, s2)) {
oshape.push_back(s1); oshape.push_back(s1);
} else { } else {
RELAY_ERROR( throw Error(ErrorBuilder()
"Incompatible broadcast type " << "Incompatible broadcast type "
<< t1 << " and " << t2).Raise(); << t1 << " and " << t2);
} }
} }
......
...@@ -25,7 +25,7 @@ ...@@ -25,7 +25,7 @@
#ifndef TVM_RELAY_OP_TYPE_RELATIONS_H_ #ifndef TVM_RELAY_OP_TYPE_RELATIONS_H_
#define TVM_RELAY_OP_TYPE_RELATIONS_H_ #define TVM_RELAY_OP_TYPE_RELATIONS_H_
#include <tvm/relay/error.h> #include <tvm/ir/error.h>
#include <tvm/relay/type.h> #include <tvm/relay/type.h>
#include <string> #include <string>
......
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
* contains a data type such as `int`, `float`, `uint`. * contains a data type such as `int`, `float`, `uint`.
*/ */
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/error.h> #include <tvm/ir/error.h>
#include "../ir/type_functor.h" #include "../ir/type_functor.h"
namespace tvm { namespace tvm {
...@@ -55,11 +55,12 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> { ...@@ -55,11 +55,12 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
Kind expected, const std::string& description) { Kind expected, const std::string& description) {
Kind k = this->VisitType(t); Kind k = this->VisitType(t);
if (k != expected) { if (k != expected) {
ReportFatalError(RELAY_ERROR("Incorrect kind for a " << description ReportFatalError(ErrorBuilder()
<< ". Type " << t << " inside " << outer << "Incorrect kind for a " << description
<< " is of kind " << k << ". Type " << t << " inside " << outer
<< " but was expected to be " << " is of kind " << k
<< expected)); << " but was expected to be "
<< expected);
} }
} }
...@@ -127,8 +128,9 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> { ...@@ -127,8 +128,9 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
TypeCall tc = GetRef<TypeCall>(op); TypeCall tc = GetRef<TypeCall>(op);
const auto* gtv = op->func.as<GlobalTypeVarNode>(); const auto* gtv = op->func.as<GlobalTypeVarNode>();
if (gtv == nullptr) { if (gtv == nullptr) {
ReportFatalError(RELAY_ERROR("The callee in " << tc ReportFatalError(
<< " is not a global type var, but is " << op->func)); ErrorBuilder() <<"The callee in " << tc
<< " is not a global type var, but is " << op->func);
} }
CheckKindMatches(op->func, tc, Kind::kAdtHandle, "type call function"); CheckKindMatches(op->func, tc, Kind::kAdtHandle, "type call function");
...@@ -141,8 +143,9 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> { ...@@ -141,8 +143,9 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
auto var = GetRef<GlobalTypeVar>(gtv); auto var = GetRef<GlobalTypeVar>(gtv);
auto data = mod->LookupTypeDef(var); auto data = mod->LookupTypeDef(var);
if (data->type_vars.size() != op->args.size()) { if (data->type_vars.size() != op->args.size()) {
ReportFatalError(RELAY_ERROR("Expected " << data->type_vars.size() << "arguments for " << tc ReportFatalError(ErrorBuilder()
<< "; got " << op->args.size())); << "Expected " << data->type_vars.size() << "arguments for " << tc
<< "; got " << op->args.size());
} }
return Kind::kType; return Kind::kType;
} }
...@@ -161,8 +164,9 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> { ...@@ -161,8 +164,9 @@ struct KindChecker : TypeFunctor<Kind(const Type&)> {
for (const auto& con : op->constructors) { for (const auto& con : op->constructors) {
if (!con->belong_to.same_as(op->header)) { if (!con->belong_to.same_as(op->header)) {
ReportFatalError(RELAY_ERROR(con << " has header " << con->belong_to ReportFatalError(ErrorBuilder()
<< " but " << op << " has header " << op->header)); <<con << " has header " << con->belong_to
<< " but " << op << " has header " << op->header);
} }
for (const Type& t : con->inputs) { for (const Type& t : con->inputs) {
......
...@@ -28,7 +28,7 @@ ...@@ -28,7 +28,7 @@
* dynamic error unless exhaustiveness is checked in advance. * dynamic error unless exhaustiveness is checked in advance.
*/ */
#include <tvm/relay/adt.h> #include <tvm/relay/adt.h>
#include <tvm/relay/error.h> #include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h> #include <tvm/relay/pattern_functor.h>
#include <stack> #include <stack>
......
...@@ -38,7 +38,7 @@ ...@@ -38,7 +38,7 @@
* constraints we will trigger an error. * constraints we will trigger an error.
*/ */
#include <tvm/relay/error.h> #include <tvm/ir/error.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/pattern_functor.h> #include <tvm/relay/pattern_functor.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
...@@ -144,11 +144,12 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -144,11 +144,12 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
} catch (const dmlc::Error &e) { } catch (const dmlc::Error &e) {
this->ReportFatalError( this->ReportFatalError(
expr, expr,
RELAY_ERROR("Error unifying `" ErrorBuilder()
<< "Error unifying `"
<< t1 << t1
<< "` and `" << "` and `"
<< t2 << t2
<< "`: " << e.what())); << "`: " << e.what());
return Type(); return Type();
} }
} }
...@@ -188,9 +189,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -188,9 +189,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
if (!mod_.defined()) { if (!mod_.defined()) {
this->ReportFatalError( this->ReportFatalError(
GetRef<GlobalVar>(op), GetRef<GlobalVar>(op),
RELAY_ERROR( ErrorBuilder() <<
"Cannot do type inference on global variables " \ "Cannot do type inference on global variables " \
"without a module")); "without a module");
} }
Expr e = mod_->Lookup(var); Expr e = mod_->Lookup(var);
return e->checked_type(); return e->checked_type();
...@@ -239,16 +240,18 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -239,16 +240,18 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
auto* tc = unified.as<TypeCallNode>(); auto* tc = unified.as<TypeCallNode>();
if (!tc) { if (!tc) {
this->ReportFatalError(pc, RELAY_ERROR("Expected a type call, got " << unified)); this->ReportFatalError(pc, ErrorBuilder() << "Expected a type call, got " << unified);
} }
if (td->header != tc->func) { if (td->header != tc->func) {
this->ReportFatalError(pc, RELAY_ERROR("ADT headers must match, but we have " this->ReportFatalError(pc,
<< td->header << " and " << tc->func)); ErrorBuilder() << "ADT headers must match, but we have "
<< td->header << " and " << tc->func);
} }
if (td->type_vars.size() != tc->args.size()) { if (td->type_vars.size() != tc->args.size()) {
this->ReportFatalError(pc, RELAY_ERROR("The number of type args must match" this->ReportFatalError(pc,
<< "the number of type vars in the type data: " ErrorBuilder() << "The number of type args must match"
<< td->type_vars.size() << " != " << tc->args.size())); << "the number of type vars in the type data: "
<< td->type_vars.size() << " != " << tc->args.size());
} }
std::unordered_map<TypeVar, Type, ObjectHash, ObjectEqual> type_var_map_; std::unordered_map<TypeVar, Type, ObjectHash, ObjectEqual> type_var_map_;
for (size_t i = 0; i < td->type_vars.size(); ++i) { for (size_t i = 0; i < td->type_vars.size(); ++i) {
...@@ -256,9 +259,10 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -256,9 +259,10 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
} }
CHECK(con->constructor->inputs.size() == con->patterns.size()) << "not enough pattern"; CHECK(con->constructor->inputs.size() == con->patterns.size()) << "not enough pattern";
if (con->constructor->inputs.size() != con->patterns.size()) { if (con->constructor->inputs.size() != con->patterns.size()) {
this->ReportFatalError(pc, RELAY_ERROR("Not enough inputs for the constructor; " this->ReportFatalError(pc,
<< "expected " << con->constructor->inputs.size() ErrorBuilder() << "Not enough inputs for the constructor; "
<< ", got " << con->patterns.size())); << "expected " << con->constructor->inputs.size()
<< ", got " << con->patterns.size());
} }
for (size_t i = 0; i < con->constructor->inputs.size(); ++i) { for (size_t i = 0; i < con->constructor->inputs.size(); ++i) {
VisitPattern(con->patterns[i], Bind(con->constructor->inputs[i], type_var_map_)); VisitPattern(con->patterns[i], Bind(con->constructor->inputs[i], type_var_map_));
...@@ -278,7 +282,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -278,7 +282,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
auto* tt = unified.as<TupleTypeNode>(); auto* tt = unified.as<TupleTypeNode>();
if (!tt) { if (!tt) {
this->ReportFatalError(pt, RELAY_ERROR("Expected a tuple type, got " << unified)); this->ReportFatalError(pt, ErrorBuilder() << "Expected a tuple type, got " << unified);
} }
CHECK(tup->patterns.size() == tt->fields.size()) << "not enough pattern"; CHECK(tup->patterns.size() == tt->fields.size()) << "not enough pattern";
for (size_t i = 0; i < tup->patterns.size(); ++i) { for (size_t i = 0; i < tup->patterns.size(); ++i) {
...@@ -310,7 +314,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -310,7 +314,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Match match = GetRef<Match>(op); Match match = GetRef<Match>(op);
Array<Pattern> unmatched_cases = UnmatchedCases(match, this->mod_); Array<Pattern> unmatched_cases = UnmatchedCases(match, this->mod_);
if (unmatched_cases.size() != 0) { if (unmatched_cases.size() != 0) {
RelayErrorStream ss; ErrorBuilder ss;
ss << "match expression does not handle the following cases: "; ss << "match expression does not handle the following cases: ";
int i = 0; int i = 0;
for (auto cs : unmatched_cases) { for (auto cs : unmatched_cases) {
...@@ -454,8 +458,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -454,8 +458,9 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
if (fn_ty_node == nullptr && inc_ty_node == nullptr) { if (fn_ty_node == nullptr && inc_ty_node == nullptr) {
this->ReportFatalError( this->ReportFatalError(
GetRef<Call>(call), GetRef<Call>(call),
RELAY_ERROR("only expressions with function types can be called, found " ErrorBuilder()
<< ftype)); << "only expressions with function types can be called, found "
<< ftype);
} }
// incomplete type => it must be a function taking the arg types // incomplete type => it must be a function taking the arg types
...@@ -470,11 +475,12 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -470,11 +475,12 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
Array<Type> type_args = call->type_args; Array<Type> type_args = call->type_args;
if (type_args.size() > fn_ty_node->type_params.size()) { if (type_args.size() > fn_ty_node->type_params.size()) {
this->ReportFatalError(GetRef<Call>(call), this->ReportFatalError(GetRef<Call>(call),
RELAY_ERROR("Incorrect number of type args in " ErrorBuilder()
<< "Incorrect number of type args in "
<< call->span << ": " << call->span << ": "
<< "Expected " << "Expected "
<< fn_ty_node->type_params.size() << fn_ty_node->type_params.size()
<< "but got " << type_args.size())); << "but got " << type_args.size());
} }
FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args); FuncType fn_ty = InstantiateFuncType(fn_ty_node, type_args);
...@@ -488,13 +494,15 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>, ...@@ -488,13 +494,15 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
if (type_arity < number_of_args) { if (type_arity < number_of_args) {
this->ReportFatalError( this->ReportFatalError(
GetRef<Call>(call), GetRef<Call>(call),
RELAY_ERROR("the function is provided too many arguments " ErrorBuilder()
<< "expected " << type_arity << ", found " << number_of_args)); << "the function is provided too many arguments "
<< "expected " << type_arity << ", found " << number_of_args);
} else { } else {
this->ReportFatalError( this->ReportFatalError(
GetRef<Call>(call), GetRef<Call>(call),
RELAY_ERROR("the function is provided too few arguments " ErrorBuilder()
<< "expected " << type_arity << ", found " << number_of_args)); << "the function is provided too few arguments "
<< "expected " << type_arity << ", found " << number_of_args);
} }
} }
......
...@@ -124,10 +124,11 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { ...@@ -124,10 +124,11 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
} else { } else {
Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type); Type resolved = this->VisitType(lhs->resolved_type, rhs->resolved_type);
if (!resolved.defined()) { if (!resolved.defined()) {
solver_->ReportError(RELAY_ERROR("unable to unify: " solver_->ReportError(
<< "`" << PrettyPrint(lhs->resolved_type) << "` and `" ErrorBuilder() << "unable to unify: "
<< PrettyPrint(rhs->resolved_type) << "`"), << "`" << PrettyPrint(lhs->resolved_type) << "` and `"
this->loc); << PrettyPrint(rhs->resolved_type) << "`",
this->loc);
return lhs->resolved_type; return lhs->resolved_type;
} else { } else {
TypeNode* top = solver_->GetTypeNode(resolved); TypeNode* top = solver_->GetTypeNode(resolved);
...@@ -225,13 +226,13 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { ...@@ -225,13 +226,13 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
tvm::Array<IndexExpr> shape; tvm::Array<IndexExpr> shape;
if (tt1->shape.size() != tt2->shape.size()) { if (tt1->shape.size() != tt2->shape.size()) {
this->solver_->ReportError( this->solver_->ReportError(
RELAY_ERROR( ErrorBuilder() <<
"tensor type `" << PrettyPrint(tt1) << "tensor type `" << PrettyPrint(tt1) <<
"` has " << tt1->shape.size() << "` has " << tt1->shape.size() <<
" dimensions, while `" << " dimensions, while `" <<
PrettyPrint(tt2) << PrettyPrint(tt2) <<
"` has " << tt2->shape.size() << "` has " << tt2->shape.size() <<
" dimensions"), this->loc); " dimensions", this->loc);
return Type(nullptr); return Type(nullptr);
} }
...@@ -253,7 +254,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> { ...@@ -253,7 +254,7 @@ class TypeSolver::Unifier : public TypeFunctor<Type(const Type&, const Type&)> {
} }
if (mismatches.size() != 0) { if (mismatches.size() != 0) {
RelayErrorStream err; ErrorBuilder err;
err << "in particular "; err << "in particular ";
for (auto mismatch : mismatches) { for (auto mismatch : mismatches) {
err << "dimension " err << "dimension "
...@@ -639,10 +640,11 @@ bool TypeSolver::Solve() { ...@@ -639,10 +640,11 @@ bool TypeSolver::Solve() {
rnode->resolved = false; rnode->resolved = false;
} catch (const dmlc::Error& err) { } catch (const dmlc::Error& err) {
rnode->resolved = false; rnode->resolved = false;
this->ReportError(RELAY_ERROR("an internal invariant was violated while " this->ReportError(
"typechecking your program " ErrorBuilder() << "an internal invariant was violated while "
<< err.what()), << "typechecking your program "
rnode->location); << err.what(),
rnode->location);
} }
// Mark inqueue as false after the function call // Mark inqueue as false after the function call
......
...@@ -27,7 +27,7 @@ ...@@ -27,7 +27,7 @@
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/type.h> #include <tvm/relay/type.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/error.h> #include <tvm/ir/error.h>
#include <vector> #include <vector>
#include <queue> #include <queue>
#include <unordered_map> #include <unordered_map>
......
...@@ -41,9 +41,10 @@ bool QnnConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& at ...@@ -41,9 +41,10 @@ bool QnnConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& at
// Check the scale and zero point types // Check the scale and zero point types
const auto* input_scales_tuple = types[1].as<TupleTypeNode>(); const auto* input_scales_tuple = types[1].as<TupleTypeNode>();
if (input_scales_tuple == nullptr) { if (input_scales_tuple == nullptr) {
throw relay::Error( throw Error(
RELAY_ERROR("qnn concatenate requires a tuple of scales as the second argument, found " ErrorBuilder()
<< PrettyPrint(types[1]))); << "qnn concatenate requires a tuple of scales as the second argument, found "
<< PrettyPrint(types[1]));
} }
for (const auto& input_scale : input_scales_tuple->fields) { for (const auto& input_scale : input_scales_tuple->fields) {
CHECK(IsScalarType(input_scale, DataType::Float(32))); // input_scales[idx] CHECK(IsScalarType(input_scale, DataType::Float(32))); // input_scales[idx]
...@@ -51,9 +52,10 @@ bool QnnConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& at ...@@ -51,9 +52,10 @@ bool QnnConcatenateRel(const Array<Type>& types, int num_inputs, const Attrs& at
const auto* input_zero_points_tuple = types[2].as<TupleTypeNode>(); const auto* input_zero_points_tuple = types[2].as<TupleTypeNode>();
if (input_zero_points_tuple == nullptr) { if (input_zero_points_tuple == nullptr) {
throw relay::Error( throw Error(
RELAY_ERROR("qnn concatenate requires a tuple of zero_points as the third argument, found " ErrorBuilder()
<< PrettyPrint(types[2]))); << "qnn concatenate requires a tuple of zero_points as the third argument, found "
<< PrettyPrint(types[2]));
} }
for (const auto& input_zero_point : input_zero_points_tuple->fields) { for (const auto& input_zero_point : input_zero_points_tuple->fields) {
CHECK(IsScalarType(input_zero_point, DataType::Int(32))); // input_zero_points[idx] CHECK(IsScalarType(input_zero_point, DataType::Int(32))); // input_zero_points[idx]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment