Commit 9e1a5ec4 by Tianqi Chen, committed by GitHub

[RUNTIME] Enable OpenCL (#17)

parent e9ff9a89
@@ -26,6 +26,7 @@ endif
 export LDFLAGS = -pthread -lm
 export CFLAGS = -std=c++11 -Wall -O2\
 	-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC
+export FRAMEWORKS=
 ifneq ($(ADD_CFLAGS), NONE)
 	CFLAGS += $(ADD_CFLAGS)
@@ -43,6 +44,20 @@ else
 	CFLAGS += -DTVM_CUDA_RUNTIME=0
 endif
+ifeq ($(USE_OPENCL), 1)
+	CFLAGS += -DTVM_OPENCL_RUNTIME=1
+	UNAME_S := $(shell uname -s)
+	ifeq ($(UNAME_S), Darwin)
+		FRAMEWORKS += -framework OpenCL
+	else
+		LDFLAGS += -lOpenCL
+	endif
+else
+	CFLAGS += -DTVM_OPENCL_RUNTIME=0
+endif
 include tests/cpp/unittest.mk
 test: $(TEST)
@@ -59,7 +74,7 @@ lib/libtvm.a: $(ALL_DEP)
 lib/libtvm.so: $(ALL_DEP)
 	@mkdir -p $(@D)
-	$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
+	$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
 $(LIB_HALIDE_IR): LIBHALIDEIR
......
@@ -151,6 +151,23 @@ typedef TVMArray* TVMArrayHandle;
 TVM_DLL const char *TVMGetLastError(void);
 /*!
+ * \brief Initialize a certain type of device; this may not be
+ *  necessary for all device types, but it is needed for OpenCL.
+ *
+ * \param dev_mask The device mask of the device type to be initialized.
+ * \param option_keys Additional option keys to pass.
+ * \param option_vals Additional option values to pass.
+ * \param num_options Number of options to be passed in.
+ * \param out_code 1: success, 0: already initialized.
+ * \return Whether the function call succeeded.
+ */
+TVM_DLL int TVMDeviceInit(int dev_mask,
+                          const char** option_keys,
+                          const char** option_vals,
+                          int num_options,
+                          int *out_code);
+/*!
  * \brief Whether the specified context is enabled.
  *
  * \param ctx The context to be checked.
......
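The new TVMDeviceInit entry point above is a plain C symbol, so it can be driven directly through ctypes. The sketch below mirrors the calling convention that the Python binding later in this commit uses; the library path and the device_init helper name are assumptions for illustration, and dev_mask 4 is kOpenCL, matching that binding.

import ctypes

# Load the shared library built by the Makefile change above.
# The path is an assumption; adjust to wherever lib/libtvm.so lands.
_lib = ctypes.CDLL("lib/libtvm.so")
_lib.TVMGetLastError.restype = ctypes.c_char_p

def device_init(dev_mask, **options):
    """Return True if this call initialized the device,
    False if it was already initialized (out_code == 0)."""
    keys = [k.encode("utf-8") for k in options]
    vals = [str(v).encode("utf-8") for v in options.values()]
    arr_t = ctypes.c_char_p * len(keys)
    out_code = ctypes.c_int()
    ret = _lib.TVMDeviceInit(ctypes.c_int(dev_mask),
                             arr_t(*keys), arr_t(*vals),
                             ctypes.c_int(len(keys)),
                             ctypes.byref(out_code))
    if ret != 0:
        raise RuntimeError(_lib.TVMGetLastError().decode("utf-8"))
    return out_code.value != 0

device_init(4)  # dev_mask 4 is kOpenCL, as in the Python binding below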
@@ -37,6 +37,9 @@ ADD_CFLAGS =
 # whether to use CUDA during compilation
 USE_CUDA = 1
+# whether to use OpenCL during compilation
+USE_OPENCL = 0
 # add the path to the CUDA library to the link and compile flags
 # if you have already added them to the environment variables, leave this as NONE
 # USE_CUDA_PATH = /usr/local/cuda
......
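With USE_OPENCL flipped to 1 in config.mk, a quick way to smoke-test the resulting build from Python is to initialize OpenCL and probe the context, the same way the test script at the end of this commit does. A minimal sketch, assuming the freshly built package is on the Python path:

import tvm

tvm.init_opencl()             # returns False if already initialized
print(tvm.opencl(0).enabled)  # True only on a USE_OPENCL=1 build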
@@ -12,7 +12,7 @@ from . import collections
 from . import schedule
 from . import ndarray as nd
-from .ndarray import cpu, gpu, opencl
+from .ndarray import cpu, gpu, opencl, init_opencl
 from ._base import TVMError
 from .function import *
@@ -7,7 +7,7 @@ import ctypes
 import numpy as np
 from .._base import _LIB
-from .._base import c_array
+from .._base import c_array, c_str
 from .._base import check_call
@@ -182,6 +182,30 @@ def sync(ctx):
     check_call(_LIB.TVMSynchronize(ctx, None))
+def init_opencl(**kwargs):
+    """Initialize OpenCL with the given options.
+
+    Parameters
+    ----------
+    kwargs : dict
+        Additional string options passed to the device runtime.
+    """
+    keys = []
+    vals = []
+    for k, v in kwargs.items():
+        keys.append(c_str(k))
+        vals.append(c_str(v))
+    dev_mask = ctypes.c_int(4)
+    out_code = ctypes.c_int()
+    check_call(_LIB.TVMDeviceInit(
+        dev_mask,
+        c_array(ctypes.c_char_p, keys),
+        c_array(ctypes.c_char_p, vals),
+        ctypes.c_int(len(keys)),
+        ctypes.byref(out_code)))
+    return out_code.value != 0
 class NDArrayBase(object):
     """A simple Device/CPU Array object in runtime."""
     __slots__ = ["handle"]
......
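A short usage sketch of the new binding. Per the out_code convention in the C header, the first call returns True and repeated calls return False; this assumes an OpenCL build where the device-side init reports "already initialized" on repeat calls, and the option key below is purely illustrative, since this commit does not define any specific keys.

import tvm

# "platform" is a hypothetical option key, for illustration only.
print(tvm.init_opencl(platform="example"))  # True: initialized (out_code 1)
print(tvm.init_opencl())                    # False if already initialized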
@@ -9,6 +9,7 @@ import numpy as _np
 from ._ctypes._runtime_api import TVMContext, TVMDataType, NDArrayBase
 from ._ctypes._runtime_api import cpu, gpu, opencl, empty, sync
 from ._ctypes._runtime_api import _init_runtime_module
+from ._ctypes._runtime_api import init_opencl
 class NDArray(NDArrayBase):
......
@@ -24,7 +24,7 @@ class Schedule(NodeBase):
             k = k.op
         if not isinstance(k, _tensor.Operation):
             raise ValueError("Expect schedule key to be Tensor or Operation")
-        if not k in self.stage_map:
+        if k not in self.stage_map:
             raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
         return self.stage_map[k]
......
@@ -64,6 +64,23 @@ inline size_t GetDataAlignment(TVMArray* arr) {
 using namespace tvm::runtime;
+int TVMDeviceInit(int dev_mask,
+                  const char** option_keys,
+                  const char** option_vals,
+                  int num_options,
+                  int* out_code) {
+  API_BEGIN();
+  *out_code = 1;
+  switch (dev_mask) {
+    case kOpenCL: {
+      *out_code = DeviceInit<kOpenCL>(option_keys, option_vals, num_options);
+      break;
+    }
+    default: break;
+  }
+  API_END();
+}
 int TVMContextEnabled(TVMContext ctx,
                       int* out_enabled) {
   API_BEGIN();
......
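Note the dispatch semantics above: out_code is pre-set to 1 and only kOpenCL has an init hook, so TVMDeviceInit is a successful no-op for every other mask. A sketch using the hypothetical device_init helper from earlier; the kCPU and kGPU mask values are assumed to be 1 and 2.

# Only mask 4 (kOpenCL) does real work; other masks fall through the
# default branch and still report success via the pre-set out_code.
for mask in (1, 2, 4):  # assumed: kCPU=1, kGPU=2, kOpenCL=4
    print(mask, device_init(mask))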
 /*!
  * Copyright (c) 2016 by Contributors
- * \file device_api.hx
+ * \file device_api.h
  * \brief Device specific API
  */
 #ifndef TVM_RUNTIME_DEVICE_API_H_
@@ -12,6 +12,21 @@
 namespace tvm {
 namespace runtime {
 /*!
+ * \brief Initialize the device.
+ * \param option_keys Additional option keys to pass.
+ * \param option_vals Additional option values to pass.
+ * \param num_options Number of options to be passed in.
+ * \return true on success; false if the device was already initialized.
+ * \tparam xpu The device mask.
+ */
+template<TVMDeviceMask xpu>
+inline bool DeviceInit(const char** option_keys,
+                       const char** option_vals,
+                       int num_options) {
+  return true;
+}
+/*!
  * \brief Whether ctx is enabled.
  * \param ctx The device context to perform operation.
  * \tparam xpu The device mask.
@@ -93,7 +108,8 @@ inline void StreamSync(TVMContext ctx, TVMStreamHandle stream);
 }  // namespace runtime
 }  // namespace tvm
-#include "./device_api_gpu.h"
 #include "./device_api_cpu.h"
+#include "./device_api_gpu.h"
+#include "./device_api_opencl.h"
 #endif  // TVM_RUNTIME_DEVICE_API_H_
 /*!
  * Copyright (c) 2016 by Contributors
- * \file ctxice_api_gpu.h
+ * \file device_api_gpu.h
  * \brief GPU specific API
  */
 #ifndef TVM_RUNTIME_DEVICE_API_GPU_H_
@@ -14,15 +14,6 @@
 namespace tvm {
 namespace runtime {
-/*!
- * \brief Check CUDA error.
- * \param msg Message to print if an error occurred.
- */
-#define CHECK_CUDA_ERROR(msg) \
-  { \
-    cudaError_t e = cudaGetLastError(); \
-    CHECK_EQ(e, cudaSuccess) << (msg) << " CUDA: " << cudaGetErrorString(e); \
-  }
 /*!
  * \brief Protected CUDA call.
......
@@ -2,6 +2,7 @@ import tvm
 import numpy as np
 def enabled_ctx_list():
+    tvm.init_opencl()
     ctx_list = [tvm.cpu(0), tvm.gpu(0), tvm.opencl(0)]
     ctx_list = [ctx for ctx in ctx_list if ctx.enabled]
     return ctx_list
......
@@ -16,13 +16,15 @@ fi
 cp make/config.mk config.mk
 echo "USE_CUDA=0" >> config.mk
-if [ ! ${TRAVIS_OS_NAME} == "osx" ]; then
+echo "USE_OPENCL=0" >> config.mk
+if [ ${TRAVIS_OS_NAME} == "osx" ]; then
+    echo "USE_OPENCL=1" >> config.mk
+else
     # use g++-4.8 for linux
     if [ ${CXX} == "g++" ]; then
        export CXX=g++-4.8
     fi
-    echo "USE_OPENCL=0" >> config.mk
 fi
 if [ ${TASK} == "cpp_test" ] || [ ${TASK} == "all_test" ]; then
......