Commit 8214d6ca by Tianqi Chen, committed by GitHub

[DLPack] Upgrade dlpack to 0.2 (#609)

parent a152a9cb
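
This commit bumps the dlpack submodule to 0.2 and switches TVM to DLPack's prefixed enum names throughout: the type codes kInt/kUInt/kFloat become kDLInt/kDLUInt/kDLFloat, and the device types kCPU/kGPU/kOpenCL/kMetal/kVPI/kROCM become kDLCPU/kDLGPU/kDLOpenCL/kDLMetal/kDLVPI/kDLROCM. As a minimal sketch of the new spellings in user code, modeled on the first hunk below (the helper name and its error handling are illustrative, not part of the commit):

#include <tvm/runtime/c_runtime_api.h>

// Allocate a 1-D float32 tensor of 10 elements on CPU, using the
// DLPack 0.2 names kDLFloat and kDLCPU (formerly kFloat and kCPU).
int AllocFloat32OnCPU(DLTensor** out) {
  int64_t shape[1] = {10};
  return TVMArrayAlloc(shape, /*ndim=*/1,
                       /*dtype_code=*/kDLFloat, /*dtype_bits=*/32,
                       /*dtype_lanes=*/1,
                       /*device_type=*/kDLCPU, /*device_id=*/0, out);
}
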
@@ -28,10 +28,10 @@ void Verify(tvm::runtime::Module mod, std::string fname) {
   DLTensor* x;
   DLTensor* y;
   int ndim = 1;
-  int dtype_code = kFloat;
+  int dtype_code = kDLFloat;
   int dtype_bits = 32;
   int dtype_lanes = 1;
-  int device_type = kCPU;
+  int device_type = kDLCPU;
   int device_id = 0;
   int64_t shape[1] = {10};
   TVMArrayAlloc(shape, ndim, dtype_code, dtype_bits, dtype_lanes,
......
-Subproject commit 9422e98f3f4dafc6bc3473cf8484543ad376aab6
+Subproject commit 10892ac964f1af7c81aae145cd3fab78bbccd297
@@ -105,10 +105,10 @@ inline TNodeRef TVMArgValue::AsNodeRef() const {
 inline TVMArgValue::operator Halide::Expr() const {
   if (type_code_ == kNull) return Expr();
-  if (type_code_ == kInt) {
+  if (type_code_ == kDLInt) {
     return Expr(static_cast<int>(value_.v_int64));
   }
-  if (type_code_ == kFloat) {
+  if (type_code_ == kDLFloat) {
     return Expr(static_cast<float>(value_.v_float64));
   }
   TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
......
@@ -217,25 +217,25 @@ class ExtTypeVTable {
 class TVMPODValue_ {
  public:
   operator double() const {
-    TVM_CHECK_TYPE_CODE(type_code_, kFloat);
+    TVM_CHECK_TYPE_CODE(type_code_, kDLFloat);
     return value_.v_float64;
   }
   operator int64_t() const {
-    TVM_CHECK_TYPE_CODE(type_code_, kInt);
+    TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
     return value_.v_int64;
   }
   operator uint64_t() const {
-    TVM_CHECK_TYPE_CODE(type_code_, kInt);
+    TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
     return value_.v_int64;
   }
   operator int() const {
-    TVM_CHECK_TYPE_CODE(type_code_, kInt);
+    TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
     CHECK_LE(value_.v_int64,
              std::numeric_limits<int>::max());
     return static_cast<int>(value_.v_int64);
   }
   operator bool() const {
-    TVM_CHECK_TYPE_CODE(type_code_, kInt);
+    TVM_CHECK_TYPE_CODE(type_code_, kDLInt);
     return value_.v_int64 != 0;
   }
   operator void*() const {
@@ -430,7 +430,7 @@ class TVMRetValue : public TVMPODValue_ {
     return *this;
   }
   TVMRetValue& operator=(double value) {
-    this->SwitchToPOD(kFloat);
+    this->SwitchToPOD(kDLFloat);
     value_.v_float64 = value;
     return *this;
   }
@@ -445,12 +445,12 @@ class TVMRetValue : public TVMPODValue_ {
     return *this;
   }
   TVMRetValue& operator=(int64_t value) {
-    this->SwitchToPOD(kInt);
+    this->SwitchToPOD(kDLInt);
     value_.v_int64 = value;
     return *this;
   }
   TVMRetValue& operator=(int value) {
-    this->SwitchToPOD(kInt);
+    this->SwitchToPOD(kDLInt);
     value_.v_int64 = value;
     return *this;
   }
@@ -460,7 +460,7 @@ class TVMRetValue : public TVMPODValue_ {
     return *this;
   }
   TVMRetValue& operator=(bool value) {
-    this->SwitchToPOD(kInt);
+    this->SwitchToPOD(kDLInt);
     value_.v_int64 = value;
     return *this;
   }
@@ -609,9 +609,9 @@ class TVMRetValue : public TVMPODValue_ {
 // implementation details
 inline const char* TypeCode2Str(int type_code) {
   switch (type_code) {
-    case kInt: return "int";
-    case kUInt: return "uint";
-    case kFloat: return "float";
+    case kDLInt: return "int";
+    case kDLUInt: return "uint";
+    case kDLFloat: return "float";
     case kStr: return "str";
     case kBytes: return "bytes";
     case kHandle: return "handle";
@@ -648,11 +648,11 @@ inline TVMType String2TVMType(std::string s) {
   t.bits = 32; t.lanes = 1;
   const char* scan;
   if (s.substr(0, 3) == "int") {
-    t.code = kInt; scan = s.c_str() + 3;
+    t.code = kDLInt; scan = s.c_str() + 3;
   } else if (s.substr(0, 4) == "uint") {
-    t.code = kUInt; scan = s.c_str() + 4;
+    t.code = kDLUInt; scan = s.c_str() + 4;
   } else if (s.substr(0, 5) == "float") {
-    t.code = kFloat; scan = s.c_str() + 5;
+    t.code = kDLFloat; scan = s.c_str() + 5;
   } else if (s.substr(0, 6) == "handle") {
     t.code = kHandle;
     t.bits = 64;  // handle uses 64 bit by default.
@@ -724,17 +724,17 @@ class TVMArgsSetter {
                 std::is_integral<T>::value>::type>
   void operator()(size_t i, T value) const {
     values_[i].v_int64 = static_cast<int64_t>(value);
-    type_codes_[i] = kInt;
+    type_codes_[i] = kDLInt;
   }
   void operator()(size_t i, uint64_t value) const {
     values_[i].v_int64 = static_cast<int64_t>(value);
     CHECK_LE(value,
              static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
-    type_codes_[i] = kInt;
+    type_codes_[i] = kDLInt;
   }
   void operator()(size_t i, double value) const {
     values_[i].v_float64 = value;
-    type_codes_[i] = kFloat;
+    type_codes_[i] = kDLFloat;
   }
   void operator()(size_t i, std::nullptr_t value) const {
     values_[i].v_handle = value;
......
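
The packed_func.h hunks above are where the renamed codes surface in the FFI: TVMArgsSetter tags each TVMValue with a kDL* type code, and the TVMPODValue_ conversion operators check that code on the way out. A small round-trip sketch under that API (the lambda and the names RoundTripExample/add are illustrative, not from the commit):

#include <tvm/runtime/packed_func.h>

using namespace tvm::runtime;

void RoundTripExample() {
  PackedFunc add([](TVMArgs args, TVMRetValue* rv) {
    int64_t a = args[0];               // checked against kDLInt
    double b = args[1];                // checked against kDLFloat
    *rv = static_cast<double>(a) + b;  // SwitchToPOD(kDLFloat)
  });
  double c = add(1, 2.5);  // c == 3.5
  (void)c;
}
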
@@ -161,10 +161,10 @@ void fromJavaContext(JNIEnv *env, jobject jctx, TVMContext *ctx) {
 jobject tvmRetValueToJava(JNIEnv *env, TVMValue value, int tcode) {
   switch (tcode) {
-    case kUInt:
-    case kInt:
+    case kDLUInt:
+    case kDLInt:
       return newTVMValueLong(env, static_cast<jlong>(value.v_int64));
-    case kFloat:
+    case kDLFloat:
       return newTVMValueDouble(env, static_cast<jdouble>(value.v_float64));
     case kModuleHandle:
       return newModule(env, reinterpret_cast<jlong>(value.v_handle));
......
@@ -62,7 +62,7 @@ JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgLong(
   value.v_int64 = static_cast<int64_t>(arg);
   TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
   e->tvmFuncArgValues.push_back(value);
-  e->tvmFuncArgTypes.push_back(kInt);
+  e->tvmFuncArgTypes.push_back(kDLInt);
 }
 JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgDouble(
@@ -71,7 +71,7 @@ JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgDouble(
   value.v_float64 = static_cast<double>(arg);
   TVMFuncArgsThreadLocalEntry *e = TVMFuncArgsThreadLocalStore::Get();
   e->tvmFuncArgValues.push_back(value);
-  e->tvmFuncArgTypes.push_back(kFloat);
+  e->tvmFuncArgTypes.push_back(kDLFloat);
 }
 JNIEXPORT void JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncPushArgString(
......
@@ -27,9 +27,9 @@ TVM_REGISTER_API("_max_value")
 TVM_REGISTER_API("_const")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
-    if (args[0].type_code() == kInt) {
+    if (args[0].type_code() == kDLInt) {
       *ret = make_const(args[1], args[0].operator int64_t());
-    } else if (args[0].type_code() == kFloat) {
+    } else if (args[0].type_code() == kDLFloat) {
       *ret = make_const(args[1], args[0].operator double());
     } else {
       LOG(FATAL) << "only accept int or float";
......
@@ -133,7 +133,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
 inline int DetectROCMComputeVersion() {
   TVMContext tvm_ctx;
-  tvm_ctx.device_type = kROCM;
+  tvm_ctx.device_type = kDLROCM;
   tvm_ctx.device_id = 0;
   TVMRetValue val;
   tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
......
@@ -242,7 +242,7 @@ llvm::Type* CodeGenLLVM::LLVMType(const Type& t) const {
     CHECK_EQ(t.lanes(), 1);
     return t_void_p_;
   }
-  llvm::Type* etype;
+  llvm::Type* etype = nullptr;
   if (t.is_int() || t.is_uint()) {
     etype = llvm::Type::getIntNTy(*ctx_, t.bits());
   } else if (t.is_float()) {
......
@@ -132,7 +132,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
 inline int DetectCUDAComputeVersion() {
   TVMContext tvm_ctx;
-  tvm_ctx.device_type = kGPU;
+  tvm_ctx.device_type = kDLGPU;
   tvm_ctx.device_id = 0;
   TVMRetValue val;
   tvm::runtime::DeviceAPI::Get(tvm_ctx)->GetAttr(
......
@@ -340,16 +340,16 @@ class StackVM {
   static OpCode GetLoad(TVMType t) {
     CHECK_EQ(t.lanes, 1U);
     if (t.code == kHandle) return ARRAY_LOAD_HANDLE;
-    if (t.code == kInt) {
+    if (t.code == kDLInt) {
       switch (t.bits) {
         case 32 : return ARRAY_LOAD_INT32;
         case 64 : return ARRAY_LOAD_INT64;
       }
-    } else if (t.code == kUInt) {
+    } else if (t.code == kDLUInt) {
       switch (t.bits) {
         case 32 : return ARRAY_LOAD_UINT32;
       }
-    } else if (t.code == kFloat) {
+    } else if (t.code == kDLFloat) {
       switch (t.bits) {
         case 64 : return ARRAY_LOAD_FP64;
       }
@@ -365,16 +365,16 @@ class StackVM {
   static OpCode GetStore(TVMType t) {
     CHECK_EQ(t.lanes, 1U);
     if (t.code == kHandle) return ARRAY_STORE_HANDLE;
-    if (t.code == kInt) {
+    if (t.code == kDLInt) {
       switch (t.bits) {
         case 32 : return ARRAY_STORE_INT32;
         case 64 : return ARRAY_STORE_INT64;
       }
-    } else if (t.code == kUInt) {
+    } else if (t.code == kDLUInt) {
       switch (t.bits) {
         case 32 : return ARRAY_STORE_UINT32;
       }
-    } else if (t.code == kFloat) {
+    } else if (t.code == kDLFloat) {
       switch (t.bits) {
         case 64 : return ARRAY_STORE_FP64;
       }
......
@@ -91,10 +91,10 @@ class VPIDeviceAPI final : public runtime::DeviceAPI {
                        TVMContext ctx_from,
                        TVMContext ctx_to,
                        TVMStreamHandle stream) final {
-    if (static_cast<int>(ctx_from.device_type) == kVPI) {
+    if (static_cast<int>(ctx_from.device_type) == kDLVPI) {
       from = RealAddr(static_cast<const char*>(from) + from_offset, size);
     }
-    if (static_cast<int>(ctx_to.device_type) == kVPI) {
+    if (static_cast<int>(ctx_to.device_type) == kDLVPI) {
       to = RealAddr(static_cast<char*>(to) + to_offset, size);
     }
     memcpy(to, from, size);
......
@@ -30,9 +30,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.cblas.matmul")
     CHECK(C->strides == nullptr);
     CHECK(B->strides == nullptr);
     CHECK(A->strides == nullptr);
-    CHECK(TypeMatch(A->dtype, kFloat, 32));
-    CHECK(TypeMatch(B->dtype, kFloat, 32));
-    CHECK(TypeMatch(C->dtype, kFloat, 32));
+    CHECK(TypeMatch(A->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(B->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(C->dtype, kDLFloat, 32));
     cblas_sgemm(CblasColMajor,
                 transb ? CblasTrans : CblasNoTrans,
                 transa ? CblasTrans : CblasNoTrans,
......
@@ -13,17 +13,17 @@ namespace contrib {
 // CuDNN Data Type
 cudnnDataType_t CuDNNDataType::DLTypeToCuDNNType(const DLDataType &dtype) {
   switch (dtype.code) {
-    case kInt:
+    case kDLInt:
       if (dtype.bits == 8 && dtype.lanes == 1) return CUDNN_DATA_INT8;
       else if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_INT32;
       else if (dtype.bits == 8 && dtype.lanes == 4) return CUDNN_DATA_INT8x4;
       else
         LOG(FATAL) << "Unsupported type";
       break;
-    case kUInt:
+    case kDLUInt:
       LOG(FATAL) << "Unsupported type";
       break;
-    case kFloat:
+    case kDLFloat:
       if (dtype.bits == 32 && dtype.lanes == 1) return CUDNN_DATA_FLOAT;
       else if (dtype.bits == 64 && dtype.lanes == 1) return CUDNN_DATA_DOUBLE;
       else if (dtype.bits == 16 && dtype.lanes == 1) return CUDNN_DATA_HALF;
......
@@ -44,10 +44,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_inference")
     CHECK(kernel->strides == nullptr);
     CHECK(bias->strides == nullptr);
-    CHECK(TypeMatch(input->dtype, kFloat, 32));
-    CHECK(TypeMatch(kernel->dtype, kFloat, 32));
-    CHECK(TypeMatch(bias->dtype, kFloat, 32));
-    CHECK(TypeMatch(output->dtype, kFloat, 32));
+    CHECK(TypeMatch(input->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(kernel->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(output->dtype, kDLFloat, 32));
     nnp_convolution_inference(nnp_convolution_algorithm_auto,
                               nnp_convolution_transform_strategy_block_based,
@@ -102,10 +102,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.convolution_output")
     CHECK(kernel->strides == nullptr);
     CHECK(bias->strides == nullptr);
-    CHECK(TypeMatch(input->dtype, kFloat, 32));
-    CHECK(TypeMatch(kernel->dtype, kFloat, 32));
-    CHECK(TypeMatch(bias->dtype, kFloat, 32));
-    CHECK(TypeMatch(output->dtype, kFloat, 32));
+    CHECK(TypeMatch(input->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(kernel->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(bias->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(output->dtype, kDLFloat, 32));
     nnp_convolution_output(nnp_convolution_algorithm_auto,
                            batch_size,
......
@@ -29,9 +29,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_inference")
     CHECK(C->strides == nullptr);
     CHECK(B->strides == nullptr);
     CHECK(A->strides == nullptr);
-    CHECK(TypeMatch(A->dtype, kFloat, 32));
-    CHECK(TypeMatch(B->dtype, kFloat, 32));
-    CHECK(TypeMatch(C->dtype, kFloat, 32));
+    CHECK(TypeMatch(A->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(B->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(C->dtype, kDLFloat, 32));
     nnp_fully_connected_inference(B->shape[1],
                                   B->shape[0],
@@ -58,9 +58,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.nnpack.fully_connected_output")
     CHECK(C->strides == nullptr);
     CHECK(B->strides == nullptr);
     CHECK(A->strides == nullptr);
-    CHECK(TypeMatch(A->dtype, kFloat, 32));
-    CHECK(TypeMatch(B->dtype, kFloat, 32));
-    CHECK(TypeMatch(C->dtype, kFloat, 32));
+    CHECK(TypeMatch(A->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(B->dtype, kDLFloat, 32));
+    CHECK(TypeMatch(C->dtype, kDLFloat, 32));
     nnp_fully_connected_output(A->shape[0],
                                B->shape[1],
......
@@ -72,7 +72,7 @@ class BuiltinLower : public IRMutator {
     int64_t nbytes = GetVectorBytes(op->type);
     if (device_type_.defined()) {
       if (arith::GetConst(device_type_, &dev_type)) {
-        if (dev_type == kCPU) {
+        if (dev_type == kDLCPU) {
           int32_t constant_size = op->constant_allocation_size();
           if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) {
             return stmt;
......
@@ -107,12 +107,13 @@ LoweredFunc MakeAPI(Stmt body,
     } else if (t.is_int() || t.is_uint()) {
       std::ostringstream msg;
       msg << name << ": Expect arg[" << i << "] to be int";
-      seq_check.emplace_back(AssertStmt::make(tcode == kInt, msg.str(), nop));
+      seq_check.emplace_back(AssertStmt::make(tcode == kDLInt, msg.str(), nop));
     } else {
       CHECK(t.is_float());
       std::ostringstream msg;
       msg << name << ": Expect arg[" << i << "] to be float";
-      seq_check.emplace_back(AssertStmt::make(tcode == kFloat, msg.str(), nop));
+      seq_check.emplace_back(
+          AssertStmt::make(tcode == kDLFloat, msg.str(), nop));
     }
   } else {
     args.push_back(v_arg);
@@ -148,7 +149,7 @@ LoweredFunc MakeAPI(Stmt body,
   seq_check.push_back(AttrStmt::make(
       node, attr::device_context_type, device_type, nop));
   Stmt set_device = IfThenElse::make(
-      device_type != kCPU, Evaluate::make(Call::make(
+      device_type != kDLCPU, Evaluate::make(Call::make(
           Int(32), intrinsic::tvm_call_packed,
          {StringImm::make(runtime::symbol::tvm_set_device),
           device_type, device_id}, Call::Intrinsic)));
......
@@ -25,12 +25,12 @@ namespace runtime {
  */
 inline std::string DeviceName(int type) {
   switch (type) {
-    case kCPU: return "cpu";
-    case kGPU: return "gpu";
-    case kOpenCL: return "opencl";
-    case kMetal: return "metal";
-    case kVPI: return "vpi";
-    case kROCM: return "rocm";
+    case kDLCPU: return "cpu";
+    case kDLGPU: return "gpu";
+    case kDLOpenCL: return "opencl";
+    case kDLMetal: return "metal";
+    case kDLVPI: return "vpi";
+    case kDLROCM: return "rocm";
     case kExtDev: return "ext_dev";
     default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
   }
@@ -126,7 +126,7 @@ inline void TVMArrayFree_(TVMArray* arr) {
 inline void VerifyType(int dtype_code, int dtype_bits, int dtype_lanes) {
   CHECK_GE(dtype_lanes, 1);
-  if (dtype_code == kFloat) {
+  if (dtype_code == kDLFloat) {
     CHECK_EQ(dtype_bits % 32, 0);
   } else {
     CHECK_EQ(dtype_bits % 8, 0);
@@ -382,10 +382,10 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
   CHECK_EQ(from_size, to_size)
     << "TVMArrayCopyFromTo: The size must exactly match";
   TVMContext ctx = from->ctx;
-  if (ctx.device_type == kCPU) {
+  if (ctx.device_type == kDLCPU) {
     ctx = to->ctx;
   } else {
-    CHECK(to->ctx.device_type == kCPU ||
+    CHECK(to->ctx.device_type == kDLCPU ||
           to->ctx.device_type == from->ctx.device_type)
       << "Can not copy across different ctx types directly";
   }
@@ -401,7 +401,7 @@ int TVMArrayCopyFromBytes(TVMArrayHandle handle,
                           size_t nbytes) {
   API_BEGIN();
   TVMContext cpu_ctx;
-  cpu_ctx.device_type = kCPU;
+  cpu_ctx.device_type = kDLCPU;
   cpu_ctx.device_id = 0;
   size_t arr_size = GetDataSize(handle);
   CHECK_EQ(arr_size, nbytes)
@@ -418,7 +418,7 @@ int TVMArrayCopyToBytes(TVMArrayHandle handle,
                         size_t nbytes) {
   API_BEGIN();
   TVMContext cpu_ctx;
-  cpu_ctx.device_type = kCPU;
+  cpu_ctx.device_type = kDLCPU;
   cpu_ctx.device_id = 0;
   size_t arr_size = GetDataSize(handle);
   CHECK_EQ(arr_size, nbytes)
......
@@ -68,7 +68,7 @@ class CPUDeviceAPI final : public DeviceAPI {
 struct CPUWorkspacePool : public WorkspacePool {
   CPUWorkspacePool() :
-      WorkspacePool(kCPU, CPUDeviceAPI::Global()) {}
+      WorkspacePool(kDLCPU, CPUDeviceAPI::Global()) {}
 };
 void* CPUDeviceAPI::AllocWorkspace(TVMContext ctx, size_t size) {
......
@@ -79,7 +79,7 @@ class CUDADeviceAPI final : public DeviceAPI {
     cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
     from = static_cast<const char*>(from) + from_offset;
     to = static_cast<char*>(to) + to_offset;
-    if (ctx_from.device_type == kGPU && ctx_to.device_type == kGPU) {
+    if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLGPU) {
       CUDA_CALL(cudaSetDevice(ctx_from.device_id));
       if (ctx_from.device_id == ctx_to.device_id) {
         GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
@@ -88,10 +88,10 @@ class CUDADeviceAPI final : public DeviceAPI {
                 from, ctx_from.device_id,
                 size, cu_stream);
       }
-    } else if (ctx_from.device_type == kGPU && ctx_to.device_type == kCPU) {
+    } else if (ctx_from.device_type == kDLGPU && ctx_to.device_type == kDLCPU) {
       CUDA_CALL(cudaSetDevice(ctx_from.device_id));
       GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
-    } else if (ctx_from.device_type == kCPU && ctx_to.device_type == kGPU) {
+    } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLGPU) {
       CUDA_CALL(cudaSetDevice(ctx_to.device_id));
       GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
     } else {
@@ -140,7 +140,7 @@ class CUDADeviceAPI final : public DeviceAPI {
 typedef dmlc::ThreadLocalStore<CUDAThreadEntry> CUDAThreadStore;
 CUDAThreadEntry::CUDAThreadEntry()
-    : pool(kGPU, CUDADeviceAPI::Global()) {
+    : pool(kDLGPU, CUDADeviceAPI::Global()) {
 }
 CUDAThreadEntry* CUDAThreadEntry::ThreadLocal() {
......
@@ -462,7 +462,7 @@ void GraphRuntime::SetupStorage() {
     int64_t shape[] = {static_cast<int64_t>(pool_entry_bytes[i] + 3) / 4};
     DLTensor* tensor;
     TVM_CCALL(TVMArrayAlloc(
-        shape, 1, kFloat, 32, 1, ctx_.device_type, ctx_.device_id, &tensor));
+        shape, 1, kDLFloat, 32, 1, ctx_.device_type, ctx_.device_id, &tensor));
     storage_pool_.push_back(tensor);
   }
   // Assign the pooled entries.
......
@@ -45,14 +45,14 @@ class MetalWorkspace final : public DeviceAPI {
   ~MetalWorkspace();
   // Get command queue for given context.
   id<MTLCommandQueue> GetCommandQueue(TVMContext ctx) {
-    CHECK_EQ(ctx.device_type, kMetal);
+    CHECK_EQ(ctx.device_type, kDLMetal);
     CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < queues.size())
         << "Invalid Metal device_id=" << ctx.device_id;
     return queues[ctx.device_id];
   }
   // Get device for given context
   id<MTLDevice> GetDevice(TVMContext ctx) {
-    CHECK_EQ(ctx.device_type, kMetal);
+    CHECK_EQ(ctx.device_type, kDLMetal);
     CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < devices.size())
         << "Invalid Metal device_id=" << ctx.device_id;
     return devices[ctx.device_id];
@@ -91,9 +91,9 @@ class MetalThreadEntry {
   WorkspacePool pool;
   // constructor
   MetalThreadEntry()
-      : pool(static_cast<DLDeviceType>(kMetal), MetalWorkspace::Global()) {
+      : pool(static_cast<DLDeviceType>(kDLMetal), MetalWorkspace::Global()) {
     context.device_id = 0;
-    context.device_type = static_cast<DLDeviceType>(kMetal);
+    context.device_type = static_cast<DLDeviceType>(kDLMetal);
   }
   ~MetalThreadEntry();
   // Get temp buffer with at least size under ctx.
......
@@ -150,13 +150,13 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
   this->Init();
   CHECK(stream == nullptr);
   TVMContext ctx = ctx_from;
-  if (ctx_from.device_type == kCPU) ctx = ctx_to;
+  if (ctx_from.device_type == kDLCPU) ctx = ctx_to;
   id<MTLCommandQueue> queue = GetCommandQueue(ctx);
   id<MTLCommandBuffer> cb = [queue commandBuffer];
   int from_dev_type = static_cast<int>(ctx_from.device_type);
   int to_dev_type = static_cast<int>(ctx_to.device_type);
-  if (from_dev_type == kMetal && to_dev_type == kMetal) {
+  if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) {
     CHECK_EQ(ctx_from.device_id, ctx_to.device_id)
         << "Metal disallow cross device copy.";
     id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
@@ -167,7 +167,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
                    size:size];
     [encoder endEncoding];
     [cb commit];
-  } else if (from_dev_type == kMetal && to_dev_type == kCPU) {
+  } else if (from_dev_type == kDLMetal && to_dev_type == kDLCPU) {
     // copy to a local buffer before get into global buffer.
     id<MTLBuffer> from_buf = (__bridge id<MTLBuffer>)(from);
     if (from_buf.storageMode != MTLStorageModeShared) {
@@ -190,7 +190,7 @@ void MetalWorkspace::CopyDataFromTo(const void* from,
              static_cast<char*>([from_buf contents]) + from_offset,
              size);
     }
-  } else if (from_dev_type == kCPU && to_dev_type == kMetal) {
+  } else if (from_dev_type == kDLCPU && to_dev_type == kDLMetal) {
     id<MTLBuffer> to_buf = (__bridge id<MTLBuffer>)(to);
     if (to_buf.storageMode != MTLStorageModeShared) {
       id<MTLBuffer> temp = MetalThreadEntry::ThreadLocal()
......
@@ -133,7 +133,7 @@ class OpenCLWorkspace final : public DeviceAPI {
   void Init();
   // get the queue of the context
   cl_command_queue GetQueue(TVMContext ctx) {
-    CHECK_EQ(ctx.device_type, kOpenCL);
+    CHECK_EQ(ctx.device_type, kDLOpenCL);
     this->Init();
     CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < queues.size())
         << "Invalid OpenCL device_id=" << ctx.device_id;
@@ -178,9 +178,9 @@ class OpenCLThreadEntry {
   WorkspacePool pool;
   // constructor
   OpenCLThreadEntry()
-      : pool(kOpenCL, OpenCLWorkspace::Global()) {
+      : pool(kDLOpenCL, OpenCLWorkspace::Global()) {
     context.device_id = 0;
-    context.device_type = kOpenCL;
+    context.device_type = kDLOpenCL;
   }
   // get the global workspace
   static OpenCLThreadEntry* ThreadLocal();
......
@@ -76,13 +76,13 @@ void OpenCLWorkspace::CopyDataFromTo(const void* from,
                                      TVMStreamHandle stream) {
   this->Init();
   CHECK(stream == nullptr);
-  if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kOpenCL) {
+  if (ctx_from.device_type == kDLOpenCL && ctx_to.device_type == kDLOpenCL) {
    OPENCL_CALL(clEnqueueCopyBuffer(
        this->GetQueue(ctx_to),
        static_cast<cl_mem>((void*)from),  // NOLINT(*)
        static_cast<cl_mem>(to),
        from_offset, to_offset, size, 0, nullptr, nullptr));
-  } else if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kCPU) {
+  } else if (ctx_from.device_type == kDLOpenCL && ctx_to.device_type == kDLCPU) {
     OPENCL_CALL(clEnqueueReadBuffer(
         this->GetQueue(ctx_from),
         static_cast<cl_mem>((void*)from),  // NOLINT(*)
@@ -90,7 +90,7 @@ void OpenCLWorkspace::CopyDataFromTo(const void* from,
         static_cast<char*>(to) + to_offset,
         0, nullptr, nullptr));
     OPENCL_CALL(clFinish(this->GetQueue(ctx_from)));
-  } else if (ctx_from.device_type == kCPU && ctx_to.device_type == kOpenCL) {
+  } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLOpenCL) {
     OPENCL_CALL(clEnqueueWriteBuffer(
         this->GetQueue(ctx_to),
         static_cast<cl_mem>(to),
......
@@ -104,12 +104,12 @@ enum ArgConvertCode {
 inline ArgConvertCode GetArgConvertCode(TVMType t) {
   CHECK_EQ(t.lanes, 1U)
       << "Cannot pass vector type argument to devic function for now";
-  if (t.code == kInt) {
+  if (t.code == kDLInt) {
     if (t.bits == 64U) return INT64_TO_INT64;
     if (t.bits == 32U) return INT64_TO_INT32;
-  } else if (t.code == kUInt) {
+  } else if (t.code == kDLUInt) {
     if (t.bits == 32U) return INT64_TO_UINT32;
-  } else if (t.code == kFloat) {
+  } else if (t.code == kDLFloat) {
     if (t.bits == 64U) return FLOAT64_TO_FLOAT64;
     if (t.bits == 32U) return FLOAT64_TO_FLOAT32;
   } else if (t.code == kHandle) {
......
@@ -77,7 +77,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
     hipStream_t hip_stream = static_cast<hipStream_t>(stream);
     from = static_cast<const char*>(from) + from_offset;
     to = static_cast<char*>(to) + to_offset;
-    if (ctx_from.device_type == kROCM && ctx_to.device_type == kROCM) {
+    if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLROCM) {
       ROCM_CALL(hipSetDevice(ctx_from.device_id));
       if (ctx_from.device_id == ctx_to.device_id) {
         GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream);
@@ -86,10 +86,10 @@ class ROCMDeviceAPI final : public DeviceAPI {
                 from, ctx_from.device_id,
                 size, hip_stream);
       }
-    } else if (ctx_from.device_type == kROCM && ctx_to.device_type == kCPU) {
+    } else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) {
       ROCM_CALL(hipSetDevice(ctx_from.device_id));
       GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream);
-    } else if (ctx_from.device_type == kCPU && ctx_to.device_type == kROCM) {
+    } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM) {
       ROCM_CALL(hipSetDevice(ctx_to.device_id));
       GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream);
     } else {
@@ -138,7 +138,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
 typedef dmlc::ThreadLocalStore<ROCMThreadEntry> ROCMThreadStore;
 ROCMThreadEntry::ROCMThreadEntry()
-    : pool(kROCM, ROCMDeviceAPI::Global()) {
+    : pool(kDLROCM, ROCMDeviceAPI::Global()) {
 }
 ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() {
......
@@ -55,12 +55,12 @@ class RPCDeviceAPI final : public DeviceAPI {
           static_cast<const RemoteSpace*>(to)->data, to_offset,
           size, ctx_from, ctx_to, stream);
     } else if (from_dev_type > kRPCSessMask &&
-               to_dev_type == kCPU) {
+               to_dev_type == kDLCPU) {
       GetSess(ctx_from)->CopyFromRemote(
           static_cast<const RemoteSpace*>(from)->data, from_offset,
           to, to_offset, size,
          ctx_from);
-    } else if (from_dev_type == kCPU &&
+    } else if (from_dev_type == kDLCPU &&
                to_dev_type > kRPCSessMask) {
       GetSess(ctx_to)->CopyToRemote(
           (void*)from, from_offset,  // NOLINT(*)
......
@@ -162,9 +162,9 @@ class RPCSession::EventHandler {
       int tcode = type_codes[i];
       TVMValue value = arg_values[i];
       switch (tcode) {
-        case kInt:
-        case kUInt:
-        case kFloat:
+        case kDLInt:
+        case kDLUInt:
+        case kDLFloat:
         case kTVMType: {
           writer_->Write(&value, sizeof(TVMValue));
           break;
@@ -315,9 +315,9 @@ class RPCSession::EventHandler {
     int tcode = arg_buf_->tcode[arg_index_];
     static_assert(sizeof(TVMValue) == sizeof(uint64_t), "invariant");
     switch (tcode) {
-      case kInt:
-      case kUInt:
-      case kFloat:
+      case kDLInt:
+      case kDLUInt:
+      case kDLFloat:
       case kTVMType:
       case kHandle:
       case kStr:
@@ -352,9 +352,9 @@ class RPCSession::EventHandler {
     TVMValue& value = arg_buf_->value[arg_index_];
     if (arg_recv_stage_ == 0) {
       switch (tcode) {
-        case kInt:
-        case kUInt:
-        case kFloat:
+        case kDLInt:
+        case kDLUInt:
+        case kDLFloat:
         case kTVMType:
         case kTVMContext: {
           this->Read(&value, sizeof(TVMValue));
@@ -484,7 +484,7 @@ class RPCSession::EventHandler {
     this->Read(&offset, sizeof(offset));
     this->Read(&size, sizeof(size));
     this->Read(&ctx, sizeof(ctx));
-    if (ctx.device_type == kCPU) {
+    if (ctx.device_type == kDLCPU) {
       RPCCode code = RPCCode::kCopyAck;
       writer_->Write(&code, sizeof(code));
      writer_->Write(reinterpret_cast<char*>(handle) + offset, size);
@@ -492,7 +492,7 @@ class RPCSession::EventHandler {
       temp_data_.resize(size + 1);
       try {
         TVMContext cpu_ctx;
-        cpu_ctx.device_type = kCPU;
+        cpu_ctx.device_type = kDLCPU;
         cpu_ctx.device_id = 0;
         DeviceAPI::Get(ctx)->CopyDataFromTo(
             reinterpret_cast<void*>(handle), offset,
@@ -531,7 +531,7 @@ class RPCSession::EventHandler {
     int ret_tcode = kNull;
     RPCCode code = RPCCode::kReturn;
     std::string errmsg;
-    if (copy_ctx_.device_type == kCPU) {
+    if (copy_ctx_.device_type == kDLCPU) {
       this->Read(
           reinterpret_cast<char*>(copy_handle_) + copy_offset_, copy_size_);
     } else {
@@ -539,7 +539,7 @@ class RPCSession::EventHandler {
       this->Read(&temp_data_[0], copy_size_);
       try {
         TVMContext cpu_ctx;
-        cpu_ctx.device_type = kCPU;
+        cpu_ctx.device_type = kDLCPU;
         cpu_ctx.device_id = 0;
         DeviceAPI::Get(copy_ctx_)->CopyDataFromTo(
             temp_data_.data(), 0,
@@ -915,10 +915,10 @@ void RPCCopyAmongRemote(TVMArgs args, TVMRetValue *rv) {
   TVMContext ctx_to = args[6];
   TVMStreamHandle stream = args[7];
   TVMContext ctx = ctx_from;
-  if (ctx.device_type == kCPU) {
+  if (ctx.device_type == kDLCPU) {
     ctx = ctx_to;
   } else {
-    CHECK(ctx_to.device_type == kCPU ||
+    CHECK(ctx_to.device_type == kDLCPU ||
           ctx_to.device_type == ctx_from.device_type)
       << "Can not copy across different ctx types directly";
   }
......
@@ -14,7 +14,7 @@ TEST(PackedFunc, Basic) {
   Var v = PackedFunc([&](TVMArgs args, TVMRetValue* rv) {
       CHECK(args.num_args == 3);
       CHECK(args.values[0].v_float64 == 1.0);
-      CHECK(args.type_codes[0] == kFloat);
+      CHECK(args.type_codes[0] == kDLFloat);
       CHECK(args.values[1].v_handle == &a);
       CHECK(args.type_codes[1] == kArrayHandle);
      CHECK(args.values[2].v_handle == &x);
......
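
One consequence visible across the contrib and runtime hunks: the TVM codes are now DLPack's own enum values, so a DLTensor's fields compare against them directly. A hedged sketch of the TypeMatch-style check used above (semantics inferred from the call sites; the helper IsFloat32OnGPU is illustrative, not from the commit):

#include <dlpack/dlpack.h>

// Mirrors the usage in the cblas/nnpack hunks: dtype code and bit width
// must match, and the type must be scalar (lanes == 1).
inline bool TypeMatch(DLDataType t, int code, int bits) {
  return t.code == code && t.bits == bits && t.lanes == 1;
}

inline bool IsFloat32OnGPU(const DLTensor* t) {
  return TypeMatch(t->dtype, kDLFloat, 32) &&
         t->ctx.device_type == kDLGPU;
}
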