Commit e9ff9a89 by Tianqi Chen Committed by GitHub

[RUNTIME] Finish GPU runtime and python interface (#16)

* [RUNTIME] Finish GPU runtime and python interface

* fix travis test

* fix build
parent f2f1526d
......@@ -91,3 +91,4 @@ ENV/
*.pyc
*~
build
config.mk
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2\
-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC
ifndef config
ifneq ("$(wildcard ./config.mk)","")
config = config.mk
else
config = make/config.mk
endif
endif
include $(config)
# specify tensor path
.PHONY: clean all test doc
......@@ -13,6 +19,30 @@ SRC = $(wildcard src/*.cc src/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
# Use the nvcc binary found under USE_CUDA_PATH unless the path is left as NONE.
ifneq ($(USE_CUDA_PATH), NONE)
NVCC=$(USE_CUDA_PATH)/bin/nvcc
endif
# Baseline link and compile flags, exported to the environment of sub-commands.
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2\
-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC
# Append the extra flags from the configuration unless they are set to NONE.
ifneq ($(ADD_CFLAGS), NONE)
CFLAGS += $(ADD_CFLAGS)
endif
ifneq ($(ADD_LDFLAGS), NONE)
LDFLAGS += $(ADD_LDFLAGS)
endif
# TVM_CUDA_RUNTIME is always defined explicitly: 1 (plus the CUDA libraries
# on the link line) when USE_CUDA is 1, and 0 otherwise.
ifeq ($(USE_CUDA), 1)
CFLAGS += -DTVM_CUDA_RUNTIME=1
LDFLAGS += -lcuda -lcudart
else
CFLAGS += -DTVM_CUDA_RUNTIME=0
endif
include tests/cpp/unittest.mk
test: $(TEST)
......
......@@ -17,6 +17,20 @@
namespace tvm {
/*!
*\brief whether to use CUDA runtime
*/
#ifndef TVM_CUDA_RUNTIME
#define TVM_CUDA_RUNTIME 1
#endif
/*!
*\brief whether to use opencl runtime
*/
#ifndef TVM_OPENCL_RUNTIME
#define TVM_OPENCL_RUNTIME 0
#endif
using ::tvm::Node;
using ::tvm::NodeRef;
using ::tvm::AttrVisitor;
......
......@@ -34,7 +34,7 @@
TVM_EXTERN_C {
/*! \brief type of array index. */
typedef unsigned tvm_index_t;
typedef uint32_t tvm_index_t;
/*!
* \brief union type for arguments and return values
......@@ -68,7 +68,7 @@ typedef enum {
/*! \brief NVidia GPU device(CUDA) */
kGPU = 2,
/*! \brief opencl device */
KOpenCL = 4
kOpenCL = 4
} TVMDeviceMask;
/*!
......@@ -79,7 +79,7 @@ typedef struct {
int dev_mask;
/*! \brief the device id */
int dev_id;
} TVMDevice;
} TVMContext;
/*! \brief The type code in TVMDataType */
typedef enum {
......@@ -122,8 +122,8 @@ typedef struct {
tvm_index_t ndim;
/*! \brief The data type flag */
TVMDataType dtype;
/*! \brief The device this array sits on */
TVMDevice device;
/*! \brief The device context this array sits on */
TVMContext ctx;
} TVMArray;
/*!
......@@ -151,20 +151,30 @@ typedef TVMArray* TVMArrayHandle;
TVM_DLL const char *TVMGetLastError(void);
/*!
* \brief Whether the specified context is enabled.
*
* \param ctx The context to be checked.
* \param out_enabled whether the ctx is enabled.
* \return Whether the function is successful.
*/
TVM_DLL int TVMContextEnabled(TVMContext ctx,
int* out_enabled);
/*!
* \brief Allocate a nd-array's memory,
* including space of shape, of given spec.
*
* \param shape The shape of the array, the data content will be copied to out
* \param ndim The number of dimension of the array.
* \param dtype The array data type.
* \param device The device this array sits on.
* \param ctx The ctx this array sits on.
* \param out The output handle.
* \return Whether the function is successful.
*/
TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
tvm_index_t ndim,
int dtype,
TVMDevice device,
TVMDataType dtype,
TVMContext ctx,
TVMArrayHandle* out);
/*!
* \brief Free the TVM Array.
......@@ -183,9 +193,10 @@ TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVMStreamHandle stream);
/*!
* \brief Wait until all computations on stream complete.
* \param stream the stream to be synchronized.
* \param ctx The ctx to be synchronized.
* \param stream The stream to be synchronized.
*/
TVM_DLL int TVMSynchronize(TVMStreamHandle stream);
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
/*!
* \brief Launch a generated TVM function
......
#-------------------------------------------------------------------------------
# Template configuration for compiling
#
# If you want to change the configuration, please use the following
# steps. Assume you are on the root directory. First copy this
# file so that any local changes will be ignored by git
#
# $ cp make/config.mk .
#
# Next modify the according entries, and then compile by
#
# $ make
#
# or build in parallel with 8 threads
#
# $ make -j8
#-------------------------------------------------------------------------------
#---------------------
# choice of compiler
#--------------------
export NVCC = nvcc
# whether compile with debug
DEBUG = 0
# the additional link flags you want to add
ADD_LDFLAGS =
# the additional compile flags you want to add
ADD_CFLAGS =
#---------------------------------------------
# matrix computation libraries for CPU/GPU
#---------------------------------------------
# whether use CUDA during compile
USE_CUDA = 1
# add the path to CUDA library to link and compile flag
# if you have already added them to your environment variables, leave it as NONE
# USE_CUDA_PATH = /usr/local/cuda
USE_CUDA_PATH = NONE
# whether use cuda runtime compiling for writing kernels in native language (i.e. Python)
USE_NVRTC = 0
......@@ -3,7 +3,7 @@
from __future__ import absolute_import as _abs
from ._ctypes._api import register_node
from . import tensor as tensor
from . import tensor
from . import expr
from . import stmt
from . import make
......@@ -11,5 +11,8 @@ from . import ir_pass
from . import collections
from . import schedule
from . import ndarray as nd
from .ndarray import cpu, gpu, opencl
from ._base import TVMError
from .function import *
......@@ -58,6 +58,7 @@ def check_call(ret):
if ret != 0:
raise TVMError(py_str(_LIB.TVMGetLastError()))
def c_str(string):
"""Create ctypes char * from a python string
Parameters
......@@ -72,6 +73,26 @@ def c_str(string):
"""
return ctypes.c_char_p(string.encode('utf-8'))
def c_array(ctype, values):
"""Create ctypes array from a python array
Parameters
----------
ctype : ctypes data type
data type of the array we want to convert to
values : tuple or list
data content
Returns
-------
out : ctypes array
Created ctypes array
"""
return (ctype * len(values))(*values)
def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True):
"""Convert ctypes returned doc string information into parameters docstring.
......
# pylint: disable=invalid-name, protected-access, too-many-arguments, global-statement
# pylint: disable=attribute-defined-outside-init, no-member, missing-docstring
"""Symbolic configuration API."""
from __future__ import absolute_import as _abs
import ctypes
import numpy as np
from .._base import _LIB
from .._base import c_array
from .._base import check_call
tvm_index_t = ctypes.c_uint32
class TVMContext(ctypes.Structure):
    """TVM context structure.

    Mirrors the C struct with two int fields: ``dev_mask`` selects the
    device type (see ``MASK2STR``) and ``dev_id`` selects the device.

    Parameters
    ----------
    dev_mask : int
        The device type mask.

    dev_id : int
        The integer device id.
    """
    _fields_ = [("dev_mask", ctypes.c_int),
                ("dev_id", ctypes.c_int)]
    MASK2STR = {
        1 : 'cpu',
        2 : 'gpu',
        4 : 'opencl'
    }

    def __init__(self, dev_mask, dev_id):
        super(TVMContext, self).__init__()
        self.dev_mask = dev_mask
        self.dev_id = dev_id

    def __repr__(self):
        # repr() must never raise: fall back to a descriptive name when the
        # mask is not one of the known device types (e.g. a zeroed struct).
        name = TVMContext.MASK2STR.get(self.dev_mask)
        if name is None:
            name = "unknown_dev_mask_%d" % self.dev_mask
        return "%s(%d)" % (name, self.dev_id)

    @property
    def enabled(self):
        """Whether the runtime library reports this context as enabled."""
        ret = ctypes.c_int()
        check_call(_LIB.TVMContextEnabled(self, ctypes.byref(ret)))
        return ret.value != 0
def cpu(dev_id=0):
"""Construct a CPU device context
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The context with device mask 1 and the given device id.
"""
return TVMContext(1, dev_id)
def gpu(dev_id=0):
"""Construct a GPU device context
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The context with device mask 2 and the given device id.
"""
return TVMContext(2, dev_id)
def opencl(dev_id=0):
"""Construct an OpenCL device context
Parameters
----------
dev_id : int, optional
The integer device id
Returns
-------
ctx : TVMContext
The context with device mask 4 and the given device id.
"""
return TVMContext(4, dev_id)
class TVMDataType(ctypes.Structure):
    """TVM datatype structure.

    Parameters
    ----------
    type_str : str or numpy.dtype
        Type name made of ``int``, ``uint`` or ``float`` followed by an
        optional bit width, e.g. ``"float32"``. A missing (or zero) width
        defaults to 32 bits.

    lanes : int, optional
        The number of lanes stored in the ``lanes`` field.

    Raises
    ------
    ValueError
        If the type prefix is unknown, the width is not a number, or the
        width is not a power of two between 8 and 128.
    """
    _fields_ = [("type_code", ctypes.c_uint8),
                ("bits", ctypes.c_uint8),
                ("lanes", ctypes.c_uint16)]
    CODE2STR = {
        0 : 'int',
        1 : 'uint',
        2 : 'float'
    }

    def __init__(self, type_str, lanes=1):
        super(TVMDataType, self).__init__()
        if isinstance(type_str, np.dtype):
            type_str = str(type_str)
        if type_str.startswith("int"):
            self.type_code = 0
            width = type_str[3:]
        elif type_str.startswith("uint"):
            self.type_code = 1
            width = type_str[4:]
        elif type_str.startswith("float"):
            self.type_code = 2
            width = type_str[5:]
        else:
            raise ValueError("Donot know how to handle type %s" % type_str)
        # int("") raises, so an omitted width has to be mapped to 0 here for
        # the 32-bit default below to be reachable.
        try:
            bits = int(width) if width else 0
        except ValueError:
            raise ValueError("Donot know how to handle type %s" % type_str)
        bits = 32 if bits == 0 else bits
        # The bits field is a c_uint8: anything above 255 would be silently
        # truncated on assignment, so reject it together with other bad widths.
        if (bits & (bits - 1)) != 0 or bits < 8 or bits > 255:
            raise ValueError("Donot know how to handle type %s" % type_str)
        self.bits = bits
        self.lanes = lanes

    def __repr__(self):
        x = "%s%d" % (TVMDataType.CODE2STR[self.type_code], self.bits)
        if self.lanes != 1:
            x += "x%d" % self.lanes
        return x
class TVMArray(ctypes.Structure):
"""ctypes layout of the TVMArray struct in the C API"""
# Fields, in order: data pointer, shape pointer, strides pointer,
# number of dimensions, data type, and the context the array sits on.
_fields_ = [("data", ctypes.c_void_p),
("shape", ctypes.POINTER(tvm_index_t)),
("strides", ctypes.POINTER(tvm_index_t)),
("ndim", tvm_index_t),
("dtype", TVMDataType),
("ctx", TVMContext)]
TVMArrayHandle = ctypes.POINTER(TVMArray)
def numpyasarray(np_data):
    """Wrap a C-contiguous numpy array in a TVMArray that shares its memory.

    Parameters
    ----------
    np_data : numpy.ndarray
        The source array; it must be C-contiguous.

    Returns
    -------
    arr : TVMArray
        CPU-context descriptor pointing at the numpy buffer.

    shape : ctypes array
        The shape buffer referenced by ``arr``; callers hold on to it while
        ``arr`` is in use.
    """
    assert np_data.flags['C_CONTIGUOUS']
    c_shape = c_array(tvm_index_t, np_data.shape)
    tvm_arr = TVMArray()
    tvm_arr.data = np_data.ctypes.data_as(ctypes.c_void_p)
    tvm_arr.shape = c_shape
    tvm_arr.strides = None
    tvm_arr.ndim = np_data.ndim
    tvm_arr.dtype = TVMDataType(np.dtype(np_data.dtype).name)
    # numpy memory always lives on the CPU
    tvm_arr.ctx = cpu(0)
    return tvm_arr, c_shape
_ndarray_cls = None
def empty(shape, dtype="float32", ctx=cpu(0)):
    """Allocate an uninitialized array with the given shape on a context.

    Parameters
    ----------
    shape : tuple of int
        The shape of the array

    dtype : type or str
        The data type of the array.

    ctx : TVMContext
        The context of the array

    Returns
    -------
    arr : tvm.nd.NDArray
        The array tvm supported.
    """
    c_shape = c_array(tvm_index_t, shape)
    out_handle = TVMArrayHandle()
    check_call(_LIB.TVMArrayAlloc(
        c_shape,
        tvm_index_t(len(c_shape)),
        TVMDataType(dtype),
        ctx,
        ctypes.byref(out_handle)))
    # _ndarray_cls is registered through _init_runtime_module
    return _ndarray_cls(out_handle)
def sync(ctx):
"""Synchronize the given context
Calls TVMSynchronize on the context with a null stream handle.
Parameters
----------
ctx : TVMContext
The context to be synced
"""
check_call(_LIB.TVMSynchronize(ctx, None))
class NDArrayBase(object):
    """A simple Device/CPU Array object in runtime.

    Wraps a TVMArrayHandle; the handle is passed to TVMArrayFree when the
    object is destroyed.
    """
    __slots__ = ["handle"]
    # pylint: disable=no-member
    def __init__(self, handle):
        """Initialize the function with handle

        Parameters
        ----------
        handle : TVMArrayHandle
            the handle to the underlying C++ TVMArray
        """
        self.handle = handle

    def __del__(self):
        check_call(_LIB.TVMArrayFree(self.handle))

    @property
    def shape(self):
        """Shape of this array"""
        contents = self.handle.contents
        return tuple(contents.shape[i] for i in range(contents.ndim))

    @property
    def dtype(self):
        """Type of this array"""
        return str(self.handle.contents.dtype)

    @property
    def ctx(self):
        """context of this array"""
        return self.handle.contents.ctx

    @property
    def context(self):
        """context of this array"""
        return self.ctx

    def __setitem__(self, in_slice, value):
        """Set ndarray value

        Only whole-array assignment (``arr[:] = value``) is supported.
        """
        # A strided slice such as arr[::2] must be rejected as well,
        # otherwise it would silently overwrite the whole array.
        if (not isinstance(in_slice, slice) or
                in_slice.start is not None or
                in_slice.stop is not None or
                in_slice.step not in (None, 1)):
            raise ValueError('Array only support set from numpy array')
        if isinstance(value, NDArrayBase):
            if value.handle is not self.handle:
                value.copyto(self)
        elif isinstance(value, (np.ndarray, np.generic)):
            self._sync_copyfrom(value)
        else:
            raise TypeError('type %s not supported' % str(type(value)))

    def _sync_copyfrom(self, source_array):
        """Perform a synchronous copy from the array.

        Parameters
        ----------
        source_array : array_like
            The data source we should like to copy from.
        """
        if not isinstance(source_array, np.ndarray):
            try:
                source_array = np.array(source_array, dtype=self.dtype)
            # Catch Exception only, so KeyboardInterrupt and SystemExit
            # are not converted into a TypeError.
            except Exception:
                raise TypeError('array must be an array_like data,' +
                                'type %s is not supported' % str(type(source_array)))
        source_array = np.ascontiguousarray(source_array, dtype=self.dtype)
        if source_array.shape != self.shape:
            raise ValueError('array shape do not match the shape of NDArray')
        source_tvm_arr, shape = numpyasarray(source_array)
        check_call(_LIB.TVMArrayCopyFromTo(
            ctypes.byref(source_tvm_arr), self.handle, None))
        # keep the shape buffer referenced until the copy has been issued
        _ = shape

    def asnumpy(self):
        """Convert this array to numpy array

        Returns
        -------
        np_arr : numpy.ndarray
            The corresponding numpy array.
        """
        np_arr = np.empty(self.shape, dtype=self.dtype)
        tvm_arr, shape = numpyasarray(np_arr)
        check_call(_LIB.TVMArrayCopyFromTo(
            self.handle, ctypes.byref(tvm_arr), None))
        # keep the shape buffer referenced until the copy has been issued
        _ = shape
        return np_arr

    def copyto(self, target):
        """Copy array to target

        Parameters
        ----------
        target : tvm.NDArray or TVMContext
            The target array to be copied, must have same shape as this array.
            When a context is given, a new array is allocated on it first.

        Returns
        -------
        target : tvm.NDArray
            The array that received the copy.
        """
        if isinstance(target, TVMContext):
            target = empty(self.shape, self.dtype, target)
        if isinstance(target, NDArrayBase):
            check_call(_LIB.TVMArrayCopyFromTo(
                self.handle, target.handle, None))
        else:
            raise ValueError("Unsupported target type %s" % str(type(target)))
        return target
def _init_runtime_module(ndarray_class):
# Register the class that empty() instantiates around each new handle.
global _ndarray_cls
_ndarray_cls = ndarray_class
"""TVM Runtime API.
This is a simplified runtime API for quick testing and prototyping.
"""
# pylint: disable=unused-import
from __future__ import absolute_import as _abs
import numpy as _np
from ._ctypes._runtime_api import TVMContext, TVMDataType, NDArrayBase
from ._ctypes._runtime_api import cpu, gpu, opencl, empty, sync
from ._ctypes._runtime_api import _init_runtime_module
class NDArray(NDArrayBase):
"""Lightweight NDArray class of TVM runtime.
Strictly speaking, this is only an array container (a buffer object).
No arithmetic operations are defined.
All operations are performed by TVM functions.
The goal is not to re-build yet another array library.
Instead, this is a minimal data structure to demonstrate
how we can use TVM in existing projects which might have their own array containers.
"""
pass
def array(arr, ctx=cpu(0)):
    """Copy array-like data into a newly allocated array on a context.

    Parameters
    ----------
    arr : numpy.ndarray
        The array to be copied from

    ctx : TVMContext
        The device context to create the array

    Returns
    -------
    ret : tvm.nd.NDArray
        The created array
    """
    source = arr if isinstance(arr, _np.ndarray) else _np.array(arr)
    result = empty(source.shape, source.dtype, ctx)
    result[:] = source
    return result
_init_runtime_module(NDArray)
......@@ -14,6 +14,6 @@
#include <string>
#include <exception>
#include "./c_api_registry.h"
#include "../runtime/runtime_common.h"
#include "../runtime/runtime_base.h"
#endif // TVM_C_API_C_API_COMMON_H_
/*!
* Copyright (c) 2016 by Contributors
* \file c_runtime_api.cc
* \brief Device specific implementations
*/
#include <tvm/c_runtime_api.h>
#include <algorithm>
#include "./runtime_base.h"
#include "./device_api.h"
namespace tvm {
namespace runtime {
/*!
 * \brief Allocate a TVMArray header with no shape, strides or data attached.
 * \return A heap-allocated TVMArray whose pointers are null and ndim is 0.
 */
inline TVMArray* TVMArrayCreate_() {
  TVMArray* result = new TVMArray();
  result->data = nullptr;
  result->ndim = 0;
  result->strides = nullptr;
  result->shape = nullptr;
  return result;
}
/*!
 * \brief Release a TVMArray together with its shape, strides and data.
 * \param arr The array to free; a null pointer is accepted and ignored.
 */
inline void TVMArrayFree_(TVMArray* arr) {
  if (arr == nullptr) return;
  // delete[] on a null pointer is a no-op, so no separate checks are needed
  delete[] arr->shape;
  delete[] arr->strides;
  if (arr->data != nullptr) {
    // the data space is owned by the device the array sits on
    TVM_DEVICE_SWITCH(arr->ctx, {
        FreeDataSpace<xpu>(arr->ctx, arr->data);
      });
  }
  delete arr;
}
/*!
 * \brief Validate a data type descriptor through CHECK assertions.
 *
 * Requires at least one lane, a bit width that is a multiple of 32 for
 * kFloat and a multiple of 8 for every other type code, and a bit width
 * with at most one bit set.
 * \param dtype The data type to verify.
 */
inline void VerifyType(TVMDataType dtype) {
CHECK_GE(dtype.lanes, 1U);
if (dtype.type_code == kFloat) {
CHECK_EQ(dtype.bits % 32U, 0U);
} else {
CHECK_EQ(dtype.bits % 8U, 0U);
}
// x & (x - 1) clears the lowest set bit; the result is 0 only when at most one bit is set
CHECK_EQ(dtype.bits & (dtype.bits - 1), 0);
}
/*!
 * \brief Compute the number of bytes taken by the array's elements.
 * \param arr The array; its shape, ndim and dtype fields are read.
 * \return Product of all shape entries times the bytes of one element
 *         (bits / 8 times lanes).
 */
inline size_t GetDataSize(TVMArray* arr) {
  const size_t bytes_per_elem = (arr->dtype.bits / 8) * arr->dtype.lanes;
  size_t num_elems = 1;
  const tvm_index_t* const shape_end = arr->shape + arr->ndim;
  for (const tvm_index_t* dim = arr->shape; dim != shape_end; ++dim) {
    num_elems *= *dim;
  }
  return num_elems * bytes_per_elem;
}
/*!
 * \brief Compute the alignment requested for the array's data space.
 *
 * The result is the smallest power of two that is at least 8 and at least
 * the size in bytes of one element (bits / 8 times lanes). Rounding up to a
 * power of two matters when lanes is not a power of two: posix_memalign
 * only accepts power-of-two alignments, and the CUDA allocator checks that
 * the alignment divides 256.
 * \param arr The array; only its dtype field is read.
 * \return The alignment in bytes.
 */
inline size_t GetDataAlignment(TVMArray* arr) {
  const size_t elem_bytes =
      static_cast<size_t>(arr->dtype.bits / 8) * arr->dtype.lanes;
  size_t align = 8;
  while (align < elem_bytes) {
    align <<= 1;
  }
  return align;
}
} // namespace runtime
} // namespace tvm
using namespace tvm::runtime;
/*!
 * \brief Report whether the given context is enabled.
 * \param ctx The context to be checked.
 * \param out_enabled Receives 0 or the result of CheckEnabled for the device.
 * \return The status produced by the API_BEGIN/API_END guard.
 */
int TVMContextEnabled(TVMContext ctx,
int* out_enabled) {
API_BEGIN();
// A device whose runtime macro is 0 is reported as disabled up front,
// without entering TVM_DEVICE_SWITCH for it.
if (ctx.dev_mask == kGPU && TVM_CUDA_RUNTIME == 0) {
*out_enabled = 0;
} else if (ctx.dev_mask == kOpenCL && TVM_OPENCL_RUNTIME == 0) {
*out_enabled = 0;
} else {
TVM_DEVICE_SWITCH(ctx, {
*out_enabled = CheckEnabled<xpu>(ctx);
});
}
API_END();
}
/*!
 * \brief Allocate an array header, a private copy of the shape, and the
 *        device data space for the given spec.
 * \param shape The shape to copy; ndim entries are read.
 * \param ndim The number of dimensions.
 * \param dtype The element type; checked with VerifyType.
 * \param ctx The context the data space is allocated on.
 * \param out Receives the new array handle.
 * \return The status produced by the API guard macros.
 */
int TVMArrayAlloc(const tvm_index_t* shape,
tvm_index_t ndim,
TVMDataType dtype,
TVMContext ctx,
TVMArrayHandle* out) {
// declared outside the guarded section so the error handler below can free it
TVMArray* arr = nullptr;
API_BEGIN();
// array header
arr = TVMArrayCreate_();
// ndim
arr->ndim = ndim;
// dtype
VerifyType(dtype);
arr->dtype = dtype;
// the array keeps its own copy of the shape
tvm_index_t* shape_copy = new tvm_index_t[ndim];
std::copy(shape, shape + ndim, shape_copy);
arr->shape = shape_copy;
// ctx
arr->ctx = ctx;
// size and alignment depend on the shape and dtype stored above
size_t size = GetDataSize(arr);
size_t alignment = GetDataAlignment(arr);
// ctx data pointer
TVM_DEVICE_SWITCH(ctx, {
arr->data = AllocDataSpace<xpu>(ctx, size, alignment);
});
*out = arr;
// on failure the partially built array is released
API_END_HANDLE_ERROR(TVMArrayFree_(arr));
}
/*!
 * \brief Free an array through TVMArrayFree_.
 * \param handle The array handle to release.
 * \return The status produced by the API_BEGIN/API_END guard.
 */
int TVMArrayFree(TVMArrayHandle handle) {
  API_BEGIN();
  TVMArrayFree_(handle);
  API_END();
}
/*!
 * \brief Copy the contents of one array into another.
 *
 * Only the total byte sizes are compared; shapes and dtypes are not.
 * \param from The source array.
 * \param to The target array.
 * \param stream The stream handle forwarded to the device copy routine.
 * \return The status produced by the API_BEGIN/API_END guard.
 */
int TVMArrayCopyFromTo(TVMArrayHandle from,
TVMArrayHandle to,
TVMStreamHandle stream) {
API_BEGIN();
size_t from_size = GetDataSize(from);
size_t to_size = GetDataSize(to);
CHECK_EQ(from_size, to_size)
<< "TVMArrayCopyFromTo: The size must exactly match";
// Pick the device whose copy routine runs: the target's context when the
// source is on the CPU, otherwise the source's context.
TVMContext ctx = from->ctx;
if (ctx.dev_mask == kCPU) {
ctx = to->ctx;
} else {
// a non-CPU source may only be copied to the CPU or to the same device type
CHECK(to->ctx.dev_mask == kCPU ||
to->ctx.dev_mask == from->ctx.dev_mask)
<< "Can not copy across different ctx types directly";
}
TVM_DEVICE_SWITCH(ctx, {
CopyDataFromTo<xpu>(from->data, to->data,
from_size,
from->ctx,
to->ctx,
stream);
});
API_END();
}
/*!
 * \brief Synchronize a stream by dispatching to the device's StreamSync.
 * \param ctx The context to be synchronized.
 * \param stream The stream to be synchronized.
 * \return The status produced by the API_BEGIN/API_END guard.
 */
int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
API_BEGIN();
TVM_DEVICE_SWITCH(ctx, {
StreamSync<xpu>(ctx, stream);
});
API_END();
}
/*!
* Copyright (c) 2016 by Contributors
* \file device_api.h
* \brief Device specific API
*/
#ifndef TVM_RUNTIME_DEVICE_API_H_
#define TVM_RUNTIME_DEVICE_API_H_
#include <tvm/base.h>
#include <tvm/c_runtime_api.h>
namespace tvm {
namespace runtime {
/*!
 * \brief Whether ctx is enabled.
 * \param ctx The device context to perform operation.
 * \return This generic definition always returns true.
 * \tparam xpu The device mask.
 */
template<TVMDeviceMask xpu>
inline bool CheckEnabled(TVMContext ctx) {
return true;
}
/*!
* \brief Allocate a data space on device.
* \param ctx The device context to perform operation.
* \param size The size of the memory
* \param alignment The alignment of the memory.
* \return The allocated device pointer
* \tparam xpu The device mask.
*/
template<TVMDeviceMask xpu>
inline void* AllocDataSpace(TVMContext ctx, size_t size, size_t alignment);
/*!
* \brief Free a data space on device.
* \param ctx The device context to perform operation.
* \param ptr The data space.
* \tparam xpu The device mask.
*/
template<TVMDeviceMask xpu>
inline void FreeDataSpace(TVMContext ctx, void* ptr);
/*!
 * \brief copy data from one place to another
 * \param from The source data pointer.
 * \param to The target data pointer.
 * \param size The size of the memory
 * \param ctx_from The source context
 * \param ctx_to The target context
 * \param stream The stream handle passed along for the copy.
 * \tparam xpu The device mask.
 */
template<TVMDeviceMask xpu>
inline void CopyDataFromTo(const void* from,
void* to,
size_t size,
TVMContext ctx_from,
TVMContext ctx_to,
TVMStreamHandle stream);
/*!
* \brief Synchronize the stream
* \param ctx The context to perform operation.
* \param stream The stream to be sync.
* \tparam xpu The device mask.
*/
template<TVMDeviceMask xpu>
inline void StreamSync(TVMContext ctx, TVMStreamHandle stream);
// Macro to run CUDA related code: when TVM_CUDA_RUNTIME is non-zero, OP runs
// in a scope where xpu is kGPU; otherwise the macro logs a fatal error.
#if TVM_CUDA_RUNTIME
#define TVM_RUN_CUDA(OP) { const TVMDeviceMask xpu = kGPU; OP; }
#else
#define TVM_RUN_CUDA(OP) LOG(FATAL) << "CUDA is not enabled";
#endif
// Macro to run OpenCL related code: when TVM_OPENCL_RUNTIME is non-zero, OP
// runs in a scope where xpu is kOpenCL; otherwise it logs a fatal error.
#if TVM_OPENCL_RUNTIME
#define TVM_RUN_OPENCL(OP) { const TVMDeviceMask xpu = kOpenCL; OP; }
#else
#define TVM_RUN_OPENCL(OP) LOG(FATAL) << "OpenCL is not enabled";
#endif
// Macro to switch between devices: runs OP with the constant xpu set to the
// device mask of ctx, and logs a fatal error for an unknown mask.
#define TVM_DEVICE_SWITCH(ctx, OP) \
switch (ctx.dev_mask) { \
case kCPU: { const TVMDeviceMask xpu = kCPU; OP; break; } \
case kGPU: TVM_RUN_CUDA(OP); break; \
case kOpenCL: TVM_RUN_OPENCL(OP); break; \
default: LOG(FATAL) << "unknown device_mask " << ctx.dev_mask; \
}
} // namespace runtime
} // namespace tvm
#include "./device_api_gpu.h"
#include "./device_api_cpu.h"
#endif // TVM_RUNTIME_DEVICE_API_H_
/*!
* Copyright (c) 2016 by Contributors
* \file device_api_cpu.h
* \brief CPU specific API
*/
#ifndef TVM_RUNTIME_DEVICE_API_CPU_H_
#define TVM_RUNTIME_DEVICE_API_CPU_H_
#include <dmlc/logging.h>
#include <cstdlib>
#include <cstring>
#include "./device_api.h"
namespace tvm {
namespace runtime {
/*!
 * \brief Allocate aligned host memory.
 *
 * Declared inline: an explicit specialization defined in a header is not
 * implicitly inline, so without it every translation unit including this
 * header would emit its own definition.
 * \param ctx The device context (unused on the CPU).
 * \param size The size of the memory in bytes.
 * \param alignment The alignment of the memory.
 * \return The allocated pointer.
 * \throw std::bad_alloc when the allocation fails.
 */
template<>
inline void* AllocDataSpace<kCPU>(TVMContext ctx, size_t size, size_t alignment) {
  void* ptr = nullptr;
#if _MSC_VER
  ptr = _aligned_malloc(size, alignment);
  if (ptr == nullptr) throw std::bad_alloc();
#else
  int ret = posix_memalign(&ptr, alignment, size);
  if (ret != 0) throw std::bad_alloc();
#endif
  return ptr;
}
/*!
 * \brief Free host memory obtained from AllocDataSpace<kCPU>.
 *
 * Declared inline so the header-defined explicit specialization can be
 * included from more than one translation unit.
 * \param ctx The device context (unused on the CPU).
 * \param ptr The data space to release.
 */
template<>
inline void FreeDataSpace<kCPU>(TVMContext ctx, void* ptr) {
#if _MSC_VER
  _aligned_free(ptr);
#else
  free(ptr);
#endif
}
/*!
 * \brief Copy bytes between two host buffers with memcpy.
 *
 * Declared inline so the header-defined explicit specialization can be
 * included from more than one translation unit.
 * \param from The source pointer.
 * \param to The target pointer.
 * \param size The number of bytes to copy.
 * \param ctx_from The source context (unused).
 * \param ctx_to The target context (unused).
 * \param stream The stream handle (unused).
 */
template<>
inline void CopyDataFromTo<kCPU>(const void* from,
                                 void* to,
                                 size_t size,
                                 TVMContext ctx_from,
                                 TVMContext ctx_to,
                                 TVMStreamHandle stream) {
  memcpy(to, from, size);
}
/*!
 * \brief Stream synchronization on the CPU; intentionally does nothing.
 *
 * Declared inline so the header-defined explicit specialization can be
 * included from more than one translation unit.
 * \param ctx The context (unused).
 * \param stream The stream handle (unused).
 */
template<>
inline void StreamSync<kCPU>(TVMContext ctx, TVMStreamHandle stream) {
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_DEVICE_API_CPU_H_
/*!
* Copyright (c) 2016 by Contributors
* \file device_api_gpu.h
* \brief GPU specific API
*/
#ifndef TVM_RUNTIME_DEVICE_API_GPU_H_
#define TVM_RUNTIME_DEVICE_API_GPU_H_
#include <dmlc/logging.h>
#include "./device_api.h"
#if TVM_CUDA_RUNTIME
#include <cuda_runtime.h>
namespace tvm {
namespace runtime {
/*!
* \brief Check CUDA error.
* \param msg Message to print if an error occured.
*/
#define CHECK_CUDA_ERROR(msg) \
{ \
cudaError_t e = cudaGetLastError(); \
CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
}
/*!
 * \brief Protected CUDA call.
 * \param func Expression to call.
 *
 * It checks the status returned by the expression. Both cudaSuccess and
 * cudaErrorCudartUnloading are accepted; any other status fails the CHECK
 * with the CUDA error string.
 */
#define CUDA_CALL(func) \
{ \
cudaError_t e = (func); \
CHECK(e == cudaSuccess || e == cudaErrorCudartUnloading) \
<< "CUDA: " << cudaGetErrorString(e); \
}
/*!
 * \brief Allocate a data space on a CUDA device with cudaMalloc.
 * \param ctx The context; its dev_id selects the device.
 * \param size The size of the memory in bytes.
 * \param alignment The requested alignment; it must divide 256.
 * \return The allocated device pointer.
 */
template<>
inline void* AllocDataSpace<kGPU>(TVMContext ctx, size_t size, size_t alignment) {
  CUDA_CALL(cudaSetDevice(ctx.dev_id));
  CHECK_EQ(256 % alignment, 0U)
      << "CUDA space is aligned at 256 bytes";
  void* dev_ptr;
  CUDA_CALL(cudaMalloc(&dev_ptr, size));
  return dev_ptr;
}
/*!
 * \brief Free a data space on a CUDA device with cudaFree.
 * \param ctx The context; its dev_id selects the device.
 * \param ptr The data space to release.
 */
template<>
inline void FreeDataSpace<kGPU>(TVMContext ctx, void* ptr) {
// make the owning device current before releasing the pointer
CUDA_CALL(cudaSetDevice(ctx.dev_id));
CUDA_CALL(cudaFree(ptr));
}
/*!
 * \brief Issue a CUDA memory copy, asynchronous only when a stream is given.
 * \param from The source pointer.
 * \param to The target pointer.
 * \param size The number of bytes to copy.
 * \param kind The direction of the copy.
 * \param stream The stream to copy on; 0 selects the blocking cudaMemcpy.
 */
inline void GPUCopy(const void* from,
                    void* to,
                    size_t size,
                    cudaMemcpyKind kind,
                    cudaStream_t stream) {
  if (stream == 0) {
    CUDA_CALL(cudaMemcpy(to, from, size, kind));
    return;
  }
  CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
}
/*!
 * \brief Copy bytes between host and GPU memory, or between GPU memories.
 *
 * Supports GPU to GPU (same device or peer device), GPU to CPU and CPU to
 * GPU; any other combination is a fatal error.
 * \param from The source pointer.
 * \param to The target pointer.
 * \param size The number of bytes to copy.
 * \param ctx_from The source context.
 * \param ctx_to The target context.
 * \param stream The stream handle, cast to cudaStream_t.
 */
template<>
inline void CopyDataFromTo<kGPU>(const void* from,
                                 void* to,
                                 size_t size,
                                 TVMContext ctx_from,
                                 TVMContext ctx_to,
                                 TVMStreamHandle stream) {
  cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
  if (ctx_from.dev_mask == kGPU && ctx_to.dev_mask == kGPU) {
    CUDA_CALL(cudaSetDevice(ctx_from.dev_id));
    if (ctx_from.dev_id == ctx_to.dev_id) {
      GPUCopy(from, to, size, cudaMemcpyDeviceToDevice, cu_stream);
    } else {
      // Check the returned status like every other CUDA call in this file,
      // so a failed peer copy is reported instead of being dropped.
      CUDA_CALL(cudaMemcpyPeerAsync(to, ctx_to.dev_id,
                                    from, ctx_from.dev_id,
                                    size, cu_stream));
    }
  } else if (ctx_from.dev_mask == kGPU && ctx_to.dev_mask == kCPU) {
    CUDA_CALL(cudaSetDevice(ctx_from.dev_id));
    GPUCopy(from, to, size, cudaMemcpyDeviceToHost, cu_stream);
  } else if (ctx_from.dev_mask == kCPU && ctx_to.dev_mask == kGPU) {
    CUDA_CALL(cudaSetDevice(ctx_to.dev_id));
    GPUCopy(from, to, size, cudaMemcpyHostToDevice, cu_stream);
  } else {
    LOG(FATAL) << "expect copy from/to GPU or between GPU";
  }
}
/*!
 * \brief Block until the given CUDA stream has finished its work.
 * \param ctx The context; its dev_id selects the device.
 * \param stream The stream handle, cast to cudaStream_t.
 */
template<>
inline void StreamSync<kGPU>(TVMContext ctx, TVMStreamHandle stream) {
  CUDA_CALL(cudaSetDevice(ctx.dev_id));
  cudaStream_t cu_stream = static_cast<cudaStream_t>(stream);
  CUDA_CALL(cudaStreamSynchronize(cu_stream));
}
} // namespace runtime
} // namespace tvm
#endif // TVM_CUDA_RUNTIME
#endif // TVM_RUNTIME_DEVICE_API_GPU_H_
......@@ -5,7 +5,7 @@
*/
#include <dmlc/thread_local.h>
#include <string>
#include "./runtime_common.h"
#include "./runtime_base.h"
struct TVMErrorEntry {
std::string last_error;
......
/*!
* Copyright (c) 2016 by Contributors
* \file c_runtime_common.h
* \brief Common fields of all C APIs
* \file runtime_base.h
* \brief Base of all C APIs
*/
#ifndef TVM_RUNTIME_RUNTIME_COMMON_H_
#define TVM_RUNTIME_RUNTIME_COMMON_H_
#ifndef TVM_RUNTIME_RUNTIME_BASE_H_
#define TVM_RUNTIME_RUNTIME_BASE_H_
#include <tvm/c_runtime_api.h>
#include <exception>
#include <stdexcept>
/*! \brief macro to guard beginning and end section of all functions */
#define API_BEGIN() try {
......@@ -33,4 +33,4 @@ inline int TVMAPIHandleException(const std::runtime_error &e) {
return -1;
}
#endif // TVM_RUNTIME_RUNTIME_COMMON_H_
#endif // TVM_RUNTIME_RUNTIME_BASE_H_
import tvm
import numpy as np
def enabled_ctx_list():
    """Return the cpu(0), gpu(0) and opencl(0) contexts that report enabled."""
    candidates = (tvm.cpu(0), tvm.gpu(0), tvm.opencl(0))
    return [ctx for ctx in candidates if ctx.enabled]
ENABLED_CTX_LIST = enabled_ctx_list()
print("Testing using contexts:", ENABLED_CTX_LIST)
def test_nd_create():
    """Round-trip small random arrays through every enabled context."""
    dtypes = ("float32", "int8", "uint16")
    for ctx in ENABLED_CTX_LIST:
        for dtype in dtypes:
            raw = np.random.randint(0, 10, size=(3, 4))
            x = np.array(raw, dtype=dtype)
            y = tvm.nd.array(x, ctx=ctx)
            z = y.copyto(ctx)
            assert y.dtype == x.dtype
            assert y.shape == x.shape
            assert isinstance(y, tvm.nd.NDArray)
            np.testing.assert_equal(x, y.asnumpy())
            np.testing.assert_equal(x, z.asnumpy())
        # not required for correctness; exercises the sync entry point
        tvm.nd.sync(ctx)
if __name__ == "__main__":
test_nd_create()
......@@ -14,6 +14,9 @@ if [ ${TASK} == "lint" ] || [ ${TASK} == "all_test" ]; then
fi
fi
cp make/config.mk config.mk
echo "USE_CUDA=0" >> config.mk
echo "USE_OPENCL=0" >> config.mk
if [ ! ${TRAVIS_OS_NAME} == "osx" ]; then
# use g++-4.8 for linux
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment