Unverified Commit 12e737f5 by Krzysztof Parzyszek Committed by GitHub

Make "none" DataType explicit (#5491)

* Make "none" DataType explicit

The None data type is created when converting an empty string to DataType.
Add functions to create it and recognize it. Convert it to the "void" LLVM
type in LLVM codegen.

* Rename "none" to "void"

* Map VoidType:Type -> Void:DataType in GetRuntimeDataType

* Map Void:DataType -> VoidType:Type in GetType
parent 3aa103e7
...@@ -107,7 +107,7 @@ class DataType { ...@@ -107,7 +107,7 @@ class DataType {
} }
/*! \return whether type is a handle type. */ /*! \return whether type is a handle type. */
bool is_handle() const { bool is_handle() const {
return code() == DataType::kHandle; return code() == DataType::kHandle && !is_void();
} }
/*! \return whether type is a vector type. */ /*! \return whether type is a vector type. */
bool is_vector() const { bool is_vector() const {
...@@ -117,6 +117,10 @@ class DataType { ...@@ -117,6 +117,10 @@ class DataType {
bool is_vector_bool() const { bool is_vector_bool() const {
return is_vector() && bits() == 1; return is_vector() && bits() == 1;
} }
/*! \return whether type is a Void type. */
bool is_void() const {
return code() == DataType::kHandle && bits() == 0 && lanes() == 0;
}
/*! /*!
* \brief Create a new data type by change lanes to a specified value. * \brief Create a new data type by change lanes to a specified value.
* \param lanes The target number of lanes. * \param lanes The target number of lanes.
...@@ -212,6 +216,13 @@ class DataType { ...@@ -212,6 +216,13 @@ class DataType {
return DataType(kHandle, bits, lanes); return DataType(kHandle, bits, lanes);
} }
/*! /*!
* \brief Construct a Void type.
* \return The constructed data type.
*/
static DataType Void() {
return DataType(kHandle, 0, 0);
}
/*!
* \brief Get the corresponding type of TVMShapeIndex. * \brief Get the corresponding type of TVMShapeIndex.
* \return The type of TVM shape index. * \return The type of TVM shape index.
*/ */
...@@ -335,6 +346,9 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) ...@@ -335,6 +346,9 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
os << "bool"; return os; os << "bool"; return os;
} }
if (DataType(t).is_void()) {
return os << "void";
}
if (t.code < kTVMCustomBegin) { if (t.code < kTVMCustomBegin) {
os << TypeCode2Str(t.code); os << TypeCode2Str(t.code);
} else { } else {
...@@ -361,9 +375,9 @@ inline std::string DLDataType2String(DLDataType t) { ...@@ -361,9 +375,9 @@ inline std::string DLDataType2String(DLDataType t) {
inline DLDataType String2DLDataType(std::string s) { inline DLDataType String2DLDataType(std::string s) {
DLDataType t; DLDataType t;
// handle None type // handle void type
if (s.length() == 0) { if (s.length() == 0) {
t.bits = 0; t.lanes = 0; t.code = kTVMOpaqueHandle; t = DataType::Void();
return t; return t;
} }
t.bits = 32; t.lanes = 1; t.bits = 32; t.lanes = 1;
......
...@@ -309,6 +309,9 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const { ...@@ -309,6 +309,9 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& dtype) const {
CHECK_EQ(dtype.lanes(), 1); CHECK_EQ(dtype.lanes(), 1);
return t_void_p_; return t_void_p_;
} }
if (dtype.is_void()) {
return t_void_;
}
llvm::Type* etype = nullptr; llvm::Type* etype = nullptr;
if (dtype.is_int() || dtype.is_uint()) { if (dtype.is_int() || dtype.is_uint()) {
etype = llvm::Type::getIntNTy(*ctx_, dtype.bits()); etype = llvm::Type::getIntNTy(*ctx_, dtype.bits());
......
...@@ -38,6 +38,8 @@ runtime::DataType GetRuntimeDataType(const Type& type) { ...@@ -38,6 +38,8 @@ runtime::DataType GetRuntimeDataType(const Type& type) {
return n->dtype; return n->dtype;
} else if (type.as<PointerTypeNode>()) { } else if (type.as<PointerTypeNode>()) {
return DataType::Handle(); return DataType::Handle();
} else if (IsVoidType(type)) {
return DataType::Void();
} else { } else {
LOG(FATAL) << "Type " << type LOG(FATAL) << "Type " << type
<< " does not have a corresponding runtime::DataType"; << " does not have a corresponding runtime::DataType";
...@@ -57,9 +59,8 @@ Type GetType(const PrimExpr& expr) { ...@@ -57,9 +59,8 @@ Type GetType(const PrimExpr& expr) {
} }
// Default: return the type indicated by the dtype. // Default: return the type indicated by the dtype.
runtime::DataType dtype = expr.dtype(); runtime::DataType dtype = expr.dtype();
// These types already implies the specific type. if (dtype.is_void()) {
if (dtype.is_int() || dtype.is_uint() || dtype.is_float()) { return VoidType();
return PrimType(dtype);
} }
return PrimType(dtype); return PrimType(dtype);
} }
......
...@@ -43,6 +43,18 @@ def test_llvm_intrin(): ...@@ -43,6 +43,18 @@ def test_llvm_intrin():
fcode = tvm.build(mod, None, "llvm") fcode = tvm.build(mod, None, "llvm")
def test_llvm_void_intrin():
ib = tvm.tir.ir_builder.create()
A = ib.pointer("uint8", name="A")
# Create an intrinsic that returns void.
x = tvm.tir.call_llvm_intrin('', 'llvm.va_start', tvm.tir.const(1, 'uint32'), A)
ib.emit(x)
body = ib.get()
mod = tvm.IRModule.from_expr(
tvm.tir.PrimFunc([A], body).with_attr("global_symbol", "main"))
fcode = tvm.build(mod, None, "llvm")
def test_llvm_overloaded_intrin(): def test_llvm_overloaded_intrin():
# Name lookup for overloaded intrinsics in LLVM 4- requires a name # Name lookup for overloaded intrinsics in LLVM 4- requires a name
# that includes the overloaded types. # that includes the overloaded types.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment