Commit 2f462cca by Tianqi Chen Committed by GitHub

[MODULE] Enable OpenCL and CUDA Modules (#53)

parent efae4be0
...@@ -374,34 +374,6 @@ TVM_DLL int TVMFuncListGlobalNames(int *out_size, ...@@ -374,34 +374,6 @@ TVM_DLL int TVMFuncListGlobalNames(int *out_size,
// Array related apis for quick proptying // Array related apis for quick proptying
/*! /*!
* \brief Initialize certain type of devices, this may
* not be necessary for all device types. But is needed for OpenCL.
*
* \param dev_mask The device mask of device type to be initialized
* \param option_keys Additional option keys to pass.
* \param option_vals Additional option values to pass
* \param num_options Number of options to be passed into it.
* \param out_code 1: success, 0: already initialized
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMDeviceInit(int dev_mask,
const char** option_keys,
const char** option_vals,
int num_options,
int *out_code);
/*!
* \brief Whether the specified context is enabled.
*
* \param ctx The context to be checked.
* \param out_enabled whether the ctx is enabled.
* \return Whether the function is successful.
*/
TVM_DLL int TVMContextEnabled(TVMContext ctx,
int* out_enabled);
/*!
* \brief Allocate a nd-array's memory, * \brief Allocate a nd-array's memory,
* including space of shape, of given spec. * including space of shape, of given spec.
* *
......
...@@ -535,8 +535,9 @@ inline const char* TypeCode2Str(int type_code) { ...@@ -535,8 +535,9 @@ inline const char* TypeCode2Str(int type_code) {
} }
inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*) inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*)
os << TypeCode2Str(t.code) os << TypeCode2Str(t.code);
<< static_cast<int>(t.bits); if (t.code == kHandle) return os;
os << static_cast<int>(t.bits);
if (t.lanes != 1) { if (t.lanes != 1) {
os << 'x' << static_cast<int>(t.lanes); os << 'x' << static_cast<int>(t.lanes);
} }
...@@ -559,7 +560,7 @@ inline TVMType String2TVMType(std::string s) { ...@@ -559,7 +560,7 @@ inline TVMType String2TVMType(std::string s) {
t.code = kUInt; scan = s.c_str() + 4; t.code = kUInt; scan = s.c_str() + 4;
} else if (s.substr(0, 5) == "float") { } else if (s.substr(0, 5) == "float") {
t.code = kFloat; scan = s.c_str() + 5; t.code = kFloat; scan = s.c_str() + 5;
} else if (s == "handle") { } else if (s.substr(0, 6) == "handle") {
t.code = kHandle; t.code = kHandle;
t.bits = 64; // handle uses 64 bit by default. t.bits = 64; // handle uses 64 bit by default.
scan = s.c_str() + 6; scan = s.c_str() + 6;
......
...@@ -15,7 +15,7 @@ from . import schedule ...@@ -15,7 +15,7 @@ from . import schedule
from . import module from . import module
from . import ndarray as nd from . import ndarray as nd
from .ndarray import cpu, gpu, opencl, init_opencl, cl from .ndarray import cpu, gpu, opencl, cl
from ._base import TVMError from ._base import TVMError
from .api import * from .api import *
......
...@@ -7,10 +7,9 @@ import ctypes ...@@ -7,10 +7,9 @@ import ctypes
import numpy as np import numpy as np
from .._base import _LIB, check_call from .._base import _LIB, check_call
from .._base import c_array, c_str from .._base import c_array
from ._types import TVMType, tvm_index_t from ._types import TVMType, tvm_index_t
class TVMContext(ctypes.Structure): class TVMContext(ctypes.Structure):
"""TVM context strucure.""" """TVM context strucure."""
_fields_ = [("dev_mask", ctypes.c_int), _fields_ = [("dev_mask", ctypes.c_int),
...@@ -29,12 +28,6 @@ class TVMContext(ctypes.Structure): ...@@ -29,12 +28,6 @@ class TVMContext(ctypes.Structure):
return "%s(%d)" % ( return "%s(%d)" % (
TVMContext.MASK2STR[self.dev_mask], self.dev_id) TVMContext.MASK2STR[self.dev_mask], self.dev_id)
@property
def enabled(self):
ret = ctypes.c_int()
check_call(_LIB.TVMContextEnabled(self, ctypes.byref(ret)))
return ret.value != 0
class TVMArray(ctypes.Structure): class TVMArray(ctypes.Structure):
"""TVMValue in C API""" """TVMValue in C API"""
...@@ -141,30 +134,6 @@ def sync(ctx): ...@@ -141,30 +134,6 @@ def sync(ctx):
check_call(_LIB.TVMSynchronize(ctx, None)) check_call(_LIB.TVMSynchronize(ctx, None))
def init_opencl(**kwargs):
"""Initialize the opencl with the options.
Parameters
----------
kwargs : dict
The options
"""
keys = []
vals = []
for k, v in kwargs.items():
keys.append(c_str(k))
vals.append(c_str(v))
dev_mask = ctypes.c_int(4)
out_code = ctypes.c_int()
check_call(_LIB.TVMDeviceInit(
dev_mask,
c_array(ctypes.c_char_p, keys),
c_array(ctypes.c_char_p, vals),
ctypes.c_int(len(keys)),
ctypes.byref(out_code)))
return out_code.value != 0
class NDArrayBase(object): class NDArrayBase(object):
"""A simple Device/CPU Array object in runtime.""" """A simple Device/CPU Array object in runtime."""
__slots__ = ["handle"] __slots__ = ["handle"]
......
"""Utilities to make tempdir"""
from __future__ import absolute_import as _abs
import os
import tempfile
import shutil
class TempDirectory(object):
"""Helper object to manage temp directory during testing"""
def __init__(self):
self.temp_dir = tempfile.mkdtemp()
def __del__(self):
shutil.rmtree(self.temp_dir)
def relpath(self, name):
"""Relative path in temp dir
Parameters
----------
name : str
The name of the file.
"""
return os.path.join(self.temp_dir, name)
def tempdir():
"""Return a new temp dir which deletes the contents when exit
Returns
-------
temp : TempDirectory
The temp directory object
"""
return TempDirectory()
# coding: utf-8 # coding: utf-8
"""Information about nnvm.""" """Information about nnvm."""
from __future__ import absolute_import from __future__ import absolute_import
import sys
import os import os
import platform import platform
def find_lib_path(): def find_lib_path():
"""Find dynamic library files. """Find dynamic library files.
...@@ -12,6 +14,7 @@ def find_lib_path(): ...@@ -12,6 +14,7 @@ def find_lib_path():
lib_path : list(string) lib_path : list(string)
List of all found path to the libraries List of all found path to the libraries
""" """
use_runtime = os.environ.get("TVM_USE_RUNTIME_LIB", False)
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
api_path = os.path.join(curr_path, '../../lib/') api_path = os.path.join(curr_path, '../../lib/')
cmake_build_path = os.path.join(curr_path, '../../build/Release/') cmake_build_path = os.path.join(curr_path, '../../build/Release/')
...@@ -26,15 +29,24 @@ def find_lib_path(): ...@@ -26,15 +29,24 @@ def find_lib_path():
dll_path.append(os.path.join(curr_path, '../../windows', vs_configuration)) dll_path.append(os.path.join(curr_path, '../../windows', vs_configuration))
elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None): elif os.name == "posix" and os.environ.get('LD_LIBRARY_PATH', None):
dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")]) dll_path.extend([p.strip() for p in os.environ['LD_LIBRARY_PATH'].split(":")])
if os.name == 'nt': if os.name == 'nt':
dll_path = [os.path.join(p, 'libtvm.dll') for p in dll_path] lib_dll_path = [os.path.join(p, 'libtvm.dll') for p in dll_path]
runtime_dll_path = [os.path.join(p, 'libtvm_runtime.dll') for p in dll_path]
else: else:
dll_path = [os.path.join(p, 'libtvm.so') for p in dll_path] lib_dll_path = [os.path.join(p, 'libtvm.so') for p in dll_path]
lib_path = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)] runtime_dll_path = [os.path.join(p, 'libtvm_runtime.so') for p in dll_path]
if len(lib_path) == 0:
dll_path = runtime_dll_path if use_runtime else lib_dll_path
lib_found = [p for p in dll_path if os.path.exists(p) and os.path.isfile(p)]
if len(lib_found) == 0:
raise RuntimeError('Cannot find the files.\n' + raise RuntimeError('Cannot find the files.\n' +
'List of candidates:\n' + str('\n'.join(dll_path))) 'List of candidates:\n' + str('\n'.join(dll_path)))
return lib_path if use_runtime:
sys.stderr.write("Loading runtime library... this is execution only\n")
sys.stderr.flush()
return lib_found
# current version # current version
......
...@@ -9,7 +9,6 @@ import numpy as _np ...@@ -9,7 +9,6 @@ import numpy as _np
from ._ctypes._ndarray import TVMContext, TVMType, NDArrayBase from ._ctypes._ndarray import TVMContext, TVMType, NDArrayBase
from ._ctypes._ndarray import cpu, gpu, opencl, empty, sync from ._ctypes._ndarray import cpu, gpu, opencl, empty, sync
from ._ctypes._ndarray import _init_ndarray_module from ._ctypes._ndarray import _init_ndarray_module
from ._ctypes._ndarray import init_opencl
from ._ctypes._function import Function from ._ctypes._function import Function
cl = opencl cl = opencl
......
...@@ -21,7 +21,7 @@ TVM_REGISTER_API(_codegen_build) ...@@ -21,7 +21,7 @@ TVM_REGISTER_API(_codegen_build)
} }
}); });
TVM_REGISTER_API(_codegen_target_enabled) TVM_REGISTER_API(_codegen_enabled)
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = TargetEnabled(args[0]); *ret = TargetEnabled(args[0]);
}); });
......
...@@ -61,10 +61,13 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) { ...@@ -61,10 +61,13 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) { if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code).operator std::string(); code = (*f)(code).operator std::string();
} }
std::string fmt = "ptx";
std::string ptx; std::string ptx;
if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) { if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) {
ptx = (*f)(code).operator std::string(); ptx = (*f)(code).operator std::string();
// Dirty matching to check PTX vs cubin.
// TODO(tqchen) more reliable checks
if (ptx[0] != '/') fmt = "cubin";
} else { } else {
ptx = NVRTCCompile(code); ptx = NVRTCCompile(code);
} }
...@@ -80,7 +83,7 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) { ...@@ -80,7 +83,7 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
} }
fmap[f->name] = info; fmap[f->name] = info;
} }
return CUDAModuleCreate(ptx, "ptx", fmap, code); return CUDAModuleCreate(ptx, fmt, fmap, code);
} }
TVM_REGISTER_API(_codegen_build_cuda) TVM_REGISTER_API(_codegen_build_cuda)
......
...@@ -200,38 +200,6 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func, ...@@ -200,38 +200,6 @@ int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
API_END(); API_END();
} }
int TVMDeviceInit(int dev_mask,
const char** option_keys,
const char** option_vals,
int num_options,
int* out_code) {
API_BEGIN();
*out_code = 1;
switch (dev_mask) {
case kOpenCL: {
*out_code = DeviceInit<kOpenCL>(option_keys, option_vals, num_options);
break;
}
default: break;
}
API_END();
}
int TVMContextEnabled(TVMContext ctx,
int* out_enabled) {
API_BEGIN();
if (ctx.dev_mask == kGPU && TVM_CUDA_RUNTIME == 0) {
*out_enabled = 0;
} else if (ctx.dev_mask == kOpenCL && TVM_OPENCL_RUNTIME == 0) {
*out_enabled = 0;
} else {
TVM_DEVICE_SWITCH(ctx, {
*out_enabled = CheckEnabled<xpu>(ctx);
});
}
API_END();
}
int TVMArrayAlloc(const tvm_index_t* shape, int TVMArrayAlloc(const tvm_index_t* shape,
tvm_index_t ndim, tvm_index_t ndim,
TVMType dtype, TVMType dtype,
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "./cuda_module.h" #include "./cuda_module.h"
#if TVM_CUDA_RUNTIME #if TVM_CUDA_RUNTIME
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
...@@ -60,7 +61,12 @@ class CUDAModuleNode : public runtime::ModuleNode { ...@@ -60,7 +61,12 @@ class CUDAModuleNode : public runtime::ModuleNode {
void SaveToFile(const std::string& file_name, void SaveToFile(const std::string& file_name,
const std::string& format) final { const std::string& format) final {
LOG(FATAL) << "Not implemented"; std::string fmt = GetFileFormat(file_name, format);
CHECK_EQ(fmt, fmt_)
<< "Can only save to format=" << fmt_;
std::string meta_file = GetMetaFilePath(file_name);
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, data_);
} }
std::string GetSource(const std::string& format) final { std::string GetSource(const std::string& format) final {
...@@ -212,9 +218,13 @@ Module CUDAModuleCreate( ...@@ -212,9 +218,13 @@ Module CUDAModuleCreate(
// Load module from module. // Load module from module.
Module CUDAModuleLoad(const std::string& file_name, Module CUDAModuleLoad(const std::string& file_name,
const std::string& format) { const std::string& format) {
std::string data;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format); std::string fmt = GetFileFormat(file_name, format);
std::string data = LoadBinaryFile(file_name); std::string meta_file = GetMetaFilePath(file_name);
return CUDAModuleCreate(data, fmt, {{}}, std::string()); LoadBinaryFromFile(file_name, &data);
LoadMetaDataFromFile(meta_file, &fmap);
return CUDAModuleCreate(data, fmt, fmap, std::string());
} }
TVM_REGISTER_GLOBAL(_module_loadfile_cubin) TVM_REGISTER_GLOBAL(_module_loadfile_cubin)
......
...@@ -12,31 +12,6 @@ ...@@ -12,31 +12,6 @@
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
/*! /*!
* \brief Initialize the device.
* \param option_keys Additional option keys to pass.
* \param option_vals Additional option values to pass
* \param num_options Number of options to be passed into it.
* \return 0 if success, 1: if already initialized
* \tparam xpu The device mask.
*/
template<TVMDeviceMask xpu>
inline bool DeviceInit(const char** option_keys,
const char** option_vals,
int num_options) {
return true;
}
/*!
* \brief Whether ctx is enabled.
* \param ctx The device context to perform operation.
* \tparam xpu The device mask.
*/
template<TVMDeviceMask xpu>
inline bool CheckEnabled(TVMContext ctx) {
return true;
}
/*!
* \brief Allocate a data space on device. * \brief Allocate a data space on device.
* \param ctx The device context to perform operation. * \param ctx The device context to perform operation.
* \param size The size of the memory * \param size The size of the memory
......
/*!
* Copyright (c) 2017 by Contributors
* \file file_util.cc
*/
#include <dmlc/json.h>
#include <dmlc/logging.h>
#include <tvm/runtime/packed_func.h>
#include <fstream>
#include "./file_util.h"
namespace tvm {
namespace runtime {
void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
std::vector<std::string> sarg_types(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
sarg_types[i] = TVMType2String(arg_types[i]);
}
writer->BeginObject();
writer->WriteObjectKeyValue("name", name);
writer->WriteObjectKeyValue("arg_types", sarg_types);
writer->WriteObjectKeyValue("thread_axis_tags", thread_axis_tags);
writer->EndObject();
}
void FunctionInfo::Load(dmlc::JSONReader* reader) {
dmlc::JSONObjectReadHelper helper;
std::vector<std::string> sarg_types;
helper.DeclareField("name", &name);
helper.DeclareField("arg_types", &sarg_types);
helper.DeclareField("thread_axis_tags", &thread_axis_tags);
helper.ReadAllFields(reader);
arg_types.resize(sarg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
arg_types[i] = String2TVMType(sarg_types[i]);
}
}
std::string GetFileFormat(const std::string& file_name,
const std::string& format) {
std::string fmt = format;
if (fmt.length() == 0) {
size_t pos = file_name.find_last_of(".");
if (pos != std::string::npos) {
return file_name.substr(pos + 1, file_name.length() - pos - 1);
} else {
return "";
}
} else {
return format;
}
}
std::string GetMetaFilePath(const std::string& file_name) {
size_t pos = file_name.find_last_of(".");
if (pos != std::string::npos) {
return file_name.substr(0, pos) + ".tvm_meta.json";
} else {
return file_name + ".tvm_meta.json";
}
}
void LoadBinaryFromFile(const std::string& file_name,
std::string* data) {
std::ifstream fs(file_name, std::ios::in | std::ios::binary);
CHECK(!fs.fail()) << "Cannot open " << file_name;
// get its size:
fs.seekg(0, std::ios::end);
size_t size = fs.tellg();
fs.seekg(0, std::ios::beg);
data->resize(size);
fs.read(&(*data)[0], size);
}
void SaveBinaryToFile(
const std::string& file_name,
const std::string& data) {
std::ofstream fs(file_name, std::ios::out | std::ios::binary);
CHECK(!fs.fail()) << "Cannot open " << file_name;
fs.write(&data[0], data.length());
}
void SaveMetaDataToFile(
const std::string& file_name,
const std::unordered_map<std::string, FunctionInfo>& fmap) {
std::string version = "0.1.0";
std::ofstream fs(file_name.c_str());
CHECK(!fs.fail()) << "Cannot open file " << file_name;
dmlc::JSONWriter writer(&fs);
writer.BeginObject();
writer.WriteObjectKeyValue("tvm_version", version);
writer.WriteObjectKeyValue("func_info", fmap);
writer.EndObject();
fs.close();
}
void LoadMetaDataFromFile(
const std::string& file_name,
std::unordered_map<std::string, FunctionInfo>* fmap) {
std::ifstream fs(file_name.c_str());
CHECK(!fs.fail()) << "Cannot open file " << file_name;
std::string version;
dmlc::JSONReader reader(&fs);
dmlc::JSONObjectReadHelper helper;
helper.DeclareField("tvm_version", &version);
helper.DeclareField("func_info", fmap);
helper.ReadAllFields(&reader);
fs.close();
}
} // namespace runtime
} // namespace tvm
...@@ -6,9 +6,8 @@ ...@@ -6,9 +6,8 @@
#ifndef TVM_RUNTIME_FILE_UTIL_H_ #ifndef TVM_RUNTIME_FILE_UTIL_H_
#define TVM_RUNTIME_FILE_UTIL_H_ #define TVM_RUNTIME_FILE_UTIL_H_
#include <dmlc/logging.h>
#include <fstream>
#include <string> #include <string>
#include "./meta_data.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -17,39 +16,48 @@ namespace runtime { ...@@ -17,39 +16,48 @@ namespace runtime {
* \param file_name The name of the file. * \param file_name The name of the file.
* \param format The format of the file. * \param format The format of the file.
*/ */
inline std::string GetFileFormat(const std::string& file_name, std::string GetFileFormat(const std::string& file_name,
const std::string& format) { const std::string& format);
std::string fmt = format;
if (fmt.length() == 0) { /*!
size_t pos = file_name.find_last_of("."); * \brief Get meta file path given file name and format.
if (pos != std::string::npos) { * \param file_name The name of the file.
return file_name.substr(pos + 1, file_name.length() - pos - 1); */
} else { std::string GetMetaFilePath(const std::string& file_name);
return "";
} /*!
} else { * \brief Load binary file into a in-memory buffer.
return format; * \param file_name The name of the file.
} * \param data The data to be loaded.
} */
void LoadBinaryFromFile(const std::string& file_name,
std::string* data);
/*! /*!
* \brief Load binary file into a in-memory buffer. * \brief Load binary file into a in-memory buffer.
* \param file_name The name of the file. * \param file_name The name of the file.
* \param The binary
*/
void SaveBinaryToFile(const std::string& file_name,
const std::string& data);
/*!
* \brief Save meta data to file.
* \param file_name The name of the file.
* \param fmap The function info map.
*/ */
inline std::string LoadBinaryFile(const std::string& file_name) { void SaveMetaDataToFile(
std::ifstream fs(file_name, std::ios::in | std::ios::binary); const std::string& file_name,
CHECK(!fs.fail()) const std::unordered_map<std::string, FunctionInfo>& fmap);
<< "Cannot open " << file_name;
// get its size:
fs.seekg(0, std::ios::end);
size_t size = fs.tellg();
fs.seekg(0, std::ios::beg);
std::string data;
data.resize(size);
fs.read(&data[0], size);
return data;
}
/*!
* \brief Load meta data to file.
* \param file_name The name of the file.
* \param fmap The function info map.
*/
void LoadMetaDataFromFile(
const std::string& file_name,
std::unordered_map<std::string, FunctionInfo>* fmap);
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
#endif // TVM_RUNTIME_FILE_UTIL_H_ #endif // TVM_RUNTIME_FILE_UTIL_H_
...@@ -27,30 +27,8 @@ struct FunctionInfo { ...@@ -27,30 +27,8 @@ struct FunctionInfo {
std::vector<TVMType> arg_types; std::vector<TVMType> arg_types;
std::vector<std::string> thread_axis_tags; std::vector<std::string> thread_axis_tags;
void Save(dmlc::JSONWriter *writer) const { void Save(dmlc::JSONWriter *writer) const;
std::vector<std::string> sarg_types(arg_types.size()); void Load(dmlc::JSONReader *reader);
for (size_t i = 0; i < arg_types.size(); ++i) {
sarg_types[i] = TVMType2String(arg_types[i]);
}
writer->BeginObject();
writer->WriteObjectKeyValue("name", name);
writer->WriteObjectKeyValue("arg_types", sarg_types);
writer->WriteObjectKeyValue("thread_axis_tags", thread_axis_tags);
writer->EndObject();
}
void Load(dmlc::JSONReader *reader) {
dmlc::JSONObjectReadHelper helper;
std::vector<std::string> sarg_types;
helper.DeclareField("name", &name);
helper.DeclareField("arg_types", &sarg_types);
helper.DeclareField("thread_axis_tags", &thread_axis_tags);
helper.ReadAllFields(reader);
arg_types.resize(sarg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
arg_types[i] = String2TVMType(sarg_types[i]);
}
}
}; };
} // namespace runtime } // namespace runtime
......
...@@ -83,6 +83,25 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) { ...@@ -83,6 +83,25 @@ const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
} }
} }
bool RuntimeEnabled(const std::string& target) {
std::string load_f_name;
if (target == "cpu") {
return true;
} else if (target == "cuda" || target == "gpu") {
load_f_name = "_module_loadfile_ptx";
} else if (target == "cl" || target == "opencl") {
load_f_name = "_module_loadfile_cl";
} else {
LOG(FATAL) << "Unknown optional runtime " << target;
}
return runtime::Registry::Get(load_f_name) != nullptr;
}
TVM_REGISTER_GLOBAL(_module_enabled)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = RuntimeEnabled(args[0]);
});
TVM_REGISTER_GLOBAL(_module__GetSource) TVM_REGISTER_GLOBAL(_module__GetSource)
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator Module()->GetSource(args[1]); *ret = args[0].operator Module()->GetSource(args[1]);
......
...@@ -15,121 +15,6 @@ ...@@ -15,121 +15,6 @@
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
namespace cl {
inline std::string GetPlatformInfo(
cl_platform_id pid, cl_platform_info param_name) {
size_t ret_size;
OPENCL_CALL(clGetPlatformInfo(pid, param_name, 0, nullptr, &ret_size));
std::string ret;
ret.resize(ret_size);
OPENCL_CALL(clGetPlatformInfo(pid, param_name, ret_size, &ret[0], nullptr));
return ret;
}
inline std::string GetDeviceInfo(
cl_device_id pid, cl_device_info param_name) {
size_t ret_size;
OPENCL_CALL(clGetDeviceInfo(pid, param_name, 0, nullptr, &ret_size));
std::string ret;
ret.resize(ret_size);
OPENCL_CALL(clGetDeviceInfo(pid, param_name, ret_size, &ret[0], nullptr));
return ret;
}
inline std::vector<cl_platform_id> GetPlatformIDs() {
cl_uint ret_size;
OPENCL_CALL(clGetPlatformIDs(0, nullptr, &ret_size));
std::vector<cl_platform_id> ret;
ret.resize(ret_size);
OPENCL_CALL(clGetPlatformIDs(ret_size, &ret[0], nullptr));
return ret;
}
inline std::vector<cl_device_id> GetDeviceIDs(
cl_platform_id pid, std::string device_type) {
cl_device_type dtype = CL_DEVICE_TYPE_ALL;
if (device_type == "cpu") dtype = CL_DEVICE_TYPE_CPU;
if (device_type == "gpu") dtype = CL_DEVICE_TYPE_CPU;
if (device_type == "accelerator") dtype = CL_DEVICE_TYPE_ACCELERATOR;
cl_uint ret_size;
OPENCL_CALL(clGetDeviceIDs(pid, dtype, 0, nullptr, &ret_size));
std::vector<cl_device_id> ret;
ret.resize(ret_size);
OPENCL_CALL(clGetDeviceIDs(pid, dtype, ret_size, &ret[0], nullptr));
return ret;
}
inline bool MatchPlatformInfo(
cl_platform_id pid,
cl_platform_info param_name,
std::string value) {
if (value.length() == 0) return true;
std::string param_value = GetPlatformInfo(pid, param_name);
return param_value.find(value) != std::string::npos;
}
} // namespace cl
template<>
inline bool DeviceInit<kOpenCL>(const char** option_keys,
const char** option_vals,
int num_options) {
cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
std::lock_guard<std::mutex>(w->mu);
if (w->initialized()) return false;
// matching conditions
std::string platform_name, device_type;
for (int i = 0; i < num_options; ++i) {
std::string key = option_keys[i];
std::string val = option_vals[i];
if (key == "platform_name") {
platform_name = val;
} else if (key == "device_type") {
device_type = val;
} else {
LOG(FATAL) << "unknown DeviceInit option " << key;
}
}
// matched platforms
std::vector<cl_platform_id> platform_matched;
for (cl_platform_id pid : cl::GetPlatformIDs()) {
bool matched = true;
if (!cl::MatchPlatformInfo(pid, CL_PLATFORM_NAME, platform_name)) matched = false;
if (matched) platform_matched.push_back(pid);
}
if (platform_matched.size() == 0) {
LOG(FATAL) << "No OpenCL platform matched given existing options ...";
}
if (platform_matched.size() > 1) {
LOG(WARNING) << "Multiple OpenCL platforms matched, use the first one ... ";
}
w->platform_id = platform_matched[0];
LOG(INFO) << "Initialize OpenCL platform \'"
<< cl::GetPlatformInfo(w->platform_id, CL_PLATFORM_NAME) << '\'';
std::vector<cl_device_id> devices_matched =
cl::GetDeviceIDs(w->platform_id, device_type);
CHECK_GT(devices_matched.size(), 0U)
<< "No OpenCL device any device matched given the options";
w->devices = devices_matched;
cl_int err_code;
w->context = clCreateContext(
nullptr, w->devices.size(), &(w->devices[0]),
nullptr, nullptr, &err_code);
OPENCL_CHECK_ERROR(err_code);
CHECK_EQ(w->queues.size(), 0U);
for (size_t i = 0; i < w->devices.size(); ++i) {
cl_device_id did = w->devices[i];
w->queues.push_back(
clCreateCommandQueue(w->context, did, 0, &err_code));
OPENCL_CHECK_ERROR(err_code);
LOG(INFO) << "opencl(" << i
<< ")=\'" << cl::GetDeviceInfo(did, CL_DEVICE_NAME)
<< "\' cl_device_id=" << did;
}
return true;
}
template<> template<>
inline void* AllocDataSpace<kOpenCL>(TVMContext ctx, size_t size, size_t alignment) { inline void* AllocDataSpace<kOpenCL>(TVMContext ctx, size_t size, size_t alignment) {
......
...@@ -7,11 +7,14 @@ ...@@ -7,11 +7,14 @@
#if TVM_OPENCL_RUNTIME #if TVM_OPENCL_RUNTIME
#include <tvm/runtime/registry.h>
#include <vector> #include <vector>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include "../void_addr_args.h" #include "../void_addr_args.h"
#include "../thread_storage_scope.h" #include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -67,7 +70,12 @@ class OpenCLModuleNode : public ModuleNode { ...@@ -67,7 +70,12 @@ class OpenCLModuleNode : public ModuleNode {
void SaveToFile(const std::string& file_name, void SaveToFile(const std::string& file_name,
const std::string& format) final { const std::string& format) final {
LOG(FATAL) << "Not implemented"; std::string fmt = GetFileFormat(file_name, format);
CHECK_EQ(fmt, fmt_)
<< "Can only save to format=" << fmt_;
std::string meta_file = GetMetaFilePath(file_name);
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, data_);
} }
std::string GetSource(const std::string& format) final { std::string GetSource(const std::string& format) final {
...@@ -294,6 +302,27 @@ Module OpenCLModuleCreate( ...@@ -294,6 +302,27 @@ Module OpenCLModuleCreate(
return Module(n); return Module(n);
} }
// Load module from module.
Module OpenCLModuleLoad(const std::string& file_name,
const std::string& format) {
std::string data;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
LoadBinaryFromFile(file_name, &data);
LoadMetaDataFromFile(meta_file, &fmap);
return OpenCLModuleCreate(data, fmt, fmap);
}
TVM_REGISTER_GLOBAL(_module_loadfile_cl)
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenCLModuleLoad(args[0], args[1]);
});
TVM_REGISTER_GLOBAL(_module_loadfile_clbin)
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenCLModuleLoad(args[0], args[1]);
});
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#if TVM_OPENCL_RUNTIME #if TVM_OPENCL_RUNTIME
#include <tvm/runtime/registry.h>
#include <dmlc/thread_local.h> #include <dmlc/thread_local.h>
namespace tvm { namespace tvm {
...@@ -23,6 +24,123 @@ OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() { ...@@ -23,6 +24,123 @@ OpenCLThreadEntry* OpenCLThreadEntry::ThreadLocal() {
return OpenCLThreadStore::Get(); return OpenCLThreadStore::Get();
} }
std::string GetPlatformInfo(
cl_platform_id pid, cl_platform_info param_name) {
size_t ret_size;
OPENCL_CALL(clGetPlatformInfo(pid, param_name, 0, nullptr, &ret_size));
std::string ret;
ret.resize(ret_size);
OPENCL_CALL(clGetPlatformInfo(pid, param_name, ret_size, &ret[0], nullptr));
return ret;
}
std::string GetDeviceInfo(
cl_device_id pid, cl_device_info param_name) {
size_t ret_size;
OPENCL_CALL(clGetDeviceInfo(pid, param_name, 0, nullptr, &ret_size));
std::string ret;
ret.resize(ret_size);
OPENCL_CALL(clGetDeviceInfo(pid, param_name, ret_size, &ret[0], nullptr));
return ret;
}
std::vector<cl_platform_id> GetPlatformIDs() {
cl_uint ret_size;
OPENCL_CALL(clGetPlatformIDs(0, nullptr, &ret_size));
std::vector<cl_platform_id> ret;
ret.resize(ret_size);
OPENCL_CALL(clGetPlatformIDs(ret_size, &ret[0], nullptr));
return ret;
}
std::vector<cl_device_id> GetDeviceIDs(
cl_platform_id pid, std::string device_type) {
cl_device_type dtype = CL_DEVICE_TYPE_ALL;
if (device_type == "cpu") dtype = CL_DEVICE_TYPE_CPU;
if (device_type == "gpu") dtype = CL_DEVICE_TYPE_CPU;
if (device_type == "accelerator") dtype = CL_DEVICE_TYPE_ACCELERATOR;
cl_uint ret_size;
OPENCL_CALL(clGetDeviceIDs(pid, dtype, 0, nullptr, &ret_size));
std::vector<cl_device_id> ret;
ret.resize(ret_size);
OPENCL_CALL(clGetDeviceIDs(pid, dtype, ret_size, &ret[0], nullptr));
return ret;
}
bool MatchPlatformInfo(
cl_platform_id pid,
cl_platform_info param_name,
std::string value) {
if (value.length() == 0) return true;
std::string param_value = GetPlatformInfo(pid, param_name);
return param_value.find(value) != std::string::npos;
}
bool InitOpenCL(TVMArgs args, TVMRetValue* rv) {
cl::OpenCLWorkspace* w = cl::OpenCLWorkspace::Global();
std::lock_guard<std::mutex>(w->mu);
if (w->initialized()) return false;
// matching conditions
std::string platform_name, device_type;
for (size_t i = 0; i < args.num_args; ++i) {
std::string arg = args[i];
size_t pos = arg.find_first_of('=');
CHECK_EQ(pos, std::string::npos)
<< "Argumentes need to be key=value";
std::string key = arg.substr(0, pos);
std::string val = arg.substr(pos + 1, arg.length() - pos - 1);
if (key == "platform_name") {
platform_name = val;
} else if (key == "device_type") {
device_type = val;
} else {
LOG(FATAL) << "unknown DeviceInit option " << key;
}
}
// matched platforms
std::vector<cl_platform_id> platform_matched;
for (cl_platform_id pid : cl::GetPlatformIDs()) {
bool matched = true;
if (!cl::MatchPlatformInfo(pid, CL_PLATFORM_NAME, platform_name)) matched = false;
if (matched) platform_matched.push_back(pid);
}
if (platform_matched.size() == 0) {
LOG(FATAL) << "No OpenCL platform matched given existing options ...";
}
if (platform_matched.size() > 1) {
LOG(WARNING) << "Multiple OpenCL platforms matched, use the first one ... ";
}
w->platform_id = platform_matched[0];
LOG(INFO) << "Initialize OpenCL platform \'"
<< cl::GetPlatformInfo(w->platform_id, CL_PLATFORM_NAME) << '\'';
std::vector<cl_device_id> devices_matched =
cl::GetDeviceIDs(w->platform_id, device_type);
CHECK_GT(devices_matched.size(), 0U)
<< "No OpenCL device any device matched given the options";
w->devices = devices_matched;
cl_int err_code;
w->context = clCreateContext(
nullptr, w->devices.size(), &(w->devices[0]),
nullptr, nullptr, &err_code);
OPENCL_CHECK_ERROR(err_code);
CHECK_EQ(w->queues.size(), 0U);
for (size_t i = 0; i < w->devices.size(); ++i) {
cl_device_id did = w->devices[i];
w->queues.push_back(
clCreateCommandQueue(w->context, did, 0, &err_code));
OPENCL_CHECK_ERROR(err_code);
LOG(INFO) << "opencl(" << i
<< ")=\'" << cl::GetDeviceInfo(did, CL_DEVICE_NAME)
<< "\' cl_device_id=" << did;
}
return true;
}
TVM_REGISTER_GLOBAL(_module_init_opencl)
.set_body(InitOpenCL);
} // namespace cl } // namespace cl
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
......
...@@ -20,9 +20,9 @@ def test_add(): ...@@ -20,9 +20,9 @@ def test_add():
# one line to build the function. # one line to build the function.
def check_device(device, host="stackvm"): def check_device(device, host="stackvm"):
if not tvm.codegen.target_enabled(host): if not tvm.codegen.enabled(host):
return return
if not tvm.codegen.target_enabled(device): if not tvm.codegen.enabled(device):
return return
fadd = tvm.build(s, [A, B, C], fadd = tvm.build(s, [A, B, C],
device, host, device, host,
...@@ -37,7 +37,8 @@ def test_add(): ...@@ -37,7 +37,8 @@ def test_add():
np.testing.assert_allclose( np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy()) c.asnumpy(), a.asnumpy() + b.asnumpy())
tvm.init_opencl() if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
check_device("cuda", "llvm") check_device("cuda", "llvm")
check_device("opencl") check_device("opencl")
......
...@@ -54,9 +54,9 @@ def test_gemm(): ...@@ -54,9 +54,9 @@ def test_gemm():
# one line to build the function. # one line to build the function.
def check_device(device, host="stackvm"): def check_device(device, host="stackvm"):
if not tvm.codegen.target_enabled(host): if not tvm.codegen.enabled(host):
return return
if not tvm.codegen.target_enabled(device): if not tvm.codegen.enabled(device):
return return
f = tvm.build(s, [A, B, C], device, host, f = tvm.build(s, [A, B, C], device, host,
...@@ -76,8 +76,9 @@ def test_gemm(): ...@@ -76,8 +76,9 @@ def test_gemm():
np.testing.assert_allclose( np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5) c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
check_device("cuda") check_device("cuda")
tvm.init_opencl()
check_device("opencl") check_device("opencl")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -19,9 +19,9 @@ def test_sum(): ...@@ -19,9 +19,9 @@ def test_sum():
# one line to build the function. # one line to build the function.
def check_device(device, host="stackvm"): def check_device(device, host="stackvm"):
if not tvm.codegen.target_enabled(host): if not tvm.codegen.enabled(host):
return return
if not tvm.codegen.target_enabled(device): if not tvm.codegen.enabled(device):
return return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0) ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
fsum = tvm.build(s, fsum = tvm.build(s,
...@@ -37,7 +37,9 @@ def test_sum(): ...@@ -37,7 +37,9 @@ def test_sum():
np.testing.assert_allclose( np.testing.assert_allclose(
b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4) b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)
tvm.init_opencl() if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
check_device("cuda") check_device("cuda")
check_device("opencl") check_device("opencl")
......
...@@ -23,9 +23,9 @@ def test_scan(): ...@@ -23,9 +23,9 @@ def test_scan():
# one line to build the function. # one line to build the function.
def check_device(device, host="stackvm"): def check_device(device, host="stackvm"):
if not tvm.codegen.target_enabled(host): if not tvm.codegen.enabled(host):
return return
if not tvm.codegen.target_enabled(device): if not tvm.codegen.enabled(device):
return return
fscan = tvm.build(s, [X, res], fscan = tvm.build(s, [X, res],
device, host, device, host,
...@@ -41,7 +41,9 @@ def test_scan(): ...@@ -41,7 +41,9 @@ def test_scan():
np.testing.assert_allclose( np.testing.assert_allclose(
b.asnumpy(), np.cumsum(a_np, axis=0)) b.asnumpy(), np.cumsum(a_np, axis=0))
tvm.init_opencl() if tvm.module.enabled("opencl"):
tvm.module.init_opencl()
check_device("cuda") check_device("cuda")
check_device("opencl") check_device("opencl")
......
import tvm import tvm
from tvm.addon import testing
import numpy as np import numpy as np
def test_add_pipeline(): def test_add_pipeline():
...@@ -27,9 +28,9 @@ def test_add_pipeline(): ...@@ -27,9 +28,9 @@ def test_add_pipeline():
fsplits = tvm.ir_pass.SplitHostDevice(fapi) fsplits = tvm.ir_pass.SplitHostDevice(fapi)
def check_target(device, host="stackvm"): def check_target(device, host="stackvm"):
if not tvm.codegen.target_enabled(host): if not tvm.codegen.enabled(host):
return return
if not tvm.codegen.target_enabled(device): if not tvm.codegen.enabled(device):
return return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0) ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
mhost = tvm.codegen.build(fsplits[0], host) mhost = tvm.codegen.build(fsplits[0], host)
...@@ -47,8 +48,33 @@ def test_add_pipeline(): ...@@ -47,8 +48,33 @@ def test_add_pipeline():
np.testing.assert_allclose( np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy()) c.asnumpy(), a.asnumpy() + b.asnumpy())
def check_module_save(device, host="stackvm"):
if not tvm.codegen.enabled(host):
return
if not tvm.codegen.enabled(device):
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
fmt = "ptx" if device == "cuda" else "cl"
mhost = tvm.codegen.build(fsplits[0], host)
mdev = tvm.codegen.build(fsplits[1:], device)
temp = testing.tempdir()
mpath = temp.relpath("test.%s" % fmt)
mdev.save(mpath)
mdev2 = tvm.module.load(mpath)
mhost.import_module(mdev2)
f = mhost.entry_func
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx)
f(a, b, c)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy())
check_target("cuda", host="stackvm") check_target("cuda", host="stackvm")
check_target("cuda", host="llvm") check_target("cuda", host="llvm")
check_module_save("cuda", host="stackvm")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -8,7 +8,7 @@ def tvm_call_packed(*args): ...@@ -8,7 +8,7 @@ def tvm_call_packed(*args):
def run_jit(fapi, check): def run_jit(fapi, check):
for target in ["llvm", "stackvm"]: for target in ["llvm", "stackvm"]:
if not tvm.codegen.target_enabled(target): if not tvm.codegen.enabled(target):
continue continue
f = tvm.codegen.build(fapi, target) f = tvm.codegen.build(fapi, target)
s = f.get_source() s = f.get_source()
...@@ -95,7 +95,7 @@ def test_llvm_add_pipeline(): ...@@ -95,7 +95,7 @@ def test_llvm_add_pipeline():
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0) fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0)
def check_llvm(): def check_llvm():
if not tvm.codegen.target_enabled("llvm"): if not tvm.codegen.enabled("llvm"):
return return
# build and invoke the kernel. # build and invoke the kernel.
f = tvm.codegen.build(fapi, "llvm") f = tvm.codegen.build(fapi, "llvm")
......
import tvm import tvm
from tvm.addon import cc_compiler as cc from tvm.addon import cc_compiler as cc, testing
import os import os
import tempfile
import numpy as np import numpy as np
import subprocess
runtime_py = """
import os
import sys
os.environ["TVM_USE_RUNTIME_LIB"] = "1"
import tvm
import numpy as np
path_dso = sys.argv[1]
dtype = sys.argv[2]
ff = tvm.module.load(path_dso)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
ff(a)
np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
print("Finish runtime checking...")
"""
def test_dso_module_load(): def test_dso_module_load():
if not tvm.codegen.target_enabled("llvm"): if not tvm.codegen.enabled("llvm"):
return return
dtype = 'int64' dtype = 'int64'
temp_dir = tempfile.mkdtemp() temp = testing.tempdir()
def save_object(names): def save_object(names):
n = tvm.Var('n') n = tvm.Var('n')
...@@ -25,10 +40,10 @@ def test_dso_module_load(): ...@@ -25,10 +40,10 @@ def test_dso_module_load():
for name in names: for name in names:
m.save(name) m.save(name)
path_obj = "%s/test.o" % temp_dir path_obj = temp.relpath("test.o")
path_ll = "%s/test.ll" % temp_dir path_ll = temp.relpath("test.ll")
path_bc = "%s/test.bc" % temp_dir path_bc = temp.relpath("test.bc")
path_dso = "%s/test.so" % temp_dir path_dso = temp.relpath("test.so")
save_object([path_obj, path_ll, path_bc]) save_object([path_obj, path_ll, path_bc])
cc.create_shared(path_dso, [path_obj]) cc.create_shared(path_dso, [path_obj])
...@@ -41,14 +56,14 @@ def test_dso_module_load(): ...@@ -41,14 +56,14 @@ def test_dso_module_load():
a = tvm.nd.array(np.zeros(10, dtype=dtype)) a = tvm.nd.array(np.zeros(10, dtype=dtype))
f2(a) f2(a)
np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0])) np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
files = [path_obj, path_ll, path_bc, path_dso]
for f in files:
os.remove(f)
os.rmdir(temp_dir)
path_runtime_py = temp.relpath("runtime.py")
with open(path_runtime_py, "w") as fo:
fo.write(runtime_py)
def test_cuda_module_load(): subprocess.check_call(
pass "python %s %s %s" % (path_runtime_py, path_dso, dtype),
shell=True)
if __name__ == "__main__": if __name__ == "__main__":
test_dso_module_load() test_dso_module_load()
...@@ -2,9 +2,11 @@ import tvm ...@@ -2,9 +2,11 @@ import tvm
import numpy as np import numpy as np
def enabled_ctx_list(): def enabled_ctx_list():
tvm.init_opencl() if tvm.module.enabled("opencl"):
ctx_list = [tvm.cpu(0), tvm.gpu(0), tvm.opencl(0)] tvm.module.init_opencl()
ctx_list = [ctx for ctx in ctx_list if ctx.enabled]
ctx_list = [('cpu', tvm.cpu(0)), ('gpu', tvm.gpu(0)), ('cl', tvm.opencl(0))]
ctx_list = [x[1] for x in ctx_list if tvm.module.enabled(x[0])]
return ctx_list return ctx_list
ENABLED_CTX_LIST = enabled_ctx_list() ENABLED_CTX_LIST = enabled_ctx_list()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment