Unverified Commit f63631fc by Tianqi Chen Committed by GitHub

[RUNTIME] Scaffold structured error handling. (#2838)

parent a1c2fd1b
Subproject commit 9acddddfc349eda4ef99552d11cb905afeafed39
Subproject commit 2b5b1ba9c1103f438d164aca32da7cffd8cd48e8
tvm.error
---------
.. automodule:: tvm.error
:members:
:imported-members:
......@@ -11,6 +11,7 @@ Python API
target
build
module
error
ndarray
container
function
......
.. _error_guide:
Error Handling Guide
====================
TVM contains structured error classes to indicate specific types of error.
Please raise a specific error type when possible, so that users can
write code to handle a specific error category if necessary.
All the error types are defined in :any:`tvm.error` namespace.
You can directly raise the specific error object in python.
In other languages like c++, you simply add ``<ErrorType>:`` prefix to
the error message(see below).
Raise a Specific Error in C++
-----------------------------
You can add ``<ErrorType>:`` prefix to your error message to
raise an error of the corresponding type.
Note that you do not have to add a new type
:any:`tvm.error.TVMError` will be raised by default when
there is no error type prefix in the message.
This mechanism works for both ``LOG(FATAL)`` and ``CHECK`` macros.
The following code gives an example on how to do so.
.. code:: c
// src/api_test.cc
void ErrorTest(int x, int y) {
CHECK_EQ(x, y) << "ValueError: expect x and y to be equal."
if (x == 1) {
LOG(FATAL) << "InternalError: cannot reach here";
}
}
The above function is registered as PackedFunc into the python frontend,
under the name ``tvm._api_internal._ErrorTest``.
Here is what will happen if we call the registered function:
.. code::
>>> import tvm
>>> tvm._api_internal._ErrorTest(0, 1)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/path/to/tvm/python/tvm/_ffi/_ctypes/function.py", line 190, in __call__
raise get_last_ffi_error()
ValueError: Traceback (most recent call last):
[bt] (3) /path/to/tvm/build/libtvm.so(TVMFuncCall+0x48) [0x7fab500b8ca8]
[bt] (2) /path/to/tvm/build/libtvm.so(+0x1c4126) [0x7fab4f7f5126]
[bt] (1) /path/to/tvm/build/libtvm.so(+0x1ba2f8) [0x7fab4f7eb2f8]
[bt] (0) /path/to/tvm/build/libtvm.so(+0x177d12) [0x7fab4f7a8d12]
File "/path/to/tvm/src/api/api_test.cc", line 80
ValueError: Check failed: x == y (0 vs. 1) : expect x and y to be equal.
>>>
>>> tvm._api_internal._ErrorTest(1, 1)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/path/to/tvm/python/tvm/_ffi/_ctypes/function.py", line 190, in __call__
raise get_last_ffi_error()
tvm.error.InternalError: Traceback (most recent call last):
[bt] (3) /path/to/tvm/build/libtvm.so(TVMFuncCall+0x48) [0x7fab500b8ca8]
[bt] (2) /path/to/tvm/build/libtvm.so(+0x1c4126) [0x7fab4f7f5126]
[bt] (1) /path/to/tvm/build/libtvm.so(+0x1ba35c) [0x7fab4f7eb35c]
[bt] (0) /path/to/tvm/build/libtvm.so(+0x177d12) [0x7fab4f7a8d12]
File "/path/to/tvm/src/api/api_test.cc", line 83
InternalError: cannot reach here
TVM hint: You hit an internal error. Please open a thread on https://discuss.tvm.ai/ to report it.
As you can see in the above example, TVM's ffi system combines
both the python and c++'s stacktrace into a single message, and generate the
corresponding error class automatically.
How to choose an Error Type
---------------------------
You can go through the error types are listed below, try to use common
sense and also refer to the choices in the existing code.
We try to keep a reasonable amount of error types.
If you feel there is a need to add a new error type, do the following steps:
- Send a RFC proposal with a description and usage examples in the current codebase.
- Add the new error type to :any:`tvm.error` with clear documents.
- Update the list in this file to include the new error type.
- Change the code to use the new error type.
We also recommend to use less abstraction when creating the short error messages.
The code is more readable in this way, and also opens path to craft specific
error messages when necessary.
.. code:: python
def preferred():
# Very clear about what is being raised and what is the error message.
raise OpNotImplemented("Operator relu is not implemented in the MXNet fronend")
def _op_not_implemented(op_name):
return OpNotImplemented("Operator {} is not implemented.").format(op_name)
def not_preferred():
# Introduces another level of indirection.
raise _op_not_implemented("relu")
If we need to introduce a wrapper function that constructs multi-line error messages,
please put wrapper in the same file so other developers can look up the implementation easily.
System-wide Errors
------------------
.. autoclass:: tvm.error.TVMError
.. autoclass:: tvm.error.InternalError
Frontend Errors
---------------
.. autoclass:: tvm.error.OpNotImplemented
.. autoclass:: tvm.error.OpAttributeInvalid
.. autoclass:: tvm.error.OpAttributeRequired
.. autoclass:: tvm.error.OpAttributeNotImplemented
......@@ -28,5 +28,6 @@ Here are guidelines for contributing to various aspect of the project:
committer_guide
document
code_guide
error_handling
pull_request
git_howto
......@@ -19,6 +19,7 @@ from . import target
from . import generic
from . import hybrid
from . import testing
from . import error
from . import ndarray as nd
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
......
......@@ -7,7 +7,7 @@ import ctypes
import traceback
from numbers import Number, Integral
from ..base import _LIB, check_call
from ..base import _LIB, get_last_ffi_error, py2cerror
from ..base import c_str, string_types
from ..node_generic import convert_to_node, NodeGeneric
from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
......@@ -55,6 +55,7 @@ def convert_to_tvm_func(pyfunc):
rv = local_pyfunc(*pyargs)
except Exception:
msg = traceback.format_exc()
msg = py2cerror(msg)
_LIB.TVMAPISetLastError(c_str(msg))
return -1
......@@ -65,7 +66,8 @@ def convert_to_tvm_func(pyfunc):
values, tcodes, _ = _make_tvm_args((rv,), temp_args)
if not isinstance(ret, TVMRetValueHandle):
ret = TVMRetValueHandle(ret)
check_call(_LIB.TVMCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1)))
if _LIB.TVMCFuncSetReturn(ret, values, tcodes, ctypes.c_int(1)) != 0:
raise get_last_ffi_error()
_ = temp_args
_ = rv
return 0
......@@ -76,8 +78,9 @@ def convert_to_tvm_func(pyfunc):
# TVM_FREE_PYOBJ will be called after it is no longer needed.
pyobj = ctypes.py_object(f)
ctypes.pythonapi.Py_IncRef(pyobj)
check_call(_LIB.TVMFuncCreateFromCFunc(
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)))
if _LIB.TVMFuncCreateFromCFunc(
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0:
raise get_last_ffi_error()
return _CLASS_FUNCTION(handle, False)
......@@ -168,7 +171,8 @@ class FunctionBase(object):
def __del__(self):
if not self.is_global and _LIB is not None:
check_call(_LIB.TVMFuncFree(self.handle))
if _LIB.TVMFuncFree(self.handle) != 0:
raise get_last_ffi_error()
def __call__(self, *args):
"""Call the function with positional arguments
......@@ -180,9 +184,10 @@ class FunctionBase(object):
values, tcodes, num_args = _make_tvm_args(args, temp_args)
ret_val = TVMValue()
ret_tcode = ctypes.c_int()
check_call(_LIB.TVMFuncCall(
self.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
if _LIB.TVMFuncCall(
self.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode)) != 0:
raise get_last_ffi_error()
_ = temp_args
_ = args
return RETURN_SWITCH[ret_tcode.value](ret_val)
......@@ -194,9 +199,10 @@ def __init_handle_by_constructor__(fconstructor, args):
values, tcodes, num_args = _make_tvm_args(args, temp_args)
ret_val = TVMValue()
ret_tcode = ctypes.c_int()
check_call(_LIB.TVMFuncCall(
fconstructor.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode)))
if _LIB.TVMFuncCall(
fconstructor.handle, values, tcodes, ctypes.c_int(num_args),
ctypes.byref(ret_val), ctypes.byref(ret_tcode)) != 0:
raise get_last_ffi_error()
_ = temp_args
_ = args
assert ret_tcode.value == TypeCode.NODE_HANDLE
......
from ..base import TVMError
from ..base import get_last_ffi_error
from libcpp.vector cimport vector
from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule
......@@ -148,7 +148,7 @@ cdef inline c_str(pystr):
cdef inline CALL(int ret):
if ret != 0:
raise TVMError(py_str(TVMGetLastError()))
raise get_last_ffi_error()
cdef inline object ctypes_handle(void* chandle):
......
......@@ -2,7 +2,7 @@ import ctypes
import traceback
from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types
from ..base import string_types, py2cerror
from ..node_generic import convert_to_node, NodeGeneric
from ..runtime_ctypes import TVMType, TVMContext, TVMByteArray
......@@ -38,6 +38,7 @@ cdef int tvm_callback(TVMValue* args,
rv = local_pyfunc(*pyargs)
except Exception:
msg = traceback.format_exc()
msg = py2cerror(msg)
TVMAPISetLastError(c_str(msg))
return -1
if rv is not None:
......
# coding: utf-8
# pylint: disable=invalid-name
""" ctypes library of nnvm and helper functions """
"""Base library for TVM FFI."""
from __future__ import absolute_import
import sys
......@@ -30,10 +30,6 @@ else:
py_str = lambda x: x
class TVMError(Exception):
"""Error thrown by TVM function"""
def _load_lib():
"""Load libary by searching possible path."""
lib_path = libinfo.find_lib_path()
......@@ -56,21 +52,6 @@ _FFI_MODE = os.environ.get("TVM_FFI", "auto")
#----------------------------
# helper function in ctypes.
#----------------------------
def check_call(ret):
"""Check the return value of C API call
This function will raise exception when error occurs.
Wrap every API call with this function
Parameters
----------
ret : int
return value from API calls
"""
if ret != 0:
raise TVMError(py_str(_LIB.TVMGetLastError()))
def c_str(string):
"""Create ctypes char * from a python string
Parameters
......@@ -118,3 +99,200 @@ def decorate(func, fwrapped):
"""
import decorator
return decorator.decorate(func, fwrapped)
#-----------------------------------------
# Base code for structured error handling.
#-----------------------------------------
# Maps error type to its constructor
ERROR_TYPE = {}
class TVMError(RuntimeError):
"""Default error thrown by TVM functions.
TVMError will be raised if you do not give any error type specification,
"""
def register_error(func_name=None, cls=None):
"""Register an error class so it can be recognized by the ffi error handler.
Parameters
----------
func_name : str or function or class
The name of the error function.
cls : function
The function to create the class
Returns
-------
fregister : function
Register function if f is not specified.
Examples
--------
.. code-block:: python
@tvm.error.register_error
class MyError(RuntimeError):
pass
err_inst = tvm.error.create_ffi_error("MyError: xyz")
assert isinstance(err_inst, MyError)
"""
if callable(func_name):
cls = func_name
func_name = cls.__name__
def register(mycls):
"""internal register function"""
err_name = func_name if isinstance(func_name, str) else mycls.__name__
ERROR_TYPE[err_name] = mycls
return mycls
if cls is None:
return register
return register(cls)
def _valid_error_name(name):
"""Check whether name is a valid error name."""
return all(x.isalnum() or x in "_." for x in name)
def _find_error_type(line):
"""Find the error name given the first line of the error message.
Parameters
----------
line : str
The first line of error message.
Returns
-------
name : str The error name
"""
end_pos = line.find(":")
if end_pos == -1:
return None
err_name = line[:end_pos]
if _valid_error_name(err_name):
return err_name
return None
def c2pyerror(err_msg):
"""Translate C API error message to python style.
Parameters
----------
err_msg : str
The error message.
Returns
-------
new_msg : str
Translated message.
err_type : str
Detected error type.
"""
arr = err_msg.split("\n")
if arr[-1] == "":
arr.pop()
err_type = _find_error_type(arr[0])
trace_mode = False
stack_trace = []
message = []
for line in arr:
if trace_mode:
if line.startswith(" "):
stack_trace.append(line)
else:
trace_mode = False
if not trace_mode:
if line.startswith("Stack trace"):
trace_mode = True
else:
message.append(line)
out_msg = ""
if stack_trace:
out_msg += "Traceback (most recent call last):\n"
out_msg += "\n".join(reversed(stack_trace)) + "\n"
out_msg += "\n".join(message)
return out_msg, err_type
def py2cerror(err_msg):
"""Translate python style error message to C style.
Parameters
----------
err_msg : str
The error message.
Returns
-------
new_msg : str
Translated message.
"""
arr = err_msg.split("\n")
if arr[-1] == "":
arr.pop()
trace_mode = False
stack_trace = []
message = []
for line in arr:
if trace_mode:
if line.startswith(" "):
stack_trace.append(line)
else:
trace_mode = False
if not trace_mode:
if line.find("Traceback") != -1:
trace_mode = True
else:
message.append(line)
# Remove the first error name if there are two of them.
# RuntimeError: MyErrorName: message => MyErrorName: message
head_arr = message[0].split(":", 3)
if len(head_arr) >= 3 and _valid_error_name(head_arr[1].strip()):
head_arr[1] = head_arr[1].strip()
message[0] = ":".join(head_arr[1:])
# reverse the stack trace.
out_msg = "\n".join(message)
if stack_trace:
out_msg += "\nStack trace:\n"
out_msg += "\n".join(reversed(stack_trace)) + "\n"
return out_msg
def get_last_ffi_error():
"""Create error object given result of TVMGetLastError.
Returns
-------
err : object
The error object based on the err_msg
"""
c_err_msg = py_str(_LIB.TVMGetLastError())
py_err_msg, err_type = c2pyerror(c_err_msg)
if err_type.startswith("tvm.error."):
err_type = err_type[10:]
return ERROR_TYPE.get(err_type, TVMError)(py_err_msg)
def check_call(ret):
"""Check the return value of C API call
This function will raise exception when error occurs.
Wrap every API call with this function
Parameters
----------
ret : int
return value from API calls
"""
if ret != 0:
raise get_last_ffi_error()
"""Structured error classes in TVM.
Each error class takes an error message as its input.
See the example sections for for suggested message conventions.
To make the code more readable, we recommended developers to
copy the examples and raise errors with the same message convention.
"""
from ._ffi.base import register_error, TVMError
@register_error
class InternalError(TVMError):
"""Internal error in the system.
Examples
--------
.. code :: c++
// Example code C++
LOG(FATAL) << "InternalError: internal error detail.";
.. code :: python
# Example code in python
raise InternalError("internal error detail")
"""
def __init__(self, msg):
# Patch up additional hint message.
if "TVM hint:" not in msg:
msg += ("\nTVM hint: You hit an internal error. " +
"Please open a thread on https://discuss.tvm.ai/ to report it.")
super(InternalError, self).__init__(msg)
register_error("ValueError", ValueError)
register_error("TypeError", TypeError)
@register_error
class OpError(TVMError):
"""Base class of all operator errors in frontends."""
@register_error
class OpNotImplemented(OpError, NotImplementedError):
"""Operator is not implemented.
Example
-------
.. code:: python
raise OpNotImplemented(
"Operator {} is not supported in {} frontend".format(
missing_op, frontend_name))
"""
@register_error
class OpAttributeRequired(OpError, AttributeError):
"""Required attribute is not found.
Example
-------
.. code:: python
raise OpAttributeRequired(
"Required attribute {} not found in operator {}".format(
attr_name, op_name))
"""
@register_error
class OpAttributeInvalid(OpError, AttributeError):
"""Attribute value is invalid when taking in a frontend operator.
Example
-------
.. code:: python
raise OpAttributeInvalid(
"Value {} in attribute {} of operator {} is not valid".format(
value, attr_name, op_name))
"""
@register_error
class OpAttributeUnimplemented(OpError, NotImplementedError):
"""Attribute is not supported in a certain frontend.
Example
-------
.. code:: python
raise OpAttributeUnimplemented(
"Attribute {} is not supported in operator {}".format(
attr_name, op_name))
"""
......@@ -39,6 +39,30 @@ TVM_REGISTER_API("_nop")
.set_body([](TVMArgs args, TVMRetValue *ret) {
});
TVM_REGISTER_API("_test_wrap_callback")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc pf = args[0];
*ret = runtime::TypedPackedFunc<void()>([pf](){
pf();
});
});
TVM_REGISTER_API("_test_raise_error_callback")
.set_body([](TVMArgs args, TVMRetValue *ret) {
std::string msg = args[0];
*ret = runtime::TypedPackedFunc<void()>([msg](){
LOG(FATAL) << msg;
});
});
TVM_REGISTER_API("_test_check_eq_callback")
.set_body([](TVMArgs args, TVMRetValue *ret) {
std::string msg = args[0];
*ret = runtime::TypedPackedFunc<void(int x, int y)>([msg](int x, int y){
CHECK_EQ(x, y) << msg;
});
});
TVM_REGISTER_API("_context_test")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLContext ctx = args[0];
......@@ -49,6 +73,20 @@ TVM_REGISTER_API("_context_test")
*ret = ctx;
});
// in src/api_test.cc
void ErrorTest(int x, int y) {
// raise ValueError
CHECK_EQ(x, y) << "ValueError: expect x and y to be equal.";
if (x == 1) {
// raise InternalError.
LOG(FATAL) << "InternalError: cannot reach here";
}
}
TVM_REGISTER_API("_ErrorTest")
.set_body_typed<void(int, int)>(ErrorTest);
// internal function used for debug and testing purposes
TVM_REGISTER_API("_ndarray_use_count")
.set_body([](TVMArgs args, TVMRetValue *ret) {
......
......@@ -13,10 +13,14 @@
#ifdef _LIBCPP_SGX_CONFIG
#include "sgx/trusted/runtime.h"
#endif
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
#include <sstream>
#endif
#include <array>
#include <algorithm>
#include <string>
#include <cstdlib>
#include <cctype>
#include "runtime_base.h"
namespace tvm {
......@@ -104,6 +108,169 @@ void DeviceAPI::SyncStreamFromTo(TVMContext ctx,
TVMStreamHandle event_dst) {
LOG(FATAL) << "Device does not support stream api.";
}
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
//--------------------------------------------------------
// Error handling mechanism
// -------------------------------------------------------
// Standard error message format, {} means optional
//--------------------------------------------------------
// {error_type:} {message0}
// {message1}
// {message2}
// {Stack trace:} // stack traces follow by this line
// {trace 0} // two spaces in the begining.
// {trace 1}
// {trace 2}
//--------------------------------------------------------
/*!
* \brief Normalize error message
*
* Parse them header generated by by LOG(FATAL) and CHECK
* and reformat the message into the standard format.
*
* This function will also merge all the stack traces into
* one trace and trim them.
*
* \param err_msg The error message.
* \return normalized message.
*/
std::string NormalizeError(std::string err_msg) {
// ------------------------------------------------------------------------
// log with header, {} indicates optional
//-------------------------------------------------------------------------
// [timestamp] file_name:line_number: {check_msg:} {error_type:} {message0}
// {message1}
// Stack trace:
// {stack trace 0}
// {stack trace 1}
//-------------------------------------------------------------------------
// Normalzied version
//-------------------------------------------------------------------------
// error_type: check_msg message0
// {message1}
// Stack trace:
// File file_name, line lineno
// {stack trace 0}
// {stack trace 1}
//-------------------------------------------------------------------------
int line_number = 0;
std::istringstream is(err_msg);
std::string line, file_name, error_type, check_msg;
// Parse log header and set the fields,
// Return true if it the log is in correct format,
// return false if something is wrong.
auto parse_log_header = [&]() {
// skip timestamp
if (is.peek() != '[') {
getline(is, line);
return true;
}
if (!(is >> line)) return false;
// get filename
while (is.peek() == ' ') is.get();
if (!getline(is, file_name, ':')) return false;
// get line number
if (!(is >> line_number)) return false;
// get rest of the message.
while (is.peek() == ' ' || is.peek() == ':') is.get();
if (!getline(is, line)) return false;
// detect check message, rewrite to remote extra :
if (line.compare(0, 13, "Check failed:") == 0) {
size_t end_pos = line.find(':', 13);
if (end_pos == std::string::npos) return false;
check_msg = line.substr(0, end_pos + 1) + ' ';
line = line.substr(end_pos + 1);
}
return true;
};
// if not in correct format, do not do any rewrite.
if (!parse_log_header()) return err_msg;
// Parse error type.
{
size_t start_pos = 0, end_pos;
for (; start_pos < line.length() && line[start_pos] == ' '; ++start_pos) {}
for (end_pos = start_pos; end_pos < line.length(); ++end_pos) {
char ch = line[end_pos];
if (ch == ':') {
error_type = line.substr(start_pos, end_pos - start_pos);
break;
}
// [A-Z0-9a-z_.]
if (!std::isalpha(ch) && !std::isdigit(ch) && ch != '_' && ch != '.') break;
}
if (error_type.length() != 0) {
// if we successfully detected error_type: trim the following space.
for (start_pos = end_pos + 1;
start_pos < line.length() && line[start_pos] == ' '; ++start_pos) {}
line = line.substr(start_pos);
} else {
// did not detect error_type, use default value.
line = line.substr(start_pos);
error_type = "TVMError";
}
}
// Seperate out stack trace.
std::ostringstream os;
os << error_type << ": " << check_msg << line << '\n';
bool trace_mode = true;
std::vector<std::string> stack_trace;
while (getline(is, line)) {
if (trace_mode) {
if (line.compare(0, 2, " ") == 0) {
stack_trace.push_back(line);
} else {
trace_mode = false;
// remove EOL trailing stacktrace.
if (line.length() == 0) continue;
}
}
if (!trace_mode) {
if (line.compare(0, 11, "Stack trace") == 0) {
trace_mode = true;
} else {
os << line << '\n';
}
}
}
if (stack_trace.size() != 0 || file_name.length() != 0) {
os << "Stack trace:\n";
if (file_name.length() != 0) {
os << " File \"" << file_name << "\", line " << line_number << "\n";
}
// Print out stack traces, optionally trim the c++ traces
// about the frontends (as they will be provided by the frontends).
bool ffi_boundary = false;
for (const auto& line : stack_trace) {
// Heuristic to detect python ffi.
if (line.find("libffi.so") != std::string::npos ||
line.find("core.cpython") != std::string::npos) {
ffi_boundary = true;
}
// If the backtrace is not c++ backtrace with the prefix " [bt]",
// then we can stop trimming.
if (ffi_boundary && line.compare(0, 6, " [bt]") != 0) {
ffi_boundary = false;
}
if (!ffi_boundary) {
os << line << '\n';
}
// The line after TVMFuncCall cound be in FFI.
if (line.find("(TVMFuncCall") != std::string::npos) {
ffi_boundary = true;
}
}
}
return os.str();
}
#else
std::string NormalizeError(std::string err_msg) {
return err_msg;
}
#endif
} // namespace runtime
} // namespace tvm
......@@ -121,6 +288,11 @@ const char *TVMGetLastError() {
return TVMAPIRuntimeStore::Get()->last_error.c_str();
}
int TVMAPIHandleException(const std::runtime_error &e) {
TVMAPISetLastError(NormalizeError(e.what()).c_str());
return -1;
}
void TVMAPISetLastError(const char* msg) {
#ifndef _LIBCPP_SGX_CONFIG
TVMAPIRuntimeStore::Get()->last_error = msg;
......@@ -279,9 +451,7 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, resource_handle);
if (ret != 0) {
std::string err = "TVMCall CFunc Error:\n";
err += TVMGetLastError();
throw dmlc::Error(err);
throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace());
}
});
} else {
......@@ -293,9 +463,7 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
int ret = func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, rpack.get());
if (ret != 0) {
std::string err = "TVMCall CFunc Error:\n";
err += TVMGetLastError();
throw dmlc::Error(err);
throw dmlc::Error(TVMGetLastError() + ::dmlc::StackTrace());
}
});
}
......
......@@ -26,9 +26,6 @@
* \param e the exception
* \return the return value of API after exception is handled
*/
inline int TVMAPIHandleException(const std::runtime_error &e) {
TVMAPISetLastError(e.what());
return -1;
}
int TVMAPIHandleException(const std::runtime_error &e);
#endif // TVM_RUNTIME_RUNTIME_BASE_H_
"""Test runtime error handling"""
import tvm
def test_op_translation():
ferror = tvm._api_internal._test_raise_error_callback(
"OpNotImplemented: myop")
try:
ferror()
assert False
except tvm.error.OpNotImplemented as e:
msg = str(e)
assert isinstance(e, NotImplementedError)
assert msg.find("api_test.cc") != -1
fchk_eq = tvm._api_internal._test_check_eq_callback(
"InternalError: myop")
try:
fchk_eq(0, 1)
assert False
except tvm.error.InternalError as e:
msg = str(e)
assert msg.find("api_test.cc") != -1
try:
tvm._api_internal._ErrorTest(0, 1)
assert False
except ValueError as e:
msg = str(e)
assert msg.find("api_test.cc") != -1
def test_deep_callback():
def error_callback():
raise ValueError("callback error")
wrap1 = tvm._api_internal._test_wrap_callback(error_callback)
def flevel2():
wrap1()
wrap2 = tvm._api_internal._test_wrap_callback(flevel2)
def flevel3():
wrap2()
wrap3 = tvm._api_internal._test_wrap_callback(flevel3)
try:
wrap3()
assert False
except ValueError as e:
msg = str(e)
idx2 = msg.find("in flevel2")
idx3 = msg.find("in flevel3")
assert idx2 != -1 and idx3 != -1
assert idx2 > idx3
if __name__ == "__main__":
test_op_translation()
test_deep_callback()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment