Commit 3957926e by Tianqi Chen Committed by GitHub

[RUNTIME] Make rutnime DLPack compatible, allow new device plugin (#71)

* [RUNTIME] Refactor runtime to be DLPack compatible. Enable plugin of new runtime.

* fix mac compile

* ok
parent 9e660dbe
......@@ -4,3 +4,6 @@
[submodule "HalideIR"]
path = HalideIR
url = ssh://git@github.com/tqchen/HalideIR
[submodule "dlpack"]
path = dlpack
url = https://github.com/dmlc/dlpack
......@@ -14,6 +14,9 @@ tvm_option(USE_MSVC_MT "Build with MT" OFF)
include_directories("include")
include_directories("HalideIR/src")
include_directories("dlpack/include")
set(TVM_LINKER_LIBS "")
set(TVM_RUNTIME_LINKER_LIBS "")
......
......@@ -28,7 +28,7 @@ ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\
-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
-Iinclude -Idlpack/include -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
ifdef CUDA_PATH
NVCC=$(CUDA_PATH)/bin/nvcc
......
Subproject commit 9f433c5ecfdd47184339cdd2b99706d24fae3aa1
......@@ -70,7 +70,13 @@ class BufferNode : public Node {
Array<Expr> strides;
/*! \brief data type in the content of the tensor */
Type dtype;
// Maybe need more information(alignment) later
/*!
* \brief The offset in bytes to the beginning pointer to data
* Can be undefined, indicating this must be zero.
*/
Expr byte_offset;
/*! \brief Alignment bytes size of byte_offset */
int offset_alignment;
/*! \brief constructor */
BufferNode() {}
......@@ -80,13 +86,17 @@ class BufferNode : public Node {
v->Visit("shape", &shape);
v->Visit("strides", &strides);
v->Visit("dtype", &dtype);
v->Visit("byte_offset", &byte_offset);
v->Visit("offset_alignment", &offset_alignment);
}
static Buffer make(std::string name,
Var ptr,
Array<Expr> shape,
Array<Expr> strides,
Type dtype);
Type dtype,
Expr byte_offset,
int offset_alignment);
static constexpr const char* _type_key = "Buffer";
TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node);
......
......@@ -41,6 +41,14 @@ using Halide::Internal::const_true;
using Halide::Internal::const_false;
using Halide::Internal::is_no_op;
inline Type TVMShapeIndexType() {
if (std::is_signed<tvm_index_t>::value) {
return Int(sizeof(tvm_index_t) * 8);
} else {
return UInt(sizeof(tvm_index_t) * 8);
}
}
inline Type TVMType2Type(TVMType t) {
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
}
......
......@@ -167,7 +167,8 @@ enum TVMArrayFieldKind {
kStrides = 3,
kTypeCode = 4,
kTypeBits = 5,
kTypeLanes = 6
kTypeLanes = 6,
kByteOffset = 7
};
} // namespace intrinsic
......
......@@ -31,22 +31,22 @@
#include <stdint.h>
#include <stddef.h>
// TVM Runtime is DLPack compatible.
#include <dlpack/dlpack.h>
TVM_EXTERN_C {
/*! \brief type of array index. */
typedef uint32_t tvm_index_t;
typedef int64_t tvm_index_t;
/*!
* \brief The type code in TVMType
* \note TVMType is used in two places.
*/
typedef enum {
kInt = 0U,
kUInt = 1U,
kFloat = 2U,
kHandle = 3U,
// The type code of other types are compatible with DLPack.
// The next few fields are extension types
// that is used by TVM API calls.
kHandle = 3U,
kNull = 4U,
kArrayHandle = 5U,
kTVMType = 6U,
......@@ -67,14 +67,17 @@ typedef enum {
*
* \note Arguments TVM API function always takes bits=64 and lanes=1
*/
typedef struct {
/*! \brief type code, in TVMTypeCode */
uint8_t code;
/*! \brief number of bits of the type */
uint8_t bits;
/*! \brief number of lanes, */
uint16_t lanes;
} TVMType;
typedef DLDataType TVMType;
/*!
* \brief The Device information, abstract away common device types.
*/
typedef DLContext TVMContext;
/*!
* \brief The tensor array stucture to TVM API.
*/
typedef DLTensor TVMArray;
/*!
* \brief Union type of values
......@@ -97,50 +100,6 @@ typedef struct {
size_t size;
} TVMByteArray;
/*!
* \brief The device type
*/
typedef enum {
/*! \brief CPU device */
kCPU = 1,
/*! \brief NVidia GPU device(CUDA) */
kGPU = 2,
/*! \brief opencl device */
kOpenCL = 4
} TVMDeviceMask;
/*!
* \brief The Device information, abstract away common device types.
*/
typedef struct {
/*! \brief The device type mask */
int dev_mask;
/*! \brief the device id */
int dev_id;
} TVMContext;
/*!
* \brief Data structure representing a n-dimensional array(tensor).
* This is used to pass data specification into TVM.
*/
typedef struct {
/*! \brief The data field pointer on specified device */
void* data;
/*! \brief The shape pointers of the array */
const tvm_index_t* shape;
/*!
* \brief The stride data about each dimension of the array, can be NULL
* When strides is NULL, it indicates that the array is empty.
*/
const tvm_index_t* strides;
/*! \brief number of dimensions of the array */
tvm_index_t ndim;
/*! \brief The data type flag */
TVMType dtype;
/*! \brief The device context this array sits on */
TVMContext ctx;
} TVMArray;
/*! \brief Handle to TVM runtime modules. */
typedef void* TVMModuleHandle;
/*! \brief Handle to packed function handle. */
......
......@@ -8,35 +8,37 @@ import numpy as np
from .._base import _LIB, check_call
from .._base import c_array
from ._types import TVMType, tvm_index_t
from ._types import TVMType, tvm_shape_index_t
class TVMContext(ctypes.Structure):
"""TVM context strucure."""
_fields_ = [("dev_mask", ctypes.c_int),
("dev_id", ctypes.c_int)]
_fields_ = [("device_id", ctypes.c_int),
("device_type", ctypes.c_int)]
MASK2STR = {
1 : 'cpu',
2 : 'gpu',
4 : 'opencl'
}
def __init__(self, dev_mask, dev_id):
def __init__(self, device_id, device_type):
super(TVMContext, self).__init__()
self.dev_mask = dev_mask
self.dev_id = dev_id
self.device_id = device_id
self.device_type = device_type
def __repr__(self):
return "%s(%d)" % (
TVMContext.MASK2STR[self.dev_mask], self.dev_id)
TVMContext.MASK2STR[self.device_type], self.device_id)
class TVMArray(ctypes.Structure):
"""TVMValue in C API"""
_fields_ = [("data", ctypes.c_void_p),
("shape", ctypes.POINTER(tvm_index_t)),
("strides", ctypes.POINTER(tvm_index_t)),
("ndim", tvm_index_t),
("ctx", TVMContext),
("ndim", ctypes.c_int),
("dtype", TVMType),
("ctx", TVMContext)]
("shape", ctypes.POINTER(tvm_shape_index_t)),
("strides", ctypes.POINTER(tvm_shape_index_t)),
("byte_offset", ctypes.c_size_t)]
TVMArrayHandle = ctypes.POINTER(TVMArray)
......@@ -50,7 +52,7 @@ def cpu(dev_id=0):
dev_id : int, optional
The integer device id
"""
return TVMContext(1, dev_id)
return TVMContext(dev_id, 1)
def gpu(dev_id=0):
......@@ -61,7 +63,7 @@ def gpu(dev_id=0):
dev_id : int, optional
The integer device id
"""
return TVMContext(2, dev_id)
return TVMContext(dev_id, 2)
def opencl(dev_id=0):
......@@ -72,7 +74,7 @@ def opencl(dev_id=0):
dev_id : int, optional
The integer device id
"""
return TVMContext(4, dev_id)
return TVMContext(dev_id, 4)
def numpyasarray(np_data):
......@@ -81,7 +83,7 @@ def numpyasarray(np_data):
data = np_data
assert data.flags['C_CONTIGUOUS']
arr = TVMArray()
shape = c_array(tvm_index_t, data.shape)
shape = c_array(tvm_shape_index_t, data.shape)
arr.data = data.ctypes.data_as(ctypes.c_void_p)
arr.shape = shape
arr.strides = None
......@@ -114,8 +116,8 @@ def empty(shape, dtype="float32", ctx=cpu(0)):
arr : tvm.nd.NDArray
The array tvm supported.
"""
shape = c_array(tvm_index_t, shape)
ndim = tvm_index_t(len(shape))
shape = c_array(tvm_shape_index_t, shape)
ndim = ctypes.c_int(len(shape))
handle = TVMArrayHandle()
dtype = TVMType(dtype)
check_call(_LIB.TVMArrayAlloc(
......
......@@ -6,7 +6,7 @@ import ctypes
import numpy as np
from .._base import py_str
tvm_index_t = ctypes.c_uint32
tvm_shape_index_t = ctypes.c_int64
class TypeCode(object):
"""Type code used in API calls"""
......
......@@ -261,8 +261,10 @@ def call_packed(*args):
def Buffer(shape, dtype=None,
name="buffer",
ptr=None,
strides=None):
data=None,
strides=None,
byte_offset=None,
offset_alignment=0):
"""Create a new symbolic buffer
Parameters
......@@ -276,12 +278,18 @@ def Buffer(shape, dtype=None,
name : str, optional
The name of the buffer.
ptr : Var, optional
data : Var, optional
The data pointer in the buffer.
strides: array of Expr
The stride of the buffer.
byte_offset: Expr, optional
The offset in bytes to data pointer.
offset_alignment: int, optional
The alignment of offset
Returns
-------
buffer : Buffer
......@@ -290,11 +298,11 @@ def Buffer(shape, dtype=None,
shape = (shape,) if isinstance(shape, _expr.Expr) else shape
dtype = float32 if dtype is None else dtype
strides = () if strides is None else strides
if ptr is None:
ptr = Var(name, "handle")
if data is None:
data = Var(name, "handle")
return _api_internal._Buffer(
name, ptr, shape, strides, dtype)
name, data, shape, strides, dtype, byte_offset, offset_alignment)
def _IterVar(dom, name, iter_type, thread_tag=''):
......
......@@ -138,7 +138,9 @@ TVM_REGISTER_API(_Buffer)
args[1],
args[2],
args[3],
args[4]);
args[4],
args[5],
args[6]);
});
TVM_REGISTER_API(_Tensor)
......
......@@ -33,17 +33,18 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_int32_ = llvm::Type::getInt32Ty(*ctx);
t_int64_ = llvm::Type::getInt64Ty(*ctx);
t_float64_ = llvm::Type::getDoubleTy(*ctx);
t_tvm_index_ = llvm::Type::getIntNTy(*ctx, sizeof(tvm_index_t) * 8);
t_tvm_shape_index_ = llvm::Type::getIntNTy(*ctx, TVMShapeIndexType().bits());
t_tvm_context_ = llvm::StructType::create({t_int_, t_int_});
t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_});
t_tvm_func_handle_ = t_void_p_;
t_tvm_array_ = llvm::StructType::create(
{t_void_p_,
t_tvm_index_->getPointerTo(),
t_tvm_index_->getPointerTo(),
t_tvm_index_,
t_tvm_context_,
t_int_,
t_tvm_type_,
t_tvm_context_});
t_tvm_shape_index_->getPointerTo(),
t_tvm_shape_index_->getPointerTo(),
t_int64_});
t_tvm_value_ = llvm::StructType::create({t_float64_});
t_f_tvm_par_for_lambda_ = llvm::FunctionType::get(
t_int_, {t_int64_, t_int64_, t_void_p_}, false);
......@@ -663,25 +664,29 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(0)}); break;
}
case intrinsic::kShape: {
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(1)}); break;
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(4)}); break;
}
case intrinsic::kStrides: {
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(2)}); break;
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(5)}); break;
}
case intrinsic::kNDim: {
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(3)}); break;
ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(2)}); break;
}
case intrinsic::kTypeCode: {
ret = builder_->CreateInBoundsGEP(
arr, {zero, ConstInt32(4), ConstInt32(0)}); break;
arr, {zero, ConstInt32(3), ConstInt32(0)}); break;
}
case intrinsic::kTypeBits: {
ret = builder_->CreateInBoundsGEP(
arr, {zero, ConstInt32(4), ConstInt32(1)}); break;
arr, {zero, ConstInt32(3), ConstInt32(1)}); break;
}
case intrinsic::kTypeLanes: {
ret = builder_->CreateInBoundsGEP(
arr, {zero, ConstInt32(4), ConstInt32(2)}); break;
arr, {zero, ConstInt32(3), ConstInt32(2)}); break;
}
case intrinsic::kByteOffset: {
ret = builder_->CreateInBoundsGEP(
arr, {zero, ConstInt32(6)}); break;
}
default: LOG(FATAL) << "unknown field code";
}
......
......@@ -160,7 +160,7 @@ class CodeGenLLVM :
llvm::MDNode* md_very_likely_branch_{nullptr};
llvm::MDNode* md_tbaa_root_{nullptr};
// TVM related data types
llvm::Type* t_tvm_index_{nullptr};
llvm::Type* t_tvm_shape_index_{nullptr};
llvm::Type* t_tvm_func_handle_{nullptr};
llvm::StructType* t_tvm_context_{nullptr};
llvm::StructType* t_tvm_type_{nullptr};
......
......@@ -166,6 +166,7 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
case intrinsic::kTypeCode: PushOp(StackVM::TVM_ARRAY_GET_TYPE_CODE); break;
case intrinsic::kTypeBits: PushOp(StackVM::TVM_ARRAY_GET_TYPE_BITS); break;
case intrinsic::kTypeLanes: PushOp(StackVM::TVM_ARRAY_GET_TYPE_LANES); break;
case intrinsic::kByteOffset: PushOp(StackVM::TVM_ARRAY_GET_BYTE_OFFSET); break;
default: LOG(FATAL) << "unknown field code";
}
} else if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
......@@ -227,15 +228,12 @@ void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64,
void CodeGenStackVM::PushCast(Type dst, Type src) {
if (dst.is_int()) {
if (src.is_int()) return;
if (src.is_uint() && src.bits() <= 32) return;
} else if (dst.is_uint() && dst.bits() <= 32) {
if (src.is_int()) return;
if (src.is_uint() && src.bits() <= 32) return;
if (src.is_int() || src.is_uint()) return;
} else if (dst.is_uint()) {
if (src.is_int() || src.is_uint()) return;
} else if (dst.is_float()) {
if (src.is_float()) return;
}
LOG(FATAL) << "Cannot handle cast " << src << " to " << dst;
}
void CodeGenStackVM::VisitExpr_(const StringImm *op) {
......
......@@ -139,6 +139,7 @@ int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const {
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_SHAPE);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_STRIDES);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_NDIM);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_BYTE_OFFSET);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_CODE);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_BITS);
STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_LANES);
......@@ -352,6 +353,9 @@ void StackVM::Run(State* s) const {
case TVM_ARRAY_GET_NDIM: {
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, ndim); break;
}
case TVM_ARRAY_GET_BYTE_OFFSET: {
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, byte_offset); break;
}
case TVM_ARRAY_GET_TYPE_CODE: {
STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.code); break;
}
......
......@@ -199,7 +199,8 @@ class StackVM {
TVM_ARRAY_GET_NDIM,
TVM_ARRAY_GET_TYPE_CODE,
TVM_ARRAY_GET_TYPE_BITS,
TVM_ARRAY_GET_TYPE_LANES
TVM_ARRAY_GET_TYPE_LANES,
TVM_ARRAY_GET_BYTE_OFFSET
};
/*! \brief The code structure */
union Code {
......
......@@ -22,7 +22,8 @@ Buffer::Buffer(Array<Expr> shape,
: Buffer(BufferNode::make(
name,
Var(name, Type(Type::Handle, 0, 0)),
shape, Array<Expr>(), dtype)) {
shape, Array<Expr>(), dtype,
Expr(), 0)) {
}
inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
......@@ -40,6 +41,9 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
base = base + index[i] * n->strides[i];
}
}
if (!is_zero(n->byte_offset)) {
base = base + (n->byte_offset / n->dtype.bytes());
}
return base;
}
......@@ -58,13 +62,27 @@ Buffer BufferNode::make(std::string name,
Var data,
Array<Expr> shape,
Array<Expr> strides,
Type dtype) {
Type dtype,
Expr byte_offset,
int offset_alignment) {
auto n = std::make_shared<BufferNode>();
n->name = name;
n->data = data;
n->shape = shape;
n->strides = strides;
n->dtype = dtype;
if (!byte_offset.defined()) {
byte_offset = make_const(shape[0].type(), 0);
}
if (offset_alignment != 0) {
CHECK_EQ(offset_alignment % dtype.bytes(), 0)
<< "Offset alignments must be at least " << dtype.bytes();
} else {
offset_alignment = dtype.bytes();
}
n->byte_offset = byte_offset;
n->offset_alignment = offset_alignment;
return Buffer(n);
}
......
......@@ -36,7 +36,8 @@ LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_unpacked_args) {
const Type tvm_index_type = UInt(32);
const Type tvm_shape_type = TVMShapeIndexType();
const Type tvm_ndim_type = Int(32);
const Stmt nop = Evaluate::make(0);
int num_args = static_cast<int>(api_args.size());
CHECK_LE(num_unpacked_args, num_args);
......@@ -120,13 +121,15 @@ LoweredFunc MakeAPI(Stmt body,
<< "api_args can only be Buffer or Var";
Buffer buf(api_args[i].node_);
// dimension checks
Expr v_ndim = TVMArrayGet(tvm_index_type, v_arg, intrinsic::kNDim);
Expr v_ndim = TVMArrayGet(tvm_ndim_type, v_arg, intrinsic::kNDim);
std::ostringstream ndim_err_msg;
ndim_err_msg << "arg_" << i
<< ".ndim is expected to equal "
<< buf->shape.size();
seq_init.emplace_back(
MakeAssertEQ(v_ndim, UIntImm::make(tvm_index_type, buf->shape.size()),
MakeAssertEQ(v_ndim,
make_const(tvm_ndim_type,
static_cast<int64_t>(buf->shape.size())),
ndim_err_msg.str()));
// type checks
Type dtype = buf->dtype;
......@@ -147,7 +150,7 @@ LoweredFunc MakeAPI(Stmt body,
}
// shape field
Var v_shape(v_arg->name_hint + ".shape", Handle());
handle_data_type.Set(v_shape, UIntImm::make(tvm_index_type, 0));
handle_data_type.Set(v_shape, make_const(tvm_shape_type, 0));
seq_init.emplace_back(LetStmt::make(
v_shape, TVMArrayGet(Handle(), v_arg, intrinsic::kShape), nop));
for (size_t k = 0; k < buf->shape.size(); ++k) {
......@@ -155,12 +158,12 @@ LoweredFunc MakeAPI(Stmt body,
field_name << v_shape->name_hint << '[' << k << ']';
f_push(buf->shape[k],
cast(buf->shape[k].type(),
Load::make(tvm_index_type, v_shape, IntImm::make(Int(32), k))),
Load::make(tvm_shape_type, v_shape, IntImm::make(Int(32), k))),
field_name.str());
}
// strides field
Var v_strides(v_arg->name_hint + ".strides", Handle());
handle_data_type.Set(v_strides, UIntImm::make(tvm_index_type, 0));
handle_data_type.Set(v_strides, make_const(tvm_shape_type, 0));
seq_init.emplace_back(LetStmt::make(
v_strides, TVMArrayGet(Handle(), v_arg, intrinsic::kStrides), nop));
if (buf->strides.size() == 0) {
......@@ -174,10 +177,13 @@ LoweredFunc MakeAPI(Stmt body,
field_name << v_strides->name_hint << '[' << k << ']';
f_push(buf->strides[k],
cast(buf->shape[k].type(),
Load::make(tvm_index_type, v_strides, IntImm::make(Int(32), k))),
Load::make(tvm_shape_type, v_strides, IntImm::make(Int(32), k))),
field_name.str());
}
}
// Byte_offset field.
f_push(buf->byte_offset, TVMArrayGet(UInt(64), v_arg, intrinsic::kByteOffset),
v_arg->name_hint + ".byte_offset");
}
}
......
......@@ -7,17 +7,56 @@
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <dmlc/timer.h>
#include <array>
#include <algorithm>
#include <string>
#include <cstdlib>
#include <thread>
#include <mutex>
#include "./runtime_base.h"
#include "./device_api.h"
namespace tvm {
namespace runtime {
class DeviceAPIManager {
public:
static const int kMaxDeviceAPI = 16;
// Get API
static DeviceAPI* Get(TVMContext ctx) {
return Global()->GetAPI(ctx.device_type);
}
private:
std::array<DeviceAPI*, kMaxDeviceAPI> api_;
std::mutex mutex_;
// constructor
DeviceAPIManager() {
std::fill(api_.begin(), api_.end(), nullptr);
}
// Global static variable.
static DeviceAPIManager* Global() {
static DeviceAPIManager inst;
return &inst;
}
// Get or initialize API.
DeviceAPI* GetAPI(DLDeviceType type) {
if (api_[type] != nullptr) return api_[type];
std::lock_guard<std::mutex> lock(mutex_);
if (api_[type] != nullptr) return api_[type];
std::string factory = "_device_api_" + DeviceName(type);
auto* f = Registry::Get(factory);
CHECK(f != nullptr)
<< "Device API " << DeviceName(type) << " is not enabled.";
void* ptr = (*f)();
api_[type] = static_cast<DeviceAPI*>(ptr);
return api_[type];
}
};
inline TVMArray* TVMArrayCreate_() {
TVMArray* arr = new TVMArray();
arr->shape = nullptr;
......@@ -33,9 +72,8 @@ inline void TVMArrayFree_(TVMArray* arr) {
delete[] arr->shape;
delete[] arr->strides;
if (arr->data != nullptr) {
TVM_DEVICE_SWITCH(arr->ctx, {
FreeDataSpace<xpu>(arr->ctx, arr->data);
});
DeviceAPIManager::Get(arr->ctx)->FreeDataSpace(
arr->ctx, arr->data);
}
}
delete arr;
......@@ -282,10 +320,8 @@ int TVMArrayAlloc(const tvm_index_t* shape,
arr->ctx = ctx;
size_t size = GetDataSize(arr);
size_t alignment = GetDataAlignment(arr);
// ctx data pointer
TVM_DEVICE_SWITCH(ctx, {
arr->data = AllocDataSpace<xpu>(ctx, size, alignment);
});
arr->data = DeviceAPIManager::Get(ctx)->AllocDataSpace(
ctx, size, alignment);
*out = arr;
API_END_HANDLE_ERROR(TVMArrayFree_(arr));
}
......@@ -306,28 +342,21 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
CHECK_EQ(from_size, to_size)
<< "TVMArrayCopyFromTo: The size must exactly match";
TVMContext ctx = from->ctx;
if (ctx.dev_mask == kCPU) {
if (ctx.device_type == kCPU) {
ctx = to->ctx;
} else {
CHECK(to->ctx.dev_mask == kCPU ||
to->ctx.dev_mask == from->ctx.dev_mask)
CHECK(to->ctx.device_type == kCPU ||
to->ctx.device_type == from->ctx.device_type)
<< "Can not copy across different ctx types directly";
}
TVM_DEVICE_SWITCH(ctx, {
CopyDataFromTo<xpu>(from->data, to->data,
from_size,
from->ctx,
to->ctx,
stream);
});
DeviceAPIManager::Get(ctx)->CopyDataFromTo(
from->data, to->data, from_size,
from->ctx, to->ctx, stream);
API_END();
}
int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
API_BEGIN();
TVM_DEVICE_SWITCH(ctx, {
StreamSync<xpu>(ctx, stream);
});
DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
API_END();
}
......@@ -7,6 +7,7 @@
#define TVM_RUNTIME_DEVICE_API_CPU_H_
#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include <cstdlib>
#include <cstring>
#include "./device_api.h"
......@@ -14,8 +15,9 @@
namespace tvm {
namespace runtime {
template<>
void* AllocDataSpace<kCPU>(TVMContext ctx, size_t size, size_t alignment) {
class CPUDeviceAPI : public DeviceAPI {
public:
void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final {
void* ptr;
#if _MSC_VER
ptr = _aligned_malloc(size, alignment);
......@@ -25,30 +27,35 @@ void* AllocDataSpace<kCPU>(TVMContext ctx, size_t size, size_t alignment) {
if (ret != 0) throw std::bad_alloc();
#endif
return ptr;
}
}
template<>
void FreeDataSpace<kCPU>(TVMContext ctx, void* ptr) {
void FreeDataSpace(TVMContext ctx, void* ptr) final {
#if _MSC_VER
_aligned_free(ptr);
#else
free(ptr);
#endif
}
}
template<>
void CopyDataFromTo<kCPU>(const void* from,
void CopyDataFromTo(const void* from,
void* to,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) {
TVMStreamHandle stream) final {
memcpy(to, from, size);
}
}
template<>
void StreamSync<kCPU>(TVMContext ctx, TVMStreamHandle stream) {
}
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
}
};
TVM_REGISTER_GLOBAL(_device_api_cpu)
.set_body([](TVMArgs args, TVMRetValue* rv) {
static CPUDeviceAPI inst;
DeviceAPI* ptr = &inst;
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_DEVICE_API_CPU_H_
/*!
* Copyright (c) 2017 by Contributors
* \file device_api_cuda.h
* \file cuda_device_api.cc
* \brief GPU specific API
*/
#ifndef TVM_RUNTIME_CUDA_DEVICE_API_CUDA_H_
#define TVM_RUNTIME_CUDA_DEVICE_API_CUDA_H_
#include "./cuda_common.h"
#include <tvm/runtime/config.h>
#if TVM_CUDA_RUNTIME
#include <dmlc/logging.h>
#include <tvm/runtime/registry.h>
#include <cuda_runtime.h>
#include "./cuda_common.h"
#include "../device_api.h"
namespace tvm {
namespace runtime {
template<>
inline void* AllocDataSpace<kGPU>(TVMContext ctx, size_t size, size_t alignment) {
CUDA_CALL(cudaSetDevice(ctx.dev_id));
class CUDADeviceAPI : public DeviceAPI {
public:
void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final {
CUDA_CALL(cudaSetDevice(ctx.device_id));
CHECK_EQ(256 % alignment, 0U)
<< "CUDA space is aligned at 256 bytes";
void *ret;
CUDA_CALL(cudaMalloc(&ret, size));
return ret;
}
}
template<>
inline void FreeDataSpace<kGPU>(TVMContext ctx, void* ptr) {
CUDA_CALL(cudaSetDevice(ctx.dev_id));
void FreeDataSpace(TVMContext ctx, void* ptr) final {
CUDA_CALL(cudaSetDevice(ctx.device_id));
CUDA_CALL(cudaFree(ptr));
}
inline void GPUCopy(const void* from,
void* to,
size_t size,
cudaMemcpyKind kind,
cudaStream_t stream) {
if (stream != 0) {
CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
} else {
CUDA_CALL(cudaMemcpy(to, from, size, kind));
}
}
template<>
inline void CopyDataFromTo<kGPU>(const void* from,
void CopyDataFromTo(const void* from,
void* to,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) {
TVMStreamHandle stream) final {
cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
if (ctx_from.dev_mask == kGPU && ctx_to.dev_mask == kGPU) {
CUDA_CALL(cudaSetDevice(ctx_from.dev_id));
if (ctx_from.dev_id == ctx_to.dev_id) {
if (ctx_from.device_type == kGPU && ctx_to.device_type == kGPU) {
CUDA_CALL(cudaSetDevice(ctx_from.device_id));
if (ctx_from.device_id == ctx_to.device_id) {
GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
} else {
cudaMemcpyPeerAsync(to, ctx_to.dev_id,
from, ctx_from.dev_id,
cudaMemcpyPeerAsync(to, ctx_to.device_id,
from, ctx_from.device_id,
size, cu_stream);
}
} else if (ctx_from.dev_mask == kGPU && ctx_to.dev_mask == kCPU) {
CUDA_CALL(cudaSetDevice(ctx_from.dev_id));
} else if (ctx_from.device_type == kGPU && ctx_to.device_type == kCPU) {
CUDA_CALL(cudaSetDevice(ctx_from.device_id));
GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
} else if (ctx_from.dev_mask == kCPU && ctx_to.dev_mask == kGPU) {
CUDA_CALL(cudaSetDevice(ctx_to.dev_id));
} else if (ctx_from.device_type == kCPU && ctx_to.device_type == kGPU) {
CUDA_CALL(cudaSetDevice(ctx_to.device_id));
GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
} else {
LOG(FATAL) << "expect copy from/to GPU or between GPU";
}
}
}
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
CUDA_CALL(cudaSetDevice(ctx.device_id));
CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
}
private:
static void GPUCopy(const void* from,
void* to,
size_t size,
cudaMemcpyKind kind,
cudaStream_t stream) {
if (stream != 0) {
CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
} else {
CUDA_CALL(cudaMemcpy(to, from, size, kind));
}
}
};
template<>
inline void StreamSync<kGPU>(TVMContext ctx, TVMStreamHandle stream) {
CUDA_CALL(cudaSetDevice(ctx.dev_id));
CUDA_CALL(cudaStreamSynchronize(
static_cast<cudaStream_t>(stream)));
}
TVM_REGISTER_GLOBAL(_device_api_gpu)
.set_body([](TVMArgs args, TVMRetValue* rv) {
static CUDADeviceAPI inst;
DeviceAPI* ptr = &inst;
*rv = static_cast<void*>(ptr);
});
} // namespace runtime
} // namespace tvm
#endif // TVM_CUDA_RUNTIME
#endif // TVM_RUNTIME_CUDA_DEVICE_API_CUDA_H_
......@@ -50,9 +50,9 @@ class CUDAModuleNode : public runtime::ModuleNode {
}
void PreCompile(const std::string& name, TVMContext ctx) final {
CUDA_CALL(cudaSetDevice(ctx.dev_id));
CUDA_CALL(cudaSetDevice(ctx.device_id));
cudaFree(nullptr);
this->GetFunc(ctx.dev_id, name);
this->GetFunc(ctx.device_id, name);
}
PackedFunc GetFunction(
......@@ -79,15 +79,15 @@ class CUDAModuleNode : public runtime::ModuleNode {
}
}
// get a CUfunction from primary context in dev_id
CUfunction GetFunc(int dev_id, const std::string& func_name) {
// get a CUfunction from primary context in device_id
CUfunction GetFunc(int device_id, const std::string& func_name) {
std::lock_guard<std::mutex> lock(mutex_);
// must recheck under the lock scope
if (module_[dev_id] == nullptr) {
CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[dev_id]), data_.c_str()));
if (module_[device_id] == nullptr) {
CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str()));
}
CUfunction func;
CUresult result = cuModuleGetFunction(&func, module_[dev_id], func_name.c_str());
CUresult result = cuModuleGetFunction(&func, module_[device_id], func_name.c_str());
if (result != CUDA_SUCCESS) {
const char *msg;
cuGetErrorName(result, &msg);
......@@ -132,14 +132,14 @@ class CUDAWrappedFunc {
void operator()(TVMArgs args,
TVMRetValue* rv,
void** void_args) const {
int dev_id;
CUDA_CALL(cudaGetDevice(&dev_id));
if (fcache_[dev_id] == nullptr) {
fcache_[dev_id] = m_->GetFunc(dev_id, func_name_);
int device_id;
CUDA_CALL(cudaGetDevice(&device_id));
if (fcache_[device_id] == nullptr) {
fcache_[device_id] = m_->GetFunc(device_id, func_name_);
}
ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
CUDA_DRIVER_CALL(cuLaunchKernel(
fcache_[dev_id],
fcache_[device_id],
wl.grid_dim(0),
wl.grid_dim(1),
wl.grid_dim(2),
......@@ -169,23 +169,23 @@ void AutoSetCUDADevice(const TVMArgs& args, TVMRetValue* rv) {
int* type_codes = static_cast<int*>(args[1].operator void*());
int num_args = args[2].operator int();
int dev_id = -1;
int device_id = -1;
for (int i = 0; i < num_args; ++i) {
if (type_codes[i] == kArrayHandle) {
TVMContext ctx = static_cast<TVMArray*>(values[i].v_handle)->ctx;
CHECK_EQ(ctx.dev_mask, kGPU)
CHECK_EQ(ctx.device_type, kGPU)
<< "All operands need to be GPU";
if (dev_id == -1) {
dev_id = ctx.dev_id;
if (device_id == -1) {
device_id = ctx.device_id;
} else {
CHECK_EQ(dev_id, ctx.dev_id)
CHECK_EQ(device_id, ctx.device_id)
<< "Operands comes from different devices ";
}
}
}
CHECK_NE(dev_id, -1)
CHECK_NE(device_id, -1)
<< "Cannot detect device id from list";
CUDA_CALL(cudaSetDevice(dev_id));
CUDA_CALL(cudaSetDevice(device_id));
}
PackedFunc CUDAModuleNode::GetFunction(
......
......@@ -8,30 +8,31 @@
#include <tvm/base.h>
#include <tvm/runtime/c_runtime_api.h>
#include <string>
namespace tvm {
namespace runtime {
/*!
class DeviceAPI {
public:
/*! \brief virtual destructor */
virtual ~DeviceAPI() {}
/*!
* \brief Allocate a data space on device.
* \param ctx The device context to perform operation.
* \param size The size of the memory
* \param alignment The alignment of the memory.
* \return The allocated device pointer
* \tparam xpu The device mask.
*/
template<TVMDeviceMask xpu>
inline void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment);
/*!
virtual void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) = 0;
/*!
* \brief Free a data space on device.
* \param ctx The device context to perform operation.
* \param ptr The data space.
* \tparam xpu The device mask.
*/
template<TVMDeviceMask xpu>
inline void FreeDataSpace(TVMContext ctx, void* ptr);
/*!
virtual void FreeDataSpace(TVMContext ctx, void* ptr) = 0;
/*!
* \brief copy data from one place to another
* \param dev The device to perform operation.
* \param from The source array.
......@@ -39,52 +40,33 @@ inline void FreeDataSpace(TVMContext ctx, void* ptr);
* \param size The size of the memory
* \param ctx_from The source context
* \param ctx_to The target context
* \tparam xpu The device mask.
*/
template<TVMDeviceMask xpu>
inline void CopyDataFromTo(const void* from,
virtual void CopyDataFromTo(const void* from,
void* to,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream);
/*!
TVMStreamHandle stream) = 0;
/*!
* \brief Synchronize the stream
* \param ctx The context to perform operation.
* \param stream The stream to be sync.
* \tparam xpu The device mask.
*/
template<TVMDeviceMask xpu>
inline void StreamSync(TVMContext ctx, TVMStreamHandle stream);
// macro to run cuda related code
#if TVM_CUDA_RUNTIME
#define TVM_RUN_CUDA(OP) { const TVMDeviceMask xpu = kGPU; OP; }
#else
#define TVM_RUN_CUDA(OP) LOG(FATAL) << "CUDA is not enabled";
#endif
virtual void StreamSync(TVMContext ctx, TVMStreamHandle stream) = 0;
};
// macro to run opencl related code
#if TVM_OPENCL_RUNTIME
#define TVM_RUN_OPENCL(OP) { const TVMDeviceMask xpu = kOpenCL; OP; }
#else
#define TVM_RUN_OPENCL(OP) LOG(FATAL) << "OpenCL is not enabled";
#endif
// macro to switch options between devices
#define TVM_DEVICE_SWITCH(ctx, OP) \
switch (ctx.dev_mask) { \
case kCPU: { const TVMDeviceMask xpu = kCPU; OP; break; } \
case kGPU: TVM_RUN_CUDA(OP); break; \
case kOpenCL: TVM_RUN_OPENCL(OP); break; \
default: LOG(FATAL) << "unknown device_mask " << ctx.dev_mask; \
/*!
* \brief The name of Device API factory.
* \param type The device type.
*/
inline std::string DeviceName(DLDeviceType type) {
switch (static_cast<int>(type)) {
case kCPU: return "cpu";
case kGPU: return "gpu";
case kOpenCL: return "opencl";
default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
}
}
} // namespace runtime
} // namespace tvm
#include "./device_api_cpu.h"
#include "./cuda/device_api_cuda.h"
#include "./opencl/device_api_opencl.h"
#endif // TVM_RUNTIME_DEVICE_API_H_
/*!
* Copyright (c) 2017 by Contributors
* \file device_api_opencl.h
* \brief OpenCL specific API
*/
#ifndef TVM_RUNTIME_OPENCL_DEVICE_API_OPENCL_H_
#define TVM_RUNTIME_OPENCL_DEVICE_API_OPENCL_H_
#include <tvm/runtime/config.h>
#if TVM_OPENCL_RUNTIME
#include <string>
#include <vector>
#include "./opencl_common.h"
namespace tvm {
namespace runtime {
template<>
inline void* AllocDataSpace<kOpenCL>(TVMContext ctx, size_t size, size_t alignment) {
cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
cl_int err_code;
cl_mem mptr = clCreateBuffer(
w->context, CL_MEM_READ_WRITE, size, nullptr, &err_code);
OPENCL_CHECK_ERROR(err_code);
return mptr;
}
template<>
inline void FreeDataSpace<kOpenCL>(TVMContext ctx, void* ptr) {
cl_mem mptr = static_cast<cl_mem>(ptr);
OPENCL_CALL(clReleaseMemObject(mptr));
}
template<>
inline void CopyDataFromTo<kOpenCL>(const void* from,
void* to,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) {
CHECK(stream == nullptr);
cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
if (ctx_from.dev_mask == kOpenCL && ctx_to.dev_mask == kOpenCL) {
OPENCL_CALL(clEnqueueCopyBuffer(
w->GetQueue(ctx_to),
static_cast<cl_mem>((void*)from), // NOLINT(*)
static_cast<cl_mem>(to),
0, 0, size, 0, nullptr, nullptr));
} else if (ctx_from.dev_mask == kOpenCL && ctx_to.dev_mask == kCPU) {
OPENCL_CALL(clEnqueueReadBuffer(
w->GetQueue(ctx_from),
static_cast<cl_mem>((void*)from), // NOLINT(*)
CL_FALSE, 0, size, to,
0, nullptr, nullptr));
OPENCL_CALL(clFinish(w->GetQueue(ctx_from)));
} else if (ctx_from.dev_mask == kCPU && ctx_to.dev_mask == kOpenCL) {
OPENCL_CALL(clEnqueueWriteBuffer(
w->GetQueue(ctx_to),
static_cast<cl_mem>(to),
CL_FALSE, 0, size, from,
0, nullptr, nullptr));
OPENCL_CALL(clFinish(w->GetQueue(ctx_to)));
} else {
LOG(FATAL) << "Expect copy from/to GPU or between GPU";
}
}
template<>
inline void StreamSync<kOpenCL>(TVMContext ctx, TVMStreamHandle stream) {
CHECK(stream == nullptr);
cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
OPENCL_CALL(clFinish(w->GetQueue(ctx)));
}
} // namespace runtime
} // namespace tvm
#endif // TVM_OPENCL_RUNTIME
#endif // TVM_RUNTIME_OPENCL_DEVICE_API_OPENCL_H_
......@@ -10,8 +10,8 @@
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <dmlc/logging.h>
#if TVM_OPENCL_RUNTIME
#include "../device_api.h"
#ifdef __APPLE__
#include <OpenCL/opencl.h>
......@@ -101,7 +101,7 @@ inline const char* CLGetErrorString(cl_int error) {
/*!
* \brief Process global OpenCL workspace.
*/
class OpenCLWorkspace {
class OpenCLWorkspace : public DeviceAPI {
public:
// global platform id
cl_platform_id platform_id;
......@@ -132,13 +132,23 @@ class OpenCLWorkspace {
}
// get the queue of the context
cl_command_queue GetQueue(TVMContext ctx) const {
CHECK_EQ(ctx.dev_mask, kOpenCL);
CHECK_EQ(ctx.device_type, kOpenCL);
CHECK(initialized())
<< "The OpenCL is not initialized";
CHECK(ctx.dev_id >= 0 && static_cast<size_t>(ctx.dev_id) < queues.size())
<< "Invalid OpenCL dev_id=" << ctx.dev_id;
return queues[ctx.dev_id];
CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < queues.size())
<< "Invalid OpenCL device_id=" << ctx.device_id;
return queues[ctx.device_id];
}
// override device API
void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final;
void FreeDataSpace(TVMContext ctx, void* ptr) final;
void CopyDataFromTo(const void* from,
void* to,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) final;
void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
// get the global workspace
static OpenCLWorkspace* Global();
};
......@@ -160,8 +170,8 @@ class OpenCLThreadEntry {
std::vector<KTEntry> kernel_table;
OpenCLThreadEntry() {
context.dev_id = 0;
context.dev_mask = kOpenCL;
context.device_id = 0;
context.device_type = kOpenCL;
}
// get the global workspace
static OpenCLThreadEntry* ThreadLocal();
......
/*!
* Copyright (c) 2017 by Contributors
* \file opencl_workspace.cc
* \file opencl_device_api.cc
*/
#include "./opencl_common.h"
......@@ -18,6 +18,57 @@ OpenCLWorkspace* OpenCLWorkspace::Global() {
return &inst;
}
void* OpenCLWorkspace::AllocDataSpace(
TVMContext ctx, size_t size, size_t alignment) {
cl_int err_code;
cl_mem mptr = clCreateBuffer(
this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code);
OPENCL_CHECK_ERROR(err_code);
return mptr;
}
void OpenCLWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
cl_mem mptr = static_cast<cl_mem>(ptr);
OPENCL_CALL(clReleaseMemObject(mptr));
}
void OpenCLWorkspace::CopyDataFromTo(const void* from,
void* to,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream) {
CHECK(stream == nullptr);
if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kOpenCL) {
OPENCL_CALL(clEnqueueCopyBuffer(
this->GetQueue(ctx_to),
static_cast<cl_mem>((void*)from), // NOLINT(*)
static_cast<cl_mem>(to),
0, 0, size, 0, nullptr, nullptr));
} else if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kCPU) {
OPENCL_CALL(clEnqueueReadBuffer(
this->GetQueue(ctx_from),
static_cast<cl_mem>((void*)from), // NOLINT(*)
CL_FALSE, 0, size, to,
0, nullptr, nullptr));
OPENCL_CALL(clFinish(this->GetQueue(ctx_from)));
} else if (ctx_from.device_type == kCPU && ctx_to.device_type == kOpenCL) {
OPENCL_CALL(clEnqueueWriteBuffer(
this->GetQueue(ctx_to),
static_cast<cl_mem>(to),
CL_FALSE, 0, size, from,
0, nullptr, nullptr));
OPENCL_CALL(clFinish(this->GetQueue(ctx_to)));
} else {
LOG(FATAL) << "Expect copy from/to OpenCL or between OpenCL";
}
}
void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
CHECK(stream == nullptr);
OPENCL_CALL(clFinish(this->GetQueue(ctx)));
}
typedef dmlc::ThreadLocalStore<OpenCLThreadEntry> OpenCLThreadStore;
OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() {
......@@ -141,6 +192,12 @@ bool InitOpenCL(TVMArgs args, TVMRetValue* rv) {
TVM_REGISTER_GLOBAL(_module_init_opencl)
.set_body(InitOpenCL);
TVM_REGISTER_GLOBAL(_device_api_opencl)
.set_body([](TVMArgs args, TVMRetValue* rv) {
DeviceAPI* ptr = OpenCLWorkspace::Global();
*rv = static_cast<void*>(ptr);
});
} // namespace cl
} // namespace runtime
} // namespace tvm
......
......@@ -123,11 +123,11 @@ class OpenCLModuleNode : public ModuleNode {
const std::string& func_name,
const KTRefEntry& e) {
std::lock_guard<std::mutex> lock(build_lock_);
int dev_id = t->context.dev_id;
if (!device_built_flag_[dev_id]) {
int device_id = t->context.device_id;
if (!device_built_flag_[device_id]) {
// build program
cl_int err;
cl_device_id dev = w->devices[dev_id];
cl_device_id dev = w->devices[device_id];
err = clBuildProgram(program_, 1, &dev, nullptr, nullptr, nullptr);
if (err != CL_SUCCESS) {
size_t len;
......@@ -139,7 +139,7 @@ class OpenCLModuleNode : public ModuleNode {
program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr);
LOG(FATAL) << "OpenCL build error for device=" << dev << log;
}
device_built_flag_[dev_id] = true;
device_built_flag_[device_id] = true;
}
// build kernel
cl_int err;
......@@ -246,23 +246,23 @@ void AutoSetOpenCLDevice(const TVMArgs& args, TVMRetValue* rv) {
int num_args = args[2].operator int();
// TODO(tqchen): merge this with CUDA logic.
int dev_id = -1;
int device_id = -1;
for (int i = 0; i < num_args; ++i) {
if (type_codes[i] == kArrayHandle) {
TVMContext ctx = static_cast<TVMArray*>(values[i].v_handle)->ctx;
CHECK_EQ(ctx.dev_mask, kOpenCL)
CHECK_EQ(ctx.device_type, kOpenCL)
<< "All operands need to be OpenCL";
if (dev_id == -1) {
dev_id = ctx.dev_id;
if (device_id == -1) {
device_id = ctx.device_id;
} else {
CHECK_EQ(dev_id, ctx.dev_id)
CHECK_EQ(device_id, ctx.device_id)
<< "Operands comes from different devices ";
}
}
}
CHECK_NE(dev_id, -1)
CHECK_NE(device_id, -1)
<< "Cannot detect device id from list";
cl::OpenCLThreadEntry::ThreadLocal()->context.dev_id = dev_id;
cl::OpenCLThreadEntry::ThreadLocal()->context.device_id = device_id;
}
PackedFunc OpenCLModuleNode::GetFunction(
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment