Unverified Commit 0ec27f4b by Tianqi Chen Committed by GitHub

[REFACTOR][FFI] Make more clear naming for C API Type codes. (#4715)

This PR introduces more clear naming prefix for C API type codes
to avoid conflict with other packages.

We also removed TVMArray and TVMType to directly use DLTensor and DLDataType.
parent 49d31443
...@@ -156,7 +156,7 @@ TVM_REGISTER_GLOBAL("tvm_ext.nd_create") ...@@ -156,7 +156,7 @@ TVM_REGISTER_GLOBAL("tvm_ext.nd_create")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
int additional_info = args[0]; int additional_info = args[0];
*rv = NDSubClass(additional_info); *rv = NDSubClass(additional_info);
CHECK_EQ(rv->type_code(), kNDArrayContainer); CHECK_EQ(rv->type_code(), kTVMNDArrayHandle);
}); });
......
...@@ -47,7 +47,7 @@ func (parray Array) nativeCPtr() (retVal uintptr) { ...@@ -47,7 +47,7 @@ func (parray Array) nativeCPtr() (retVal uintptr) {
} }
func (parray Array) nativeCopyFrom(data unsafe.Pointer, datalen int) (err error) { func (parray Array) nativeCopyFrom(data unsafe.Pointer, datalen int) (err error) {
ret := C.TVMArrayCopyFromBytes((*C.TVMArray)(unsafe.Pointer(parray.nativeCPtr())), ret := C.TVMArrayCopyFromBytes((*C.DLTensor)(unsafe.Pointer(parray.nativeCPtr())),
data, data,
C.ulong(datalen)) C.ulong(datalen))
if ret != 0 { if ret != 0 {
...@@ -65,7 +65,7 @@ func (parray Array) nativeCopyFrom(data unsafe.Pointer, datalen int) (err error) ...@@ -65,7 +65,7 @@ func (parray Array) nativeCopyFrom(data unsafe.Pointer, datalen int) (err error)
func (parray Array) CopyFrom(val interface{}) (err error) { func (parray Array) CopyFrom(val interface{}) (err error) {
var data unsafe.Pointer var data unsafe.Pointer
var datalen int var datalen int
dtype := ((*C.TVMArray)(unsafe.Pointer(parray))).dtype dtype := ((*C.DLTensor)(unsafe.Pointer(parray))).dtype
switch val.(type) { switch val.(type) {
case []int8: case []int8:
...@@ -126,7 +126,7 @@ func (parray Array) CopyFrom(val interface{}) (err error) { ...@@ -126,7 +126,7 @@ func (parray Array) CopyFrom(val interface{}) (err error) {
} }
func (parray Array) nativeCopyTo (data unsafe.Pointer, datalen int) (err error){ func (parray Array) nativeCopyTo (data unsafe.Pointer, datalen int) (err error){
ret := C.TVMArrayCopyToBytes((*C.TVMArray)(unsafe.Pointer(parray.nativeCPtr())), ret := C.TVMArrayCopyToBytes((*C.DLTensor)(unsafe.Pointer(parray.nativeCPtr())),
unsafe.Pointer(data), unsafe.Pointer(data),
C.ulong(datalen)) C.ulong(datalen))
...@@ -149,7 +149,7 @@ func (parray Array) AsSlice() (retVal interface{}, err error) { ...@@ -149,7 +149,7 @@ func (parray Array) AsSlice() (retVal interface{}, err error) {
for ii := range shape { for ii := range shape {
size *= shape[ii] size *= shape[ii]
} }
dtype := ((*C.TVMArray)(unsafe.Pointer(parray))).dtype dtype := ((*C.DLTensor)(unsafe.Pointer(parray))).dtype
switch parray.GetDType() { switch parray.GetDType() {
case "int8": case "int8":
...@@ -221,13 +221,13 @@ func (parray Array) AsSlice() (retVal interface{}, err error) { ...@@ -221,13 +221,13 @@ func (parray Array) AsSlice() (retVal interface{}, err error) {
// GetNdim returns the number of dimentions in Array // GetNdim returns the number of dimentions in Array
func (parray Array) GetNdim() (retVal int32) { func (parray Array) GetNdim() (retVal int32) {
retVal = int32(((*C.TVMArray)(unsafe.Pointer(parray))).ndim) retVal = int32(((*C.DLTensor)(unsafe.Pointer(parray))).ndim)
return return
} }
// GetShape returns the number of dimentions in Array // GetShape returns the number of dimentions in Array
func (parray Array) GetShape() (retVal []int64) { func (parray Array) GetShape() (retVal []int64) {
shapePtr := (*C.int64_t)(((*C.TVMArray)(unsafe.Pointer(parray))).shape) shapePtr := (*C.int64_t)(((*C.DLTensor)(unsafe.Pointer(parray))).shape)
ndim := parray.GetNdim() ndim := parray.GetNdim()
shapeSlice := (*[1<<31] int64)(unsafe.Pointer(shapePtr))[:ndim:ndim] shapeSlice := (*[1<<31] int64)(unsafe.Pointer(shapePtr))[:ndim:ndim]
...@@ -238,14 +238,14 @@ func (parray Array) GetShape() (retVal []int64) { ...@@ -238,14 +238,14 @@ func (parray Array) GetShape() (retVal []int64) {
// GetDType returns the number of dimentions in Array // GetDType returns the number of dimentions in Array
func (parray Array) GetDType() (retVal string) { func (parray Array) GetDType() (retVal string) {
ret := ((*C.TVMArray)(unsafe.Pointer(parray))).dtype ret := ((*C.DLTensor)(unsafe.Pointer(parray))).dtype
retVal, _ = dtypeFromTVMType(*(*pTVMType)(unsafe.Pointer(&ret))) retVal, _ = dtypeFromTVMType(*(*pTVMType)(unsafe.Pointer(&ret)))
return return
} }
// GetCtx returns the number of dimentions in Array // GetCtx returns the number of dimentions in Array
func (parray Array) GetCtx() (retVal Context) { func (parray Array) GetCtx() (retVal Context) {
ret := ((*C.TVMArray)(unsafe.Pointer(parray))).ctx ret := ((*C.DLTensor)(unsafe.Pointer(parray))).ctx
retVal = *(*Context)(unsafe.Pointer(&ret)) retVal = *(*Context)(unsafe.Pointer(&ret))
return return
} }
...@@ -342,6 +342,6 @@ func Empty(shape []int64, args ...interface{}) (parray *Array, err error) { ...@@ -342,6 +342,6 @@ func Empty(shape []int64, args ...interface{}) (parray *Array, err error) {
// //
// `ret` indicates the status of this api execution. // `ret` indicates the status of this api execution.
func nativeTVMArrayFree(parray Array) (retVal int32) { func nativeTVMArrayFree(parray Array) (retVal int32) {
retVal = (int32)(C.TVMArrayFree((*C.TVMArray)(unsafe.Pointer(parray.nativeCPtr())))) retVal = (int32)(C.TVMArrayFree((*C.DLTensor)(unsafe.Pointer(parray.nativeCPtr()))))
return return
} }
...@@ -33,38 +33,38 @@ import ( ...@@ -33,38 +33,38 @@ import (
"unsafe" "unsafe"
) )
// KHandle is golang type code for TVM enum kHandle. // KHandle is golang type code for TVM enum kTVMOpaqueHandle.
var KHandle = int32(C.kHandle) var KHandle = int32(C.kTVMOpaqueHandle)
// KNull is golang type code for TVM kNull. // KNull is golang type code for TVM kTVMNullptr.
var KNull = int32(C.kNull) var KNull = int32(C.kTVMNullptr)
// KTVMType is golang type code for TVM kTVMType. // KTVMType is golang type code for TVM kTVMDataType.
var KTVMType = int32(C.kTVMType) var KTVMType = int32(C.kTVMDataType)
// KTVMContext is golang type code for TVM kTVMContext. // KTVMContext is golang type code for TVM kTVMContext.
var KTVMContext = int32(C.kTVMContext) var KTVMContext = int32(C.kTVMContext)
// KArrayHandle is golang type code for TVM kArrayHandle. // KArrayHandle is golang type code for TVM kTVMDLTensorHandle.
var KArrayHandle = int32(C.kArrayHandle) var KArrayHandle = int32(C.kTVMDLTensorHandle)
// KObjectHandle is golang type code for TVM kObjectHandle. // KObjectHandle is golang type code for TVM kTVMObjectHandle.
var KObjectHandle = int32(C.kObjectHandle) var KObjectHandle = int32(C.kTVMObjectHandle)
// KModuleHandle is gonag type code for TVM kModuleHandle. // KModuleHandle is gonag type code for TVM kTVMModuleHandle.
var KModuleHandle = int32(C.kModuleHandle) var KModuleHandle = int32(C.kTVMModuleHandle)
// KFuncHandle is gonalg type code for TVM kFuncHandle. // KFuncHandle is gonalg type code for TVM kTVMPackedFuncHandle.
var KFuncHandle = int32(C.kFuncHandle) var KFuncHandle = int32(C.kTVMPackedFuncHandle)
// KStr is golang type code for TVM kStr. // KStr is golang type code for TVM kTVMStr.
var KStr = int32(C.kStr) var KStr = int32(C.kTVMStr)
// KBytes is golang type code for TVM kBytes. // KBytes is golang type code for TVM kTVMBytes.
var KBytes = int32(C.kBytes) var KBytes = int32(C.kTVMBytes)
// KNDArrayContainer is golang typecode for kNDArrayContainer. // KNDArrayContainer is golang typecode for kTVMNDArrayHandle.
var KNDArrayContainer = int32(C.kNDArrayContainer) var KNDArrayContainer = int32(C.kTVMNDArrayHandle)
// KExtBegin is golang enum corresponding to TVM kExtBegin. // KExtBegin is golang enum corresponding to TVM kTVMExtBegin.
var KExtBegin = int32(C.kExtBegin) var KExtBegin = int32(C.kTVMExtBegin)
// KNNVMFirst is golang enum corresponding to TVM kNNVMFirst. // KNNVMFirst is golang enum corresponding to TVM kNNVMFirst.
var KNNVMFirst = int32(C.kNNVMFirst) var KNNVMFirst = int32(C.kTVMNNVMFirst)
// KNNVMLast is golang enum corresponding to TVM kNNVMLast. // KNNVMLast is golang enum corresponding to TVM kNNVMLast.
var KNNVMLast = int32(C.kNNVMLast) var KNNVMLast = int32(C.kTVMNNVMLast)
// KExtReserveEnd is golang enum corresponding to TVM kExtReserveEnd. // KExtReserveEnd is golang enum corresponding to TVM kExtReserveEnd.
var KExtReserveEnd = int32(C.kExtReserveEnd) var KExtReserveEnd = int32(C.kTVMExtReserveEnd)
// KExtEnd is golang enum corresponding to TVM kExtEnd. // KExtEnd is golang enum corresponding to TVM kExtEnd.
var KExtEnd = int32(C.kExtEnd) var KExtEnd = int32(C.kTVMExtEnd)
// KDLInt is golang type code for TVM kDLInt. // KDLInt is golang type code for TVM kDLInt.
var KDLInt = int32(C.kDLInt) var KDLInt = int32(C.kDLInt)
// KDLUInt is golang type code for TVM kDLUInt. // KDLUInt is golang type code for TVM kDLUInt.
......
...@@ -682,7 +682,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) { ...@@ -682,7 +682,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
// datatypes lowering pass, we will lower the value to its true representation in the format // datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype. // specified by the datatype.
// TODO(gus) when do we need to start worrying about doubles not being precise enough? // TODO(gus) when do we need to start worrying about doubles not being precise enough?
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin)) { if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kTVMCustomBegin)) {
return FloatImm(t, static_cast<double>(value)); return FloatImm(t, static_cast<double>(value));
} }
LOG(FATAL) << "cannot make const for type " << t; LOG(FATAL) << "cannot make const for type " << t;
......
...@@ -88,7 +88,7 @@ inline TObjectRef NullValue() { ...@@ -88,7 +88,7 @@ inline TObjectRef NullValue() {
template<> template<>
inline DataType NullValue<DataType>() { inline DataType NullValue<DataType>() {
return DataType(kHandle, 0, 0); return DataType(DataType::kHandle, 0, 0);
} }
/*! \brief Error thrown during attribute checking. */ /*! \brief Error thrown during attribute checking. */
...@@ -492,7 +492,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) { ...@@ -492,7 +492,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) {
template<> template<>
inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) { inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
if (val.type_code() == kStr) { if (val.type_code() == kTVMStr) {
*ptr = val.operator std::string(); *ptr = val.operator std::string();
} else { } else {
LOG(FATAL) << "Expect str"; LOG(FATAL) << "Expect str";
...@@ -762,7 +762,7 @@ class AttrsNode : public BaseAttrsNode { ...@@ -762,7 +762,7 @@ class AttrsNode : public BaseAttrsNode {
// linear search. // linear search.
auto ffind = [&args](const char* key, runtime::TVMArgValue* val) { auto ffind = [&args](const char* key, runtime::TVMArgValue* val) {
for (int i = 0; i < args.size(); i += 2) { for (int i = 0; i < args.size(); i += 2) {
CHECK_EQ(args.type_codes[i], kStr); CHECK_EQ(args.type_codes[i], kTVMStr);
if (!std::strcmp(key, args.values[i].v_str)) { if (!std::strcmp(key, args.values[i].v_str)) {
*val = args[i + 1]; *val = args[i + 1];
return true; return true;
...@@ -777,7 +777,7 @@ class AttrsNode : public BaseAttrsNode { ...@@ -777,7 +777,7 @@ class AttrsNode : public BaseAttrsNode {
// construct a map then do lookup. // construct a map then do lookup.
std::unordered_map<std::string, runtime::TVMArgValue> kwargs; std::unordered_map<std::string, runtime::TVMArgValue> kwargs;
for (int i = 0; i < args.size(); i += 2) { for (int i = 0; i < args.size(); i += 2) {
CHECK_EQ(args.type_codes[i], kStr); CHECK_EQ(args.type_codes[i], kTVMStr);
kwargs[args[i].operator std::string()] = args[i + 1]; kwargs[args[i].operator std::string()] = args[i + 1];
} }
auto ffind = [&kwargs](const char *key, runtime::TVMArgValue* val) { auto ffind = [&kwargs](const char *key, runtime::TVMArgValue* val) {
......
...@@ -100,7 +100,7 @@ struct ObjectTypeChecker<Map<K, V> > { ...@@ -100,7 +100,7 @@ struct ObjectTypeChecker<Map<K, V> > {
// extensions for tvm arg value // extensions for tvm arg value
inline TVMPODValue_::operator tvm::PrimExpr() const { inline TVMPODValue_::operator tvm::PrimExpr() const {
if (type_code_ == kNull) return PrimExpr(); if (type_code_ == kTVMNullptr) return PrimExpr();
if (type_code_ == kDLInt) { if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max()); CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min()); CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
...@@ -110,7 +110,7 @@ inline TVMPODValue_::operator tvm::PrimExpr() const { ...@@ -110,7 +110,7 @@ inline TVMPODValue_::operator tvm::PrimExpr() const {
return PrimExpr(static_cast<float>(value_.v_float64)); return PrimExpr(static_cast<float>(value_.v_float64));
} }
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle); Object* ptr = static_cast<Object*>(value_.v_handle);
if (ptr->IsInstance<IterVarNode>()) { if (ptr->IsInstance<IterVarNode>()) {
...@@ -126,13 +126,13 @@ inline TVMPODValue_::operator tvm::PrimExpr() const { ...@@ -126,13 +126,13 @@ inline TVMPODValue_::operator tvm::PrimExpr() const {
} }
inline TVMPODValue_::operator tvm::Integer() const { inline TVMPODValue_::operator tvm::Integer() const {
if (type_code_ == kNull) return Integer(); if (type_code_ == kTVMNullptr) return Integer();
if (type_code_ == kDLInt) { if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max()); CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min()); CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return Integer(static_cast<int>(value_.v_int64)); return Integer(static_cast<int>(value_.v_int64));
} }
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle); TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle); Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<Integer>::Check(ptr)) CHECK(ObjectTypeChecker<Integer>::Check(ptr))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName() << "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
......
...@@ -86,62 +86,44 @@ typedef enum { ...@@ -86,62 +86,44 @@ typedef enum {
} TVMDeviceExtType; } TVMDeviceExtType;
/*! /*!
* \brief The type code in TVMType * \brief The type code in used in the TVM FFI.
* \note TVMType is used in two places.
*/ */
typedef enum { typedef enum {
// The type code of other types are compatible with DLPack. // The type code of other types are compatible with DLPack.
// The next few fields are extension types // The next few fields are extension types
// that is used by TVM API calls. // that is used by TVM API calls.
kHandle = 3U, kTVMOpaqueHandle = 3U,
kNull = 4U, kTVMNullptr = 4U,
kTVMType = 5U, kTVMDataType = 5U,
kTVMContext = 6U, kTVMContext = 6U,
kArrayHandle = 7U, kTVMDLTensorHandle = 7U,
kObjectHandle = 8U, kTVMObjectHandle = 8U,
kModuleHandle = 9U, kTVMModuleHandle = 9U,
kFuncHandle = 10U, kTVMPackedFuncHandle = 10U,
kStr = 11U, kTVMStr = 11U,
kBytes = 12U, kTVMBytes = 12U,
kNDArrayContainer = 13U, kTVMNDArrayHandle = 13U,
// Extension codes for other frameworks to integrate TVM PackedFunc. // Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and // To make sure each framework's id do not conflict, use first and
// last sections to mark ranges. // last sections to mark ranges.
// Open an issue at the repo if you need a section of code. // Open an issue at the repo if you need a section of code.
kExtBegin = 15U, kTVMExtBegin = 15U,
kNNVMFirst = 16U, kTVMNNVMFirst = 16U,
kNNVMLast = 20U, kTVMNNVMLast = 20U,
// The following section of code is used for non-reserved types. // The following section of code is used for non-reserved types.
kExtReserveEnd = 64U, kTVMExtReserveEnd = 64U,
kExtEnd = 128U, kTVMExtEnd = 128U,
// The rest of the space is used for custom, user-supplied datatypes // The rest of the space is used for custom, user-supplied datatypes
kCustomBegin = 129U, kTVMCustomBegin = 129U,
} TVMTypeCode; } TVMTypeCode;
/*! /*!
* \brief The data type used in TVM Runtime.
*
* Examples
* - float: type_code = 2, bits = 32, lanes=1
* - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
* - int8: type_code = 0, bits = 8, lanes=1
*
* \note Arguments TVM API function always takes bits=64 and lanes=1
*/
typedef DLDataType TVMType;
/*!
* \brief The Device information, abstract away common device types. * \brief The Device information, abstract away common device types.
*/ */
typedef DLContext TVMContext; typedef DLContext TVMContext;
/*!
* \brief The tensor array structure to TVM API.
*/
typedef DLTensor TVMArray;
/*! \brief the array handle */ /*! \brief the array handle */
typedef TVMArray* TVMArrayHandle; typedef DLTensor* TVMArrayHandle;
/*! /*!
* \brief Union type of values * \brief Union type of values
...@@ -152,13 +134,13 @@ typedef union { ...@@ -152,13 +134,13 @@ typedef union {
double v_float64; double v_float64;
void* v_handle; void* v_handle;
const char* v_str; const char* v_str;
TVMType v_type; DLDataType v_type;
TVMContext v_ctx; TVMContext v_ctx;
} TVMValue; } TVMValue;
/*! /*!
* \brief Byte array type used to pass in byte array * \brief Byte array type used to pass in byte array
* When kBytes is used as data type. * When kTVMBytes is used as data type.
*/ */
typedef struct { typedef struct {
const char* data; const char* data;
......
...@@ -44,7 +44,7 @@ class DataType { ...@@ -44,7 +44,7 @@ class DataType {
kInt = kDLInt, kInt = kDLInt,
kUInt = kDLUInt, kUInt = kDLUInt,
kFloat = kDLFloat, kFloat = kDLFloat,
kHandle = TVMTypeCode::kHandle, kHandle = TVMTypeCode::kTVMOpaqueHandle,
}; };
/*! \brief default constructor */ /*! \brief default constructor */
DataType() {} DataType() {}
......
...@@ -88,7 +88,7 @@ class TVM_DLL DeviceAPI { ...@@ -88,7 +88,7 @@ class TVM_DLL DeviceAPI {
virtual void* AllocDataSpace(TVMContext ctx, virtual void* AllocDataSpace(TVMContext ctx,
size_t nbytes, size_t nbytes,
size_t alignment, size_t alignment,
TVMType type_hint) = 0; DLDataType type_hint) = 0;
/*! /*!
* \brief Free a data space on device. * \brief Free a data space on device.
* \param ctx The device context to perform operation. * \param ctx The device context to perform operation.
...@@ -115,7 +115,7 @@ class TVM_DLL DeviceAPI { ...@@ -115,7 +115,7 @@ class TVM_DLL DeviceAPI {
size_t num_bytes, size_t num_bytes,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint, DLDataType type_hint,
TVMStreamHandle stream) = 0; TVMStreamHandle stream) = 0;
/*! /*!
* \brief Create a new stream of execution. * \brief Create a new stream of execution.
...@@ -177,7 +177,7 @@ class TVM_DLL DeviceAPI { ...@@ -177,7 +177,7 @@ class TVM_DLL DeviceAPI {
*/ */
virtual void* AllocWorkspace(TVMContext ctx, virtual void* AllocWorkspace(TVMContext ctx,
size_t nbytes, size_t nbytes,
TVMType type_hint = {}); DLDataType type_hint = {});
/*! /*!
* \brief Free temporal workspace in backend execution. * \brief Free temporal workspace in backend execution.
* *
......
...@@ -36,7 +36,7 @@ namespace runtime { ...@@ -36,7 +36,7 @@ namespace runtime {
* \param bits The number of bits to be matched. * \param bits The number of bits to be matched.
* \param lanes The number of lanes in the type. * \param lanes The number of lanes in the type.
*/ */
inline bool TypeMatch(TVMType t, int code, int bits, int lanes = 1) { inline bool TypeMatch(DLDataType t, int code, int bits, int lanes = 1) {
return t.code == code && t.bits == bits && t.lanes == lanes; return t.code == code && t.bits == bits && t.lanes == lanes;
} }
/*! /*!
...@@ -44,7 +44,7 @@ inline bool TypeMatch(TVMType t, int code, int bits, int lanes = 1) { ...@@ -44,7 +44,7 @@ inline bool TypeMatch(TVMType t, int code, int bits, int lanes = 1) {
* \param lhs The left operand. * \param lhs The left operand.
* \param rhs The right operand. * \param rhs The right operand.
*/ */
inline bool TypeEqual(TVMType lhs, TVMType rhs) { inline bool TypeEqual(DLDataType lhs, DLDataType rhs) {
return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes; return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
} }
} // namespace runtime } // namespace runtime
......
...@@ -167,8 +167,8 @@ jobject newObject(JNIEnv *env, const char *clsname) { ...@@ -167,8 +167,8 @@ jobject newObject(JNIEnv *env, const char *clsname) {
return object; return object;
} }
void fromJavaDType(JNIEnv *env, jobject jdtype, TVMType *dtype) { void fromJavaDType(JNIEnv *env, jobject jdtype, DLDataType *dtype) {
jclass tvmTypeClass = env->FindClass("org/apache/tvm/TVMType"); jclass tvmTypeClass = env->FindClass("org/apache/tvm/DLDataType");
dtype->code = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "typeCode", "I"))); dtype->code = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "typeCode", "I")));
dtype->bits = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "bits", "I"))); dtype->bits = (uint8_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "bits", "I")));
dtype->lanes = (uint16_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "lanes", "I"))); dtype->lanes = (uint16_t)(env->GetIntField(jdtype, env->GetFieldID(tvmTypeClass, "lanes", "I")));
...@@ -191,21 +191,21 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) { ...@@ -191,21 +191,21 @@ jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) {
return newTVMValueLong(env, static_cast<jlong>(value.v_int64)); return newTVMValueLong(env, static_cast<jlong>(value.v_int64));
case kDLFloat: case kDLFloat:
return newTVMValueDouble(env, static_cast<jdouble>(value.v_float64)); return newTVMValueDouble(env, static_cast<jdouble>(value.v_float64));
case kHandle: case kTVMOpaqueHandle:
return newTVMValueHandle(env, reinterpret_cast<jlong>(value.v_handle)); return newTVMValueHandle(env, reinterpret_cast<jlong>(value.v_handle));
case kModuleHandle: case kTVMModuleHandle:
return newModule(env, reinterpret_cast<jlong>(value.v_handle)); return newModule(env, reinterpret_cast<jlong>(value.v_handle));
case kFuncHandle: case kTVMPackedFuncHandle:
return newFunction(env, reinterpret_cast<jlong>(value.v_handle)); return newFunction(env, reinterpret_cast<jlong>(value.v_handle));
case kArrayHandle: case kTVMDLTensorHandle:
return newNDArray(env, reinterpret_cast<jlong>(value.v_handle), true); return newNDArray(env, reinterpret_cast<jlong>(value.v_handle), true);
case kNDArrayContainer: case kTVMNDArrayHandle:
return newNDArray(env, reinterpret_cast<jlong>(value.v_handle), false); return newNDArray(env, reinterpret_cast<jlong>(value.v_handle), false);
case kStr: case kTVMStr:
return newTVMValueString(env, value.v_str); return newTVMValueString(env, value.v_str);
case kBytes: case kTVMBytes:
return newTVMValueBytes(env, reinterpret_cast<TVMByteArray *>(value.v_handle)); return newTVMValueBytes(env, reinterpret_cast<TVMByteArray *>(value.v_handle));
case kNull: case kTVMNullptr:
return newObject(env, "org/apache/tvm/TVMValueNull"); return newObject(env, "org/apache/tvm/TVMValueNull");
default: default:
LOG(FATAL) << "Do NOT know how to handle return type code " << tcode; LOG(FATAL) << "Do NOT know how to handle return type code " << tcode;
......
...@@ -98,7 +98,7 @@ JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgString( ...@@ -98,7 +98,7 @@ JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgString(
value.v_str = env->GetStringUTFChars(garg, 0); value.v_str = env->GetStringUTFChars(garg, 0);
TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
e->tvmFuncArgValues.push_back(value); e->tvmFuncArgValues.push_back(value);
e->tvmFuncArgTypes.push_back(kStr); e->tvmFuncArgTypes.push_back(kTVMStr);
// release string args later // release string args later
e->tvmFuncArgPushedStrs.push_back(std::make_pair(garg, value.v_str)); e->tvmFuncArgPushedStrs.push_back(std::make_pair(garg, value.v_str));
} }
...@@ -126,7 +126,7 @@ JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes( ...@@ -126,7 +126,7 @@ JNIEXPORT void JNICALL Java_org_apache_tvm_LibInfo_tvmFuncPushArgBytes(
TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get(); TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
e->tvmFuncArgValues.push_back(value); e->tvmFuncArgValues.push_back(value);
e->tvmFuncArgTypes.push_back(kBytes); e->tvmFuncArgTypes.push_back(kTVMBytes);
e->tvmFuncArgPushedBytes.push_back(std::make_pair(garg, byteArray)); e->tvmFuncArgPushedBytes.push_back(std::make_pair(garg, byteArray));
// release (garg, data), byteArray later // release (garg, data), byteArray later
...@@ -242,7 +242,9 @@ extern "C" int funcInvokeCallback(TVMValue *args, ...@@ -242,7 +242,9 @@ extern "C" int funcInvokeCallback(TVMValue *args,
for (int i = 0; i < numArgs; ++i) { for (int i = 0; i < numArgs; ++i) {
TVMValue arg = args[i]; TVMValue arg = args[i];
int tcode = typeCodes[i]; int tcode = typeCodes[i];
if (tcode == kObjectHandle || tcode == kFuncHandle || tcode == kModuleHandle) { if (tcode == kTVMObjectHandle ||
tcode == kTVMPackedFuncHandle ||
tcode == kTVMModuleHandle) {
TVMCbArgToReturn(&arg, tcode); TVMCbArgToReturn(&arg, tcode);
} }
jobject jarg = tvmRetValueToJava(env, arg, tcode); jobject jarg = tvmRetValueToJava(env, arg, tcode);
...@@ -393,7 +395,7 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayAlloc( ...@@ -393,7 +395,7 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayAlloc(
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayGetShape( JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayGetShape(
JNIEnv *env, jobject obj, jlong jhandle, jobject jshape) { JNIEnv *env, jobject obj, jlong jhandle, jobject jshape) {
TVMArray *array = reinterpret_cast<TVMArray *>(jhandle); DLTensor *array = reinterpret_cast<DLTensor *>(jhandle);
int64_t *shape = array->shape; int64_t *shape = array->shape;
int ndim = array->ndim; int ndim = array->ndim;
...@@ -424,7 +426,7 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromJArray( ...@@ -424,7 +426,7 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromJArray(
JNIEnv *env, jobject obj, jbyteArray jarr, jlong jfrom, jlong jto) { JNIEnv *env, jobject obj, jbyteArray jarr, jlong jfrom, jlong jto) {
jbyte *data = env->GetByteArrayElements(jarr, NULL); jbyte *data = env->GetByteArrayElements(jarr, NULL);
TVMArray *from = reinterpret_cast<TVMArray *>(jfrom); DLTensor *from = reinterpret_cast<DLTensor *>(jfrom);
from->data = static_cast<void *>(data); from->data = static_cast<void *>(data);
int ret = TVMArrayCopyFromTo(static_cast<TVMArrayHandle>(from), int ret = TVMArrayCopyFromTo(static_cast<TVMArrayHandle>(from),
...@@ -438,7 +440,7 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromJArray( ...@@ -438,7 +440,7 @@ JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyFromJArray(
JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyToJArray( JNIEXPORT jint JNICALL Java_org_apache_tvm_LibInfo_tvmArrayCopyToJArray(
JNIEnv *env, jobject obj, jlong jfrom, jbyteArray jarr) { JNIEnv *env, jobject obj, jlong jfrom, jbyteArray jarr) {
TVMArray *from = reinterpret_cast<TVMArray *>(jfrom); DLTensor *from = reinterpret_cast<DLTensor *>(jfrom);
int size = static_cast<int>(env->GetArrayLength(jarr)); int size = static_cast<int>(env->GetArrayLength(jarr));
jbyte *pdata = env->GetByteArrayElements(jarr, NULL); jbyte *pdata = env->GetByteArrayElements(jarr, NULL);
int ret = 0; int ret = 0;
......
...@@ -115,8 +115,8 @@ def _make_tvm_args(args, temp_args): ...@@ -115,8 +115,8 @@ def _make_tvm_args(args, temp_args):
type_codes[i] = TypeCode.NULL type_codes[i] = TypeCode.NULL
elif isinstance(arg, NDArrayBase): elif isinstance(arg, NDArrayBase):
values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p) values[i].v_handle = ctypes.cast(arg.handle, ctypes.c_void_p)
type_codes[i] = (TypeCode.NDARRAY_CONTAINER type_codes[i] = (TypeCode.NDARRAY_HANDLE
if not arg.is_view else TypeCode.ARRAY_HANDLE) if not arg.is_view else TypeCode.DLTENSOR_HANDLE)
elif isinstance(arg, _nd._TVM_COMPATS): elif isinstance(arg, _nd._TVM_COMPATS):
values[i].v_handle = ctypes.c_void_p(arg._tvm_handle) values[i].v_handle = ctypes.c_void_p(arg._tvm_handle)
type_codes[i] = arg.__class__._tvm_tcode type_codes[i] = arg.__class__._tvm_tcode
...@@ -154,14 +154,14 @@ def _make_tvm_args(args, temp_args): ...@@ -154,14 +154,14 @@ def _make_tvm_args(args, temp_args):
type_codes[i] = TypeCode.MODULE_HANDLE type_codes[i] = TypeCode.MODULE_HANDLE
elif isinstance(arg, FunctionBase): elif isinstance(arg, FunctionBase):
values[i].v_handle = arg.handle values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE type_codes[i] = TypeCode.PACKED_FUNC_HANDLE
elif isinstance(arg, ctypes.c_void_p): elif isinstance(arg, ctypes.c_void_p):
values[i].v_handle = arg values[i].v_handle = arg
type_codes[i] = TypeCode.HANDLE type_codes[i] = TypeCode.HANDLE
elif callable(arg): elif callable(arg):
arg = convert_to_tvm_func(arg) arg = convert_to_tvm_func(arg)
values[i].v_handle = arg.handle values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE type_codes[i] = TypeCode.PACKED_FUNC_HANDLE
temp_args.append(arg) temp_args.append(arg)
else: else:
raise TypeError("Don't know how to handle type %s" % type(arg)) raise TypeError("Don't know how to handle type %s" % type(arg))
...@@ -244,15 +244,15 @@ def _handle_return_func(x): ...@@ -244,15 +244,15 @@ def _handle_return_func(x):
# setup return handle for function type # setup return handle for function type
_object.__init_by_constructor__ = __init_handle_by_constructor__ _object.__init_by_constructor__ = __init_handle_by_constructor__
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func RETURN_SWITCH[TypeCode.PACKED_FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
RETURN_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True) RETURN_SWITCH[TypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True)
C_TO_PY_ARG_SWITCH[TypeCode.FUNC_HANDLE] = _wrap_arg_func( C_TO_PY_ARG_SWITCH[TypeCode.PACKED_FUNC_HANDLE] = _wrap_arg_func(
_handle_return_func, TypeCode.FUNC_HANDLE) _handle_return_func, TypeCode.PACKED_FUNC_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func( C_TO_PY_ARG_SWITCH[TypeCode.MODULE_HANDLE] = _wrap_arg_func(
_return_module, TypeCode.MODULE_HANDLE) _return_module, TypeCode.MODULE_HANDLE)
C_TO_PY_ARG_SWITCH[TypeCode.ARRAY_HANDLE] = lambda x: _make_array(x.v_handle, True, False) C_TO_PY_ARG_SWITCH[TypeCode.DLTENSOR_HANDLE] = lambda x: _make_array(x.v_handle, True, False)
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_CONTAINER] = lambda x: _make_array(x.v_handle, False, True) C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True)
_CLASS_MODULE = None _CLASS_MODULE = None
_CLASS_FUNCTION = None _CLASS_FUNCTION = None
......
...@@ -26,18 +26,18 @@ cdef enum TVMTypeCode: ...@@ -26,18 +26,18 @@ cdef enum TVMTypeCode:
kInt = 0 kInt = 0
kUInt = 1 kUInt = 1
kFloat = 2 kFloat = 2
kHandle = 3 kTVMOpaqueHandle = 3
kNull = 4 kTVMNullptr = 4
kTVMType = 5 kTVMDataType = 5
kTVMContext = 6 kTVMContext = 6
kArrayHandle = 7 kTVMDLTensorHandle = 7
kObjectHandle = 8 kTVMObjectHandle = 8
kModuleHandle = 9 kTVMModuleHandle = 9
kFuncHandle = 10 kTVMPackedFuncHandle = 10
kStr = 11 kTVMStr = 11
kBytes = 12 kTVMBytes = 12
kNDArrayContainer = 13 kTVMNDArrayHandle = 13
kExtBegin = 15 kTVMExtBegin = 15
cdef extern from "tvm/runtime/c_runtime_api.h": cdef extern from "tvm/runtime/c_runtime_api.h":
ctypedef struct DLDataType: ctypedef struct DLDataType:
......
...@@ -41,13 +41,13 @@ cdef int tvm_callback(TVMValue* args, ...@@ -41,13 +41,13 @@ cdef int tvm_callback(TVMValue* args,
for i in range(num_args): for i in range(num_args):
value = args[i] value = args[i]
tcode = type_codes[i] tcode = type_codes[i]
if (tcode == kObjectHandle or if (tcode == kTVMObjectHandle or
tcode == kFuncHandle or tcode == kTVMPackedFuncHandle or
tcode == kModuleHandle or tcode == kTVMModuleHandle or
tcode > kExtBegin): tcode > kTVMExtBegin):
CALL(TVMCbArgToReturn(&value, tcode)) CALL(TVMCbArgToReturn(&value, tcode))
if tcode != kArrayHandle: if tcode != kTVMDLTensorHandle:
pyargs.append(make_ret(value, tcode)) pyargs.append(make_ret(value, tcode))
else: else:
pyargs.append(c_make_array(value.v_handle, True, False)) pyargs.append(c_make_array(value.v_handle, True, False))
...@@ -99,11 +99,11 @@ cdef inline int make_arg(object arg, ...@@ -99,11 +99,11 @@ cdef inline int make_arg(object arg,
cdef unsigned long long ptr cdef unsigned long long ptr
if isinstance(arg, ObjectBase): if isinstance(arg, ObjectBase):
value[0].v_handle = (<ObjectBase>arg).chandle value[0].v_handle = (<ObjectBase>arg).chandle
tcode[0] = kObjectHandle tcode[0] = kTVMObjectHandle
elif isinstance(arg, NDArrayBase): elif isinstance(arg, NDArrayBase):
value[0].v_handle = (<NDArrayBase>arg).chandle value[0].v_handle = (<NDArrayBase>arg).chandle
tcode[0] = (kNDArrayContainer if tcode[0] = (kTVMNDArrayHandle if
not (<NDArrayBase>arg).c_is_view else kArrayHandle) not (<NDArrayBase>arg).c_is_view else kTVMDLTensorHandle)
elif isinstance(arg, _TVM_COMPATS): elif isinstance(arg, _TVM_COMPATS):
ptr = arg._tvm_handle ptr = arg._tvm_handle
value[0].v_handle = (<void*>ptr) value[0].v_handle = (<void*>ptr)
...@@ -117,18 +117,18 @@ cdef inline int make_arg(object arg, ...@@ -117,18 +117,18 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, str): elif isinstance(arg, str):
tstr = c_str(arg) tstr = c_str(arg)
value[0].v_str = tstr value[0].v_str = tstr
tcode[0] = kStr tcode[0] = kTVMStr
temp_args.append(tstr) temp_args.append(tstr)
elif arg is None: elif arg is None:
value[0].v_handle = NULL value[0].v_handle = NULL
tcode[0] = kNull tcode[0] = kTVMNullptr
elif isinstance(arg, Number): elif isinstance(arg, Number):
value[0].v_float64 = arg value[0].v_float64 = arg
tcode[0] = kFloat tcode[0] = kFloat
elif isinstance(arg, TVMType): elif isinstance(arg, TVMType):
tstr = c_str(str(arg)) tstr = c_str(str(arg))
value[0].v_str = tstr value[0].v_str = tstr
tcode[0] = kStr tcode[0] = kTVMStr
temp_args.append(tstr) temp_args.append(tstr)
elif isinstance(arg, TVMContext): elif isinstance(arg, TVMContext):
value[0].v_ctx = (<DLContext*>( value[0].v_ctx = (<DLContext*>(
...@@ -142,31 +142,31 @@ cdef inline int make_arg(object arg, ...@@ -142,31 +142,31 @@ cdef inline int make_arg(object arg,
arr.size = len(arg) arr.size = len(arg)
value[0].v_handle = <void*>( value[0].v_handle = <void*>(
<unsigned long long>ctypes.addressof(arr)) <unsigned long long>ctypes.addressof(arr))
tcode[0] = kBytes tcode[0] = kTVMBytes
temp_args.append(arr) temp_args.append(arr)
elif isinstance(arg, string_types): elif isinstance(arg, string_types):
tstr = c_str(arg) tstr = c_str(arg)
value[0].v_str = tstr value[0].v_str = tstr
tcode[0] = kStr tcode[0] = kTVMStr
temp_args.append(tstr) temp_args.append(tstr)
elif isinstance(arg, (list, tuple, dict, ObjectGeneric)): elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
arg = convert_to_object(arg) arg = convert_to_object(arg)
value[0].v_handle = (<ObjectBase>arg).chandle value[0].v_handle = (<ObjectBase>arg).chandle
tcode[0] = kObjectHandle tcode[0] = kTVMObjectHandle
temp_args.append(arg) temp_args.append(arg)
elif isinstance(arg, _CLASS_MODULE): elif isinstance(arg, _CLASS_MODULE):
value[0].v_handle = c_handle(arg.handle) value[0].v_handle = c_handle(arg.handle)
tcode[0] = kModuleHandle tcode[0] = kTVMModuleHandle
elif isinstance(arg, FunctionBase): elif isinstance(arg, FunctionBase):
value[0].v_handle = (<FunctionBase>arg).chandle value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle tcode[0] = kTVMPackedFuncHandle
elif isinstance(arg, ctypes.c_void_p): elif isinstance(arg, ctypes.c_void_p):
value[0].v_handle = c_handle(arg) value[0].v_handle = c_handle(arg)
tcode[0] = kHandle tcode[0] = kTVMOpaqueHandle
elif callable(arg): elif callable(arg):
arg = convert_to_tvm_func(arg) arg = convert_to_tvm_func(arg)
value[0].v_handle = (<FunctionBase>arg).chandle value[0].v_handle = (<FunctionBase>arg).chandle
tcode[0] = kFuncHandle tcode[0] = kTVMPackedFuncHandle
temp_args.append(arg) temp_args.append(arg)
else: else:
raise TypeError("Don't know how to handle type %s" % type(arg)) raise TypeError("Don't know how to handle type %s" % type(arg))
...@@ -184,27 +184,27 @@ cdef inline bytearray make_ret_bytes(void* chandle): ...@@ -184,27 +184,27 @@ cdef inline bytearray make_ret_bytes(void* chandle):
cdef inline object make_ret(TVMValue value, int tcode): cdef inline object make_ret(TVMValue value, int tcode):
"""convert result to return value.""" """convert result to return value."""
if tcode == kObjectHandle: if tcode == kTVMObjectHandle:
return make_ret_object(value.v_handle) return make_ret_object(value.v_handle)
elif tcode == kNull: elif tcode == kTVMNullptr:
return None return None
elif tcode == kInt: elif tcode == kInt:
return value.v_int64 return value.v_int64
elif tcode == kFloat: elif tcode == kFloat:
return value.v_float64 return value.v_float64
elif tcode == kNDArrayContainer: elif tcode == kTVMNDArrayHandle:
return c_make_array(value.v_handle, False, True) return c_make_array(value.v_handle, False, True)
elif tcode == kStr: elif tcode == kTVMStr:
return py_str(value.v_str) return py_str(value.v_str)
elif tcode == kBytes: elif tcode == kTVMBytes:
return make_ret_bytes(value.v_handle) return make_ret_bytes(value.v_handle)
elif tcode == kHandle: elif tcode == kTVMOpaqueHandle:
return ctypes_handle(value.v_handle) return ctypes_handle(value.v_handle)
elif tcode == kTVMContext: elif tcode == kTVMContext:
return TVMContext(value.v_ctx.device_type, value.v_ctx.device_id) return TVMContext(value.v_ctx.device_type, value.v_ctx.device_id)
elif tcode == kModuleHandle: elif tcode == kTVMModuleHandle:
return _CLASS_MODULE(ctypes_handle(value.v_handle)) return _CLASS_MODULE(ctypes_handle(value.v_handle))
elif tcode == kFuncHandle: elif tcode == kTVMPackedFuncHandle:
fobj = _CLASS_FUNCTION(None, False) fobj = _CLASS_FUNCTION(None, False)
(<FunctionBase>fobj).chandle = value.v_handle (<FunctionBase>fobj).chandle = value.v_handle
return fobj return fobj
......
...@@ -97,5 +97,5 @@ cdef class ObjectBase: ...@@ -97,5 +97,5 @@ cdef class ObjectBase:
cdef void* chandle cdef void* chandle
ConstructorCall( ConstructorCall(
(<FunctionBase>fconstructor).chandle, (<FunctionBase>fconstructor).chandle,
kObjectHandle, args, &chandle) kTVMObjectHandle, args, &chandle)
self.chandle = chandle self.chandle = chandle
...@@ -35,13 +35,13 @@ class TypeCode(object): ...@@ -35,13 +35,13 @@ class TypeCode(object):
NULL = 4 NULL = 4
TVM_TYPE = 5 TVM_TYPE = 5
TVM_CONTEXT = 6 TVM_CONTEXT = 6
ARRAY_HANDLE = 7 DLTENSOR_HANDLE = 7
OBJECT_HANDLE = 8 OBJECT_HANDLE = 8
MODULE_HANDLE = 9 MODULE_HANDLE = 9
FUNC_HANDLE = 10 PACKED_FUNC_HANDLE = 10
STR = 11 STR = 11
BYTES = 12 BYTES = 12
NDARRAY_CONTAINER = 13 NDARRAY_HANDLE = 13
EXT_BEGIN = 15 EXT_BEGIN = 15
......
...@@ -352,7 +352,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { ...@@ -352,7 +352,7 @@ impl<'m, 't> GraphExecutor<'m, 't> {
} }
} }
// Converts a string to TVM DLDataTypeCode. @see `String2TVMType` in packed_func.h // Converts a string to TVM DLDataTypeCode. @see `String2DLDataType` in packed_func.h
named!( named!(
tvm_str_to_type<CompleteStr, DataType>, tvm_str_to_type<CompleteStr, DataType>,
do_parse!( do_parse!(
......
...@@ -32,7 +32,7 @@ ...@@ -32,7 +32,7 @@
namespace tvm { namespace tvm {
TVM_REGISTER_GLOBAL("_format_str") TVM_REGISTER_GLOBAL("_format_str")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
CHECK(args[0].type_code() == kObjectHandle); CHECK(args[0].type_code() == kTVMObjectHandle);
std::ostringstream os; std::ostringstream os;
os << args[0].operator ObjectRef(); os << args[0].operator ObjectRef();
*ret = os.str(); *ret = os.str();
...@@ -40,7 +40,7 @@ TVM_REGISTER_GLOBAL("_format_str") ...@@ -40,7 +40,7 @@ TVM_REGISTER_GLOBAL("_format_str")
TVM_REGISTER_GLOBAL("_raw_ptr") TVM_REGISTER_GLOBAL("_raw_ptr")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
CHECK(args[0].type_code() == kObjectHandle); CHECK(args[0].type_code() == kTVMObjectHandle);
*ret = reinterpret_cast<int64_t>(args[0].value().v_handle); *ret = reinterpret_cast<int64_t>(args[0].value().v_handle);
}); });
......
...@@ -64,7 +64,7 @@ TVM_REGISTER_GLOBAL("_Array") ...@@ -64,7 +64,7 @@ TVM_REGISTER_GLOBAL("_Array")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
std::vector<ObjectRef> data; std::vector<ObjectRef> data;
for (int i = 0; i < args.size(); ++i) { for (int i = 0; i < args.size(); ++i) {
if (args[i].type_code() != kNull) { if (args[i].type_code() != kTVMNullptr) {
data.push_back(args[i].operator ObjectRef()); data.push_back(args[i].operator ObjectRef());
} else { } else {
data.push_back(ObjectRef(nullptr)); data.push_back(ObjectRef(nullptr));
...@@ -78,7 +78,7 @@ TVM_REGISTER_GLOBAL("_Array") ...@@ -78,7 +78,7 @@ TVM_REGISTER_GLOBAL("_Array")
TVM_REGISTER_GLOBAL("_ArrayGetItem") TVM_REGISTER_GLOBAL("_ArrayGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
int64_t i = args[1]; int64_t i = args[1];
CHECK_EQ(args[0].type_code(), kObjectHandle); CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle); Object* ptr = static_cast<Object*>(args[0].value().v_handle);
CHECK(ptr->IsInstance<ArrayNode>()); CHECK(ptr->IsInstance<ArrayNode>());
auto* n = static_cast<const ArrayNode*>(ptr); auto* n = static_cast<const ArrayNode*>(ptr);
...@@ -89,7 +89,7 @@ TVM_REGISTER_GLOBAL("_ArrayGetItem") ...@@ -89,7 +89,7 @@ TVM_REGISTER_GLOBAL("_ArrayGetItem")
TVM_REGISTER_GLOBAL("_ArraySize") TVM_REGISTER_GLOBAL("_ArraySize")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kObjectHandle); CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle); Object* ptr = static_cast<Object*>(args[0].value().v_handle);
CHECK(ptr->IsInstance<ArrayNode>()); CHECK(ptr->IsInstance<ArrayNode>());
*ret = static_cast<int64_t>( *ret = static_cast<int64_t>(
...@@ -99,11 +99,11 @@ TVM_REGISTER_GLOBAL("_ArraySize") ...@@ -99,11 +99,11 @@ TVM_REGISTER_GLOBAL("_ArraySize")
TVM_REGISTER_GLOBAL("_Map") TVM_REGISTER_GLOBAL("_Map")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size() % 2, 0); CHECK_EQ(args.size() % 2, 0);
if (args.size() != 0 && args[0].type_code() == kStr) { if (args.size() != 0 && args[0].type_code() == kTVMStr) {
// StrMap // StrMap
StrMapNode::ContainerType data; StrMapNode::ContainerType data;
for (int i = 0; i < args.num_args; i += 2) { for (int i = 0; i < args.num_args; i += 2) {
CHECK(args[i].type_code() == kStr) CHECK(args[i].type_code() == kTVMStr)
<< "key of str map need to be str"; << "key of str map need to be str";
CHECK(args[i + 1].IsObjectRef<ObjectRef>()) CHECK(args[i + 1].IsObjectRef<ObjectRef>())
<< "value of the map to be NodeRef"; << "value of the map to be NodeRef";
...@@ -132,7 +132,7 @@ TVM_REGISTER_GLOBAL("_Map") ...@@ -132,7 +132,7 @@ TVM_REGISTER_GLOBAL("_Map")
TVM_REGISTER_GLOBAL("_MapSize") TVM_REGISTER_GLOBAL("_MapSize")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kObjectHandle); CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle); Object* ptr = static_cast<Object*>(args[0].value().v_handle);
if (ptr->IsInstance<MapNode>()) { if (ptr->IsInstance<MapNode>()) {
auto* n = static_cast<const MapNode*>(ptr); auto* n = static_cast<const MapNode*>(ptr);
...@@ -146,11 +146,11 @@ TVM_REGISTER_GLOBAL("_MapSize") ...@@ -146,11 +146,11 @@ TVM_REGISTER_GLOBAL("_MapSize")
TVM_REGISTER_GLOBAL("_MapGetItem") TVM_REGISTER_GLOBAL("_MapGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kObjectHandle); CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle); Object* ptr = static_cast<Object*>(args[0].value().v_handle);
if (ptr->IsInstance<MapNode>()) { if (ptr->IsInstance<MapNode>()) {
CHECK(args[1].type_code() == kObjectHandle); CHECK(args[1].type_code() == kTVMObjectHandle);
auto* n = static_cast<const MapNode*>(ptr); auto* n = static_cast<const MapNode*>(ptr);
auto it = n->data.find(args[1].operator ObjectRef()); auto it = n->data.find(args[1].operator ObjectRef());
CHECK(it != n->data.end()) CHECK(it != n->data.end())
...@@ -168,12 +168,12 @@ TVM_REGISTER_GLOBAL("_MapGetItem") ...@@ -168,12 +168,12 @@ TVM_REGISTER_GLOBAL("_MapGetItem")
TVM_REGISTER_GLOBAL("_MapCount") TVM_REGISTER_GLOBAL("_MapCount")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kObjectHandle); CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle); Object* ptr = static_cast<Object*>(args[0].value().v_handle);
if (ptr->IsInstance<MapNode>()) { if (ptr->IsInstance<MapNode>()) {
auto* n = static_cast<const MapNode*>(ptr); auto* n = static_cast<const MapNode*>(ptr);
CHECK_EQ(args[0].type_code(), kObjectHandle); CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
*ret = static_cast<int64_t>( *ret = static_cast<int64_t>(
n->data.count(args[1].operator ObjectRef())); n->data.count(args[1].operator ObjectRef()));
} else { } else {
...@@ -186,7 +186,7 @@ TVM_REGISTER_GLOBAL("_MapCount") ...@@ -186,7 +186,7 @@ TVM_REGISTER_GLOBAL("_MapCount")
TVM_REGISTER_GLOBAL("_MapItems") TVM_REGISTER_GLOBAL("_MapItems")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kObjectHandle); CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* ptr = static_cast<Object*>(args[0].value().v_handle); Object* ptr = static_cast<Object*>(args[0].value().v_handle);
if (ptr->IsInstance<MapNode>()) { if (ptr->IsInstance<MapNode>()) {
......
...@@ -216,7 +216,7 @@ std::string CodeGenC::GetStructRef( ...@@ -216,7 +216,7 @@ std::string CodeGenC::GetStructRef(
DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind) { DataType t, const PrimExpr& buffer, const PrimExpr& index, int kind) {
if (kind < intrinsic::kArrKindBound_) { if (kind < intrinsic::kArrKindBound_) {
std::ostringstream os; std::ostringstream os;
os << "(((TVMArray*)"; os << "(((DLTensor*)";
this->PrintExpr(buffer, os); this->PrintExpr(buffer, os);
os << ")"; os << ")";
if (kind == intrinsic::kArrAddr) { if (kind == intrinsic::kArrAddr) {
......
...@@ -200,7 +200,7 @@ void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT( ...@@ -200,7 +200,7 @@ void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT(
const std::string& type = op->args[0].as<StringImmNode>()->value; const std::string& type = op->args[0].as<StringImmNode>()->value;
const IntImmNode* num = op->args[1].as<IntImmNode>(); const IntImmNode* num = op->args[1].as<IntImmNode>();
CHECK(num != nullptr); CHECK(num != nullptr);
static_assert(alignof(TVMValue) % alignof(TVMArray) == 0, "invariant"); static_assert(alignof(TVMValue) % alignof(DLTensor) == 0, "invariant");
size_t unit = sizeof(TVMValue); size_t unit = sizeof(TVMValue);
size_t size = 0; size_t size = 0;
if (type == "shape") { if (type == "shape") {
...@@ -210,7 +210,7 @@ void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT( ...@@ -210,7 +210,7 @@ void CodeGenCHost::VisitExpr_(const CallNode *op, std::ostream& os) { // NOLINT(
} else if (type == "arg_tcode") { } else if (type == "arg_tcode") {
size = (num->value * sizeof(int) + unit - 1) / unit; size = (num->value * sizeof(int) + unit - 1) / unit;
} else if (type == "array") { } else if (type == "array") {
size = (num->value * sizeof(TVMArray) + unit - 1) / unit; size = (num->value * sizeof(DLTensor) + unit - 1) / unit;
} else { } else {
LOG(FATAL) << "Unknown stack alloca type " << type; LOG(FATAL) << "Unknown stack alloca type " << type;
} }
......
...@@ -25,19 +25,23 @@ ...@@ -25,19 +25,23 @@
namespace tvm { namespace tvm {
namespace datatype { namespace datatype {
TVM_REGISTER_GLOBAL("_datatype_register").set_body([](TVMArgs args, TVMRetValue* ret) { TVM_REGISTER_GLOBAL("_datatype_register")
.set_body([](TVMArgs args, TVMRetValue* ret) {
datatype::Registry::Global()->Register(args[0], static_cast<uint8_t>(args[1].operator int())); datatype::Registry::Global()->Register(args[0], static_cast<uint8_t>(args[1].operator int()));
}); });
TVM_REGISTER_GLOBAL("_datatype_get_type_code").set_body([](TVMArgs args, TVMRetValue* ret) { TVM_REGISTER_GLOBAL("_datatype_get_type_code")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = datatype::Registry::Global()->GetTypeCode(args[0]); *ret = datatype::Registry::Global()->GetTypeCode(args[0]);
}); });
TVM_REGISTER_GLOBAL("_datatype_get_type_name").set_body([](TVMArgs args, TVMRetValue* ret) { TVM_REGISTER_GLOBAL("_datatype_get_type_name")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Registry::Global()->GetTypeName(args[0].operator int()); *ret = Registry::Global()->GetTypeName(args[0].operator int());
}); });
TVM_REGISTER_GLOBAL("_datatype_get_type_registered").set_body([](TVMArgs args, TVMRetValue* ret) { TVM_REGISTER_GLOBAL("_datatype_get_type_registered")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = Registry::Global()->GetTypeRegistered(args[0].operator int()); *ret = Registry::Global()->GetTypeRegistered(args[0].operator int());
}); });
...@@ -47,7 +51,8 @@ Registry* Registry::Global() { ...@@ -47,7 +51,8 @@ Registry* Registry::Global() {
} }
void Registry::Register(const std::string& type_name, uint8_t type_code) { void Registry::Register(const std::string& type_name, uint8_t type_code) {
CHECK(type_code >= kCustomBegin) << "Please choose a type code >= kCustomBegin for custom types"; CHECK(type_code >= kTVMCustomBegin)
<< "Please choose a type code >= kTVMCustomBegin for custom types";
code_to_name_[type_code] = type_name; code_to_name_[type_code] = type_name;
name_to_code_[type_name] = type_code; name_to_code_[type_name] = type_code;
} }
......
...@@ -60,7 +60,7 @@ class Registry { ...@@ -60,7 +60,7 @@ class Registry {
* same code. Generally, this should be straightforward, as the user will be manually registering * same code. Generally, this should be straightforward, as the user will be manually registering
* all of their custom types. * all of their custom types.
* \param type_name The name of the type, e.g. "bfloat" * \param type_name The name of the type, e.g. "bfloat"
* \param type_code The type code, which should be greater than TVMTypeCode::kExtEnd * \param type_code The type code, which should be greater than TVMTypeCode::kTVMExtEnd
*/ */
void Register(const std::string& type_name, uint8_t type_code); void Register(const std::string& type_name, uint8_t type_code);
......
...@@ -731,7 +731,7 @@ llvm::Value *CodeGenCPU::CreateCallTracePacked(const CallNode *op) { ...@@ -731,7 +731,7 @@ llvm::Value *CodeGenCPU::CreateCallTracePacked(const CallNode *op) {
llvm::Value *ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 8); llvm::Value *ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 8);
// Check the ret_type_code and create cmp instruction. // Check the ret_type_code and create cmp instruction.
llvm::Value *cmp = builder_->CreateICmpNE( llvm::Value *cmp = builder_->CreateICmpNE(
ret_tcode_value, llvm::ConstantInt::get(t_int_, kNull)); ret_tcode_value, llvm::ConstantInt::get(t_int_, kTVMNullptr));
builder_->CreateCondBr(cmp, update_block, continue_block); builder_->CreateCondBr(cmp, update_block, continue_block);
builder_->SetInsertPoint(update_block); builder_->SetInsertPoint(update_block);
builder_->CreateBr(continue_block); builder_->CreateBr(continue_block);
......
...@@ -199,7 +199,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { ...@@ -199,7 +199,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) {
const std::string& type = op->args[0].as<StringImmNode>()->value; const std::string& type = op->args[0].as<StringImmNode>()->value;
const IntImmNode* num = op->args[1].as<IntImmNode>(); const IntImmNode* num = op->args[1].as<IntImmNode>();
CHECK(num != nullptr); CHECK(num != nullptr);
static_assert(alignof(TVMValue) % alignof(TVMArray) == 0, "invariant"); static_assert(alignof(TVMValue) % alignof(DLTensor) == 0, "invariant");
// static_assert(alignof(TVMValue) % alignof(tvm_index_t) == 0, "invariant"); // static_assert(alignof(TVMValue) % alignof(tvm_index_t) == 0, "invariant");
size_t unit = sizeof(TVMValue); size_t unit = sizeof(TVMValue);
size_t size = 0; size_t size = 0;
...@@ -210,7 +210,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) { ...@@ -210,7 +210,7 @@ void CodeGenStackVM::VisitExpr_(const CallNode* op) {
} else if (type == "arg_tcode") { } else if (type == "arg_tcode") {
size = (num->value * sizeof(int) + unit - 1) / unit; size = (num->value * sizeof(int) + unit - 1) / unit;
} else if (type == "array") { } else if (type == "array") {
size = (num->value * sizeof(TVMArray) + unit - 1) / unit; size = (num->value * sizeof(DLTensor) + unit - 1) / unit;
} else { } else {
LOG(FATAL) << "Unknown stack alloca type " << type; LOG(FATAL) << "Unknown stack alloca type " << type;
} }
......
...@@ -43,7 +43,7 @@ void DictAttrsNode::InitByPackedArgs( ...@@ -43,7 +43,7 @@ void DictAttrsNode::InitByPackedArgs(
runtime::TVMArgValue val = args[i + 1]; runtime::TVMArgValue val = args[i + 1];
if (val.IsObjectRef<ObjectRef>()) { if (val.IsObjectRef<ObjectRef>()) {
dict.Set(key, val.operator ObjectRef()); dict.Set(key, val.operator ObjectRef());
} else if (val.type_code() == kStr) { } else if (val.type_code() == kTVMStr) {
dict.Set(key, PrimExpr(val.operator std::string())); dict.Set(key, PrimExpr(val.operator std::string()));
} else { } else {
dict.Set(key, val.operator PrimExpr()); dict.Set(key, val.operator PrimExpr());
......
...@@ -129,10 +129,10 @@ void OpRegistry::UpdateAttr(const std::string& key, ...@@ -129,10 +129,10 @@ void OpRegistry::UpdateAttr(const std::string& key,
CHECK(p.second != plevel) CHECK(p.second != plevel)
<< "Attribute " << key << " of operator " << this->name << "Attribute " << key << " of operator " << this->name
<< " is already registered with same plevel=" << plevel; << " is already registered with same plevel=" << plevel;
CHECK(value.type_code() != kNull) CHECK(value.type_code() != kTVMNullptr)
<< "Registered packed_func is Null for " << key << "Registered packed_func is Null for " << key
<< " of operator " << this->name; << " of operator " << this->name;
if (p.second < plevel && value.type_code() != kNull) { if (p.second < plevel && value.type_code() != kTVMNullptr) {
op_map->data_[index] = std::make_pair(value, plevel); op_map->data_[index] = std::make_pair(value, plevel);
} }
} }
...@@ -195,7 +195,7 @@ TVM_REGISTER_GLOBAL("relay.op._Register") ...@@ -195,7 +195,7 @@ TVM_REGISTER_GLOBAL("relay.op._Register")
LOG(FATAL) << "attrs type key no longer supported"; LOG(FATAL) << "attrs type key no longer supported";
} else { } else {
// normal attr table override. // normal attr table override.
if (args[2].type_code() == kFuncHandle) { if (args[2].type_code() == kTVMPackedFuncHandle) {
// do an eager copy of the PackedFunc // do an eager copy of the PackedFunc
PackedFunc f = args[2]; PackedFunc f = args[2];
// If we get a function from frontend, avoid deleting it. // If we get a function from frontend, avoid deleting it.
......
...@@ -97,7 +97,7 @@ runtime::TVMRetValue ReflectionVTable::GetAttr( ...@@ -97,7 +97,7 @@ runtime::TVMRetValue ReflectionVTable::GetAttr(
success = true; success = true;
} else if (!self->IsInstance<DictAttrsNode>()) { } else if (!self->IsInstance<DictAttrsNode>()) {
VisitAttrs(self, &getter); VisitAttrs(self, &getter);
success = getter.found_ref_object || ret.type_code() != kNull; success = getter.found_ref_object || ret.type_code() != kTVMNullptr;
} else { } else {
// specially handle dict attr // specially handle dict attr
DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self); DictAttrsNode* dnode = static_cast<DictAttrsNode*>(self);
...@@ -258,13 +258,13 @@ void InitNodeByPackedArgs(Object* n, const TVMArgs& args) { ...@@ -258,13 +258,13 @@ void InitNodeByPackedArgs(Object* n, const TVMArgs& args) {
// Expose to FFI APIs. // Expose to FFI APIs.
void NodeGetAttr(TVMArgs args, TVMRetValue* ret) { void NodeGetAttr(TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kObjectHandle); CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* self = static_cast<Object*>(args[0].value().v_handle); Object* self = static_cast<Object*>(args[0].value().v_handle);
*ret = ReflectionVTable::Global()->GetAttr(self, args[1]); *ret = ReflectionVTable::Global()->GetAttr(self, args[1]);
} }
void NodeListAttrNames(TVMArgs args, TVMRetValue* ret) { void NodeListAttrNames(TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args[0].type_code(), kObjectHandle); CHECK_EQ(args[0].type_code(), kTVMObjectHandle);
Object* self = static_cast<Object*>(args[0].value().v_handle); Object* self = static_cast<Object*>(args[0].value().v_handle);
auto names = std::make_shared<std::vector<std::string> >( auto names = std::make_shared<std::vector<std::string> >(
......
...@@ -39,11 +39,11 @@ ...@@ -39,11 +39,11 @@
namespace tvm { namespace tvm {
inline std::string Type2String(const DataType& t) { inline std::string Type2String(const DataType& t) {
return runtime::TVMType2String(t); return runtime::DLDataType2String(t);
} }
inline DataType String2Type(std::string s) { inline DataType String2Type(std::string s) {
return DataType(runtime::String2TVMType(s)); return DataType(runtime::String2DLDataType(s));
} }
// indexer to index all the nodes // indexer to index all the nodes
......
...@@ -260,9 +260,9 @@ class BuiltinLower : public StmtExprMutator { ...@@ -260,9 +260,9 @@ class BuiltinLower : public StmtExprMutator {
intrinsic::kTVMValueContent, arg)); intrinsic::kTVMValueContent, arg));
int arg_tcode = api_type.code(); int arg_tcode = api_type.code();
if (api_type.is_handle() && arg.as<StringImmNode>()) { if (api_type.is_handle() && arg.as<StringImmNode>()) {
arg_tcode = kStr; arg_tcode = kTVMStr;
} }
if (IsArrayHandle(arg)) arg_tcode = kArrayHandle; if (IsArrayHandle(arg)) arg_tcode = kTVMDLTensorHandle;
prep_seq_.emplace_back( prep_seq_.emplace_back(
StoreNode::make(stack_tcode_, StoreNode::make(stack_tcode_,
ConstInt32(arg_tcode), ConstInt32(arg_tcode),
......
...@@ -124,10 +124,10 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -124,10 +124,10 @@ LoweredFunc MakeAPI(Stmt body,
std::ostringstream msg; std::ostringstream msg;
msg << name << ": Expect arg[" << i << "] to be pointer"; msg << name << ": Expect arg[" << i << "] to be pointer";
seq_check.emplace_back( seq_check.emplace_back(
AssertStmtNode::make(tcode == kHandle || AssertStmtNode::make(tcode == kTVMOpaqueHandle ||
tcode == kNDArrayContainer || tcode == kTVMNDArrayHandle ||
tcode == kArrayHandle || tcode == kTVMDLTensorHandle ||
tcode == kNull, msg.str(), nop)); tcode == kTVMNullptr, msg.str(), nop));
} else if (t.is_int() || t.is_uint()) { } else if (t.is_int() || t.is_uint()) {
std::ostringstream msg; std::ostringstream msg;
msg << name << ": Expect arg[" << i << "] to be int"; msg << name << ": Expect arg[" << i << "] to be int";
......
...@@ -108,7 +108,7 @@ Doc PrintBool(bool value) { ...@@ -108,7 +108,7 @@ Doc PrintBool(bool value) {
} }
Doc PrintDType(DataType dtype) { Doc PrintDType(DataType dtype) {
return Doc(runtime::TVMType2String(dtype)); return Doc(runtime::DLDataType2String(dtype));
} }
Doc PrintString(const std::string& value) { Doc PrintString(const std::string& value) {
......
...@@ -932,7 +932,7 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { ...@@ -932,7 +932,7 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
LOG(FATAL) << "do not allow void as argument"; LOG(FATAL) << "do not allow void as argument";
} }
void Visit(const char* key, DataType* value) final { void Visit(const char* key, DataType* value) final {
PrintKV(key, PrintString(runtime::TVMType2String(*value))); PrintKV(key, PrintString(runtime::DLDataType2String(*value)));
} }
void Visit(const char* key, runtime::NDArray* value) final { void Visit(const char* key, runtime::NDArray* value) final {
LOG(FATAL) << "do not allow NDarray as argument"; LOG(FATAL) << "do not allow NDarray as argument";
......
...@@ -1182,7 +1182,7 @@ double ToScalar(const runtime::NDArray& array) { ...@@ -1182,7 +1182,7 @@ double ToScalar(const runtime::NDArray& array) {
return reinterpret_cast<double*>(array->data)[0]; return reinterpret_cast<double*>(array->data)[0];
} }
} }
LOG(FATAL) << "Unknown data type: " << tvm::runtime::TVMType2String(array->dtype); LOG(FATAL) << "Unknown data type: " << tvm::runtime::DLDataType2String(array->dtype);
// make compiler happy // make compiler happy
return -std::numeric_limits<double>::infinity(); return -std::numeric_limits<double>::infinity();
} }
......
...@@ -151,7 +151,7 @@ DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) { ...@@ -151,7 +151,7 @@ DeviceAPI* DeviceAPI::Get(TVMContext ctx, bool allow_missing) {
void* DeviceAPI::AllocWorkspace(TVMContext ctx, void* DeviceAPI::AllocWorkspace(TVMContext ctx,
size_t size, size_t size,
TVMType type_hint) { DLDataType type_hint) {
return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint); return AllocDataSpace(ctx, size, kTempAllocaAlignment, type_hint);
} }
...@@ -431,7 +431,7 @@ void* TVMBackendAllocWorkspace(int device_type, ...@@ -431,7 +431,7 @@ void* TVMBackendAllocWorkspace(int device_type,
ctx.device_type = static_cast<DLDeviceType>(device_type); ctx.device_type = static_cast<DLDeviceType>(device_type);
ctx.device_id = device_id; ctx.device_id = device_id;
TVMType type_hint; DLDataType type_hint;
type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint); type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint);
type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint); type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint);
type_hint.lanes = 1; type_hint.lanes = 1;
...@@ -479,22 +479,22 @@ int TVMFuncCall(TVMFunctionHandle func, ...@@ -479,22 +479,22 @@ int TVMFuncCall(TVMFunctionHandle func,
(*static_cast<const PackedFunc*>(func)).CallPacked( (*static_cast<const PackedFunc*>(func)).CallPacked(
TVMArgs(args, arg_type_codes, num_args), &rv); TVMArgs(args, arg_type_codes, num_args), &rv);
// handle return string. // handle return string.
if (rv.type_code() == kStr || if (rv.type_code() == kTVMStr ||
rv.type_code() == kTVMType || rv.type_code() == kTVMDataType ||
rv.type_code() == kBytes) { rv.type_code() == kTVMBytes) {
TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get(); TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get();
if (rv.type_code() != kTVMType) { if (rv.type_code() != kTVMDataType) {
e->ret_str = *rv.ptr<std::string>(); e->ret_str = *rv.ptr<std::string>();
} else { } else {
e->ret_str = rv.operator std::string(); e->ret_str = rv.operator std::string();
} }
if (rv.type_code() == kBytes) { if (rv.type_code() == kTVMBytes) {
e->ret_bytes.data = e->ret_str.c_str(); e->ret_bytes.data = e->ret_str.c_str();
e->ret_bytes.size = e->ret_str.length(); e->ret_bytes.size = e->ret_str.length();
*ret_type_code = kBytes; *ret_type_code = kTVMBytes;
ret_val->v_handle = &(e->ret_bytes); ret_val->v_handle = &(e->ret_bytes);
} else { } else {
*ret_type_code = kStr; *ret_type_code = kTVMStr;
ret_val->v_str = e->ret_str.c_str(); ret_val->v_str = e->ret_str.c_str();
} }
} else { } else {
......
...@@ -52,7 +52,7 @@ void ConvolutionForward( ...@@ -52,7 +52,7 @@ void ConvolutionForward(
// Set Ctx // Set Ctx
entry_ptr->conv_entry.ctx = x->ctx; entry_ptr->conv_entry.ctx = x->ctx;
// Set Data Type // Set Data Type
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(conv_dtype)); entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(conv_dtype));
cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype); cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(x->dtype);
// Dims includes N and C // Dims includes N and C
int full_dims = dims + 2; int full_dims = dims + 2;
...@@ -194,8 +194,8 @@ void OutputShape( ...@@ -194,8 +194,8 @@ void OutputShape(
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
// Set Data Type // Set Data Type
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(conv_dtype)); entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(conv_dtype));
cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(data_dtype)); cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(data_dtype));
// Set Format // Set Format
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format); entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
// Dims includes N and C // Dims includes N and C
...@@ -276,8 +276,8 @@ void FindAlgo( ...@@ -276,8 +276,8 @@ void FindAlgo(
CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal(); CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
// Set Data Type // Set Data Type
entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(conv_dtype)); entry_ptr->conv_entry.data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(conv_dtype));
cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2TVMType(data_dtype)); cudnnDataType_t data_type = CuDNNDataType::DLTypeToCuDNNType(String2DLDataType(data_dtype));
// Set Format // Set Format
entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format); entry_ptr->conv_entry.tensor_format = static_cast<cudnnTensorFormat_t>(format);
// Dims includes N and C // Dims includes N and C
......
...@@ -142,10 +142,11 @@ class ExampleJsonModule : public ModuleNode { ...@@ -142,10 +142,11 @@ class ExampleJsonModule : public ModuleNode {
this->curr_subgraph_ = name; this->curr_subgraph_ = name;
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
for (auto i = 0; i < args.size(); ++i) { for (auto i = 0; i < args.size(); ++i) {
CHECK(args[i].type_code() == kNDArrayContainer || args[i].type_code() == kArrayHandle) CHECK(args[i].type_code() == kTVMNDArrayHandle ||
args[i].type_code() == kTVMDLTensorHandle)
<< "Expect NDArray or DLTensor as inputs" << "Expect NDArray or DLTensor as inputs"
<< "\n"; << "\n";
if (args[i].type_code() == kArrayHandle) { if (args[i].type_code() == kTVMDLTensorHandle) {
DLTensor* arg = args[i]; DLTensor* arg = args[i];
this->data_entry_[i].CopyFrom(arg); this->data_entry_[i].CopyFrom(arg);
} else { } else {
...@@ -158,7 +159,7 @@ class ExampleJsonModule : public ModuleNode { ...@@ -158,7 +159,7 @@ class ExampleJsonModule : public ModuleNode {
} }
CHECK_GT(graph_.count(this->curr_subgraph_), 0U); CHECK_GT(graph_.count(this->curr_subgraph_), 0U);
auto out_idx = graph_[this->curr_subgraph_].back().output; auto out_idx = graph_[this->curr_subgraph_].back().output;
if (args[args.size() - 1].type_code() == kArrayHandle) { if (args[args.size() - 1].type_code() == kTVMDLTensorHandle) {
DLTensor* arg = args[args.size() - 1]; DLTensor* arg = args[args.size() - 1];
this->data_entry_[out_idx].CopyTo(arg); this->data_entry_[out_idx].CopyTo(arg);
} else { } else {
...@@ -341,4 +342,3 @@ TVM_REGISTER_GLOBAL("module.loadbinary_examplejson") ...@@ -341,4 +342,3 @@ TVM_REGISTER_GLOBAL("module.loadbinary_examplejson")
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -40,7 +40,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference") ...@@ -40,7 +40,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
DLTensor *input = args[0]; DLTensor *input = args[0];
DLTensor *kernel = args[1]; DLTensor *kernel = args[1];
DLTensor *bias = nullptr; DLTensor *bias = nullptr;
if (args[2].type_code() == kArrayHandle) { if (args[2].type_code() == kTVMDLTensorHandle) {
bias = args[2]; bias = args[2];
} }
DLTensor *output = args[3]; DLTensor *output = args[3];
...@@ -103,7 +103,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference") ...@@ -103,7 +103,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
const size_t workspace_elements = (workspace_size + sizeof(float) - 1) / sizeof(float); const size_t workspace_elements = (workspace_size + sizeof(float) - 1) / sizeof(float);
TVMContext ctx = input->ctx; TVMContext ctx = input->ctx;
TVMType type_hint = input->dtype; DLDataType type_hint = input->dtype;
DeviceAPI* cpu_api = DeviceAPI::Get(ctx); DeviceAPI* cpu_api = DeviceAPI::Get(ctx);
void* workspace_buffer = void* workspace_buffer =
...@@ -140,7 +140,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra ...@@ -140,7 +140,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra
DLTensor *input = args[0]; DLTensor *input = args[0];
DLTensor *transformed_kernel = args[1]; DLTensor *transformed_kernel = args[1];
DLTensor *bias = nullptr; DLTensor *bias = nullptr;
if (args[2].type_code() == kArrayHandle) { if (args[2].type_code() == kTVMDLTensorHandle) {
bias = args[2]; bias = args[2];
} }
DLTensor *output = args[3]; DLTensor *output = args[3];
...@@ -199,7 +199,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra ...@@ -199,7 +199,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference_without_weight_tra
const size_t workspace_elements = (workspace_size + sizeof(float) - 1) / sizeof(float); const size_t workspace_elements = (workspace_size + sizeof(float) - 1) / sizeof(float);
TVMContext ctx = input->ctx; TVMContext ctx = input->ctx;
TVMType type_hint = input->dtype; DLDataType type_hint = input->dtype;
DeviceAPI* cpu_api = DeviceAPI::Get(ctx); DeviceAPI* cpu_api = DeviceAPI::Get(ctx);
void* workspace_buffer = void* workspace_buffer =
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -182,8 +182,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") ...@@ -182,8 +182,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
CHECK_LT(axis, input->ndim) << "Axis out of boundary for " CHECK_LT(axis, input->ndim) << "Axis out of boundary for "
"input ndim " << input->ndim; "input ndim " << input->ndim;
auto data_dtype = TVMType2String(input->dtype); auto data_dtype = DLDataType2String(input->dtype);
auto out_dtype = TVMType2String(output->dtype); auto out_dtype = DLDataType2String(output->dtype);
if (data_dtype == "float32") { if (data_dtype == "float32") {
if (out_dtype == "int32") { if (out_dtype == "int32") {
...@@ -333,8 +333,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk") ...@@ -333,8 +333,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.topk")
} }
CHECK(axis >= 0 && axis < input->ndim) << "Axis out of boundary for input ndim " << input->ndim; CHECK(axis >= 0 && axis < input->ndim) << "Axis out of boundary for input ndim " << input->ndim;
auto data_dtype = TVMType2String(input->dtype); auto data_dtype = DLDataType2String(input->dtype);
auto out_dtype = (indices_out == nullptr) ? "int64" : TVMType2String(indices_out->dtype); auto out_dtype = (indices_out == nullptr) ? "int64" : DLDataType2String(indices_out->dtype);
if (data_dtype == "float32") { if (data_dtype == "float32") {
if (out_dtype == "int32") { if (out_dtype == "int32") {
......
...@@ -45,7 +45,7 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -45,7 +45,7 @@ class CPUDeviceAPI final : public DeviceAPI {
void* AllocDataSpace(TVMContext ctx, void* AllocDataSpace(TVMContext ctx,
size_t nbytes, size_t nbytes,
size_t alignment, size_t alignment,
TVMType type_hint) final { DLDataType type_hint) final {
void* ptr; void* ptr;
#if _MSC_VER #if _MSC_VER
ptr = _aligned_malloc(nbytes, alignment); ptr = _aligned_malloc(nbytes, alignment);
...@@ -76,7 +76,7 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -76,7 +76,7 @@ class CPUDeviceAPI final : public DeviceAPI {
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint, DLDataType type_hint,
TVMStreamHandle stream) final { TVMStreamHandle stream) final {
memcpy(static_cast<char*>(to) + to_offset, memcpy(static_cast<char*>(to) + to_offset,
static_cast<const char*>(from) + from_offset, static_cast<const char*>(from) + from_offset,
...@@ -86,7 +86,7 @@ class CPUDeviceAPI final : public DeviceAPI { ...@@ -86,7 +86,7 @@ class CPUDeviceAPI final : public DeviceAPI {
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
} }
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final; void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final; void FreeWorkspace(TVMContext ctx, void* data) final;
static const std::shared_ptr<CPUDeviceAPI>& Global() { static const std::shared_ptr<CPUDeviceAPI>& Global() {
...@@ -103,7 +103,7 @@ struct CPUWorkspacePool : public WorkspacePool { ...@@ -103,7 +103,7 @@ struct CPUWorkspacePool : public WorkspacePool {
void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx, void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx,
size_t size, size_t size,
TVMType type_hint) { DLDataType type_hint) {
return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get() return dmlc::ThreadLocalStore<CPUWorkspacePool>::Get()
->AllocWorkspace(ctx, size); ->AllocWorkspace(ctx, size);
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -111,7 +111,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -111,7 +111,7 @@ class CUDADeviceAPI final : public DeviceAPI {
void* AllocDataSpace(TVMContext ctx, void* AllocDataSpace(TVMContext ctx,
size_t nbytes, size_t nbytes,
size_t alignment, size_t alignment,
TVMType type_hint) final { DLDataType type_hint) final {
CUDA_CALL(cudaSetDevice(ctx.device_id)); CUDA_CALL(cudaSetDevice(ctx.device_id));
CHECK_EQ(256 % alignment, 0U) CHECK_EQ(256 % alignment, 0U)
<< "CUDA space is aligned at 256 bytes"; << "CUDA space is aligned at 256 bytes";
...@@ -132,7 +132,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -132,7 +132,7 @@ class CUDADeviceAPI final : public DeviceAPI {
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint, DLDataType type_hint,
TVMStreamHandle stream) final { TVMStreamHandle stream) final {
cudaStream_t cu_stream = static_cast<cudaStream_t>(stream); cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
from = static_cast<const char*>(from) + from_offset; from = static_cast<const char*>(from) + from_offset;
...@@ -191,7 +191,7 @@ class CUDADeviceAPI final : public DeviceAPI { ...@@ -191,7 +191,7 @@ class CUDADeviceAPI final : public DeviceAPI {
->stream = static_cast<cudaStream_t>(stream); ->stream = static_cast<cudaStream_t>(stream);
} }
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final { void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final {
return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -34,7 +34,7 @@ namespace runtime { ...@@ -34,7 +34,7 @@ namespace runtime {
void FunctionInfo::Save(dmlc::JSONWriter* writer) const { void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
std::vector<std::string> sarg_types(arg_types.size()); std::vector<std::string> sarg_types(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) { for (size_t i = 0; i < arg_types.size(); ++i) {
sarg_types[i] = TVMType2String(arg_types[i]); sarg_types[i] = DLDataType2String(arg_types[i]);
} }
writer->BeginObject(); writer->BeginObject();
writer->WriteObjectKeyValue("name", name); writer->WriteObjectKeyValue("name", name);
...@@ -52,7 +52,7 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { ...@@ -52,7 +52,7 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
helper.ReadAllFields(reader); helper.ReadAllFields(reader);
arg_types.resize(sarg_types.size()); arg_types.resize(sarg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) { for (size_t i = 0; i < arg_types.size(); ++i) {
arg_types[i] = String2TVMType(sarg_types[i]); arg_types[i] = String2DLDataType(sarg_types[i]);
} }
} }
......
...@@ -176,7 +176,7 @@ PackedFunc GraphRuntimeDebug::GetFunction( ...@@ -176,7 +176,7 @@ PackedFunc GraphRuntimeDebug::GetFunction(
}); });
} else if (name == "debug_get_output") { } else if (name == "debug_get_output") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kStr) { if (args[0].type_code() == kTVMStr) {
this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]); this->DebugGetNodeOutput(this->GetNodeIndex(args[0]), args[1]);
} else { } else {
this->DebugGetNodeOutput(args[0], args[1]); this->DebugGetNodeOutput(args[0], args[1]);
......
...@@ -250,9 +250,9 @@ void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) { ...@@ -250,9 +250,9 @@ void GraphRuntime::ShareParams(const GraphRuntime& other, dmlc::Stream* strm) {
void GraphRuntime::SetupStorage() { void GraphRuntime::SetupStorage() {
// Grab saved optimization plan from graph. // Grab saved optimization plan from graph.
std::vector<TVMType> vtype; std::vector<DLDataType> vtype;
for (const std::string& s_type : attrs_.dltype) { for (const std::string& s_type : attrs_.dltype) {
vtype.push_back(tvm::runtime::String2TVMType(s_type)); vtype.push_back(tvm::runtime::String2DLDataType(s_type));
} }
// Size and device type of each storage pool entry. // Size and device type of each storage pool entry.
...@@ -371,7 +371,7 @@ std::pair<std::function<void()>, std::shared_ptr<GraphRuntime::OpArgs> > GraphRu ...@@ -371,7 +371,7 @@ std::pair<std::function<void()>, std::shared_ptr<GraphRuntime::OpArgs> > GraphRu
DLTensor* t = &arg_ptr->args[i]; DLTensor* t = &arg_ptr->args[i];
v.v_handle = t; v.v_handle = t;
arg_ptr->arg_values.push_back(v); arg_ptr->arg_values.push_back(v);
arg_ptr->arg_tcodes.push_back(kArrayHandle); arg_ptr->arg_tcodes.push_back(kTVMDLTensorHandle);
if (param.flatten_data) { if (param.flatten_data) {
arg_ptr->shape_data[i] = std::accumulate( arg_ptr->shape_data[i] = std::accumulate(
t->shape, t->shape + t->ndim, 1, std::multiplies<int64_t>()); t->shape, t->shape + t->ndim, 1, std::multiplies<int64_t>());
...@@ -414,7 +414,7 @@ PackedFunc GraphRuntime::GetFunction( ...@@ -414,7 +414,7 @@ PackedFunc GraphRuntime::GetFunction(
// Return member functions during query. // Return member functions during query.
if (name == "set_input") { if (name == "set_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kStr) { if (args[0].type_code() == kTVMStr) {
int in_idx = this->GetInputIndex(args[0]); int in_idx = this->GetInputIndex(args[0]);
if (in_idx >= 0) this->SetInput(in_idx, args[1]); if (in_idx >= 0) this->SetInput(in_idx, args[1]);
} else { } else {
...@@ -423,7 +423,7 @@ PackedFunc GraphRuntime::GetFunction( ...@@ -423,7 +423,7 @@ PackedFunc GraphRuntime::GetFunction(
}); });
} else if (name == "set_input_zero_copy") { } else if (name == "set_input_zero_copy") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
if (args[0].type_code() == kStr) { if (args[0].type_code() == kTVMStr) {
int in_idx = this->GetInputIndex(args[0]); int in_idx = this->GetInputIndex(args[0]);
if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]); if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]);
} else { } else {
...@@ -441,7 +441,7 @@ PackedFunc GraphRuntime::GetFunction( ...@@ -441,7 +441,7 @@ PackedFunc GraphRuntime::GetFunction(
} else if (name == "get_input") { } else if (name == "get_input") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
int in_idx = 0; int in_idx = 0;
if (args[0].type_code() == kStr) { if (args[0].type_code() == kTVMStr) {
in_idx = this->GetInputIndex(args[0]); in_idx = this->GetInputIndex(args[0]);
} else { } else {
in_idx = args[0]; in_idx = args[0];
......
...@@ -81,7 +81,7 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, ...@@ -81,7 +81,7 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr,
const ObjectPtr<Object>& sptr_to_self) { const ObjectPtr<Object>& sptr_to_self) {
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
TVMValue ret_value; TVMValue ret_value;
int ret_type_code = kNull; int ret_type_code = kTVMNullptr;
int ret = (*faddr)( int ret = (*faddr)(
const_cast<TVMValue*>(args.values), const_cast<TVMValue*>(args.values),
const_cast<int*>(args.type_codes), const_cast<int*>(args.type_codes),
...@@ -89,7 +89,7 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, ...@@ -89,7 +89,7 @@ PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr,
&ret_value, &ret_value,
&ret_type_code); &ret_type_code);
CHECK_EQ(ret, 0) << TVMGetLastError(); CHECK_EQ(ret, 0) << TVMGetLastError();
if (ret_type_code != kNull) { if (ret_type_code != kTVMNullptr) {
*rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code); *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code);
} }
}); });
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -37,7 +37,7 @@ namespace runtime { ...@@ -37,7 +37,7 @@ namespace runtime {
/*! \brief function information needed by device */ /*! \brief function information needed by device */
struct FunctionInfo { struct FunctionInfo {
std::string name; std::string name;
std::vector<TVMType> arg_types; std::vector<DLDataType> arg_types;
std::vector<std::string> thread_axis_tags; std::vector<std::string> thread_axis_tags;
void Save(dmlc::JSONWriter *writer) const; void Save(dmlc::JSONWriter *writer) const;
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -84,7 +84,7 @@ class MetalWorkspace final : public DeviceAPI { ...@@ -84,7 +84,7 @@ class MetalWorkspace final : public DeviceAPI {
void* AllocDataSpace(TVMContext ctx, void* AllocDataSpace(TVMContext ctx,
size_t nbytes, size_t nbytes,
size_t alignment, size_t alignment,
TVMType type_hint) final; DLDataType type_hint) final;
void FreeDataSpace(TVMContext ctx, void* ptr) final; void FreeDataSpace(TVMContext ctx, void* ptr) final;
void CopyDataFromTo(const void* from, void CopyDataFromTo(const void* from,
size_t from_size, size_t from_size,
...@@ -93,10 +93,10 @@ class MetalWorkspace final : public DeviceAPI { ...@@ -93,10 +93,10 @@ class MetalWorkspace final : public DeviceAPI {
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint, DLDataType type_hint,
TVMStreamHandle stream) final; TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final; void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final; void FreeWorkspace(TVMContext ctx, void* data) final;
// get the global workspace // get the global workspace
static const std::shared_ptr<MetalWorkspace>& Global(); static const std::shared_ptr<MetalWorkspace>& Global();
......
...@@ -62,7 +62,7 @@ void MetalWorkspace::GetAttr( ...@@ -62,7 +62,7 @@ void MetalWorkspace::GetAttr(
case kMultiProcessorCount: return; case kMultiProcessorCount: return;
case kMaxThreadDimensions: return; case kMaxThreadDimensions: return;
case kExist: break; case kExist: break;
case kGcnArch: return; case kGcnArch: return;
} }
} }
...@@ -145,7 +145,7 @@ void MetalWorkspace::SetDevice(TVMContext ctx) { ...@@ -145,7 +145,7 @@ void MetalWorkspace::SetDevice(TVMContext ctx) {
} }
void* MetalWorkspace::AllocDataSpace( void* MetalWorkspace::AllocDataSpace(
TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) { TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) {
this->Init(); this->Init();
id<MTLDevice> dev = GetDevice(ctx); id<MTLDevice> dev = GetDevice(ctx);
// GPU memory only // GPU memory only
...@@ -176,7 +176,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from, ...@@ -176,7 +176,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint, DLDataType type_hint,
TVMStreamHandle stream) { TVMStreamHandle stream) {
this->Init(); this->Init();
CHECK(stream == nullptr); CHECK(stream == nullptr);
...@@ -261,7 +261,7 @@ void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) { ...@@ -261,7 +261,7 @@ void MetalWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
void* MetalWorkspace::AllocWorkspace(TVMContext ctx, void* MetalWorkspace::AllocWorkspace(TVMContext ctx,
size_t size, size_t size,
TVMType type_hint) { DLDataType type_hint) {
return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
} }
......
...@@ -48,7 +48,7 @@ class MicroDeviceAPI final : public DeviceAPI { ...@@ -48,7 +48,7 @@ class MicroDeviceAPI final : public DeviceAPI {
void* AllocDataSpace(TVMContext ctx, void* AllocDataSpace(TVMContext ctx,
size_t nbytes, size_t nbytes,
size_t alignment, size_t alignment,
TVMType type_hint) final { DLDataType type_hint) final {
ObjectPtr<MicroSession>& session = MicroSession::Current(); ObjectPtr<MicroSession>& session = MicroSession::Current();
void* data = session->AllocateInSection(SectionKind::kHeap, nbytes).cast_to<void*>(); void* data = session->AllocateInSection(SectionKind::kHeap, nbytes).cast_to<void*>();
CHECK(data != nullptr) << "unable to allocate " << nbytes << " bytes on device heap"; CHECK(data != nullptr) << "unable to allocate " << nbytes << " bytes on device heap";
...@@ -72,7 +72,7 @@ class MicroDeviceAPI final : public DeviceAPI { ...@@ -72,7 +72,7 @@ class MicroDeviceAPI final : public DeviceAPI {
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint, DLDataType type_hint,
TVMStreamHandle stream) final { TVMStreamHandle stream) final {
std::tuple<int, int> type_from_to(ctx_from.device_type, ctx_to.device_type); std::tuple<int, int> type_from_to(ctx_from.device_type, ctx_to.device_type);
if (type_from_to == std::make_tuple(kDLMicroDev, kDLMicroDev)) { if (type_from_to == std::make_tuple(kDLMicroDev, kDLMicroDev)) {
...@@ -123,7 +123,7 @@ class MicroDeviceAPI final : public DeviceAPI { ...@@ -123,7 +123,7 @@ class MicroDeviceAPI final : public DeviceAPI {
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
} }
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final { void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final {
ObjectPtr<MicroSession>& session = MicroSession::Current(); ObjectPtr<MicroSession>& session = MicroSession::Current();
void* data = session->AllocateInSection(SectionKind::kWorkspace, size).cast_to<void*>(); void* data = session->AllocateInSection(SectionKind::kWorkspace, size).cast_to<void*>();
......
...@@ -333,9 +333,9 @@ std::tuple<DevPtr, DevPtr> MicroSession::EncoderAppend( ...@@ -333,9 +333,9 @@ std::tuple<DevPtr, DevPtr> MicroSession::EncoderAppend(
for (int i = 0; i < num_args; i++) { for (int i = 0; i < num_args; i++) {
switch (type_codes[i]) { switch (type_codes[i]) {
case kNDArrayContainer: case kTVMNDArrayHandle:
case kArrayHandle: { case kTVMDLTensorHandle: {
TVMArray* base_arr_handle = args[i]; DLTensor* base_arr_handle = args[i];
// All uTVM arrays store a `MicroDevSpace` struct in their `data` field, // All uTVM arrays store a `MicroDevSpace` struct in their `data` field,
// which wraps the actual data and stores a reference to the session, in // which wraps the actual data and stores a reference to the session, in
// order to prevent premature session destruction. // order to prevent premature session destruction.
...@@ -371,7 +371,7 @@ std::tuple<DevPtr, DevPtr> MicroSession::EncoderAppend( ...@@ -371,7 +371,7 @@ std::tuple<DevPtr, DevPtr> MicroSession::EncoderAppend(
} }
template <typename T> template <typename T>
DevPtr MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder, const TVMArray& arr) { DevPtr MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder, const DLTensor& arr) {
auto tvm_arr_slot = encoder->Alloc<T>(); auto tvm_arr_slot = encoder->Alloc<T>();
auto shape_slot = encoder->Alloc<int64_t>(arr.ndim); auto shape_slot = encoder->Alloc<int64_t>(arr.ndim);
...@@ -396,7 +396,7 @@ DevPtr MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder, const TVMAr ...@@ -396,7 +396,7 @@ DevPtr MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder, const TVMAr
strides_dev_addr.value(), strides_dev_addr.value(),
TargetVal { .val64 = arr.byte_offset }); TargetVal { .val64 = arr.byte_offset });
CHECK(dev_arr.ctx.device_type == static_cast<DLDeviceType>(kDLMicroDev)) CHECK(dev_arr.ctx.device_type == static_cast<DLDeviceType>(kDLMicroDev))
<< "attempt to write TVMArray with non-micro device type"; << "attempt to write DLTensor with non-micro device type";
// Update the device type to CPU, because from the microcontroller's // Update the device type to CPU, because from the microcontroller's
// perspective, it is. // perspective, it is.
dev_arr.ctx.device_type = DLDeviceType::kDLCPU; dev_arr.ctx.device_type = DLDeviceType::kDLCPU;
......
...@@ -231,13 +231,13 @@ class MicroSession : public ModuleNode { ...@@ -231,13 +231,13 @@ class MicroSession : public ModuleNode {
std::tuple<DevPtr, DevPtr> EncoderAppend(TargetDataLayoutEncoder* encoder, const TVMArgs& args); std::tuple<DevPtr, DevPtr> EncoderAppend(TargetDataLayoutEncoder* encoder, const TVMArgs& args);
/*! /*!
* \brief appends a `TVMArray` to the host-side buffer of `encoder` * \brief appends a `DLTensor` to the host-side buffer of `encoder`
* \param encoder encoder being used to append `arr` * \param encoder encoder being used to append `arr`
* \param arr TVMArray to be appended * \param arr DLTensor to be appended
* \return device address of the allocated `TVMArray` * \return device address of the allocated `DLTensor`
*/ */
template <typename T> template <typename T>
DevPtr EncoderAppend(TargetDataLayoutEncoder* encoder, const TVMArray& arr); DevPtr EncoderAppend(TargetDataLayoutEncoder* encoder, const DLTensor& arr);
/*! /*!
* \brief checks and logs if there was an error during the device's most recent execution * \brief checks and logs if there was an error during the device's most recent execution
......
...@@ -324,7 +324,7 @@ std::function<void()> CreateTVMOp(const DSOModule& module, const TVMOpParam& par ...@@ -324,7 +324,7 @@ std::function<void()> CreateTVMOp(const DSOModule& module, const TVMOpParam& par
void* v_handle; void* v_handle;
} TVMValue; } TVMValue;
/*typedef*/ enum { /*typedef*/ enum {
kArrayHandle = 7U, kTVMDLTensorHandle = 7U,
} /*TVMTypeCode*/; } /*TVMTypeCode*/;
struct OpArgs { struct OpArgs {
DynArray<DLTensor> args; DynArray<DLTensor> args;
...@@ -345,7 +345,7 @@ std::function<void()> CreateTVMOp(const DSOModule& module, const TVMOpParam& par ...@@ -345,7 +345,7 @@ std::function<void()> CreateTVMOp(const DSOModule& module, const TVMOpParam& par
DLTensor* t = &(arg_ptr->args[i]); DLTensor* t = &(arg_ptr->args[i]);
v.v_handle = t; v.v_handle = t;
arg_ptr->arg_values[i] = v; arg_ptr->arg_values[i] = v;
arg_ptr->arg_tcodes[i] = kArrayHandle; arg_ptr->arg_tcodes[i] = kTVMDLTensorHandle;
if (param.flatten_data) { if (param.flatten_data) {
arg_ptr->shape_data[i] = arg_ptr->shape_data[i] =
std::accumulate(t->shape, t->shape + t->ndim, 1, std::multiplies<int64_t>()); std::accumulate(t->shape, t->shape + t->ndim, 1, std::multiplies<int64_t>());
......
...@@ -193,7 +193,7 @@ class OpenCLWorkspace : public DeviceAPI { ...@@ -193,7 +193,7 @@ class OpenCLWorkspace : public DeviceAPI {
void* AllocDataSpace(TVMContext ctx, void* AllocDataSpace(TVMContext ctx,
size_t size, size_t size,
size_t alignment, size_t alignment,
TVMType type_hint) final; DLDataType type_hint) final;
void FreeDataSpace(TVMContext ctx, void* ptr) final; void FreeDataSpace(TVMContext ctx, void* ptr) final;
void CopyDataFromTo(const void* from, void CopyDataFromTo(const void* from,
size_t from_offset, size_t from_offset,
...@@ -202,10 +202,10 @@ class OpenCLWorkspace : public DeviceAPI { ...@@ -202,10 +202,10 @@ class OpenCLWorkspace : public DeviceAPI {
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint, DLDataType type_hint,
TVMStreamHandle stream) final; TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final; void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final; void FreeWorkspace(TVMContext ctx, void* data) final;
/*! /*!
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -118,7 +118,7 @@ void OpenCLWorkspace::GetAttr( ...@@ -118,7 +118,7 @@ void OpenCLWorkspace::GetAttr(
} }
void* OpenCLWorkspace::AllocDataSpace( void* OpenCLWorkspace::AllocDataSpace(
TVMContext ctx, size_t size, size_t alignment, TVMType type_hint) { TVMContext ctx, size_t size, size_t alignment, DLDataType type_hint) {
this->Init(); this->Init();
CHECK(context != nullptr) << "No OpenCL device"; CHECK(context != nullptr) << "No OpenCL device";
cl_int err_code; cl_int err_code;
...@@ -144,7 +144,7 @@ void OpenCLWorkspace::CopyDataFromTo(const void* from, ...@@ -144,7 +144,7 @@ void OpenCLWorkspace::CopyDataFromTo(const void* from,
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint, DLDataType type_hint,
TVMStreamHandle stream) { TVMStreamHandle stream) {
this->Init(); this->Init();
CHECK(stream == nullptr); CHECK(stream == nullptr);
...@@ -182,7 +182,7 @@ void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) { ...@@ -182,7 +182,7 @@ void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx, void* OpenCLWorkspace::AllocWorkspace(TVMContext ctx,
size_t size, size_t size,
TVMType type_hint) { DLDataType type_hint) {
return GetThreadEntry()->pool.AllocWorkspace(ctx, size); return GetThreadEntry()->pool.AllocWorkspace(ctx, size);
} }
......
...@@ -131,9 +131,9 @@ PackedFunc OpenCLModuleNode::GetFunction( ...@@ -131,9 +131,9 @@ PackedFunc OpenCLModuleNode::GetFunction(
OpenCLWrappedFunc f; OpenCLWrappedFunc f;
std::vector<size_t> arg_size(info.arg_types.size()); std::vector<size_t> arg_size(info.arg_types.size());
for (size_t i = 0; i < info.arg_types.size(); ++i) { for (size_t i = 0; i < info.arg_types.size(); ++i) {
TVMType t = info.arg_types[i]; DLDataType t = info.arg_types[i];
CHECK_EQ(t.lanes, 1U); CHECK_EQ(t.lanes, 1U);
if (t.code == kHandle) { if (t.code == kTVMOpaqueHandle) {
// specially store pointer type size in OpenCL driver // specially store pointer type size in OpenCL driver
arg_size[i] = sizeof(void*); arg_size[i] = sizeof(void*);
} else { } else {
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -184,7 +184,7 @@ class OpenGLWorkspace final : public DeviceAPI { ...@@ -184,7 +184,7 @@ class OpenGLWorkspace final : public DeviceAPI {
void* AllocDataSpace(TVMContext ctx, void* AllocDataSpace(TVMContext ctx,
size_t nbytes, size_t nbytes,
size_t alignment, size_t alignment,
TVMType type_hint) final; DLDataType type_hint) final;
void FreeDataSpace(TVMContext ctx, void* ptr) final; void FreeDataSpace(TVMContext ctx, void* ptr) final;
void CopyDataFromTo(const void* from, void CopyDataFromTo(const void* from,
size_t from_offset, size_t from_offset,
...@@ -193,7 +193,7 @@ class OpenGLWorkspace final : public DeviceAPI { ...@@ -193,7 +193,7 @@ class OpenGLWorkspace final : public DeviceAPI {
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint, DLDataType type_hint,
TVMStreamHandle stream) final; TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final; void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
...@@ -216,7 +216,7 @@ class OpenGLWorkspace final : public DeviceAPI { ...@@ -216,7 +216,7 @@ class OpenGLWorkspace final : public DeviceAPI {
* \param nbytes Number of bytes in the array. * \param nbytes Number of bytes in the array.
* \return The OpenGL texture. * \return The OpenGL texture.
*/ */
Texture CreateTexture(TVMType type, size_t nbytes); Texture CreateTexture(DLDataType type, size_t nbytes);
/*! /*!
* \brief Upload user data into a sub-region of an OpenGL texture. * \brief Upload user data into a sub-region of an OpenGL texture.
...@@ -256,7 +256,7 @@ class OpenGLWorkspace final : public DeviceAPI { ...@@ -256,7 +256,7 @@ class OpenGLWorkspace final : public DeviceAPI {
*/ */
void SetUniform(const Program& program, void SetUniform(const Program& program,
const std::string& name, const std::string& name,
TVMType type, DLDataType type,
void* value); void* value);
/*! /*!
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -121,7 +121,7 @@ void OpenGLWorkspace::GetAttr( ...@@ -121,7 +121,7 @@ void OpenGLWorkspace::GetAttr(
} }
void* OpenGLWorkspace::AllocDataSpace( void* OpenGLWorkspace::AllocDataSpace(
TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) { TVMContext ctx, size_t nbytes, size_t alignment, DLDataType type_hint) {
return reinterpret_cast<void*>(new Texture(CreateTexture(type_hint, nbytes))); return reinterpret_cast<void*>(new Texture(CreateTexture(type_hint, nbytes)));
} }
...@@ -136,7 +136,7 @@ void OpenGLWorkspace::CopyDataFromTo(const void* from, ...@@ -136,7 +136,7 @@ void OpenGLWorkspace::CopyDataFromTo(const void* from,
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint, DLDataType type_hint,
TVMStreamHandle stream) { TVMStreamHandle stream) {
CHECK(stream == nullptr); CHECK(stream == nullptr);
...@@ -312,7 +312,7 @@ GLuint OpenGLWorkspace::CreateShader(GLenum shader_kind, ...@@ -312,7 +312,7 @@ GLuint OpenGLWorkspace::CreateShader(GLenum shader_kind,
return shader; return shader;
} }
static TextureFormat GetTextureFormat(TVMType type) { static TextureFormat GetTextureFormat(DLDataType type) {
CHECK_EQ(type.lanes, 1) << "Not supporting multi-lane types."; CHECK_EQ(type.lanes, 1) << "Not supporting multi-lane types.";
switch (type.code) { switch (type.code) {
...@@ -355,7 +355,7 @@ static TextureFormat GetTextureFormat(TVMType type) { ...@@ -355,7 +355,7 @@ static TextureFormat GetTextureFormat(TVMType type) {
return {GL_R32F, GL_RED, GL_FLOAT}; return {GL_R32F, GL_RED, GL_FLOAT};
} }
Texture OpenGLWorkspace::CreateTexture(TVMType type, size_t nbytes) { Texture OpenGLWorkspace::CreateTexture(DLDataType type, size_t nbytes) {
// Create a texture. // Create a texture.
GLuint texture; GLuint texture;
OPENGL_CALL(gl->GenTextures(1, &texture)); OPENGL_CALL(gl->GenTextures(1, &texture));
...@@ -555,7 +555,7 @@ void OpenGLWorkspace::SetCurrentProgram(const Program& program) { ...@@ -555,7 +555,7 @@ void OpenGLWorkspace::SetCurrentProgram(const Program& program) {
void OpenGLWorkspace::SetUniform(const Program& program, void OpenGLWorkspace::SetUniform(const Program& program,
const std::string& name, const std::string& name,
TVMType type, DLDataType type,
void* value) { void* value) {
GLint location = gl->GetUniformLocation(program.program(), name.c_str()); GLint location = gl->GetUniformLocation(program.program(), name.c_str());
switch (type.code) { switch (type.code) {
......
...@@ -120,7 +120,7 @@ PackedFunc OpenGLModuleNode::GetFunction( ...@@ -120,7 +120,7 @@ PackedFunc OpenGLModuleNode::GetFunction(
std::vector<size_t> arg_size(func_info.arg_types.size()); std::vector<size_t> arg_size(func_info.arg_types.size());
for (size_t i = 0; i < func_info.arg_types.size(); ++i) { for (size_t i = 0; i < func_info.arg_types.size(); ++i) {
TVMType t = func_info.arg_types[i]; DLDataType t = func_info.arg_types[i];
CHECK_EQ(t.lanes, 1U); CHECK_EQ(t.lanes, 1U);
uint32_t bits = t.bits; uint32_t bits = t.bits;
CHECK_EQ(bits % 8, 0U); CHECK_EQ(bits % 8, 0U);
...@@ -222,14 +222,14 @@ void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, ...@@ -222,14 +222,14 @@ void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
break; break;
} }
case OpenGLArgKind::kInputTexture: { case OpenGLArgKind::kInputTexture: {
CHECK_EQ(type.code, kHandle) << "Type is not handle?"; CHECK_EQ(type.code, kTVMOpaqueHandle) << "Type is not handle?";
auto texture = *static_cast<gl::Texture**>(void_args[i]); auto texture = *static_cast<gl::Texture**>(void_args[i]);
m_->workspace().SetInputTexture(program, name, texture_unit, texture); m_->workspace().SetInputTexture(program, name, texture_unit, texture);
++texture_unit; ++texture_unit;
break; break;
} }
case OpenGLArgKind::kOutputTexture: { case OpenGLArgKind::kOutputTexture: {
CHECK_EQ(type.code, kHandle) << "Type is not handle?"; CHECK_EQ(type.code, kTVMOpaqueHandle) << "Type is not handle?";
CHECK(output == nullptr) << "Can only have one output texture."; CHECK(output == nullptr) << "Can only have one output texture.";
output = *static_cast<gl::Texture**>(void_args[i]); output = *static_cast<gl::Texture**>(void_args[i]);
break; break;
...@@ -241,7 +241,7 @@ void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, ...@@ -241,7 +241,7 @@ void OpenGLWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv,
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args); ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
std::unique_ptr<GLint> thread_extent(new GLint(wl.block_dim(0))); std::unique_ptr<GLint> thread_extent(new GLint(wl.block_dim(0)));
m_->workspace().SetUniform(program, shader.thread_extent_var, m_->workspace().SetUniform(program, shader.thread_extent_var,
TVMType{kDLInt, 32, 1}, DLDataType{kDLInt, 32, 1},
static_cast<void*>(thread_extent.get())); static_cast<void*>(thread_extent.get()));
m_->workspace().Render(output); m_->workspace().Render(output);
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -56,7 +56,7 @@ union ArgUnion { ...@@ -56,7 +56,7 @@ union ArgUnion {
* \return The wrapped packed function. * \return The wrapped packed function.
*/ */
template<typename F> template<typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types); inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DLDataType>& arg_types);
/*! /*!
* \brief Create a packed function that from function only packs buffer arguments. * \brief Create a packed function that from function only packs buffer arguments.
* *
...@@ -67,7 +67,7 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types); ...@@ -67,7 +67,7 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types);
* \return The wrapped packed function. * \return The wrapped packed function.
*/ */
template<typename F> template<typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_types); inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DLDataType>& arg_types);
/*! /*!
* \brief Create a packed function that from function that takes a packed arguments. * \brief Create a packed function that from function that takes a packed arguments.
* *
...@@ -78,13 +78,13 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_type ...@@ -78,13 +78,13 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_type
* \return The wrapped packed function. * \return The wrapped packed function.
*/ */
template<typename F> template<typename F>
inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types); inline PackedFunc PackFuncPackedArg(F f, const std::vector<DLDataType>& arg_types);
/*! /*!
* \brief Extract number of buffer argument from the argument types. * \brief Extract number of buffer argument from the argument types.
* \param arg_types The argument types. * \param arg_types The argument types.
* \return number of buffer arguments * \return number of buffer arguments
*/ */
inline size_t NumBufferArgs(const std::vector<TVMType>& arg_types); inline size_t NumBufferArgs(const std::vector<DLDataType>& arg_types);
// implementations details // implementations details
namespace detail { namespace detail {
...@@ -119,7 +119,7 @@ enum ArgConvertCode { ...@@ -119,7 +119,7 @@ enum ArgConvertCode {
HANDLE_TO_HANDLE HANDLE_TO_HANDLE
}; };
inline ArgConvertCode GetArgConvertCode(TVMType t) { inline ArgConvertCode GetArgConvertCode(DLDataType t) {
CHECK_EQ(t.lanes, 1U) CHECK_EQ(t.lanes, 1U)
<< "Cannot pass vector type argument to devic function for now"; << "Cannot pass vector type argument to devic function for now";
if (t.code == kDLInt) { if (t.code == kDLInt) {
...@@ -130,7 +130,7 @@ inline ArgConvertCode GetArgConvertCode(TVMType t) { ...@@ -130,7 +130,7 @@ inline ArgConvertCode GetArgConvertCode(TVMType t) {
} else if (t.code == kDLFloat) { } else if (t.code == kDLFloat) {
if (t.bits == 64U) return FLOAT64_TO_FLOAT64; if (t.bits == 64U) return FLOAT64_TO_FLOAT64;
if (t.bits == 32U) return FLOAT64_TO_FLOAT32; if (t.bits == 32U) return FLOAT64_TO_FLOAT32;
} else if (t.code == kHandle) { } else if (t.code == kTVMOpaqueHandle) {
return HANDLE_TO_HANDLE; return HANDLE_TO_HANDLE;
} }
LOG(FATAL) << "Cannot handle " << t << " as device function argument"; LOG(FATAL) << "Cannot handle " << t << " as device function argument";
...@@ -262,7 +262,7 @@ inline PackedFunc PackFuncPackedArg_( ...@@ -262,7 +262,7 @@ inline PackedFunc PackFuncPackedArg_(
} // namespace detail } // namespace detail
template<typename F> template<typename F>
inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types) { inline PackedFunc PackFuncVoidAddr(F f, const std::vector<DLDataType>& arg_types) {
std::vector<detail::ArgConvertCode> codes(arg_types.size()); std::vector<detail::ArgConvertCode> codes(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) { for (size_t i = 0; i < arg_types.size(); ++i) {
codes[i] = detail::GetArgConvertCode(arg_types[i]); codes[i] = detail::GetArgConvertCode(arg_types[i]);
...@@ -278,22 +278,22 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types) { ...@@ -278,22 +278,22 @@ inline PackedFunc PackFuncVoidAddr(F f, const std::vector<TVMType>& arg_types) {
} }
} }
inline size_t NumBufferArgs(const std::vector<TVMType>& arg_types) { inline size_t NumBufferArgs(const std::vector<DLDataType>& arg_types) {
size_t base = arg_types.size(); size_t base = arg_types.size();
for (size_t i = 0; i < arg_types.size(); ++i) { for (size_t i = 0; i < arg_types.size(); ++i) {
if (arg_types[i].code != kHandle) { if (arg_types[i].code != kTVMOpaqueHandle) {
base = i; break; base = i; break;
} }
} }
for (size_t i = base; i < arg_types.size(); ++i) { for (size_t i = base; i < arg_types.size(); ++i) {
CHECK(arg_types[i].code != kHandle) CHECK(arg_types[i].code != kTVMOpaqueHandle)
<< "Device function need to be organized"; << "Device function need to be organized";
} }
return base; return base;
} }
template<typename F> template<typename F>
inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_types) { inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<DLDataType>& arg_types) {
size_t num_buffer = NumBufferArgs(arg_types); size_t num_buffer = NumBufferArgs(arg_types);
std::vector<detail::ArgConvertCode> codes; std::vector<detail::ArgConvertCode> codes;
for (size_t i = num_buffer; i < arg_types.size(); ++i) { for (size_t i = num_buffer; i < arg_types.size(); ++i) {
...@@ -310,7 +310,7 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_type ...@@ -310,7 +310,7 @@ inline PackedFunc PackFuncNonBufferArg(F f, const std::vector<TVMType>& arg_type
} }
template<typename F> template<typename F>
inline PackedFunc PackFuncPackedArg(F f, const std::vector<TVMType>& arg_types) { inline PackedFunc PackFuncPackedArg(F f, const std::vector<DLDataType>& arg_types) {
std::vector<detail::ArgConvertCode> codes; std::vector<detail::ArgConvertCode> codes;
for (size_t i = 0; i < arg_types.size(); ++i) { for (size_t i = 0; i < arg_types.size(); ++i) {
codes.push_back(detail::GetArgConvertCode(arg_types[i])); codes.push_back(detail::GetArgConvertCode(arg_types[i]));
......
...@@ -119,7 +119,7 @@ class ROCMDeviceAPI final : public DeviceAPI { ...@@ -119,7 +119,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
*rv = value; *rv = value;
} }
void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
TVMType type_hint) final { DLDataType type_hint) final {
ROCM_CALL(hipSetDevice(ctx.device_id)); ROCM_CALL(hipSetDevice(ctx.device_id));
CHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes"; CHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes";
void* ret; void* ret;
...@@ -134,7 +134,7 @@ class ROCMDeviceAPI final : public DeviceAPI { ...@@ -134,7 +134,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
void CopyDataFromTo(const void* from, size_t from_offset, void* to, void CopyDataFromTo(const void* from, size_t from_offset, void* to,
size_t to_offset, size_t size, TVMContext ctx_from, size_t to_offset, size_t size, TVMContext ctx_from,
TVMContext ctx_to, TVMType type_hint, TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) final { TVMStreamHandle stream) final {
hipStream_t hip_stream = static_cast<hipStream_t>(stream); hipStream_t hip_stream = static_cast<hipStream_t>(stream);
from = static_cast<const char*>(from) + from_offset; from = static_cast<const char*>(from) + from_offset;
...@@ -169,7 +169,7 @@ class ROCMDeviceAPI final : public DeviceAPI { ...@@ -169,7 +169,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
ROCMThreadEntry::ThreadLocal()->stream = static_cast<hipStream_t>(stream); ROCMThreadEntry::ThreadLocal()->stream = static_cast<hipStream_t>(stream);
} }
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final { void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final {
return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -41,7 +41,7 @@ class RPCDeviceAPI final : public DeviceAPI { ...@@ -41,7 +41,7 @@ class RPCDeviceAPI final : public DeviceAPI {
void* AllocDataSpace(TVMContext ctx, void* AllocDataSpace(TVMContext ctx,
size_t nbytes, size_t nbytes,
size_t alignment, size_t alignment,
TVMType type_hint) final { DLDataType type_hint) final {
auto sess = GetSess(ctx); auto sess = GetSess(ctx);
void *data = sess->CallRemote( void *data = sess->CallRemote(
RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint); RPCCode::kDevAllocData, ctx, nbytes, alignment, type_hint);
...@@ -67,7 +67,7 @@ class RPCDeviceAPI final : public DeviceAPI { ...@@ -67,7 +67,7 @@ class RPCDeviceAPI final : public DeviceAPI {
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint, DLDataType type_hint,
TVMStreamHandle stream) final { TVMStreamHandle stream) final {
int from_dev_type = ctx_from.device_type; int from_dev_type = ctx_from.device_type;
int to_dev_type = ctx_to.device_type; int to_dev_type = ctx_to.device_type;
......
...@@ -187,7 +187,7 @@ class RPCModuleNode final : public ModuleNode { ...@@ -187,7 +187,7 @@ class RPCModuleNode final : public ModuleNode {
void* RPCWrappedFunc::UnwrapRemote(int rpc_sess_table_index, void* RPCWrappedFunc::UnwrapRemote(int rpc_sess_table_index,
const TVMArgValue& arg) { const TVMArgValue& arg) {
if (arg.type_code() == kModuleHandle) { if (arg.type_code() == kTVMModuleHandle) {
Module mod = arg; Module mod = arg;
std::string tkey = mod->type_key(); std::string tkey = mod->type_key();
CHECK_EQ(tkey, "rpc") CHECK_EQ(tkey, "rpc")
...@@ -211,15 +211,15 @@ void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess, ...@@ -211,15 +211,15 @@ void RPCWrappedFunc::WrapRemote(std::shared_ptr<RPCSession> sess,
int tcode = args.type_codes[0]; int tcode = args.type_codes[0];
if (handle == nullptr) return; if (handle == nullptr) return;
if (tcode == kFuncHandle) { if (tcode == kTVMPackedFuncHandle) {
auto wf = std::make_shared<RPCWrappedFunc>(handle, sess); auto wf = std::make_shared<RPCWrappedFunc>(handle, sess);
*rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) { *rv = PackedFunc([wf](TVMArgs args, TVMRetValue* rv) {
return wf->operator()(args, rv); return wf->operator()(args, rv);
}); });
} else if (tcode == kModuleHandle) { } else if (tcode == kTVMModuleHandle) {
auto n = make_object<RPCModuleNode>(handle, sess); auto n = make_object<RPCModuleNode>(handle, sess);
*rv = Module(n); *rv = Module(n);
} else if (tcode == kArrayHandle || tcode == kNDArrayContainer) { } else if (tcode == kTVMDLTensorHandle || tcode == kTVMNDArrayHandle) {
CHECK_EQ(args.size(), 2); CHECK_EQ(args.size(), 2);
DLTensor* tensor = args[0]; DLTensor* tensor = args[0];
void* nd_handle = args[1]; void* nd_handle = args[1];
......
...@@ -178,7 +178,7 @@ class RPCSession { ...@@ -178,7 +178,7 @@ class RPCSession {
size_t to_offset, size_t to_offset,
size_t nbytes, size_t nbytes,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint); DLDataType type_hint);
/*! /*!
* \brief Copy bytes from remote array content. * \brief Copy bytes from remote array content.
* \param from The source host data. * \param from The source host data.
...@@ -195,7 +195,7 @@ class RPCSession { ...@@ -195,7 +195,7 @@ class RPCSession {
size_t to_offset, size_t to_offset,
size_t nbytes, size_t nbytes,
TVMContext ctx_from, TVMContext ctx_from,
TVMType type_hint); DLDataType type_hint);
/*! /*!
* \brief Get a remote timer function on ctx. * \brief Get a remote timer function on ctx.
* This function consumes fhandle, caller should not call Free on fhandle. * This function consumes fhandle, caller should not call Free on fhandle.
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -54,10 +54,10 @@ void tvm_ecall_packed_func(int func_id, ...@@ -54,10 +54,10 @@ void tvm_ecall_packed_func(int func_id,
f->CallPacked(TVMArgs(arg_values, type_codes, num_args), &rv); f->CallPacked(TVMArgs(arg_values, type_codes, num_args), &rv);
int ret_type_code = rv.type_code(); int ret_type_code = rv.type_code();
if (ret_type_code == kNull) return; if (ret_type_code == kTVMNullptr) return;
TVMValue ret_value; TVMValue ret_value;
if (ret_type_code == kBytes || ret_type_code == kStr) { if (ret_type_code == kTVMBytes || ret_type_code == kTVMStr) {
// allocate a buffer in untrusted, copy the values in // allocate a buffer in untrusted, copy the values in
std::string bytes = rv; std::string bytes = rv;
...@@ -73,7 +73,7 @@ void tvm_ecall_packed_func(int func_id, ...@@ -73,7 +73,7 @@ void tvm_ecall_packed_func(int func_id,
arr->size = bytes.size(); arr->size = bytes.size();
ret_value = TVMValue{.v_handle = arr}; ret_value = TVMValue{.v_handle = arr};
ret_type_code = kBytes; ret_type_code = kTVMBytes;
} else { } else {
rv.MoveToCHost(&ret_value, &ret_type_code); rv.MoveToCHost(&ret_value, &ret_type_code);
} }
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -187,8 +187,8 @@ TVM_REGISTER_GLOBAL("__sgx_println__") ...@@ -187,8 +187,8 @@ TVM_REGISTER_GLOBAL("__sgx_println__")
case kDLInt: msg << static_cast<int64_t>(args[i]); break; case kDLInt: msg << static_cast<int64_t>(args[i]); break;
case kDLUInt: msg << static_cast<uint64_t>(args[i]); break; case kDLUInt: msg << static_cast<uint64_t>(args[i]); break;
case kDLFloat: msg << static_cast<double>(args[i]); break; case kDLFloat: msg << static_cast<double>(args[i]); break;
case kStr: case kTVMStr:
case kBytes: { case kTVMBytes: {
std::string val = args[i]; std::string val = args[i];
msg << val; msg << val;
} }
......
...@@ -395,7 +395,7 @@ void StackVM::Run(State* s) const { ...@@ -395,7 +395,7 @@ void StackVM::Run(State* s) const {
using namespace ir; using namespace ir;
int index = code[pc + 1].v_int; int index = code[pc + 1].v_int;
int kind = code[pc + 2].v_int; int kind = code[pc + 2].v_int;
TVMArray* arr = static_cast<TVMArray*>(stack[sp].v_handle); DLTensor* arr = static_cast<DLTensor*>(stack[sp].v_handle);
switch (kind) { switch (kind) {
case intrinsic::kArrData: { case intrinsic::kArrData: {
stack[sp].v_handle = arr[index].data; break; stack[sp].v_handle = arr[index].data; break;
...@@ -447,7 +447,7 @@ void StackVM::Run(State* s) const { ...@@ -447,7 +447,7 @@ void StackVM::Run(State* s) const {
using namespace ir; using namespace ir;
int index = code[pc + 1].v_int; int index = code[pc + 1].v_int;
int kind = code[pc + 2].v_int; int kind = code[pc + 2].v_int;
TVMArray* arr = static_cast<TVMArray*>(stack[sp - 1].v_handle); DLTensor* arr = static_cast<DLTensor*>(stack[sp - 1].v_handle);
switch (kind) { switch (kind) {
case intrinsic::kArrData: { case intrinsic::kArrData: {
arr[index].data = stack[sp].v_handle; break; arr[index].data = stack[sp].v_handle; break;
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -358,9 +358,9 @@ class StackVM { ...@@ -358,9 +358,9 @@ class StackVM {
* \param t the type code. * \param t the type code.
* \return The load opcode * \return The load opcode
*/ */
static OpCode GetLoad(TVMType t) { static OpCode GetLoad(DLDataType t) {
CHECK_EQ(t.lanes, 1U); CHECK_EQ(t.lanes, 1U);
if (t.code == kHandle) return ARRAY_LOAD_HANDLE; if (t.code == kTVMOpaqueHandle) return ARRAY_LOAD_HANDLE;
if (t.code == kDLInt) { if (t.code == kDLInt) {
switch (t.bits) { switch (t.bits) {
case 32 : return ARRAY_LOAD_INT32; case 32 : return ARRAY_LOAD_INT32;
...@@ -383,9 +383,9 @@ class StackVM { ...@@ -383,9 +383,9 @@ class StackVM {
* \param t the type code. * \param t the type code.
* \return The load opcode * \return The load opcode
*/ */
static OpCode GetStore(TVMType t) { static OpCode GetStore(DLDataType t) {
CHECK_EQ(t.lanes, 1U); CHECK_EQ(t.lanes, 1U);
if (t.code == kHandle) return ARRAY_STORE_HANDLE; if (t.code == kTVMOpaqueHandle) return ARRAY_STORE_HANDLE;
if (t.code == kDLInt) { if (t.code == kDLInt) {
switch (t.bits) { switch (t.bits) {
case 32 : return ARRAY_STORE_INT32; case 32 : return ARRAY_STORE_INT32;
......
...@@ -82,7 +82,7 @@ class Allocator { ...@@ -82,7 +82,7 @@ class Allocator {
* \param type_hint A type hint to the allocator. * \param type_hint A type hint to the allocator.
* \return A sized allocation in the form of a buffer. * \return A sized allocation in the form of a buffer.
*/ */
virtual Buffer Alloc(size_t nbytes, size_t alignment, TVMType type_hint) = 0; virtual Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) = 0;
/*! \brief Free a buffer allocated by the allocator. /*! \brief Free a buffer allocated by the allocator.
* \param buffer The buffer to free. * \param buffer The buffer to free.
*/ */
......
...@@ -36,7 +36,7 @@ class NaiveAllocator final : public Allocator { ...@@ -36,7 +36,7 @@ class NaiveAllocator final : public Allocator {
public: public:
explicit NaiveAllocator(TVMContext ctx) : Allocator(), used_memory_(0), ctx_(ctx) {} explicit NaiveAllocator(TVMContext ctx) : Allocator(), used_memory_(0), ctx_(ctx) {}
Buffer Alloc(size_t nbytes, size_t alignment, TVMType type_hint) override { Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) override {
Buffer buf; Buffer buf;
buf.ctx = ctx_; buf.ctx = ctx_;
buf.size = nbytes; buf.size = nbytes;
......
...@@ -44,7 +44,7 @@ class PooledAllocator final : public Allocator { ...@@ -44,7 +44,7 @@ class PooledAllocator final : public Allocator {
~PooledAllocator() { ReleaseAll(); } ~PooledAllocator() { ReleaseAll(); }
Buffer Alloc(size_t nbytes, size_t alignment, TVMType type_hint) override { Buffer Alloc(size_t nbytes, size_t alignment, DLDataType type_hint) override {
std::lock_guard<std::mutex> lock(mu_); std::lock_guard<std::mutex> lock(mu_);
size_t size = ((nbytes + page_size_ - 1) / page_size_) * page_size_; size_t size = ((nbytes + page_size_ - 1) / page_size_) * page_size_;
auto&& it = memory_pool_.find(size); auto&& it = memory_pool_.find(size);
......
...@@ -46,7 +46,7 @@ namespace runtime { ...@@ -46,7 +46,7 @@ namespace runtime {
namespace vm { namespace vm {
inline Storage make_storage(size_t size, size_t alignment, TVMType dtype_hint, TVMContext ctx) { inline Storage make_storage(size_t size, size_t alignment, DLDataType dtype_hint, TVMContext ctx) {
// We could put cache in here, from ctx to storage allocator. // We could put cache in here, from ctx to storage allocator.
auto storage_obj = SimpleObjAllocator().make_object<StorageObj>(); auto storage_obj = SimpleObjAllocator().make_object<StorageObj>();
auto alloc = MemoryManager::Global()->GetAllocator(ctx); auto alloc = MemoryManager::Global()->GetAllocator(ctx);
...@@ -336,7 +336,7 @@ Instruction Instruction::AllocTensorReg( ...@@ -336,7 +336,7 @@ Instruction Instruction::AllocTensorReg(
Instruction Instruction::AllocStorage(RegName size, Instruction Instruction::AllocStorage(RegName size,
Index alignment, Index alignment,
TVMType dtype_hint, DLDataType dtype_hint,
Index dst) { Index dst) {
Instruction instr; Instruction instr;
instr.op = Opcode::AllocStorage; instr.op = Opcode::AllocStorage;
...@@ -587,7 +587,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { ...@@ -587,7 +587,7 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) {
instr.dst << " $" << instr.dst << " $" <<
instr.alloc_storage.allocation_size << " $" << instr.alloc_storage.allocation_size << " $" <<
instr.alloc_storage.alignment << " " << instr.alloc_storage.alignment << " " <<
TVMType2String(instr.alloc_storage.dtype_hint); DLDataType2String(instr.alloc_storage.dtype_hint);
break; break;
} }
default: default:
...@@ -1019,7 +1019,7 @@ void VirtualMachine::RunLoop() { ...@@ -1019,7 +1019,7 @@ void VirtualMachine::RunLoop() {
DLOG(INFO) << DLOG(INFO) <<
"AllocStorage: allocation_size=" << size << "AllocStorage: allocation_size=" << size <<
"alignment=" << alignment << "alignment=" << alignment <<
"dtype_hint=" << TVMType2String(instr.alloc_storage.dtype_hint); "dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint);
auto storage = make_storage(size, alignment, instr.alloc_storage.dtype_hint, ctxs_[0]); auto storage = make_storage(size, alignment, instr.alloc_storage.dtype_hint, ctxs_[0]);
WriteRegister(instr.dst, storage); WriteRegister(instr.dst, storage);
......
...@@ -117,7 +117,10 @@ class VulkanDeviceAPI final : public DeviceAPI { ...@@ -117,7 +117,10 @@ class VulkanDeviceAPI final : public DeviceAPI {
} }
void SetDevice(TVMContext ctx) final { VulkanThreadEntry::ThreadLocal()->ctx = ctx; } void SetDevice(TVMContext ctx) final { VulkanThreadEntry::ThreadLocal()->ctx = ctx; }
void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final; void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final;
void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) final { void* AllocDataSpace(TVMContext ctx,
size_t nbytes,
size_t alignment,
DLDataType type_hint) final {
const auto& vctx = context(ctx.device_id); const auto& vctx = context(ctx.device_id);
VkBufferCreateInfo info; VkBufferCreateInfo info;
info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO; info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
...@@ -194,7 +197,7 @@ class VulkanDeviceAPI final : public DeviceAPI { ...@@ -194,7 +197,7 @@ class VulkanDeviceAPI final : public DeviceAPI {
} }
void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size,
TVMContext ctx_from, TVMContext ctx_to, TVMType type_hint, TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint,
TVMStreamHandle stream) final { TVMStreamHandle stream) final {
CHECK(stream == nullptr); CHECK(stream == nullptr);
TVMContext ctx = ctx_from; TVMContext ctx = ctx_from;
...@@ -327,7 +330,7 @@ class VulkanDeviceAPI final : public DeviceAPI { ...@@ -327,7 +330,7 @@ class VulkanDeviceAPI final : public DeviceAPI {
return; return;
} }
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final { void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final {
return VulkanThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); return VulkanThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
} }
...@@ -772,8 +775,8 @@ class VulkanModuleNode final : public runtime::ModuleNode { ...@@ -772,8 +775,8 @@ class VulkanModuleNode final : public runtime::ModuleNode {
{ {
auto fit = fmap_.find(func_name); auto fit = fmap_.find(func_name);
CHECK(fit != fmap_.end()); CHECK(fit != fmap_.end());
for (TVMType arg_type : fit->second.arg_types) { for (DLDataType arg_type : fit->second.arg_types) {
if (arg_type.code == kHandle) { if (arg_type.code == kTVMOpaqueHandle) {
{ {
VkDescriptorSetLayoutBinding bd; VkDescriptorSetLayoutBinding bd;
bd.binding = num_buffer; bd.binding = num_buffer;
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -47,7 +47,7 @@ class WorkspacePool::Pool { ...@@ -47,7 +47,7 @@ class WorkspacePool::Pool {
nbytes = (nbytes + (kWorkspacePageSize - 1)) / kWorkspacePageSize * kWorkspacePageSize; nbytes = (nbytes + (kWorkspacePageSize - 1)) / kWorkspacePageSize * kWorkspacePageSize;
if (nbytes == 0) nbytes = kWorkspacePageSize; if (nbytes == 0) nbytes = kWorkspacePageSize;
Entry e; Entry e;
TVMType type; DLDataType type;
type.code = kDLUInt; type.code = kDLUInt;
type.bits = 8; type.bits = 8;
type.lanes = 1; type.lanes = 1;
......
...@@ -30,16 +30,16 @@ TEST(PackedFunc, Basic) { ...@@ -30,16 +30,16 @@ TEST(PackedFunc, Basic) {
using namespace tvm::runtime; using namespace tvm::runtime;
int x = 0; int x = 0;
void* handle = &x; void* handle = &x;
TVMArray a; DLTensor a;
Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
CHECK(args.num_args == 3); CHECK(args.num_args == 3);
CHECK(args.values[0].v_float64 == 1.0); CHECK(args.values[0].v_float64 == 1.0);
CHECK(args.type_codes[0] == kDLFloat); CHECK(args.type_codes[0] == kDLFloat);
CHECK(args.values[1].v_handle == &a); CHECK(args.values[1].v_handle == &a);
CHECK(args.type_codes[1] == kArrayHandle); CHECK(args.type_codes[1] == kTVMDLTensorHandle);
CHECK(args.values[2].v_handle == &x); CHECK(args.values[2].v_handle == &x);
CHECK(args.type_codes[2] == kHandle); CHECK(args.type_codes[2] == kTVMOpaqueHandle);
*rv = Var("a"); *rv = Var("a");
})(1.0, &a, handle); })(1.0, &a, handle);
CHECK(v->name_hint == "a"); CHECK(v->name_hint == "a");
...@@ -51,7 +51,7 @@ TEST(PackedFunc, Node) { ...@@ -51,7 +51,7 @@ TEST(PackedFunc, Node) {
Var x; Var x;
Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { Var t = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
CHECK(args.num_args == 1); CHECK(args.num_args == 1);
CHECK(args.type_codes[0] == kObjectHandle); CHECK(args.type_codes[0] == kTVMObjectHandle);
Var b = args[0]; Var b = args[0];
CHECK(x.same_as(b)); CHECK(x.same_as(b));
*rv = b; *rv = b;
...@@ -63,7 +63,7 @@ TEST(PackedFunc, NDArray) { ...@@ -63,7 +63,7 @@ TEST(PackedFunc, NDArray) {
using namespace tvm; using namespace tvm;
using namespace tvm::runtime; using namespace tvm::runtime;
auto x = NDArray::Empty( auto x = NDArray::Empty(
{}, String2TVMType("float32"), {}, String2DLDataType("float32"),
TVMContext{kDLCPU, 0}); TVMContext{kDLCPU, 0});
reinterpret_cast<float*>(x->data)[0] = 10.0f; reinterpret_cast<float*>(x->data)[0] = 10.0f;
CHECK(x.use_count() == 1); CHECK(x.use_count() == 1);
...@@ -203,25 +203,25 @@ TEST(PackedFunc, ObjectConversion) { ...@@ -203,25 +203,25 @@ TEST(PackedFunc, ObjectConversion) {
using namespace tvm::runtime; using namespace tvm::runtime;
TVMRetValue rv; TVMRetValue rv;
auto x = NDArray::Empty( auto x = NDArray::Empty(
{}, String2TVMType("float32"), {}, String2DLDataType("float32"),
TVMContext{kDLCPU, 0}); TVMContext{kDLCPU, 0});
// assign null // assign null
rv = ObjectRef(); rv = ObjectRef();
CHECK_EQ(rv.type_code(), kNull); CHECK_EQ(rv.type_code(), kTVMNullptr);
// Can assign NDArray to ret type // Can assign NDArray to ret type
rv = x; rv = x;
CHECK_EQ(rv.type_code(), kNDArrayContainer); CHECK_EQ(rv.type_code(), kTVMNDArrayHandle);
// Even if we assign base type it still shows as NDArray // Even if we assign base type it still shows as NDArray
rv = ObjectRef(x); rv = ObjectRef(x);
CHECK_EQ(rv.type_code(), kNDArrayContainer); CHECK_EQ(rv.type_code(), kTVMNDArrayHandle);
// Check convert back // Check convert back
CHECK(rv.operator NDArray().same_as(x)); CHECK(rv.operator NDArray().same_as(x));
CHECK(rv.operator ObjectRef().same_as(x)); CHECK(rv.operator ObjectRef().same_as(x));
CHECK(!rv.IsObjectRef<PrimExpr>()); CHECK(!rv.IsObjectRef<PrimExpr>());
auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { auto pf1 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args[0].type_code(), kNDArrayContainer); CHECK_EQ(args[0].type_code(), kTVMNDArrayHandle);
CHECK(args[0].operator NDArray().same_as(x)); CHECK(args[0].operator NDArray().same_as(x));
CHECK(args[0].operator ObjectRef().same_as(x)); CHECK(args[0].operator ObjectRef().same_as(x));
CHECK(args[1].operator ObjectRef().get() == nullptr); CHECK(args[1].operator ObjectRef().get() == nullptr);
...@@ -238,17 +238,17 @@ TEST(PackedFunc, ObjectConversion) { ...@@ -238,17 +238,17 @@ TEST(PackedFunc, ObjectConversion) {
CHECK(pf != nullptr); CHECK(pf != nullptr);
Module m = (*pf)("", "xyz"); Module m = (*pf)("", "xyz");
rv = m; rv = m;
CHECK_EQ(rv.type_code(), kModuleHandle); CHECK_EQ(rv.type_code(), kTVMModuleHandle);
// Even if we assign base type it still shows as NDArray // Even if we assign base type it still shows as NDArray
rv = ObjectRef(m); rv = ObjectRef(m);
CHECK_EQ(rv.type_code(), kModuleHandle); CHECK_EQ(rv.type_code(), kTVMModuleHandle);
// Check convert back // Check convert back
CHECK(rv.operator Module().same_as(m)); CHECK(rv.operator Module().same_as(m));
CHECK(rv.operator ObjectRef().same_as(m)); CHECK(rv.operator ObjectRef().same_as(m));
CHECK(!rv.IsObjectRef<NDArray>()); CHECK(!rv.IsObjectRef<NDArray>());
auto pf2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) { auto pf2 = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args[0].type_code(), kModuleHandle); CHECK_EQ(args[0].type_code(), kTVMModuleHandle);
CHECK(args[0].operator Module().same_as(m)); CHECK(args[0].operator Module().same_as(m));
CHECK(args[0].operator ObjectRef().same_as(m)); CHECK(args[0].operator ObjectRef().same_as(m));
CHECK(args[1].operator ObjectRef().get() == nullptr); CHECK(args[1].operator ObjectRef().get() == nullptr);
......
...@@ -19,7 +19,7 @@ import numpy as np ...@@ -19,7 +19,7 @@ import numpy as np
@tvm.register_extension @tvm.register_extension
class MyTensorView(object): class MyTensorView(object):
_tvm_tcode = tvm.TypeCode.ARRAY_HANDLE _tvm_tcode = tvm.TypeCode.DLTENSOR_HANDLE
def __init__(self, arr): def __init__(self, arr):
self.arr = arr self.arr = arr
......
...@@ -91,7 +91,7 @@ Array<Integer> ArrayOrInt(TVMArgValue arg) { ...@@ -91,7 +91,7 @@ Array<Integer> ArrayOrInt(TVMArgValue arg) {
} }
inline bool IsTensorType(TVMArgValue arg) { inline bool IsTensorType(TVMArgValue arg) {
return (arg.type_code() == kObjectHandle && return (arg.type_code() == kTVMObjectHandle &&
static_cast<Object*>( static_cast<Object*>(
arg.value().v_handle)->IsInstance<tvm::TensorNode>()); arg.value().v_handle)->IsInstance<tvm::TensorNode>());
} }
......
...@@ -45,7 +45,7 @@ class VTADeviceAPI final : public DeviceAPI { ...@@ -45,7 +45,7 @@ class VTADeviceAPI final : public DeviceAPI {
void* AllocDataSpace(TVMContext ctx, void* AllocDataSpace(TVMContext ctx,
size_t size, size_t size,
size_t alignment, size_t alignment,
TVMType type_hint) final { DLDataType type_hint) final {
return VTABufferAlloc(size); return VTABufferAlloc(size);
} }
...@@ -60,7 +60,7 @@ class VTADeviceAPI final : public DeviceAPI { ...@@ -60,7 +60,7 @@ class VTADeviceAPI final : public DeviceAPI {
size_t size, size_t size,
TVMContext ctx_from, TVMContext ctx_from,
TVMContext ctx_to, TVMContext ctx_to,
TVMType type_hint, DLDataType type_hint,
TVMStreamHandle stream) final { TVMStreamHandle stream) final {
int kind_mask = 0; int kind_mask = 0;
if (ctx_from.device_type != kDLCPU) { if (ctx_from.device_type != kDLCPU) {
...@@ -77,7 +77,7 @@ class VTADeviceAPI final : public DeviceAPI { ...@@ -77,7 +77,7 @@ class VTADeviceAPI final : public DeviceAPI {
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
} }
void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final; void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final;
void FreeWorkspace(TVMContext ctx, void* data) final; void FreeWorkspace(TVMContext ctx, void* data) final;
...@@ -93,7 +93,7 @@ struct VTAWorkspacePool : public WorkspacePool { ...@@ -93,7 +93,7 @@ struct VTAWorkspacePool : public WorkspacePool {
WorkspacePool(kDLExtDev, VTADeviceAPI::Global()) {} WorkspacePool(kDLExtDev, VTADeviceAPI::Global()) {}
}; };
void* VTADeviceAPI::AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) { void* VTADeviceAPI::AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) {
return dmlc::ThreadLocalStore<VTAWorkspacePool>::Get() return dmlc::ThreadLocalStore<VTAWorkspacePool>::Get()
->AllocWorkspace(ctx, size); ->AllocWorkspace(ctx, size);
} }
......
...@@ -95,16 +95,16 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -95,16 +95,16 @@ var tvm_runtime = tvm_runtime || {};
var kInt = 0; var kInt = 0;
var kUInt = 1; var kUInt = 1;
var kFloat = 2; var kFloat = 2;
var kHandle = 3; var kTVMOpaqueHandle = 3;
var kNull = 4; var kNull = 4;
var kTVMType = 5; var kTVMDataType = 5;
var kTVMContext = 6; var kTVMContext = 6;
var kArrayHandle = 7; var kTVMDLTensorHandle = 7;
var kObjectHandle = 8; var kTVMObjectHandle = 8;
var kModuleHandle = 9; var kTVMModuleHandle = 9;
var kFuncHandle = 10; var kTVMPackedFuncHandle = 10;
var kStr = 11; var kTVMStr = 11;
var kBytes = 12; var kTVMBytes = 12;
//----------------------------------------- //-----------------------------------------
// TVM CWrap library // TVM CWrap library
// ---------------------------------------- // ----------------------------------------
...@@ -427,7 +427,7 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -427,7 +427,7 @@ var tvm_runtime = tvm_runtime || {};
code = kUInt; code = kUInt;
} else if (pattern.substring(0, 6) == "handle") { } else if (pattern.substring(0, 6) == "handle") {
pattern = pattern.substring(5, pattern.length); pattern = pattern.substring(5, pattern.length);
code = kHandle; code = kTVMOpaqueHandle;
bits = 64; bits = 64;
} else { } else {
throw throwError("Unknown dtype " + dtype); throw throwError("Unknown dtype " + dtype);
...@@ -453,11 +453,11 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -453,11 +453,11 @@ var tvm_runtime = tvm_runtime || {};
case kInt: case kInt:
case kUInt: return Module.getValue(vptr, "i64"); case kUInt: return Module.getValue(vptr, "i64");
case kFloat: return Module.getValue(vptr, "double"); case kFloat: return Module.getValue(vptr, "double");
case kFuncHandle: return makeTVMFunction(Module.getValue(vptr, "*")); case kTVMPackedFuncHandle: return makeTVMFunction(Module.getValue(vptr, "*"));
case kModuleHandle: return new TVMModule(Module.getValue(vptr, "*")); case kTVMModuleHandle: return new TVMModule(Module.getValue(vptr, "*"));
case kNull: return null; case kNull: return null;
case kStr: return CStringToJS(Module.getValue(vptr, "*")); case kTVMStr: return CStringToJS(Module.getValue(vptr, "*"));
case kBytes: return CBytesToJS(Module.getValue(vptr, "*")); case kTVMBytes: return CBytesToJS(Module.getValue(vptr, "*"));
default: throwError("Unsupported return type code=" + tcode); default: throwError("Unsupported return type code=" + tcode);
} }
} }
...@@ -497,9 +497,9 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -497,9 +497,9 @@ var tvm_runtime = tvm_runtime || {};
for (var i = 0; i < nargs; ++i) { for (var i = 0; i < nargs; ++i) {
var vptr = arg_value + i * SIZEOF_TVMVALUE; var vptr = arg_value + i * SIZEOF_TVMVALUE;
var tcode = Module.getValue(arg_tcode + i * SIZEOF_INT, "i32"); var tcode = Module.getValue(arg_tcode + i * SIZEOF_INT, "i32");
if (tcode == kObjectHandle || if (tcode == kTVMObjectHandle ||
tcode == kFuncHandle || tcode == kTVMPackedFuncHandle ||
tcode == kModuleHandle) { tcode == kTVMModuleHandle) {
TVM_CALL(TVMCbArgToReturn(vptr, tcode)); TVM_CALL(TVMCbArgToReturn(vptr, tcode));
} }
args.push(TVMRetValueToJS(vptr, tcode)); args.push(TVMRetValueToJS(vptr, tcode));
...@@ -630,7 +630,7 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -630,7 +630,7 @@ var tvm_runtime = tvm_runtime || {};
var sdata = new CBuffer(value.length + 1); var sdata = new CBuffer(value.length + 1);
Module.HEAPU8.set(StringToUint8Array(value), sdata.data); Module.HEAPU8.set(StringToUint8Array(value), sdata.data);
this.temp.push(sdata); this.temp.push(sdata);
Module.setValue(this.tcode + index * SIZEOF_INT, kStr, "i32"); Module.setValue(this.tcode + index * SIZEOF_INT, kTVMStr, "i32");
Module.setValue(this.value + index * SIZEOF_TVMVALUE, sdata.data, "*"); Module.setValue(this.value + index * SIZEOF_TVMVALUE, sdata.data, "*");
}, },
setBytes : function(index, value) { setBytes : function(index, value) {
...@@ -642,7 +642,7 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -642,7 +642,7 @@ var tvm_runtime = tvm_runtime || {};
Module.setValue(sheader.data + SIZEOF_POINTER, value.length, "i32"); Module.setValue(sheader.data + SIZEOF_POINTER, value.length, "i32");
this.temp.push(sdata); this.temp.push(sdata);
this.temp.push(sheader); this.temp.push(sheader);
Module.setValue(this.tcode + index * SIZEOF_INT, kBytes, "i32"); Module.setValue(this.tcode + index * SIZEOF_INT, kTVMBytes, "i32");
Module.setValue(this.value + index * SIZEOF_TVMVALUE, sheader.data, "*"); Module.setValue(this.value + index * SIZEOF_TVMVALUE, sheader.data, "*");
}, },
setArguments : function(args) { setArguments : function(args) {
...@@ -650,7 +650,7 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -650,7 +650,7 @@ var tvm_runtime = tvm_runtime || {};
var v = args[i]; var v = args[i];
var tp = typeof v; var tp = typeof v;
if (v instanceof NDArray) { if (v instanceof NDArray) {
this.setHandle(i, v.handle, kArrayHandle); this.setHandle(i, v.handle, kTVMDLTensorHandle);
} else if (v instanceof TVMConstant) { } else if (v instanceof TVMConstant) {
var code = getTVMType(v.dtype).code; var code = getTVMType(v.dtype).code;
if (code == kInt || code == kUInt) { if (code == kInt || code == kUInt) {
...@@ -658,13 +658,13 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -658,13 +658,13 @@ var tvm_runtime = tvm_runtime || {};
} else if (code == kFloat) { } else if (code == kFloat) {
this.setDouble(i, v.value); this.setDouble(i, v.value);
} else { } else {
CHECK(code == kHandle); CHECK(code == kTVMOpaqueHandle);
this.setHandle(i, v.value, kHandle); this.setHandle(i, v.value, kTVMOpaqueHandle);
} }
} else if (tp == "number") { } else if (tp == "number") {
this.setDouble(i, v); this.setDouble(i, v);
} else if (tp == "function" && v.hasOwnProperty("_tvm_function")) { } else if (tp == "function" && v.hasOwnProperty("_tvm_function")) {
this.setString(i, v._tvm_function.handle, kFuncHandle); this.setString(i, v._tvm_function.handle, kTVMPackedFuncHandle);
} else if (v === null) { } else if (v === null) {
this.setHandle(i, 0, kNull); this.setHandle(i, 0, kNull);
} else if (tp == "string") { } else if (tp == "string") {
...@@ -674,9 +674,9 @@ var tvm_runtime = tvm_runtime || {}; ...@@ -674,9 +674,9 @@ var tvm_runtime = tvm_runtime || {};
} else if (v instanceof Function) { } else if (v instanceof Function) {
v = convertFunc(v); v = convertFunc(v);
this.temp.push(v); this.temp.push(v);
this.setHandle(i, v._tvm_function.handle, kFuncHandle); this.setHandle(i, v._tvm_function.handle, kTVMPackedFuncHandle);
} else if (v instanceof TVMModule) { } else if (v instanceof TVMModule) {
this.setHandle(i, v.handle, kModuleHandle); this.setHandle(i, v.handle, kTVMModuleHandle);
} else { } else {
throwError("Unsupported argument type " + tp); throwError("Unsupported argument type " + tp);
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment