Commit 3957926e by Tianqi Chen, committed by GitHub

[RUNTIME] Make runtime DLPack compatible, allow new device plugin (#71)

* [RUNTIME] Refactor runtime to be DLPack compatible. Enable plugin of new runtime.

* fix mac compile

* ok
parent 9e660dbe
@@ -4,3 +4,6 @@
 [submodule "HalideIR"]
 	path = HalideIR
 	url = ssh://git@github.com/tqchen/HalideIR
+[submodule "dlpack"]
+	path = dlpack
+	url = https://github.com/dmlc/dlpack
@@ -14,6 +14,9 @@ tvm_option(USE_MSVC_MT "Build with MT" OFF)
 include_directories("include")
 include_directories("HalideIR/src")
+include_directories("dlpack/include")
 set(TVM_LINKER_LIBS "")
 set(TVM_RUNTIME_LINKER_LIBS "")
...
@@ -28,7 +28,7 @@ ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
 export LDFLAGS = -pthread -lm
 export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\
-	-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
+	-Iinclude -Idlpack/include -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
 ifdef CUDA_PATH
 	NVCC=$(CUDA_PATH)/bin/nvcc
...
+Subproject commit 9f433c5ecfdd47184339cdd2b99706d24fae3aa1
@@ -70,7 +70,13 @@ class BufferNode : public Node {
   Array<Expr> strides;
   /*! \brief data type in the content of the tensor */
   Type dtype;
-  // Maybe need more information(alignment) later
+  /*!
+   * \brief The offset in bytes to the beginning pointer to data
+   *  Can be undefined, indicating this must be zero.
+   */
+  Expr byte_offset;
+  /*! \brief Alignment bytes size of byte_offset */
+  int offset_alignment;
   /*! \brief constructor */
   BufferNode() {}
@@ -80,13 +86,17 @@ class BufferNode : public Node {
     v->Visit("shape", &shape);
     v->Visit("strides", &strides);
     v->Visit("dtype", &dtype);
+    v->Visit("byte_offset", &byte_offset);
+    v->Visit("offset_alignment", &offset_alignment);
   }
   static Buffer make(std::string name,
                      Var ptr,
                      Array<Expr> shape,
                      Array<Expr> strides,
-                     Type dtype);
+                     Type dtype,
+                     Expr byte_offset,
+                     int offset_alignment);
   static constexpr const char* _type_key = "Buffer";
   TVM_DECLARE_NODE_TYPE_INFO(BufferNode, Node);
...
@@ -41,6 +41,14 @@ using Halide::Internal::const_true;
 using Halide::Internal::const_false;
 using Halide::Internal::is_no_op;
 
+inline Type TVMShapeIndexType() {
+  if (std::is_signed<tvm_index_t>::value) {
+    return Int(sizeof(tvm_index_t) * 8);
+  } else {
+    return UInt(sizeof(tvm_index_t) * 8);
+  }
+}
+
 inline Type TVMType2Type(TVMType t) {
   return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
 }
...
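Worth noting what this helper now evaluates to: the c_runtime_api.h hunk further down switches `tvm_index_t` from `uint32_t` to `int64_t`, so the signed branch is taken and shape/stride indices lower as Int(64). A compile-time restatement of that assumption (a sketch, not part of the commit):

    #include <cstdint>
    #include <type_traits>
    // With typedef int64_t tvm_index_t; TVMShapeIndexType() returns Int(64):
    // signed 64-bit shape indices, matching DLTensor's int64_t* shape/strides.
    static_assert(std::is_signed<std::int64_t>::value, "signed branch is taken");
    static_assert(sizeof(std::int64_t) * 8 == 64, "Int(64) shape index type");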
@@ -167,7 +167,8 @@ enum TVMArrayFieldKind {
   kStrides = 3,
   kTypeCode = 4,
   kTypeBits = 5,
-  kTypeLanes = 6
+  kTypeLanes = 6,
+  kByteOffset = 7
 };
 }  // namespace intrinsic
...
@@ -31,22 +31,22 @@
 #include <stdint.h>
 #include <stddef.h>
 
+// TVM Runtime is DLPack compatible.
+#include <dlpack/dlpack.h>
+
 TVM_EXTERN_C {
 /*! \brief type of array index. */
-typedef uint32_t tvm_index_t;
+typedef int64_t tvm_index_t;
 
 /*!
  * \brief The type code in TVMType
  * \note TVMType is used in two places.
  */
 typedef enum {
-  kInt = 0U,
-  kUInt = 1U,
-  kFloat = 2U,
-  kHandle = 3U,
+  // The type code of other types are compatible with DLPack.
   // The next few fields are extension types
   // that is used by TVM API calls.
+  kHandle = 3U,
   kNull = 4U,
   kArrayHandle = 5U,
   kTVMType = 6U,
@@ -67,14 +67,17 @@
  *
  * \note Arguments TVM API function always takes bits=64 and lanes=1
  */
-typedef struct {
-  /*! \brief type code, in TVMTypeCode */
-  uint8_t code;
-  /*! \brief number of bits of the type */
-  uint8_t bits;
-  /*! \brief number of lanes, */
-  uint16_t lanes;
-} TVMType;
+typedef DLDataType TVMType;
+
+/*!
+ * \brief The Device information, abstract away common device types.
+ */
+typedef DLContext TVMContext;
+
+/*!
+ * \brief The tensor array stucture to TVM API.
+ */
+typedef DLTensor TVMArray;
 
 /*!
  * \brief Union type of values
@@ -97,50 +100,6 @@
   size_t size;
 } TVMByteArray;
 
-/*!
- * \brief The device type
- */
-typedef enum {
-  /*! \brief CPU device */
-  kCPU = 1,
-  /*! \brief NVidia GPU device(CUDA) */
-  kGPU = 2,
-  /*! \brief opencl device */
-  kOpenCL = 4
-} TVMDeviceMask;
-
-/*!
- * \brief The Device information, abstract away common device types.
- */
-typedef struct {
-  /*! \brief The device type mask */
-  int dev_mask;
-  /*! \brief the device id */
-  int dev_id;
-} TVMContext;
-
-/*!
- * \brief Data structure representing a n-dimensional array(tensor).
- *  This is used to pass data specification into TVM.
- */
-typedef struct {
-  /*! \brief The data field pointer on specified device */
-  void* data;
-  /*! \brief The shape pointers of the array */
-  const tvm_index_t* shape;
-  /*!
-   * \brief The stride data about each dimension of the array, can be NULL
-   *  When strides is NULL, it indicates that the array is empty.
-   */
-  const tvm_index_t* strides;
-  /*! \brief number of dimensions of the array */
-  tvm_index_t ndim;
-  /*! \brief The data type flag */
-  TVMType dtype;
-  /*! \brief The device context this array sits on */
-  TVMContext ctx;
-} TVMArray;
-
 /*! \brief Handle to TVM runtime modules. */
 typedef void* TVMModuleHandle;
 /*! \brief Handle to packed function handle. */
...
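For orientation, TVMArray is now an alias of DLPack's DLTensor. In the dlpack revision pinned above, its layout is approximately the sketch below (field types paraphrased, not copied from the header); this field order is what the reordered ctypes `_fields_` in the Python hunk below mirrors, and what the GEP indices in the codegen_llvm.cc hunk address (0 = data, 1 = ctx, 2 = ndim, 3 = dtype, 4 = shape, 5 = strides, 6 = byte_offset).

    // Approximate DLTensor layout assumed by the rest of this diff (sketch).
    typedef struct {
      void* data;            // opaque device pointer
      DLContext ctx;         // {device_id, device_type} in this dlpack revision
      int ndim;              // number of dimensions
      DLDataType dtype;      // {code, bits, lanes}
      int64_t* shape;        // shape array of length ndim
      int64_t* strides;      // strides in elements; may be NULL
      uint64_t byte_offset;  // byte offset into data
    } DLTensor;              // TVMArray is a typedef of this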
@@ -8,35 +8,37 @@ import numpy as np
 from .._base import _LIB, check_call
 from .._base import c_array
-from ._types import TVMType, tvm_index_t
+from ._types import TVMType, tvm_shape_index_t
 
 class TVMContext(ctypes.Structure):
     """TVM context strucure."""
-    _fields_ = [("dev_mask", ctypes.c_int),
-                ("dev_id", ctypes.c_int)]
+    _fields_ = [("device_id", ctypes.c_int),
+                ("device_type", ctypes.c_int)]
     MASK2STR = {
         1 : 'cpu',
         2 : 'gpu',
         4 : 'opencl'
     }
-    def __init__(self, dev_mask, dev_id):
+    def __init__(self, device_id, device_type):
         super(TVMContext, self).__init__()
-        self.dev_mask = dev_mask
-        self.dev_id = dev_id
+        self.device_id = device_id
+        self.device_type = device_type
 
     def __repr__(self):
         return "%s(%d)" % (
-            TVMContext.MASK2STR[self.dev_mask], self.dev_id)
+            TVMContext.MASK2STR[self.device_type], self.device_id)
 
 class TVMArray(ctypes.Structure):
     """TVMValue in C API"""
     _fields_ = [("data", ctypes.c_void_p),
-                ("shape", ctypes.POINTER(tvm_index_t)),
-                ("strides", ctypes.POINTER(tvm_index_t)),
-                ("ndim", tvm_index_t),
+                ("ctx", TVMContext),
+                ("ndim", ctypes.c_int),
                 ("dtype", TVMType),
-                ("ctx", TVMContext)]
+                ("shape", ctypes.POINTER(tvm_shape_index_t)),
+                ("strides", ctypes.POINTER(tvm_shape_index_t)),
+                ("byte_offset", ctypes.c_size_t)]
 
 TVMArrayHandle = ctypes.POINTER(TVMArray)
@@ -50,7 +52,7 @@ def cpu(dev_id=0):
     dev_id : int, optional
         The integer device id
     """
-    return TVMContext(1, dev_id)
+    return TVMContext(dev_id, 1)
 
 def gpu(dev_id=0):
@@ -61,7 +63,7 @@ def gpu(dev_id=0):
     dev_id : int, optional
         The integer device id
     """
-    return TVMContext(2, dev_id)
+    return TVMContext(dev_id, 2)
 
 def opencl(dev_id=0):
@@ -72,7 +74,7 @@ def opencl(dev_id=0):
     dev_id : int, optional
         The integer device id
     """
-    return TVMContext(4, dev_id)
+    return TVMContext(dev_id, 4)
 
 def numpyasarray(np_data):
@@ -81,7 +83,7 @@ def numpyasarray(np_data):
     data = np_data
     assert data.flags['C_CONTIGUOUS']
     arr = TVMArray()
-    shape = c_array(tvm_index_t, data.shape)
+    shape = c_array(tvm_shape_index_t, data.shape)
     arr.data = data.ctypes.data_as(ctypes.c_void_p)
     arr.shape = shape
     arr.strides = None
@@ -114,8 +116,8 @@ def empty(shape, dtype="float32", ctx=cpu(0)):
     arr : tvm.nd.NDArray
         The array tvm supported.
     """
-    shape = c_array(tvm_index_t, shape)
-    ndim = tvm_index_t(len(shape))
+    shape = c_array(tvm_shape_index_t, shape)
+    ndim = ctypes.c_int(len(shape))
     handle = TVMArrayHandle()
     dtype = TVMType(dtype)
     check_call(_LIB.TVMArrayAlloc(
...
@@ -6,7 +6,7 @@ import ctypes
 import numpy as np
 from .._base import py_str
 
-tvm_index_t = ctypes.c_uint32
+tvm_shape_index_t = ctypes.c_int64
 
 class TypeCode(object):
     """Type code used in API calls"""
...
@@ -261,8 +261,10 @@ def call_packed(*args):
 
 def Buffer(shape, dtype=None,
            name="buffer",
-           ptr=None,
-           strides=None):
+           data=None,
+           strides=None,
+           byte_offset=None,
+           offset_alignment=0):
     """Create a new symbolic buffer
 
     Parameters
@@ -276,12 +278,18 @@ def Buffer(shape, dtype=None,
     name : str, optional
         The name of the buffer.
 
-    ptr : Var, optional
+    data : Var, optional
         The data pointer in the buffer.
 
     strides: array of Expr
         The stride of the buffer.
 
+    byte_offset: Expr, optional
+        The offset in bytes to data pointer.
+
+    offset_alignment: int, optional
+        The alignment of offset
+
     Returns
     -------
     buffer : Buffer
@@ -290,11 +298,11 @@ def Buffer(shape, dtype=None,
     shape = (shape,) if isinstance(shape, _expr.Expr) else shape
     dtype = float32 if dtype is None else dtype
     strides = () if strides is None else strides
-    if ptr is None:
-        ptr = Var(name, "handle")
+    if data is None:
+        data = Var(name, "handle")
     return _api_internal._Buffer(
-        name, ptr, shape, strides, dtype)
+        name, data, shape, strides, dtype, byte_offset, offset_alignment)
 
 def _IterVar(dom, name, iter_type, thread_tag=''):
...
@@ -138,7 +138,9 @@ TVM_REGISTER_API(_Buffer)
                             args[1],
                             args[2],
                             args[3],
-                            args[4]);
+                            args[4],
+                            args[5],
+                            args[6]);
   });
 
 TVM_REGISTER_API(_Tensor)
...
@@ -33,17 +33,18 @@ void CodeGenLLVM::Init(const std::string& module_name,
   t_int32_ = llvm::Type::getInt32Ty(*ctx);
   t_int64_ = llvm::Type::getInt64Ty(*ctx);
   t_float64_ = llvm::Type::getDoubleTy(*ctx);
-  t_tvm_index_ = llvm::Type::getIntNTy(*ctx, sizeof(tvm_index_t) * 8);
+  t_tvm_shape_index_ = llvm::Type::getIntNTy(*ctx, TVMShapeIndexType().bits());
   t_tvm_context_ = llvm::StructType::create({t_int_, t_int_});
   t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_});
   t_tvm_func_handle_ = t_void_p_;
   t_tvm_array_ = llvm::StructType::create(
       {t_void_p_,
-       t_tvm_index_->getPointerTo(),
-       t_tvm_index_->getPointerTo(),
-       t_tvm_index_,
+       t_tvm_context_,
+       t_int_,
        t_tvm_type_,
-       t_tvm_context_});
+       t_tvm_shape_index_->getPointerTo(),
+       t_tvm_shape_index_->getPointerTo(),
+       t_int64_});
   t_tvm_value_ = llvm::StructType::create({t_float64_});
   t_f_tvm_par_for_lambda_ = llvm::FunctionType::get(
       t_int_, {t_int64_, t_int64_, t_void_p_}, false);
@@ -663,25 +664,29 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
         ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(0)}); break;
       }
       case intrinsic::kShape: {
-        ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(1)}); break;
+        ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(4)}); break;
       }
       case intrinsic::kStrides: {
-        ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(2)}); break;
+        ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(5)}); break;
      }
      case intrinsic::kNDim: {
-        ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(3)}); break;
+        ret = builder_->CreateInBoundsGEP(arr, {zero, ConstInt32(2)}); break;
      }
      case intrinsic::kTypeCode: {
        ret = builder_->CreateInBoundsGEP(
-            arr, {zero, ConstInt32(4), ConstInt32(0)}); break;
+            arr, {zero, ConstInt32(3), ConstInt32(0)}); break;
      }
      case intrinsic::kTypeBits: {
        ret = builder_->CreateInBoundsGEP(
-            arr, {zero, ConstInt32(4), ConstInt32(1)}); break;
+            arr, {zero, ConstInt32(3), ConstInt32(1)}); break;
      }
      case intrinsic::kTypeLanes: {
        ret = builder_->CreateInBoundsGEP(
-            arr, {zero, ConstInt32(4), ConstInt32(2)}); break;
+            arr, {zero, ConstInt32(3), ConstInt32(2)}); break;
      }
+      case intrinsic::kByteOffset: {
+        ret = builder_->CreateInBoundsGEP(
+            arr, {zero, ConstInt32(6)}); break;
+      }
      default: LOG(FATAL) << "unknown field code";
    }
...
@@ -160,7 +160,7 @@ class CodeGenLLVM :
  llvm::MDNode* md_very_likely_branch_{nullptr};
  llvm::MDNode* md_tbaa_root_{nullptr};
  // TVM related data types
-  llvm::Type* t_tvm_index_{nullptr};
+  llvm::Type* t_tvm_shape_index_{nullptr};
  llvm::Type* t_tvm_func_handle_{nullptr};
  llvm::StructType* t_tvm_context_{nullptr};
  llvm::StructType* t_tvm_type_{nullptr};
...
@@ -166,6 +166,7 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
      case intrinsic::kTypeCode: PushOp(StackVM::TVM_ARRAY_GET_TYPE_CODE); break;
      case intrinsic::kTypeBits: PushOp(StackVM::TVM_ARRAY_GET_TYPE_BITS); break;
      case intrinsic::kTypeLanes: PushOp(StackVM::TVM_ARRAY_GET_TYPE_LANES); break;
+      case intrinsic::kByteOffset: PushOp(StackVM::TVM_ARRAY_GET_BYTE_OFFSET); break;
      default: LOG(FATAL) << "unknown field code";
    }
  } else if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
@@ -227,15 +228,12 @@ void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64,
 void CodeGenStackVM::PushCast(Type dst, Type src) {
  if (dst.is_int()) {
-    if (src.is_int()) return;
-    if (src.is_uint() && src.bits() <= 32) return;
-  } else if (dst.is_uint() && dst.bits() <= 32) {
-    if (src.is_int()) return;
-    if (src.is_uint() && src.bits() <= 32) return;
+    if (src.is_int() || src.is_uint()) return;
+  } else if (dst.is_uint()) {
+    if (src.is_int() || src.is_uint()) return;
  } else if (dst.is_float()) {
    if (src.is_float()) return;
  }
-  LOG(FATAL) << "Cannot handle cast " << src << " to " << dst;
 }
 
 void CodeGenStackVM::VisitExpr_(const StringImm *op) {
...
@@ -139,6 +139,7 @@ int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const {
    STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_SHAPE);
    STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_STRIDES);
    STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_NDIM);
+    STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_BYTE_OFFSET);
    STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_CODE);
    STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_BITS);
    STACK_VM_PRINT_CODE0(TVM_ARRAY_GET_TYPE_LANES);
@@ -352,6 +353,9 @@ void StackVM::Run(State* s) const {
      case TVM_ARRAY_GET_NDIM: {
        STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, ndim); break;
      }
+      case TVM_ARRAY_GET_BYTE_OFFSET: {
+        STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, byte_offset); break;
+      }
      case TVM_ARRAY_GET_TYPE_CODE: {
        STACK_VM_TVM_ARRARY_GET(v_int64, int64_t, dtype.code); break;
      }
...
@@ -199,7 +199,8 @@ class StackVM {
    TVM_ARRAY_GET_NDIM,
    TVM_ARRAY_GET_TYPE_CODE,
    TVM_ARRAY_GET_TYPE_BITS,
-    TVM_ARRAY_GET_TYPE_LANES
+    TVM_ARRAY_GET_TYPE_LANES,
+    TVM_ARRAY_GET_BYTE_OFFSET
  };
  /*! \brief The code structure */
  union Code {
...
@@ -22,7 +22,8 @@ Buffer::Buffer(Array<Expr> shape,
    : Buffer(BufferNode::make(
          name,
          Var(name, Type(Type::Handle, 0, 0)),
-          shape, Array<Expr>(), dtype)) {
+          shape, Array<Expr>(), dtype,
+          Expr(), 0)) {
 }
 
 inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
@@ -40,6 +41,9 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
      base = base + index[i] * n->strides[i];
    }
  }
+  if (!is_zero(n->byte_offset)) {
+    base = base + (n->byte_offset / n->dtype.bytes());
+  }
  return base;
 }
@@ -58,13 +62,27 @@ Buffer BufferNode::make(std::string name,
                        Var data,
                        Array<Expr> shape,
                        Array<Expr> strides,
-                        Type dtype) {
+                        Type dtype,
+                        Expr byte_offset,
+                        int offset_alignment) {
  auto n = std::make_shared<BufferNode>();
  n->name = name;
  n->data = data;
  n->shape = shape;
  n->strides = strides;
  n->dtype = dtype;
+  if (!byte_offset.defined()) {
+    byte_offset = make_const(shape[0].type(), 0);
+  }
+  if (offset_alignment != 0) {
+    CHECK_EQ(offset_alignment % dtype.bytes(), 0)
+        << "Offset alignments must be at least " << dtype.bytes();
+  } else {
+    offset_alignment = dtype.bytes();
+  }
+  n->byte_offset = byte_offset;
+  n->offset_alignment = offset_alignment;
  return Buffer(n);
 }
...
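The defaulting rule above ties byte_offset to the element size: BufferOffset (earlier hunk) folds the offset in as `byte_offset / dtype.bytes()`, which is only exact when the offset covers whole elements. A standalone restatement of that rule (a sketch, not TVM code; the function name is hypothetical):

    #include <cassert>
    // Mirrors BufferNode::make's alignment handling.
    int NormalizeOffsetAlignment(int offset_alignment, int dtype_bytes) {
      if (offset_alignment != 0) {
        // must describe a whole number of elements
        assert(offset_alignment % dtype_bytes == 0);
        return offset_alignment;
      }
      return dtype_bytes;  // default: element-aligned byte_offset
    }
    // float32 examples: (0, 4) -> 4, (16, 4) -> 16, (2, 4) trips the assert.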
@@ -36,7 +36,8 @@ LoweredFunc MakeAPI(Stmt body,
                    std::string name,
                    Array<NodeRef> api_args,
                    int num_unpacked_args) {
-  const Type tvm_index_type = UInt(32);
+  const Type tvm_shape_type = TVMShapeIndexType();
+  const Type tvm_ndim_type = Int(32);
  const Stmt nop = Evaluate::make(0);
  int num_args = static_cast<int>(api_args.size());
  CHECK_LE(num_unpacked_args, num_args);
@@ -120,13 +121,15 @@ LoweredFunc MakeAPI(Stmt body,
          << "api_args can only be Buffer or Var";
      Buffer buf(api_args[i].node_);
      // dimension checks
-      Expr v_ndim = TVMArrayGet(tvm_index_type, v_arg, intrinsic::kNDim);
+      Expr v_ndim = TVMArrayGet(tvm_ndim_type, v_arg, intrinsic::kNDim);
      std::ostringstream ndim_err_msg;
      ndim_err_msg << "arg_" << i
                   << ".ndim is expected to equal "
                   << buf->shape.size();
      seq_init.emplace_back(
-          MakeAssertEQ(v_ndim, UIntImm::make(tvm_index_type, buf->shape.size()),
+          MakeAssertEQ(v_ndim,
+                       make_const(tvm_ndim_type,
+                                  static_cast<int64_t>(buf->shape.size())),
                       ndim_err_msg.str()));
      // type checks
      Type dtype = buf->dtype;
@@ -147,7 +150,7 @@ LoweredFunc MakeAPI(Stmt body,
      }
      // shape field
      Var v_shape(v_arg->name_hint + ".shape", Handle());
-      handle_data_type.Set(v_shape, UIntImm::make(tvm_index_type, 0));
+      handle_data_type.Set(v_shape, make_const(tvm_shape_type, 0));
      seq_init.emplace_back(LetStmt::make(
          v_shape, TVMArrayGet(Handle(), v_arg, intrinsic::kShape), nop));
      for (size_t k = 0; k < buf->shape.size(); ++k) {
@@ -155,12 +158,12 @@ LoweredFunc MakeAPI(Stmt body,
        field_name << v_shape->name_hint << '[' << k << ']';
        f_push(buf->shape[k],
               cast(buf->shape[k].type(),
-                    Load::make(tvm_index_type, v_shape, IntImm::make(Int(32), k))),
+                    Load::make(tvm_shape_type, v_shape, IntImm::make(Int(32), k))),
               field_name.str());
      }
      // strides field
      Var v_strides(v_arg->name_hint + ".strides", Handle());
-      handle_data_type.Set(v_strides, UIntImm::make(tvm_index_type, 0));
+      handle_data_type.Set(v_strides, make_const(tvm_shape_type, 0));
      seq_init.emplace_back(LetStmt::make(
          v_strides, TVMArrayGet(Handle(), v_arg, intrinsic::kStrides), nop));
      if (buf->strides.size() == 0) {
@@ -174,10 +177,13 @@ LoweredFunc MakeAPI(Stmt body,
        field_name << v_strides->name_hint << '[' << k << ']';
        f_push(buf->strides[k],
               cast(buf->shape[k].type(),
-                    Load::make(tvm_index_type, v_strides, IntImm::make(Int(32), k))),
+                    Load::make(tvm_shape_type, v_strides, IntImm::make(Int(32), k))),
               field_name.str());
        }
      }
+      // Byte_offset field.
+      f_push(buf->byte_offset, TVMArrayGet(UInt(64), v_arg, intrinsic::kByteOffset),
+             v_arg->name_hint + ".byte_offset");
    }
  }
...
@@ -7,17 +7,56 @@
 #include <tvm/runtime/c_runtime_api.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
 #include <dmlc/timer.h>
+#include <array>
 #include <algorithm>
 #include <string>
 #include <cstdlib>
 #include <thread>
+#include <mutex>
 #include "./runtime_base.h"
 #include "./device_api.h"
 
 namespace tvm {
 namespace runtime {
 
+class DeviceAPIManager {
+ public:
+  static const int kMaxDeviceAPI = 16;
+  // Get API
+  static DeviceAPI* Get(TVMContext ctx) {
+    return Global()->GetAPI(ctx.device_type);
+  }
+
+ private:
+  std::array<DeviceAPI*, kMaxDeviceAPI> api_;
+  std::mutex mutex_;
+  // constructor
+  DeviceAPIManager() {
+    std::fill(api_.begin(), api_.end(), nullptr);
+  }
+  // Global static variable.
+  static DeviceAPIManager* Global() {
+    static DeviceAPIManager inst;
+    return &inst;
+  }
+  // Get or initialize API.
+  DeviceAPI* GetAPI(DLDeviceType type) {
+    if (api_[type] != nullptr) return api_[type];
+    std::lock_guard<std::mutex> lock(mutex_);
+    if (api_[type] != nullptr) return api_[type];
+    std::string factory = "_device_api_" + DeviceName(type);
+    auto* f = Registry::Get(factory);
+    CHECK(f != nullptr)
+        << "Device API " << DeviceName(type) << " is not enabled.";
+    void* ptr = (*f)();
+    api_[type] = static_cast<DeviceAPI*>(ptr);
+    return api_[type];
+  }
+};
+
 inline TVMArray* TVMArrayCreate_() {
  TVMArray* arr = new TVMArray();
  arr->shape = nullptr;
@@ -33,9 +72,8 @@ inline void TVMArrayFree_(TVMArray* arr) {
    delete[] arr->shape;
    delete[] arr->strides;
    if (arr->data != nullptr) {
-      TVM_DEVICE_SWITCH(arr->ctx, {
-          FreeDataSpace<xpu>(arr->ctx, arr->data);
-        });
+      DeviceAPIManager::Get(arr->ctx)->FreeDataSpace(
+          arr->ctx, arr->data);
    }
  }
  delete arr;
@@ -282,10 +320,8 @@ int TVMArrayAlloc(const tvm_index_t* shape,
  arr->ctx = ctx;
  size_t size = GetDataSize(arr);
  size_t alignment = GetDataAlignment(arr);
-  // ctx data pointer
-  TVM_DEVICE_SWITCH(ctx, {
-      arr->data = AllocDataSpace<xpu>(ctx, size, alignment);
-    });
+  arr->data = DeviceAPIManager::Get(ctx)->AllocDataSpace(
+      ctx, size, alignment);
  *out = arr;
  API_END_HANDLE_ERROR(TVMArrayFree_(arr));
 }
@@ -306,28 +342,21 @@ int TVMArrayCopyFromTo(TVMArrayHandle from,
  CHECK_EQ(from_size, to_size)
      << "TVMArrayCopyFromTo: The size must exactly match";
  TVMContext ctx = from->ctx;
-  if (ctx.dev_mask == kCPU) {
+  if (ctx.device_type == kCPU) {
    ctx = to->ctx;
  } else {
-    CHECK(to->ctx.dev_mask == kCPU ||
-          to->ctx.dev_mask == from->ctx.dev_mask)
+    CHECK(to->ctx.device_type == kCPU ||
+          to->ctx.device_type == from->ctx.device_type)
        << "Can not copy across different ctx types directly";
  }
-
-  TVM_DEVICE_SWITCH(ctx, {
-      CopyDataFromTo<xpu>(from->data, to->data,
-                          from_size,
-                          from->ctx,
-                          to->ctx,
-                          stream);
-    });
+  DeviceAPIManager::Get(ctx)->CopyDataFromTo(
+      from->data, to->data, from_size,
+      from->ctx, to->ctx, stream);
  API_END();
 }
 
 int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
  API_BEGIN();
-  TVM_DEVICE_SWITCH(ctx, {
-      StreamSync<xpu>(ctx, stream);
-    });
+  DeviceAPIManager::Get(ctx)->StreamSync(ctx, stream);
  API_END();
 }
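DeviceAPIManager above is the heart of the plugin change: instead of dispatching through the compile-time TVM_DEVICE_SWITCH macro, the runtime looks a DeviceAPI up once per device type in the global function registry, caching it with double-checked locking so the steady-state path takes no lock. A minimal standalone restatement of that caching pattern (a sketch; `Lookup` stands in for Registry::Get plus the factory call):

    #include <array>
    #include <mutex>

    template <typename T, int kMax>
    class LazyTable {
     public:
      T* Get(int key) {
        if (slots_[key] != nullptr) return slots_[key];  // fast path, no lock
        std::lock_guard<std::mutex> lock(mutex_);
        if (slots_[key] == nullptr) {    // recheck under the lock
          slots_[key] = Lookup(key);     // hypothetical factory lookup
        }
        return slots_[key];
      }
     private:
      T* Lookup(int key);                // assumed to be provided elsewhere
      std::array<T*, kMax> slots_{};     // value-initialized to nullptr
      std::mutex mutex_;
    };

As in the original, the unlocked read of a raw pointer assumes aligned pointer loads are atomic on the supported platforms.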
@@ -7,6 +7,7 @@
 #define TVM_RUNTIME_DEVICE_API_CPU_H_
 
 #include <dmlc/logging.h>
+#include <tvm/runtime/registry.h>
 #include <cstdlib>
 #include <cstring>
 #include "./device_api.h"
@@ -14,41 +15,47 @@
 namespace tvm {
 namespace runtime {
 
-template<>
-void* AllocDataSpace<kCPU>(TVMContext ctx, size_t size, size_t alignment) {
-  void* ptr;
+class CPUDeviceAPI : public DeviceAPI {
+ public:
+  void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final {
+    void* ptr;
 #if _MSC_VER
    ptr = _aligned_malloc(size, alignment);
    if (ptr == nullptr) throw std::bad_alloc();
 #else
    int ret = posix_memalign(&ptr, alignment, size);
    if (ret != 0) throw std::bad_alloc();
 #endif
    return ptr;
  }
 
-template<>
-void FreeDataSpace<kCPU>(TVMContext ctx, void* ptr) {
+  void FreeDataSpace(TVMContext ctx, void* ptr) final {
 #if _MSC_VER
    _aligned_free(ptr);
 #else
    free(ptr);
 #endif
  }
 
-template<>
-void CopyDataFromTo<kCPU>(const void* from,
-                          void* to,
-                          size_t size,
-                          TVMContext ctx_from,
-                          TVMContext ctx_to,
-                          TVMStreamHandle stream) {
-  memcpy(to, from, size);
-}
+  void CopyDataFromTo(const void* from,
+                      void* to,
+                      size_t size,
+                      TVMContext ctx_from,
+                      TVMContext ctx_to,
+                      TVMStreamHandle stream) final {
+    memcpy(to, from, size);
+  }
 
-template<>
-void StreamSync<kCPU>(TVMContext ctx, TVMStreamHandle stream) {
-}
+  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
+  }
+};
+
+TVM_REGISTER_GLOBAL(_device_api_cpu)
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    static CPUDeviceAPI inst;
+    DeviceAPI* ptr = &inst;
+    *rv = static_cast<void*>(ptr);
+  });
 }  // namespace runtime
 }  // namespace tvm
 #endif  // TVM_RUNTIME_DEVICE_API_CPU_H_
 /*!
  * Copyright (c) 2017 by Contributors
- * \file device_api_cuda.h
+ * \file cuda_device_api.cc
  * \brief GPU specific API
  */
-#ifndef TVM_RUNTIME_CUDA_DEVICE_API_CUDA_H_
-#define TVM_RUNTIME_CUDA_DEVICE_API_CUDA_H_
-
-#include "./cuda_common.h"
+#include <tvm/runtime/config.h>
 
 #if TVM_CUDA_RUNTIME
 #include <dmlc/logging.h>
+#include <tvm/runtime/registry.h>
 #include <cuda_runtime.h>
+#include "./cuda_common.h"
+#include "../device_api.h"
 
 namespace tvm {
 namespace runtime {
 
-template<>
-inline void* AllocDataSpace<kGPU>(TVMContext ctx, size_t size, size_t alignment) {
-  CUDA_CALL(cudaSetDevice(ctx.dev_id));
-  CHECK_EQ(256 % alignment, 0U)
-      << "CUDA space is aligned at 256 bytes";
-  void *ret;
-  CUDA_CALL(cudaMalloc(&ret, size));
-  return ret;
-}
-
-template<>
-inline void FreeDataSpace<kGPU>(TVMContext ctx, void* ptr) {
-  CUDA_CALL(cudaSetDevice(ctx.dev_id));
-  CUDA_CALL(cudaFree(ptr));
-}
-
-inline void GPUCopy(const void* from,
-                    void* to,
-                    size_t size,
-                    cudaMemcpyKind kind,
-                    cudaStream_t stream) {
-  if (stream != 0) {
-    CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
-  } else {
-    CUDA_CALL(cudaMemcpy(to, from, size, kind));
-  }
-}
-
-template<>
-inline void CopyDataFromTo<kGPU>(const void* from,
-                                 void* to,
-                                 size_t size,
-                                 TVMContext ctx_from,
-                                 TVMContext ctx_to,
-                                 TVMStreamHandle stream) {
-  cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
-  if (ctx_from.dev_mask == kGPU && ctx_to.dev_mask == kGPU) {
-    CUDA_CALL(cudaSetDevice(ctx_from.dev_id));
-    if (ctx_from.dev_id == ctx_to.dev_id) {
-      GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
-    } else {
-      cudaMemcpyPeerAsync(to, ctx_to.dev_id,
-                          from, ctx_from.dev_id,
-                          size, cu_stream);
-    }
-  } else if (ctx_from.dev_mask == kGPU && ctx_to.dev_mask == kCPU) {
-    CUDA_CALL(cudaSetDevice(ctx_from.dev_id));
-    GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
-  } else if (ctx_from.dev_mask == kCPU && ctx_to.dev_mask == kGPU) {
-    CUDA_CALL(cudaSetDevice(ctx_to.dev_id));
-    GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
-  } else {
-    LOG(FATAL) << "expect copy from/to GPU or between GPU";
-  }
-}
-
-template<>
-inline void StreamSync<kGPU>(TVMContext ctx, TVMStreamHandle stream) {
-  CUDA_CALL(cudaSetDevice(ctx.dev_id));
-  CUDA_CALL(cudaStreamSynchronize(
-      static_cast<cudaStream_t>(stream)));
-}
+class CUDADeviceAPI : public DeviceAPI {
+ public:
+  void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final {
+    CUDA_CALL(cudaSetDevice(ctx.device_id));
+    CHECK_EQ(256 % alignment, 0U)
+        << "CUDA space is aligned at 256 bytes";
+    void *ret;
+    CUDA_CALL(cudaMalloc(&ret, size));
+    return ret;
+  }
+
+  void FreeDataSpace(TVMContext ctx, void* ptr) final {
+    CUDA_CALL(cudaSetDevice(ctx.device_id));
+    CUDA_CALL(cudaFree(ptr));
+  }
+
+  void CopyDataFromTo(const void* from,
+                      void* to,
+                      size_t size,
+                      TVMContext ctx_from,
+                      TVMContext ctx_to,
+                      TVMStreamHandle stream) final {
+    cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
+    if (ctx_from.device_type == kGPU && ctx_to.device_type == kGPU) {
+      CUDA_CALL(cudaSetDevice(ctx_from.device_id));
+      if (ctx_from.device_id == ctx_to.device_id) {
+        GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
+      } else {
+        cudaMemcpyPeerAsync(to, ctx_to.device_id,
+                            from, ctx_from.device_id,
+                            size, cu_stream);
+      }
+    } else if (ctx_from.device_type == kGPU && ctx_to.device_type == kCPU) {
+      CUDA_CALL(cudaSetDevice(ctx_from.device_id));
+      GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
+    } else if (ctx_from.device_type == kCPU && ctx_to.device_type == kGPU) {
+      CUDA_CALL(cudaSetDevice(ctx_to.device_id));
+      GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
+    } else {
+      LOG(FATAL) << "expect copy from/to GPU or between GPU";
+    }
+  }
+
+  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final {
+    CUDA_CALL(cudaSetDevice(ctx.device_id));
+    CUDA_CALL(cudaStreamSynchronize(static_cast<cudaStream_t>(stream)));
+  }
+
+ private:
+  static void GPUCopy(const void* from,
+                      void* to,
+                      size_t size,
+                      cudaMemcpyKind kind,
+                      cudaStream_t stream) {
+    if (stream != 0) {
+      CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
+    } else {
+      CUDA_CALL(cudaMemcpy(to, from, size, kind));
+    }
+  }
+};
+
+TVM_REGISTER_GLOBAL(_device_api_gpu)
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    static CUDADeviceAPI inst;
+    DeviceAPI* ptr = &inst;
+    *rv = static_cast<void*>(ptr);
+  });
 }  // namespace runtime
 }  // namespace tvm
 #endif  // TVM_CUDA_RUNTIME
-#endif  // TVM_RUNTIME_CUDA_DEVICE_API_CUDA_H_
@@ -50,9 +50,9 @@ class CUDAModuleNode : public runtime::ModuleNode {
  }
 
  void PreCompile(const std::string& name, TVMContext ctx) final {
-    CUDA_CALL(cudaSetDevice(ctx.dev_id));
+    CUDA_CALL(cudaSetDevice(ctx.device_id));
    cudaFree(nullptr);
-    this->GetFunc(ctx.dev_id, name);
+    this->GetFunc(ctx.device_id, name);
  }
 
  PackedFunc GetFunction(
@@ -79,15 +79,15 @@ class CUDAModuleNode : public runtime::ModuleNode {
    }
  }
-  // get a CUfunction from primary context in dev_id
-  CUfunction GetFunc(int dev_id, const std::string& func_name) {
+  // get a CUfunction from primary context in device_id
+  CUfunction GetFunc(int device_id, const std::string& func_name) {
    std::lock_guard<std::mutex> lock(mutex_);
    // must recheck under the lock scope
-    if (module_[dev_id] == nullptr) {
-      CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[dev_id]), data_.c_str()));
+    if (module_[device_id] == nullptr) {
+      CUDA_DRIVER_CALL(cuModuleLoadData(&(module_[device_id]), data_.c_str()));
    }
    CUfunction func;
-    CUresult result = cuModuleGetFunction(&func, module_[dev_id], func_name.c_str());
+    CUresult result = cuModuleGetFunction(&func, module_[device_id], func_name.c_str());
    if (result != CUDA_SUCCESS) {
      const char *msg;
      cuGetErrorName(result, &msg);
@@ -132,14 +132,14 @@ class CUDAWrappedFunc {
  void operator()(TVMArgs args,
                  TVMRetValue* rv,
                  void** void_args) const {
-    int dev_id;
-    CUDA_CALL(cudaGetDevice(&dev_id));
-    if (fcache_[dev_id] == nullptr) {
-      fcache_[dev_id] = m_->GetFunc(dev_id, func_name_);
+    int device_id;
+    CUDA_CALL(cudaGetDevice(&device_id));
+    if (fcache_[device_id] == nullptr) {
+      fcache_[device_id] = m_->GetFunc(device_id, func_name_);
    }
    ThreadWorkLoad wl = thread_axis_cfg_.Extract(args);
    CUDA_DRIVER_CALL(cuLaunchKernel(
-        fcache_[dev_id],
+        fcache_[device_id],
        wl.grid_dim(0),
        wl.grid_dim(1),
        wl.grid_dim(2),
@@ -169,23 +169,23 @@ void AutoSetCUDADevice(const TVMArgs& args, TVMRetValue* rv) {
  int* type_codes = static_cast<int*>(args[1].operator void*());
  int num_args = args[2].operator int();
-  int dev_id = -1;
+  int device_id = -1;
  for (int i = 0; i < num_args; ++i) {
    if (type_codes[i] == kArrayHandle) {
      TVMContext ctx = static_cast<TVMArray*>(values[i].v_handle)->ctx;
-      CHECK_EQ(ctx.dev_mask, kGPU)
+      CHECK_EQ(ctx.device_type, kGPU)
          << "All operands need to be GPU";
-      if (dev_id == -1) {
-        dev_id = ctx.dev_id;
+      if (device_id == -1) {
+        device_id = ctx.device_id;
      } else {
-        CHECK_EQ(dev_id, ctx.dev_id)
+        CHECK_EQ(device_id, ctx.device_id)
            << "Operands comes from different devices ";
      }
    }
  }
-  CHECK_NE(dev_id, -1)
+  CHECK_NE(device_id, -1)
      << "Cannot detect device id from list";
-  CUDA_CALL(cudaSetDevice(dev_id));
+  CUDA_CALL(cudaSetDevice(device_id));
 }
 
 PackedFunc CUDAModuleNode::GetFunction(
...
@@ -8,83 +8,65 @@
 #include <tvm/base.h>
 #include <tvm/runtime/c_runtime_api.h>
+#include <string>
 
 namespace tvm {
 namespace runtime {
-/*!
- * \brief Allocate a data space on device.
- * \param ctx The device context to perform operation.
- * \param size The size of the memory
- * \param alignment The alignment of the memory.
- * \return The allocated device pointer
- * \tparam xpu The device mask.
- */
-template<TVMDeviceMask xpu>
-inline void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment);
-
-/*!
- * \brief Free a data space on device.
- * \param ctx The device context to perform operation.
- * \param ptr The data space.
- * \tparam xpu The device mask.
- */
-template<TVMDeviceMask xpu>
-inline void FreeDataSpace(TVMContext ctx, void* ptr);
-
-/*!
- * \brief copy data from one place to another
- * \param dev The device to perform operation.
- * \param from The source array.
- * \param to The target array.
- * \param size The size of the memory
- * \param ctx_from The source context
- * \param ctx_to The target context
- * \tparam xpu The device mask.
- */
-template<TVMDeviceMask xpu>
-inline void CopyDataFromTo(const void* from,
-                           void* to,
-                           size_t size,
-                           TVMContext ctx_from,
-                           TVMContext ctx_to,
-                           TVMStreamHandle stream);
-/*!
- * \brief Synchronize the stream
- * \param ctx The context to perform operation.
- * \param stream The stream to be sync.
- * \tparam xpu The device mask.
- */
-template<TVMDeviceMask xpu>
-inline void StreamSync(TVMContext ctx, TVMStreamHandle stream);
-
-// macro to run cuda related code
-#if TVM_CUDA_RUNTIME
-#define TVM_RUN_CUDA(OP) { const TVMDeviceMask xpu = kGPU; OP; }
-#else
-#define TVM_RUN_CUDA(OP) LOG(FATAL) << "CUDA is not enabled";
-#endif
-
-// macro to run opencl related code
-#if TVM_OPENCL_RUNTIME
-#define TVM_RUN_OPENCL(OP) { const TVMDeviceMask xpu = kOpenCL; OP; }
-#else
-#define TVM_RUN_OPENCL(OP) LOG(FATAL) << "OpenCL is not enabled";
-#endif
-
-// macro to switch options between devices
-#define TVM_DEVICE_SWITCH(ctx, OP)                                 \
-  switch (ctx.dev_mask) {                                          \
-    case kCPU: { const TVMDeviceMask xpu = kCPU; OP; break; }      \
-    case kGPU: TVM_RUN_CUDA(OP); break;                            \
-    case kOpenCL: TVM_RUN_OPENCL(OP); break;                       \
-    default: LOG(FATAL) << "unknown device_mask " << ctx.dev_mask; \
-  }
+class DeviceAPI {
+ public:
+  /*! \brief virtual destructor */
+  virtual ~DeviceAPI() {}
+  /*!
+   * \brief Allocate a data space on device.
+   * \param ctx The device context to perform operation.
+   * \param size The size of the memory
+   * \param alignment The alignment of the memory.
+   * \return The allocated device pointer
+   */
+  virtual void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) = 0;
+  /*!
+   * \brief Free a data space on device.
+   * \param ctx The device context to perform operation.
+   * \param ptr The data space.
+   * \tparam xpu The device mask.
+   */
+  virtual void FreeDataSpace(TVMContext ctx, void* ptr) = 0;
+  /*!
+   * \brief copy data from one place to another
+   * \param dev The device to perform operation.
+   * \param from The source array.
+   * \param to The target array.
+   * \param size The size of the memory
+   * \param ctx_from The source context
+   * \param ctx_to The target context
+   */
+  virtual void CopyDataFromTo(const void* from,
+                              void* to,
+                              size_t size,
+                              TVMContext ctx_from,
+                              TVMContext ctx_to,
+                              TVMStreamHandle stream) = 0;
+  /*!
+   * \brief Synchronize the stream
+   * \param ctx The context to perform operation.
+   * \param stream The stream to be sync.
+   */
+  virtual void StreamSync(TVMContext ctx, TVMStreamHandle stream) = 0;
+};
+
+/*!
+ * \brief The name of Device API factory.
+ * \param type The device type.
+ */
+inline std::string DeviceName(DLDeviceType type) {
+  switch (static_cast<int>(type)) {
+    case kCPU: return "cpu";
+    case kGPU: return "gpu";
+    case kOpenCL: return "opencl";
+    default: LOG(FATAL) << "unknown type =" << type; return "Unknown";
+  }
+}
 }  // namespace runtime
 }  // namespace tvm
-#include "./device_api_cpu.h"
-#include "./cuda/device_api_cuda.h"
-#include "./opencl/device_api_opencl.h"
 #endif  // TVM_RUNTIME_DEVICE_API_H_
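With DeviceAPI now virtual and resolved through the registry, a new backend no longer has to touch this header or the core runtime: it subclasses DeviceAPI and registers a factory under "_device_api_" + DeviceName(type). A hypothetical sketch of such a plugin ("mydev" and MyDeviceAPI are made-up names; a real backend would also need a device type code and a DeviceName case):

    // Hypothetical out-of-tree device plugin (sketch only).
    class MyDeviceAPI : public tvm::runtime::DeviceAPI {
     public:
      void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final;
      void FreeDataSpace(TVMContext ctx, void* ptr) final;
      void CopyDataFromTo(const void* from, void* to, size_t size,
                          TVMContext ctx_from, TVMContext ctx_to,
                          TVMStreamHandle stream) final;
      void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
    };

    // DeviceAPIManager resolves "_device_api_" + DeviceName(type) on first
    // use, so this registration is the only runtime hook required.
    TVM_REGISTER_GLOBAL(_device_api_mydev)
    .set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) {
        static MyDeviceAPI inst;
        tvm::runtime::DeviceAPI* ptr = &inst;
        *rv = static_cast<void*>(ptr);
      });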
-/*!
- * Copyright (c) 2017 by Contributors
- * \file device_api_opencl.h
- * \brief OpenCL specific API
- */
-#ifndef TVM_RUNTIME_OPENCL_DEVICE_API_OPENCL_H_
-#define TVM_RUNTIME_OPENCL_DEVICE_API_OPENCL_H_
-
-#include <tvm/runtime/config.h>
-
-#if TVM_OPENCL_RUNTIME
-#include <string>
-#include <vector>
-#include "./opencl_common.h"
-
-namespace tvm {
-namespace runtime {
-
-template<>
-inline void* AllocDataSpace<kOpenCL>(TVMContext ctx, size_t size, size_t alignment) {
-  cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
-  cl_int err_code;
-  cl_mem mptr = clCreateBuffer(
-      w->context, CL_MEM_READ_WRITE, size, nullptr, &err_code);
-  OPENCL_CHECK_ERROR(err_code);
-  return mptr;
-}
-
-template<>
-inline void FreeDataSpace<kOpenCL>(TVMContext ctx, void* ptr) {
-  cl_mem mptr = static_cast<cl_mem>(ptr);
-  OPENCL_CALL(clReleaseMemObject(mptr));
-}
-
-template<>
-inline void CopyDataFromTo<kOpenCL>(const void* from,
-                                    void* to,
-                                    size_t size,
-                                    TVMContext ctx_from,
-                                    TVMContext ctx_to,
-                                    TVMStreamHandle stream) {
-  CHECK(stream == nullptr);
-  cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
-  if (ctx_from.dev_mask == kOpenCL && ctx_to.dev_mask == kOpenCL) {
-    OPENCL_CALL(clEnqueueCopyBuffer(
-        w->GetQueue(ctx_to),
-        static_cast<cl_mem>((void*)from),  // NOLINT(*)
-        static_cast<cl_mem>(to),
-        0, 0, size, 0, nullptr, nullptr));
-  } else if (ctx_from.dev_mask == kOpenCL && ctx_to.dev_mask == kCPU) {
-    OPENCL_CALL(clEnqueueReadBuffer(
-        w->GetQueue(ctx_from),
-        static_cast<cl_mem>((void*)from),  // NOLINT(*)
-        CL_FALSE, 0, size, to,
-        0, nullptr, nullptr));
-    OPENCL_CALL(clFinish(w->GetQueue(ctx_from)));
-  } else if (ctx_from.dev_mask == kCPU && ctx_to.dev_mask == kOpenCL) {
-    OPENCL_CALL(clEnqueueWriteBuffer(
-        w->GetQueue(ctx_to),
-        static_cast<cl_mem>(to),
-        CL_FALSE, 0, size, from,
-        0, nullptr, nullptr));
-    OPENCL_CALL(clFinish(w->GetQueue(ctx_to)));
-  } else {
-    LOG(FATAL) << "Expect copy from/to GPU or between GPU";
-  }
-}
-
-template<>
-inline void StreamSync<kOpenCL>(TVMContext ctx, TVMStreamHandle stream) {
-  CHECK(stream == nullptr);
-  cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
-  OPENCL_CALL(clFinish(w->GetQueue(ctx)));
-}
-
-}  // namespace runtime
-}  // namespace tvm
-#endif  // TVM_OPENCL_RUNTIME
-#endif  // TVM_RUNTIME_OPENCL_DEVICE_API_OPENCL_H_
@@ -10,8 +10,8 @@
 #include <tvm/runtime/c_runtime_api.h>
 #include <tvm/runtime/packed_func.h>
 #include <dmlc/logging.h>
 
 #if TVM_OPENCL_RUNTIME
+#include "../device_api.h"
 
 #ifdef __APPLE__
 #include <OpenCL/opencl.h>
@@ -101,7 +101,7 @@ inline const char* CLGetErrorString(cl_int error) {
 /*!
  * \brief Process global OpenCL workspace.
  */
-class OpenCLWorkspace {
+class OpenCLWorkspace : public DeviceAPI {
  public:
  // global platform id
  cl_platform_id platform_id;
@@ -132,13 +132,23 @@ class OpenCLWorkspace {
  }
  // get the queue of the context
  cl_command_queue GetQueue(TVMContext ctx) const {
-    CHECK_EQ(ctx.dev_mask, kOpenCL);
+    CHECK_EQ(ctx.device_type, kOpenCL);
    CHECK(initialized())
        << "The OpenCL is not initialized";
-    CHECK(ctx.dev_id >= 0 && static_cast<size_t>(ctx.dev_id) < queues.size())
-        << "Invalid OpenCL dev_id=" << ctx.dev_id;
-    return queues[ctx.dev_id];
+    CHECK(ctx.device_id >= 0 && static_cast<size_t>(ctx.device_id) < queues.size())
+        << "Invalid OpenCL device_id=" << ctx.device_id;
+    return queues[ctx.device_id];
  }
+  // override device API
+  void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment) final;
+  void FreeDataSpace(TVMContext ctx, void* ptr) final;
+  void CopyDataFromTo(const void* from,
+                      void* to,
+                      size_t size,
+                      TVMContext ctx_from,
+                      TVMContext ctx_to,
+                      TVMStreamHandle stream) final;
+  void StreamSync(TVMContext ctx, TVMStreamHandle stream) final;
  // get the global workspace
  static OpenCLWorkspace* Global();
 };
@@ -160,8 +170,8 @@ class OpenCLThreadEntry {
  std::vector<KTEntry> kernel_table;
 
  OpenCLThreadEntry() {
-    context.dev_id = 0;
-    context.dev_mask = kOpenCL;
+    context.device_id = 0;
+    context.device_type = kOpenCL;
  }
  // get the global workspace
  static OpenCLThreadEntry* ThreadLocal();
...
 /*!
  * Copyright (c) 2017 by Contributors
- * \file opencl_workspace.cc
+ * \file opencl_device_api.cc
  */
 #include "./opencl_common.h"
@@ -18,6 +18,57 @@ OpenCLWorkspace* OpenCLWorkspace::Global() {
  return &inst;
 }
 
+void* OpenCLWorkspace::AllocDataSpace(
+    TVMContext ctx, size_t size, size_t alignment) {
+  cl_int err_code;
+  cl_mem mptr = clCreateBuffer(
+      this->context, CL_MEM_READ_WRITE, size, nullptr, &err_code);
+  OPENCL_CHECK_ERROR(err_code);
+  return mptr;
+}
+
+void OpenCLWorkspace::FreeDataSpace(TVMContext ctx, void* ptr) {
+  cl_mem mptr = static_cast<cl_mem>(ptr);
+  OPENCL_CALL(clReleaseMemObject(mptr));
+}
+
+void OpenCLWorkspace::CopyDataFromTo(const void* from,
+                                     void* to,
+                                     size_t size,
+                                     TVMContext ctx_from,
+                                     TVMContext ctx_to,
+                                     TVMStreamHandle stream) {
+  CHECK(stream == nullptr);
+  if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kOpenCL) {
+    OPENCL_CALL(clEnqueueCopyBuffer(
+        this->GetQueue(ctx_to),
+        static_cast<cl_mem>((void*)from),  // NOLINT(*)
+        static_cast<cl_mem>(to),
+        0, 0, size, 0, nullptr, nullptr));
+  } else if (ctx_from.device_type == kOpenCL && ctx_to.device_type == kCPU) {
+    OPENCL_CALL(clEnqueueReadBuffer(
+        this->GetQueue(ctx_from),
+        static_cast<cl_mem>((void*)from),  // NOLINT(*)
+        CL_FALSE, 0, size, to,
+        0, nullptr, nullptr));
+    OPENCL_CALL(clFinish(this->GetQueue(ctx_from)));
+  } else if (ctx_from.device_type == kCPU && ctx_to.device_type == kOpenCL) {
+    OPENCL_CALL(clEnqueueWriteBuffer(
+        this->GetQueue(ctx_to),
+        static_cast<cl_mem>(to),
+        CL_FALSE, 0, size, from,
+        0, nullptr, nullptr));
+    OPENCL_CALL(clFinish(this->GetQueue(ctx_to)));
+  } else {
+    LOG(FATAL) << "Expect copy from/to OpenCL or between OpenCL";
+  }
+}
+
+void OpenCLWorkspace::StreamSync(TVMContext ctx, TVMStreamHandle stream) {
+  CHECK(stream == nullptr);
+  OPENCL_CALL(clFinish(this->GetQueue(ctx)));
+}
+
 typedef dmlc::ThreadLocalStore<OpenCLThreadEntry> OpenCLThreadStore;
 
 OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() {
@@ -141,6 +192,12 @@ bool InitOpenCL(TVMArgs args, TVMRetValue* rv) {
 TVM_REGISTER_GLOBAL(_module_init_opencl)
 .set_body(InitOpenCL);
 
+TVM_REGISTER_GLOBAL(_device_api_opencl)
+.set_body([](TVMArgs args, TVMRetValue* rv) {
+    DeviceAPI* ptr = OpenCLWorkspace::Global();
+    *rv = static_cast<void*>(ptr);
+  });
+
 }  // namespace cl
 }  // namespace runtime
 }  // namespace tvm
...
@@ -123,11 +123,11 @@ class OpenCLModuleNode : public ModuleNode {
      const std::string& func_name,
      const KTRefEntry& e) {
    std::lock_guard<std::mutex> lock(build_lock_);
-    int dev_id = t->context.dev_id;
-    if (!device_built_flag_[dev_id]) {
+    int device_id = t->context.device_id;
+    if (!device_built_flag_[device_id]) {
      // build program
      cl_int err;
-      cl_device_id dev = w->devices[dev_id];
+      cl_device_id dev = w->devices[device_id];
      err = clBuildProgram(program_, 1, &dev, nullptr, nullptr, nullptr);
      if (err != CL_SUCCESS) {
        size_t len;
@@ -139,7 +139,7 @@ class OpenCLModuleNode : public ModuleNode {
            program_, dev, CL_PROGRAM_BUILD_LOG, len, &log[0], nullptr);
        LOG(FATAL) << "OpenCL build error for device=" << dev << log;
      }
-      device_built_flag_[dev_id] = true;
+      device_built_flag_[device_id] = true;
    }
    // build kernel
    cl_int err;
@@ -246,23 +246,23 @@ void AutoSetOpenCLDevice(const TVMArgs& args, TVMRetValue* rv) {
  int num_args = args[2].operator int();
  // TODO(tqchen): merge this with CUDA logic.
-  int dev_id = -1;
+  int device_id = -1;
  for (int i = 0; i < num_args; ++i) {
    if (type_codes[i] == kArrayHandle) {
      TVMContext ctx = static_cast<TVMArray*>(values[i].v_handle)->ctx;
-      CHECK_EQ(ctx.dev_mask, kOpenCL)
+      CHECK_EQ(ctx.device_type, kOpenCL)
          << "All operands need to be OpenCL";
-      if (dev_id == -1) {
-        dev_id = ctx.dev_id;
+      if (device_id == -1) {
+        device_id = ctx.device_id;
      } else {
-        CHECK_EQ(dev_id, ctx.dev_id)
+        CHECK_EQ(device_id, ctx.device_id)
            << "Operands comes from different devices ";
      }
    }
  }
-  CHECK_NE(dev_id, -1)
+  CHECK_NE(device_id, -1)
      << "Cannot detect device id from list";
-  cl::OpenCLThreadEntry::ThreadLocal()->context.dev_id = dev_id;
+  cl::OpenCLThreadEntry::ThreadLocal()->context.device_id = device_id;
 }
 
 PackedFunc OpenCLModuleNode::GetFunction(
...