/* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file * distributed with this work for additional information * regarding copyright ownership. The ASF licenses this file * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * KIND, either express or implied. See the License for the * specific language governing permissions and limitations * under the License. */ /* * \file tvm/dtype.h * \brief Data type used in IR. */ // Acknowledgement: DataType structure design originates from Halide. #ifndef TVM_DTYPE_H_ #define TVM_DTYPE_H_ #include "runtime/packed_func.h" namespace tvm { class Expr; /*! * \brief Primitive data types in tvm. */ class DataType { public: /*! \brief default constructor */ DataType() {} /*! * \brief Constructor * \param dtype The DLDataType */ explicit DataType(DLDataType dtype) : data_(dtype) {} /*! * \brief Constructor * \param code The type code. * \param bits The number of bits in the type. * \param lanes The number of lanes. */ DataType(int code, int bits, int lanes) { data_.code = static_cast<uint8_t>(code); data_.bits = static_cast<uint8_t>(bits); data_.lanes = static_cast<uint16_t>(lanes); } /*! \return The type code. */ int code() const { return static_cast<int>(data_.code); } /*! \return number of bits in the data. */ int bits() const { return static_cast<int>(data_.bits); } /*! \return number of bytes to store each scalar. */ int bytes() const { return (bits() + 7) / 8; } /*! \return number of lanes in the data. */ int lanes() const { return static_cast<int>(data_.lanes); } /*! \return whether type is a scalar type. */ bool is_scalar() const { return lanes() == 1; } /*! \return whether type is a scalar type. */ bool is_bool() const { return code() == kDLUInt && bits() == 1; } /*! \return whether type is a float type. */ bool is_float() const { return code() == kDLFloat; } /*! \return whether type is an int type. */ bool is_int() const { return code() == kDLInt; } /*! \return whether type is an uint type. */ bool is_uint() const { return code() == kDLUInt; } /*! \return whether type is a handle type. */ bool is_handle() const { return code() == kHandle; } /*! \return whether type is a vector type. */ bool is_vector() const { return lanes() > 1; } /*! * \brief Create a new data type by change lanes to a specified value. * \param lanes The target number of lanes. * \return the result type. */ DataType with_lanes(int lanes) const { return DataType(data_.code, data_.bits, lanes); } /*! * \brief Create a new data type by change bits to a specified value. * \param bits The target number of bits. * \return the result type. */ DataType with_bits(int bits) const { return DataType(data_.code, bits, data_.lanes); } /*! * \brief Get the scalar version of the type. * \return the result type. */ DataType element_of() const { return with_lanes(1); } // operator overloadings bool operator==(const DataType& other) const { return data_.code == other.data_.code && data_.bits == other.data_.bits && data_.lanes == other.data_.lanes; } bool operator!=(const DataType& other) const { return !operator==(other); } operator DLDataType () const { return data_; } /*! \return the maximum possible value in this format. */ TVM_DLL Expr max() const; /*! \return the minimum possible value in this format. */ TVM_DLL Expr min() const; private: DLDataType data_; }; /*! * \brief Construct an int type. * \param bits The number of bits in the type. * \param lanes The number of lanes. * \return The constructed data type. */ inline DataType Int(int bits, int lanes = 1) { return DataType(kDLInt, bits, lanes); } /*! * \brief Construct an uint type. * \param bits The number of bits in the type. * \param lanes The number of lanes * \return The constructed data type. */ inline DataType UInt(int bits, int lanes = 1) { return DataType(kDLUInt, bits, lanes); } /*! * \brief Construct a bool type. * \param lanes The number of lanes * \return The constructed data type. */ inline DataType Bool(int lanes = 1) { return UInt(1, lanes); } /*! * \brief Construct an uint type. * \param bits The number of bits in the type. * \param lanes The number of lanes * \return The constructed data type. */ inline DataType Float(int bits, int lanes = 1) { return DataType(kDLFloat, bits, lanes); } /*! * \brief Construct a handle type. * \param bits The number of bits in the type. * \param lanes The number of lanes * \return The constructed data type. */ inline DataType Handle(int bits = 64, int lanes = 1) { return DataType(kHandle, bits, lanes); } /*! * \brief Get the corresponding type of TVMShapeIndex. * \return The type of TVM shape index. */ inline DataType TVMShapeIndexType() { if (std::is_signed<tvm_index_t>::value) { return Int(sizeof(tvm_index_t) * 8); } else { return UInt(sizeof(tvm_index_t) * 8); } } /*! * \brief Convert DLDataType to DataType. * \param t The original type. * \return The conversion result. */ inline DataType TVMType2Type(DLDataType t) { return DataType(t.code, t.bits, t.lanes); } /*! * \brief Convert DataType to DataType. * \param t The original type. * \return The conversion result. */ inline DLDataType Type2TVMType(DataType t) { return t.operator DLDataType(); } /*! * \brief Get the number of bytes needed in a vector. * \param dtype The data type. * \return Number of bytes needed. */ inline int GetVectorBytes(DataType dtype) { int data_bits = dtype.bits() * dtype.lanes(); // allow bool to exist if (dtype == Bool()) return 1; CHECK_EQ(data_bits % 8, 0U) << "Need to load/store by multiple of bytes"; return data_bits / 8; } // Overload print function. inline std::ostream& operator<<(std::ostream& os, DataType dtype) { // NOLINT(*) using namespace tvm::runtime; return os << dtype.operator DLDataType(); } // Backward compatibility using Type = DataType; } // namespace tvm #endif // TVM_DTYPE_H_