Unverified Commit 5e50f476 by Tianqi Chen Committed by GitHub

[RUNTIME] Enable auto conversion from str to runtime::String in PackedFunc, move…

[RUNTIME] Enable auto conversion from str to runtime::String in PackedFunc, move dtype related handling to data_type.h (#5251)
parent f31df01e
......@@ -27,6 +27,7 @@
#include <tvm/runtime/c_runtime_api.h>
#include <dmlc/logging.h>
#include <type_traits>
#include <string>
namespace tvm {
namespace runtime {
......@@ -263,6 +264,141 @@ inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
}
/*!
* \brief Runtime utility for getting custom type name from code
* \param type_code Custom type code
* \return Custom type name
*/
TVM_DLL std::string GetCustomTypeName(uint8_t type_code);
/*!
* \brief Runtime utility for checking whether custom type is registered
* \param type_code Custom type code
* \return Bool representing whether type is registered
*/
TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code);
/*!
* \brief Runtime utility for parsing string of the form "custom[<typename>]"
* \param s String to parse
* \param scan pointer to parsing pointer, which is scanning across s
* \return type code of custom type parsed
*/
TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);
/*!
* \brief Convert type code to its name
* \param type_code The type code .
* \return The name of type code.
*/
inline const char* TypeCode2Str(int type_code);
/*!
* \brief convert a string to TVM type.
* \param s The string to be converted.
* \return The corresponding tvm type.
*/
inline DLDataType String2DLDataType(std::string s);
/*!
* \brief convert a TVM type to string.
* \param t The type to be converted.
* \return The corresponding tvm type in string.
*/
inline std::string DLDataType2String(DLDataType t);
// implementation details
inline const char* TypeCode2Str(int type_code) {
switch (type_code) {
case kDLInt: return "int";
case kDLUInt: return "uint";
case kDLFloat: return "float";
case kTVMStr: return "str";
case kTVMBytes: return "bytes";
case kTVMOpaqueHandle: return "handle";
case kTVMNullptr: return "NULL";
case kTVMDLTensorHandle: return "ArrayHandle";
case kTVMDataType: return "DLDataType";
case kTVMContext: return "TVMContext";
case kTVMPackedFuncHandle: return "FunctionHandle";
case kTVMModuleHandle: return "ModuleHandle";
case kTVMNDArrayHandle: return "NDArrayContainer";
case kTVMObjectHandle: return "Object";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
}
inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
os << "bool"; return os;
}
if (t.code < kTVMCustomBegin) {
os << TypeCode2Str(t.code);
} else {
os << "custom[" << GetCustomTypeName(t.code) << "]";
}
if (t.code == kTVMOpaqueHandle) return os;
os << static_cast<int>(t.bits);
if (t.lanes != 1) {
os << 'x' << static_cast<int>(t.lanes);
}
return os;
}
inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
return os << dtype.operator DLDataType();
}
inline std::string DLDataType2String(DLDataType t) {
if (t.bits == 0) return "";
std::ostringstream os;
os << t;
return os.str();
}
inline DLDataType String2DLDataType(std::string s) {
DLDataType t;
// handle None type
if (s.length() == 0) {
t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle;
return t;
}
t.bits = 32; t.lanes = 1;
const char* scan;
if (s.substr(0, 3) == "int") {
t.code = kDLInt; scan = s.c_str() + 3;
} else if (s.substr(0, 4) == "uint") {
t.code = kDLUInt; scan = s.c_str() + 4;
} else if (s.substr(0, 5) == "float") {
t.code = kDLFloat; scan = s.c_str() + 5;
} else if (s.substr(0, 6) == "handle") {
t.code = kTVMOpaqueHandle;
t.bits = 64; // handle uses 64 bit by default.
scan = s.c_str() + 6;
} else if (s == "bool") {
t.code = kDLUInt;
t.bits = 1;
t.lanes = 1;
return t;
} else if (s.substr(0, 6) == "custom") {
t.code = ParseCustomDatatype(s, &scan);
} else {
scan = s.c_str();
LOG(FATAL) << "unknown type " << s;
}
char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
if (bits != 0) t.bits = bits;
char* endpt = xdelim;
if (*xdelim == 'x') {
t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
}
CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
return t;
}
} // namespace runtime
using DataType = runtime::DataType;
......
......@@ -30,6 +30,7 @@
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/container.h>
#include <functional>
#include <tuple>
#include <vector>
......@@ -52,28 +53,6 @@ class PrimExpr;
namespace runtime {
/*!
* \brief Runtime utility for getting custom type name from code
* \param type_code Custom type code
* \return Custom type name
*/
TVM_DLL std::string GetCustomTypeName(uint8_t type_code);
/*!
* \brief Runtime utility for checking whether custom type is registered
* \param type_code Custom type code
* \return Bool representing whether type is registered
*/
TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code);
/*!
* \brief Runtime utility for parsing string of the form "custom[<typename>]"
* \param s String to parse
* \param scan pointer to parsing pointer, which is scanning across s
* \return type code of custom type parsed
*/
TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);
// forward declarations
class TVMArgs;
class TVMArgValue;
......@@ -359,27 +338,6 @@ class TVMArgs {
inline TVMArgValue operator[](int i) const;
};
/*!
* \brief Convert type code to its name
* \param type_code The type code .
* \return The name of type code.
*/
inline const char* TypeCode2Str(int type_code);
/*!
* \brief convert a string to TVM type.
* \param s The string to be converted.
* \return The corresponding tvm type.
*/
inline DLDataType String2DLDataType(std::string s);
/*!
* \brief convert a TVM type to string.
* \param t The type to be converted.
* \return The corresponding tvm type in string.
*/
inline std::string DLDataType2String(DLDataType t);
// macro to check type code.
#define TVM_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " \
......@@ -554,6 +512,10 @@ class TVMArgValue : public TVMPODValue_ {
return std::string(value_.v_str);
}
}
operator tvm::runtime::String() const {
// directly use the std::string constructor for now.
return tvm::runtime::String(operator std::string());
}
operator DLDataType() const {
if (type_code_ == kTVMStr) {
return String2DLDataType(operator std::string());
......@@ -642,6 +604,10 @@ class TVMRetValue : public TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kTVMStr);
return *ptr<std::string>();
}
operator tvm::runtime::String() const {
// directly use the std::string constructor for now.
return tvm::runtime::String(operator std::string());
}
operator DLDataType() const {
if (type_code_ == kTVMStr) {
return String2DLDataType(operator std::string());
......@@ -994,96 +960,6 @@ class TVMRetValue : public TVMPODValue_ {
} \
}
// implementation details
inline const char* TypeCode2Str(int type_code) {
switch (type_code) {
case kDLInt: return "int";
case kDLUInt: return "uint";
case kDLFloat: return "float";
case kTVMStr: return "str";
case kTVMBytes: return "bytes";
case kTVMOpaqueHandle: return "handle";
case kTVMNullptr: return "NULL";
case kTVMDLTensorHandle: return "ArrayHandle";
case kTVMDataType: return "DLDataType";
case kTVMContext: return "TVMContext";
case kTVMPackedFuncHandle: return "FunctionHandle";
case kTVMModuleHandle: return "ModuleHandle";
case kTVMNDArrayHandle: return "NDArrayContainer";
case kTVMObjectHandle: return "Object";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
}
inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
os << "bool"; return os;
}
if (t.code < kTVMCustomBegin) {
os << TypeCode2Str(t.code);
} else {
os << "custom[" << GetCustomTypeName(t.code) << "]";
}
if (t.code == kTVMOpaqueHandle) return os;
os << static_cast<int>(t.bits);
if (t.lanes != 1) {
os << 'x' << static_cast<int>(t.lanes);
}
return os;
}
inline std::ostream& operator<<(std::ostream& os, const DataType& dtype) { // NOLINT(*)
return os << dtype.operator DLDataType();
}
inline std::string DLDataType2String(DLDataType t) {
if (t.bits == 0) return "";
std::ostringstream os;
os << t;
return os.str();
}
inline DLDataType String2DLDataType(std::string s) {
DLDataType t;
// handle None type
if (s.length() == 0) {
t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle;
return t;
}
t.bits = 32; t.lanes = 1;
const char* scan;
if (s.substr(0, 3) == "int") {
t.code = kDLInt; scan = s.c_str() + 3;
} else if (s.substr(0, 4) == "uint") {
t.code = kDLUInt; scan = s.c_str() + 4;
} else if (s.substr(0, 5) == "float") {
t.code = kDLFloat; scan = s.c_str() + 5;
} else if (s.substr(0, 6) == "handle") {
t.code = kTVMOpaqueHandle;
t.bits = 64; // handle uses 64 bit by default.
scan = s.c_str() + 6;
} else if (s == "bool") {
t.code = kDLUInt;
t.bits = 1;
t.lanes = 1;
return t;
} else if (s.substr(0, 6) == "custom") {
t.code = ParseCustomDatatype(s, &scan);
} else {
scan = s.c_str();
LOG(FATAL) << "unknown type " << s;
}
char* xdelim; // emulate sscanf("%ux%u", bits, lanes)
uint8_t bits = static_cast<uint8_t>(strtoul(scan, &xdelim, 10));
if (bits != 0) t.bits = bits;
char* endpt = xdelim;
if (*xdelim == 'x') {
t.lanes = static_cast<uint16_t>(strtoul(xdelim + 1, &endpt, 10));
}
CHECK(endpt == s.c_str() + s.length()) << "unknown type " << s;
return t;
}
inline TVMArgValue TVMArgs::operator[](int i) const {
CHECK_LT(i, num_args)
......
......@@ -91,6 +91,8 @@ TEST(PackedFunc, str) {
CHECK(args.num_args == 1);
std::string x = args[0];
CHECK(x == "hello");
String y = args[0];
CHECK(y == "hello");
*rv = x;
})("hello");
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment