Commit 8214d6ca by Tianqi Chen, committed by GitHub

[DLPack] Upgrade dlpack to 0.2 (#609)

parent a152a9cb
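Note on the change: upgrading the dlpack submodule to 0.2 renames the shared DLPack enums, so every TVM call site switches from the old unprefixed names (kInt, kUInt, kFloat, kCPU, kGPU, kOpenCL, kMetal, kVPI, kROCM) to the DL-prefixed ones (kDLInt, kDLUInt, kDLFloat, kDLCPU, kDLGPU, kDLOpenCL, kDLMetal, kDLVPI, kDLROCM). As a rough illustration only (not part of this diff), a minimal allocation through the C runtime API looks like the sketch below after the upgrade; the TVMArrayAlloc signature follows the call sites visible in the hunks that follow, and the helper name AllocExample is hypothetical.

// Hypothetical usage sketch: allocate a 1-D float32 tensor on CPU using the
// DL-prefixed type codes introduced by the dlpack 0.2 upgrade.
#include <dlpack/dlpack.h>
#include <tvm/runtime/c_runtime_api.h>

int AllocExample(DLTensor** out) {
  int64_t shape[1] = {10};
  int ndim = 1;
  int dtype_code = kDLFloat;   // previously kFloat
  int dtype_bits = 32;
  int dtype_lanes = 1;
  int device_type = kDLCPU;    // previously kCPU
  int device_id = 0;
  // Returns 0 on success; release the handle later with TVMArrayFree(*out).
  return TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes,
                       device_type, device_id, out);
}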
@@ -28,10 +28,10 @@ void Verify(tvm::runtime::Module mod, std::string fname) {
   DLTensor* x;
   DLTensor* y;
   int ndim = 1;
-  int dtype_code = kFloat;
+  int dtype_code = kDLFloat;
   int dtype_bits = 32;
   int dtype_lanes = 1;
-  int device_type = kCPU;
+  int device_type = kDLCPU;
   int device_id = 0;
   int64_t shape[1] = {10};
   TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes,
...
-Subproject commit 9422e98f3f4dafc6bc3473cf8484543ad376aab6
+Subproject commit 10892ac964f1af7c81aae145cd3fab78bbccd297
@@ -105,10 +105,10 @@ inline TNodeRef TVMArgValue::AsNodeRef() const {
 inline TVMArgValue::operator Halide::Expr() const {
   if (type_code_ == kNull) return Expr();
-  if (type_code_ == kInt) {
+  if (type_code_ == kDLInt) {
     return Expr(static_cast<int>(value_.v_int64));
   }
-  if (type_code_ == kFloat) {
+  if (type_code_ == kDLFloat) {
     return Expr(static_cast<float>(value_.v_float64));
   }
   TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
...
@@ -217,25 +217,25 @@ class ExtTypeVTable {
 class TVMPODValue_ {
  public:
   operator double() const {
-    TVM_CHECK_TYPE_CODE(type_code_, kFloat);
+    TVM_CHECK_TYPE_CODE(type_code_, kDLFloat);
     return value_.v_float64;
   }
   operator int64_t() const {
-    TVM_CHECK_TYPE_CODE(type_code_, kInt);
+    TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
     return value_.v_int64;
   }
   operator uint64_t() const {
-    TVM_CHECK_TYPE_CODE(type_code_, kInt);
+    TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
     return value_.v_int64;
   }
   operator int() const {
-    TVM_CHECK_TYPE_CODE(type_code_, kInt);
+    TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
     CHECK_LE(value_.v_int64,
              std::numeric_limits<int>::max());
     return static_cast<int>(value_.v_int64);
   }
   operator bool() const {
-    TVM_CHECK_TYPE_CODE(type_code_, kInt);
+    TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
     return value_.v_int64 != 0;
   }
   operator void*() const {
@@ -430,7 +430,7 @@ class TVMRetValue : public TVMPODValue_ {
     return *this;
   }
   TVMRetValue& operator=(double value) {
-    this->SwitchToPOD(kFloat);
+    this->SwitchToPOD(kDLFloat);
     value_.v_float64 = value;
     return *this;
   }
@@ -445,12 +445,12 @@ class TVMRetValue : public TVMPODValue_ {
     return *this;
   }
   TVMRetValue& operator=(int64_t value) {
-    this->SwitchToPOD(kInt);
+    this->SwitchToPOD(kDLInt);
     value_.v_int64 = value;
     return *this;
   }
   TVMRetValue& operator=(int value) {
-    this->SwitchToPOD(kInt);
+    this->SwitchToPOD(kDLInt);
     value_.v_int64 = value;
     return *this;
   }
@@ -460,7 +460,7 @@ class TVMRetValue : public TVMPODValue_ {
     return *this;
   }
   TVMRetValue& operator=(bool value) {
-    this->SwitchToPOD(kInt);
+    this->SwitchToPOD(kDLInt);
     value_.v_int64 = value;
     return *this;
   }
@@ -609,9 +609,9 @@ class TVMRetValue : public TVMPODValue_ {
 // implementation details
 inline const char* TypeCode2Str(int type_code) {
   switch (type_code) {
-    case kInt: return "int";
-    case kUInt: return "uint";
-    case kFloat: return "float";
+    case kDLInt: return "int";
+    case kDLUInt: return "uint";
+    case kDLFloat: return "float";
     case kStr: return "str";
     case kBytes: return "bytes";
     case kHandle: return "handle";
@@ -648,11 +648,11 @@ inline TVMType String2TVMType(std::string s) {
   t.bits = 32; t.lanes = 1;
   const char* scan;
   if (s.substr(0, 3) == "int") {
-    t.code = kInt; scan = s.c_str() + 3;
+    t.code = kDLInt; scan = s.c_str() + 3;
   } else if (s.substr(0, 4) == "uint") {
-    t.code = kUInt; scan = s.c_str() + 4;
+    t.code = kDLUInt; scan = s.c_str() + 4;
   } else if (s.substr(0, 5) == "float") {
-    t.code = kFloat; scan = s.c_str() + 5;
+    t.code = kDLFloat; scan = s.c_str() + 5;
   } else if (s.substr(0, 6) == "handle") {
     t.code = kHandle;
     t.bits = 64;  // handle uses 64 bit by default.
@@ -724,17 +724,17 @@ class TVMArgsSetter {
                std::is_integral<T>::value>::type>
   void operator()(size_t i, T value) const {
     values_[i].v_int64 = static_cast<int64_t>(value);
-    type_codes_[i] = kInt;
+    type_codes_[i] = kDLInt;
   }
   void operator()(size_t i, uint64_t value) const {
     values_[i].v_int64 = static_cast<int64_t>(value);
     CHECK_LE(value,
              static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
-    type_codes_[i] = kInt;
+    type_codes_[i] = kDLInt;
   }
   void operator()(size_t i, double value) const {
     values_[i].v_float64 = value;
-    type_codes_[i] = kFloat;
+    type_codes_[i] = kDLFloat;
   }
   void operator()(size_t i, std::nullptr_t value) const {
     values_[i].v_handle = value;
...
@@ -161,10 +161,10 @@ void fromJavaContext(JNIEnv *env, jobject jctx, TVMContext *ctx) {
 jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) {
   switch (tcode) {
-    case kUInt:
-    case kInt:
+    case kDLUInt:
+    case kDLInt:
       return newTVMValueLong(env, static_cast<jlong>(value.v_int64));
-    case kFloat:
+    case kDLFloat:
       return newTVMValueDouble(env, static_cast<jdouble>(value.v_float64));
     case kModuleHandle:
       return newModule(env, reinterpret_cast<jlong>(value.v_handle));
...
@@ -62,7 +62,7 @@ JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgLong(
   value.v_int64 = static_cast<int64_t>(arg);
   TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
   e->tvmFuncArgValues.push_back(value);
-  e->tvmFuncArgTypes.push_back(kInt);
+  e->tvmFuncArgTypes.push_back(kDLInt);
 }
 JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgDouble(
@@ -71,7 +71,7 @@ JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgDouble(
   value.v_float64 = static_cast<double>(arg);
   TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
   e->tvmFuncArgValues.push_back(value);
-  e->tvmFuncArgTypes.push_back(kFloat);
+  e->tvmFuncArgTypes.push_back(kDLFloat);
 }
 JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgString(
...
@@ -27,9 +27,9 @@ TVM_REGISTER_API("_max_value")
 TVM_REGISTER_API("_const")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    if (args[0].type_code() == kInt) {
+    if (args[0].type_code() == kDLInt) {
       *ret = make_const(args[1], args[0].operator int64_t());
-    } else if (args[0].type_code() == kFloat) {
+    } else if (args[0].type_code() == kDLFloat) {
       *ret = make_const(args[1], args[0].operator double());
     } else {
       LOG(FATAL) << "only accept int or float";
...
@@ -133,7 +133,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
 inline int DetectROCMComputeVersion() {
   TVMContext tvm_ctx;
-  tvm_ctx.device_type = kROCM;
+  tvm_ctx.device_type = kDLROCM;
   tvm_ctx.device_id = 0;
   TVMRetValue val;
   tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
...
@@ -242,7 +242,7 @@ llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
     CHECK_EQ(t.lanes(), 1);
     return t_void_p_;
   }
-  llvm::Type* etype;
+  llvm::Type* etype = nullptr;
   if (t.is_int() || t.is_uint()) {
     etype = llvm::Type::getIntNTy(*ctx_, t.bits());
   } else if (t.is_float()) {
...
@@ -132,7 +132,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
 inline int DetectCUDAComputeVersion() {
   TVMContext tvm_ctx;
-  tvm_ctx.device_type = kGPU;
+  tvm_ctx.device_type = kDLGPU;
   tvm_ctx.device_id = 0;
   TVMRetValue val;
   tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
...
@@ -340,16 +340,16 @@ class StackVM {
   static OpCode GetLoad(TVMType t) {
     CHECK_EQ(t.lanes, 1U);
     if (t.code == kHandle) return ARRAY_LOAD_HANDLE;
-    if (t.code == kInt) {
+    if (t.code == kDLInt) {
       switch (t.bits) {
         case 32 : return ARRAY_LOAD_INT32;
         case 64 : return ARRAY_LOAD_INT64;
       }
-    } else if (t.code == kUInt) {
+    } else if (t.code == kDLUInt) {
       switch (t.bits) {
         case 32 : return ARRAY_LOAD_UINT32;
       }
-    } else if (t.code == kFloat) {
+    } else if (t.code == kDLFloat) {
       switch (t.bits) {
         case 64 : return ARRAY_LOAD_FP64;
       }
@@ -365,16 +365,16 @@ class StackVM {
   static OpCode GetStore(TVMType t) {
     CHECK_EQ(t.lanes, 1U);
     if (t.code == kHandle) return ARRAY_STORE_HANDLE;
-    if (t.code == kInt) {
+    if (t.code == kDLInt) {
       switch (t.bits) {
         case 32 : return ARRAY_STORE_INT32;
         case 64 : return ARRAY_STORE_INT64;
       }
-    } else if (t.code == kUInt) {
+    } else if (t.code == kDLUInt) {
       switch (t.bits) {
         case 32 : return ARRAY_STORE_UINT32;
       }
-    } else if (t.code == kFloat) {
+    } else if (t.code == kDLFloat) {
       switch (t.bits) {
         case 64 : return ARRAY_STORE_FP64;
       }
...
@@ -91,10 +91,10 @@ class VPIDeviceAPI final : public runtime::DeviceAPI {
                       TVMContext ctx_from,
                       TVMContext ctx_to,
                       TVMStreamHandle stream) final {
-    if (static_cast<int>(ctx_from.device_type) == kVPI) {
+    if (static_cast<int>(ctx_from.device_type) == kDLVPI) {
       from = RealAddr(static_cast<const char*>(from) + from_offset, size);
     }
-    if (static_cast<int>(ctx_to.device_type) == kVPI) {
+    if (static_cast<int>(ctx_to.device_type) == kDLVPI) {
       to = RealAddr(static_cast<char*>(to) + to_offset, size);
     }
     memcpy(to, from, size);
...
@@ -30,9 +30,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
     CHECK(C->strides == nullptr);
     CHECK(B->strides == nullptr);
     CHECK(A->strides == nullptr);
-    CHECK(TypeMatch(A->dtype, kFloat, 32));
-    CHECK(TypeMatch(B->dtype, kFloat, 32));
-    CHECK(TypeMatch(C->dtype, kFloat, 32));
+    CHECK(TypeMatch(A->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(B->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(C->dtype, kDLFloat, 32));
     cblas_sgemm(CblasColMajor,
                 transb ? CblasTrans : CblasNoTrans,
                 transa ? CblasTrans : CblasNoTrans,
...
@@ -13,17 +13,17 @@ namespace contrib {
 // CuDNN Data Type
 cudnnDataType_t CuDNNDataType::DLTypeToCuDNNType(const DLDataType &dtype) {
   switch (dtype.code) {
-    case kInt:
+    case kDLInt:
       if (dtype.bits == 8 && dtype.lanes == 1) return CUDNN_DATA_INT8;
       else if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_INT32;
       else if (dtype.bits == 8 && dtype.lanes == 4) return CUDNN_DATA_INT8x4;
       else
         LOG(FATAL) << "Unsupported type";
       break;
-    case kUInt:
+    case kDLUInt:
       LOG(FATAL) << "Unsupported type";
       break;
-    case kFloat:
+    case kDLFloat:
       if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_FLOAT;
       else if (dtype.bits == 64 && dtype.lanes == 1) return CUDNN_DATA_DOUBLE;
       else if (dtype.bits == 16 && dtype.lanes == 1) return CUDNN_DATA_HALF;
...
@@ -44,10 +44,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
     CHECK(kernel->strides == nullptr);
     CHECK(bias->strides == nullptr);
-    CHECK(TypeMatch(input->dtype, kFloat, 32));
-    CHECK(TypeMatch(kernel->dtype, kFloat, 32));
-    CHECK(TypeMatch(bias->dtype, kFloat, 32));
-    CHECK(TypeMatch(output->dtype, kFloat, 32));
+    CHECK(TypeMatch(input->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(kernel->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(output->dtype, kDLFloat, 32));
     nnp_convolution_inference(nnp_convolution_algorithm_auto,
                               nnp_convolution_transform_strategy_block_based,
@@ -102,10 +102,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
     CHECK(kernel->strides == nullptr);
     CHECK(bias->strides == nullptr);
-    CHECK(TypeMatch(input->dtype, kFloat, 32));
-    CHECK(TypeMatch(kernel->dtype, kFloat, 32));
-    CHECK(TypeMatch(bias->dtype, kFloat, 32));
-    CHECK(TypeMatch(output->dtype, kFloat, 32));
+    CHECK(TypeMatch(input->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(kernel->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(output->dtype, kDLFloat, 32));
     nnp_convolution_output(nnp_convolution_algorithm_auto,
                            batch_size,
...
@@ -29,9 +29,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference")
     CHECK(C->strides == nullptr);
     CHECK(B->strides == nullptr);
     CHECK(A->strides == nullptr);
-    CHECK(TypeMatch(A->dtype, kFloat, 32));
-    CHECK(TypeMatch(B->dtype, kFloat, 32));
-    CHECK(TypeMatch(C->dtype, kFloat, 32));
+    CHECK(TypeMatch(A->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(B->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(C->dtype, kDLFloat, 32));
     nnp_fully_connected_inference(B->shape[1],
                                   B->shape[0],
@@ -58,9 +58,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_output")
     CHECK(C->strides == nullptr);
     CHECK(B->strides == nullptr);
     CHECK(A->strides == nullptr);
-    CHECK(TypeMatch(A->dtype, kFloat, 32));
-    CHECK(TypeMatch(B->dtype, kFloat, 32));
-    CHECK(TypeMatch(C->dtype, kFloat, 32));
+    CHECK(TypeMatch(A->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(B->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(C->dtype, kDLFloat, 32));
     nnp_fully_connected_output(A->shape[0],
                                B->shape[1],
...
@@ -72,7 +72,7 @@ class BuiltinLower : public IRMutator {
     int64_t nbytes = GetVectorBytes(op->type);
     if (device_type_.defined()) {
       if (arith::GetConst(device_type_, &dev_type)) {
-        if (dev_type == kCPU) {
+        if (dev_type == kDLCPU) {
           int32_t constant_size = op->constant_allocation_size();
           if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) {
             return stmt;
...
@@ -107,12 +107,13 @@ LoweredFunc MakeAPI(Stmt body,
     } else if (t.is_int() || t.is_uint()) {
       std::ostringstream msg;
       msg << name << ": Expect arg[" << i << "] to be int";
-      seq_check.emplace_back(AssertStmt::make(tcode == kInt, msg.str(), nop));
+      seq_check.emplace_back(AssertStmt::make(tcode == kDLInt, msg.str(), nop));
     } else {
       CHECK(t.is_float());
       std::ostringstream msg;
       msg << name << ": Expect arg[" << i << "] to be float";
-      seq_check.emplace_back(AssertStmt::make(tcode == kFloat, msg.str(), nop));
+      seq_check.emplace_back(
+          AssertStmt::make(tcode == kDLFloat, msg.str(), nop));
     }
   } else {
     args.push_back(v_arg);
@@ -148,7 +149,7 @@ LoweredFunc MakeAPI(Stmt body,
   seq_check.push_back(AttrStmt::make(
       node, attr::device_context_type, device_type, nop));
   Stmt set_device = IfThenElse::make(
-      device_type != kCPU, Evaluate::make(Call::make(
+      device_type != kDLCPU, Evaluate::make(Call::make(
           Int(32), intrinsic::tvm_call_packed,
           {StringImm::make(runtime::symbol::tvm_set_device),
            device_type, device_id}, Call::Intrinsic)));
...
@@ -25,12 +25,12 @@ namespace runtime {
  */
 inline std::string DeviceName(int type) {
   switch (type) {
-    case kCPU: return "cpu";
-    case kGPU: return "gpu";
-    case kOpenCL: return "opencl";
-    case kMetal: return "metal";
-    case kVPI: return "vpi";
-    case kROCM: return "rocm";
+    case kDLCPU: return "cpu";
+    case kDLGPU: return "gpu";
+    case kDLOpenCL: return "opencl";
+    case kDLMetal: return "metal";
+    case kDLVPI: return "vpi";
+    case kDLROCM: return "rocm";
     case kExtDev: return "ext_dev";
     default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
   }
@@ -126,7 +126,7 @@ inline void TVMArrayFree_(TVMArray* arr) {
 inline void VerifyType(int dtype_code, int dtype_bits, int dtype_lanes) {
   CHECK_GE(dtype_lanes, 1);
-  if (dtype_code == kFloat) {
+  if (dtype_code == kDLFloat) {
     CHECK_EQ(dtype_bits % 32, 0);
   } else {
     CHECK_EQ(dtype_bits % 8, 0);
@@ -382,10 +382,10 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
   CHECK_EQ(from_size, to_size)
     << "TVMArrayCopyFromTo: The size must exactly match";
   TVMContext ctx = from->ctx;
-  if (ctx.device_type == kCPU) {
+  if (ctx.device_type == kDLCPU) {
     ctx = to->ctx;
   } else {
-    CHECK(to->ctx.device_type == kCPU ||
+    CHECK(to->ctx.device_type == kDLCPU ||
           to->ctx.device_type == from->ctx.device_type)
       << "Can not copy across different ctx types directly";
   }
@@ -401,7 +401,7 @@ int TVMArrayCopyFromBytes(TVMArrayHandle handle,
                           size_t nbytes) {
   API_BEGIN();
   TVMContext cpu_ctx;
-  cpu_ctx.device_type = kCPU;
+  cpu_ctx.device_type = kDLCPU;
   cpu_ctx.device_id = 0;
   size_t arr_size = GetDataSize(handle);
   CHECK_EQ(arr_size, nbytes)
@@ -418,7 +418,7 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle,
                         size_t nbytes) {
   API_BEGIN();
   TVMContext cpu_ctx;
-  cpu_ctx.device_type = kCPU;
+  cpu_ctx.device_type = kDLCPU;
   cpu_ctx.device_id = 0;
   size_t arr_size = GetDataSize(handle);
   CHECK_EQ(arr_size, nbytes)
...
@@ -68,7 +68,7 @@ class CPUDeviceAPI final : public DeviceAPI {
 struct CPUWorkspacePool : public WorkspacePool {
   CPUWorkspacePool() :
-      WorkspacePool(kCPU, CPUDeviceAPI::Global()) {}
+      WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {}
 };
 void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx, size_t size) {
...
@@ -79,7 +79,7 @@ class CUDADeviceAPI final : public DeviceAPI {
     cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
     from = static_cast<const char*>(from) + from_offset;
     to = static_cast<char*>(to) + to_offset;
-    if (ctx_from.device_type == kGPU && ctx_to.device_type == kGPU) {
+    if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLGPU) {
       CUDA_CALL(cudaSetDevice(ctx_from.device_id));
       if (ctx_from.device_id == ctx_to.device_id) {
         GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
@@ -88,10 +88,10 @@ class CUDADeviceAPI final : public DeviceAPI {
                             from, ctx_from.device_id,
                             size, cu_stream);
       }
-    } else if (ctx_from.device_type == kGPU && ctx_to.device_type == kCPU) {
+    } else if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLCPU) {
       CUDA_CALL(cudaSetDevice(ctx_from.device_id));
       GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
-    } else if (ctx_from.device_type == kCPU && ctx_to.device_type == kGPU) {
+    } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLGPU) {
      CUDA_CALL(cudaSetDevice(ctx_to.device_id));
       GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
     } else {
@@ -140,7 +140,7 @@ class CUDADeviceAPI final : public DeviceAPI {
 typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
 CUDAThreadEntry::CUDAThreadEntry()
-    : pool(kGPU, CUDADeviceAPI::Global()) {
+    : pool(kDLGPU, CUDADeviceAPI::Global()) {
 }
 CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
...
@@ -462,7 +462,7 @@ void GraphRuntime::SetupStorage() {
     int64_t shape[] = {static_cast<int64_t>(pool_entry_bytes[i] + 3) / 4};
     DLTensor* tensor;
     TVM_CCALL(TVMArrayAlloc(
-        shape, 1, kFloat, 32, 1, ctx_.device_type, ctx_.device_id, &tensor));
+        shape, 1, kDLFloat, 32, 1, ctx_.device_type, ctx_.device_id, &tensor));
     storage_pool_.push_back(tensor);
   }
   // Assign the pooled entries.
...
@@ -45,14 +45,14 @@ class MetalWorkspace final : public DeviceAPI {
   ~MetalWorkspace();
   // Get command queue for given context.
   id<MTLCommandQueue> GetCommandQueue(TVMContext ctx) {
-    CHECK_EQ(ctx.device_type, kMetal);
+    CHECK_EQ(ctx.device_type, kDLMetal);
     CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < queues.size())
         << "Invalid Metal device_id=" << ctx.device_id;
     return queues[ctx.device_id];
   }
   // Get device for given context
   id<MTLDevice> GetDevice(TVMContext ctx) {
-    CHECK_EQ(ctx.device_type, kMetal);
+    CHECK_EQ(ctx.device_type, kDLMetal);
     CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < devices.size())
         << "Invalid Metal device_id=" << ctx.device_id;
     return devices[ctx.device_id];
@@ -91,9 +91,9 @@ class MetalThreadEntry {
   WorkspacePool pool;
   // constructor
   MetalThreadEntry()
-      : pool(static_cast<DLDeviceType>(kMetal), MetalWorkspace::Global()) {
+      : pool(static_cast<DLDeviceType>(kDLMetal), MetalWorkspace::Global()) {
     context.device_id = 0;
-    context.device_type = static_cast<DLDeviceType>(kMetal);
+    context.device_type = static_cast<DLDeviceType>(kDLMetal);
   }
   ~MetalThreadEntry();
   // Get temp buffer with at least size under ctx.
...
@@ -150,13 +150,13 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
   this->Init();
   CHECK(stream == nullptr);
   TVMContext ctx = ctx_from;
-  if (ctx_from.device_type == kCPU) ctx = ctx_to;
+  if (ctx_from.device_type == kDLCPU) ctx = ctx_to;
   id<MTLCommandQueue> queue = GetCommandQueue(ctx);
   id<MTLCommandBuffer> cb = [queue commandBuffer];
   int from_dev_type = static_cast<int>(ctx_from.device_type);
   int to_dev_type = static_cast<int>(ctx_to.device_type);
-  if (from_dev_type == kMetal && to_dev_type == kMetal) {
+  if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) {
     CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
         << "Metal disallow cross device copy.";
     id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
@@ -167,7 +167,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
                size:size];
     [encoder endEncoding];
     [cb commit];
-  } else if (from_dev_type == kMetal && to_dev_type == kCPU) {
+  } else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) {
     // copy to a local buffer before get into global buffer.
     id<MTLBuffer> from_buf = (__bridge id<MTLBuffer>)(from);
     if (from_buf.storageMode != MTLStorageModeShared) {
@@ -190,7 +190,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
              static_cast<char*>([from_buf contents]) + from_offset,
              size);
     }
-  } else if (from_dev_type == kCPU && to_dev_type == kMetal) {
+  } else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) {
     id<MTLBuffer> to_buf = (__bridge id<MTLBuffer>)(to);
     if (to_buf.storageMode != MTLStorageModeShared) {
       id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()
...
@@ -133,7 +133,7 @@ class OpenCLWorkspace final : public DeviceAPI {
   void Init();
   // get the queue of the context
   cl_command_queue GetQueue(TVMContext ctx) {
-    CHECK_EQ(ctx.device_type, kOpenCL);
+    CHECK_EQ(ctx.device_type, kDLOpenCL);
     this->Init();
     CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < queues.size())
         << "Invalid OpenCL device_id=" << ctx.device_id;
@@ -178,9 +178,9 @@ class OpenCLThreadEntry {
   WorkspacePool pool;
   // constructor
   OpenCLThreadEntry()
-      : pool(kOpenCL, OpenCLWorkspace::Global()) {
+      : pool(kDLOpenCL, OpenCLWorkspace::Global()) {
     context.device_id = 0;
-    context.device_type = kOpenCL;
+    context.device_type = kDLOpenCL;
   }
   // get the global workspace
   static OpenCLThreadEntry* ThreadLocal();
...
@@ -76,13 +76,13 @@ void OpenCLWorkspace::CopyDataFromTo(const void* from,
                                      TVMStreamHandle stream) {
   this->Init();
   CHECK(stream == nullptr);
-  if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kOpenCL) {
+  if (ctx_from.device_type == kDLOpenCL && ctx_to.device_type == kDLOpenCL) {
     OPENCL_CALL(clEnqueueCopyBuffer(
         this->GetQueue(ctx_to),
         static_cast<cl_mem>((void*)from),  // NOLINT(*)
         static_cast<cl_mem>(to),
         from_offset, to_offset, size, 0, nullptr, nullptr));
-  } else if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kCPU) {
+  } else if (ctx_from.device_type == kDLOpenCL && ctx_to.device_type == kDLCPU) {
     OPENCL_CALL(clEnqueueReadBuffer(
         this->GetQueue(ctx_from),
         static_cast<cl_mem>((void*)from),  // NOLINT(*)
@@ -90,7 +90,7 @@ void OpenCLWorkspace::CopyDataFromTo(const void* from,
         static_cast<char*>(to) + to_offset,
         0, nullptr, nullptr));
     OPENCL_CALL(clFinish(this->GetQueue(ctx_from)));
-  } else if (ctx_from.device_type == kCPU && ctx_to.device_type == kOpenCL) {
+  } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLOpenCL) {
     OPENCL_CALL(clEnqueueWriteBuffer(
         this->GetQueue(ctx_to),
         static_cast<cl_mem>(to),
...
@@ -104,12 +104,12 @@ enum ArgConvertCode {
 inline ArgConvertCode GetArgConvertCode(TVMType t) {
   CHECK_EQ(t.lanes, 1U)
       << "Cannot pass vector type argument to devic function for now";
-  if (t.code == kInt) {
+  if (t.code == kDLInt) {
     if (t.bits == 64U) return INT64_TO_INT64;
     if (t.bits == 32U) return INT64_TO_INT32;
-  } else if (t.code == kUInt) {
+  } else if (t.code == kDLUInt) {
     if (t.bits == 32U) return INT64_TO_UINT32;
-  } else if (t.code == kFloat) {
+  } else if (t.code == kDLFloat) {
     if (t.bits == 64U) return FLOAT64_TO_FLOAT64;
     if (t.bits == 32U) return FLOAT64_TO_FLOAT32;
   } else if (t.code == kHandle) {
...
@@ -77,7 +77,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
     hipStream_t hip_stream = static_cast<hipStream_t>(stream);
     from = static_cast<const char*>(from) + from_offset;
     to = static_cast<char*>(to) + to_offset;
-    if (ctx_from.device_type == kROCM && ctx_to.device_type == kROCM) {
+    if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLROCM) {
       ROCM_CALL(hipSetDevice(ctx_from.device_id));
       if (ctx_from.device_id == ctx_to.device_id) {
         GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream);
@@ -86,10 +86,10 @@ class ROCMDeviceAPI final : public DeviceAPI {
                             from, ctx_from.device_id,
                             size, hip_stream);
       }
-    } else if (ctx_from.device_type == kROCM && ctx_to.device_type == kCPU) {
+    } else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) {
       ROCM_CALL(hipSetDevice(ctx_from.device_id));
       GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream);
-    } else if (ctx_from.device_type == kCPU && ctx_to.device_type == kROCM) {
+    } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM) {
       ROCM_CALL(hipSetDevice(ctx_to.device_id));
       GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream);
     } else {
@@ -138,7 +138,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
 typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;
 ROCMThreadEntry::ROCMThreadEntry()
-    : pool(kROCM, ROCMDeviceAPI::Global()) {
+    : pool(kDLROCM, ROCMDeviceAPI::Global()) {
 }
 ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
...
@@ -55,12 +55,12 @@ class RPCDeviceAPI final : public DeviceAPI {
           static_cast<const RemoteSpace*>(to)->data, to_offset,
           size, ctx_from, ctx_to, stream);
     } else if (from_dev_type > kRPCSessMask &&
-               to_dev_type == kCPU) {
+               to_dev_type == kDLCPU) {
       GetSess(ctx_from)->CopyFromRemote(
           static_cast<const RemoteSpace*>(from)->data, from_offset,
           to, to_offset, size,
          ctx_from);
-    } else if (from_dev_type == kCPU &&
+    } else if (from_dev_type == kDLCPU &&
                to_dev_type > kRPCSessMask) {
       GetSess(ctx_to)->CopyToRemote(
           (void*)from, from_offset,  // NOLINT(*)
...
@@ -162,9 +162,9 @@ class RPCSession::EventHandler {
       int tcode = type_codes[i];
       TVMValue value = arg_values[i];
       switch (tcode) {
-        case kInt:
-        case kUInt:
-        case kFloat:
+        case kDLInt:
+        case kDLUInt:
+        case kDLFloat:
         case kTVMType: {
           writer_->Write(&value, sizeof(TVMValue));
           break;
@@ -315,9 +315,9 @@ class RPCSession::EventHandler {
     int tcode = arg_buf_->tcode[arg_index_];
     static_assert(sizeof(TVMValue) == sizeof(uint64_t), "invariant");
     switch (tcode) {
-      case kInt:
-      case kUInt:
-      case kFloat:
+      case kDLInt:
+      case kDLUInt:
+      case kDLFloat:
       case kTVMType:
       case kHandle:
       case kStr:
@@ -352,9 +352,9 @@ class RPCSession::EventHandler {
     TVMValue& value = arg_buf_->value[arg_index_];
     if (arg_recv_stage_ == 0) {
       switch (tcode) {
-        case kInt:
-        case kUInt:
-        case kFloat:
+        case kDLInt:
+        case kDLUInt:
+        case kDLFloat:
         case kTVMType:
         case kTVMContext: {
           this->Read(&value, sizeof(TVMValue));
@@ -484,7 +484,7 @@ class RPCSession::EventHandler {
     this->Read(&offset, sizeof(offset));
     this->Read(&size, sizeof(size));
     this->Read(&ctx, sizeof(ctx));
-    if (ctx.device_type == kCPU) {
+    if (ctx.device_type == kDLCPU) {
       RPCCode code = RPCCode::kCopyAck;
       writer_->Write(&code, sizeof(code));
       writer_->Write(reinterpret_cast<char*>(handle) + offset, size);
@@ -492,7 +492,7 @@ class RPCSession::EventHandler {
       temp_data_.resize(size + 1);
       try {
         TVMContext cpu_ctx;
-        cpu_ctx.device_type = kCPU;
+        cpu_ctx.device_type = kDLCPU;
         cpu_ctx.device_id = 0;
         DeviceAPI::Get(ctx)->CopyDataFromTo(
             reinterpret_cast<void*>(handle), offset,
@@ -531,7 +531,7 @@ class RPCSession::EventHandler {
     int ret_tcode = kNull;
     RPCCode code = RPCCode::kReturn;
     std::string errmsg;
-    if (copy_ctx_.device_type == kCPU) {
+    if (copy_ctx_.device_type == kDLCPU) {
       this->Read(
           reinterpret_cast<char*>(copy_handle_) + copy_offset_, copy_size_);
     } else {
@@ -539,7 +539,7 @@ class RPCSession::EventHandler {
       this->Read(&temp_data_[0], copy_size_);
       try {
         TVMContext cpu_ctx;
-        cpu_ctx.device_type = kCPU;
+        cpu_ctx.device_type = kDLCPU;
         cpu_ctx.device_id = 0;
         DeviceAPI::Get(copy_ctx_)->CopyDataFromTo(
             temp_data_.data(), 0,
@@ -915,10 +915,10 @@ void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) {
   TVMContext ctx_to = args[6];
   TVMStreamHandle stream = args[7];
   TVMContext ctx = ctx_from;
-  if (ctx.device_type == kCPU) {
+  if (ctx.device_type == kDLCPU) {
     ctx = ctx_to;
   } else {
-    CHECK(ctx_to.device_type == kCPU ||
+    CHECK(ctx_to.device_type == kDLCPU ||
           ctx_to.device_type == ctx_from.device_type)
       << "Can not copy across different ctx types directly";
   }
...
@@ -14,7 +14,7 @@ TEST(PackedFunc, Basic) {
   Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
       CHECK(args.num_args == 3);
       CHECK(args.values[0].v_float64 == 1.0);
-      CHECK(args.type_codes[0] == kFloat);
+      CHECK(args.type_codes[0] == kDLFloat);
       CHECK(args.values[1].v_handle == &a);
       CHECK(args.type_codes[1] == kArrayHandle);
       CHECK(args.values[2].v_handle == &x);
...