Commit 08505e34 by Tianqi Chen, committed by GitHub

[ADDON] Allow piggy back nvcc compiler and code (#35)

parent 88377988
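
In short, this commit lets Python piggyback on CUDA compilation: the CUDA code generator now consults two optionally registered global packed functions, tvm_callback_cuda_postproc and tvm_callback_cuda_compile, and byte arrays can now travel through the packed-function interface to carry the compiled binary back. A minimal usage sketch, modeled on the test updated at the bottom of this commit (it uses only names that appear in this diff):

import tvm
from tvm.addon import nvcc_compiler

@tvm.register_func
def tvm_callback_cuda_compile(code):
    # Called by the CUDA codegen with the generated source; the returned
    # bytes (here PTX produced by nvcc) are loaded instead of the NVRTC output.
    return nvcc_compiler.compile_source(code, target="ptx")
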
@@ -51,8 +51,9 @@ typedef enum {
   kArrayHandle = 5U,
   kTVMType = 6U,
   kNodeHandle = 7U,
-  kStr = 8U,
-  kFuncHandle = 9U
+  kFuncHandle = 8U,
+  kStr = 9U,
+  kBytes = 10U
 } TVMTypeCode;
 /*!
@@ -87,6 +88,15 @@ typedef union {
 } TVMValue;
 /*!
+ * \brief Byte array type used to pass in byte arrays
+ *  when kBytes is used as the type code.
+ */
+typedef struct {
+  const char* data;
+  size_t size;
+} TVMByteArray;
+/*!
  * \brief The device type
  */
 typedef enum {
......
@@ -112,6 +112,12 @@ class PackedFunc {
   */
  static const PackedFunc& GetGlobal(const std::string& name);
  /*!
+   * \brief Whether the global function exists.
+   * \param name The name of the function.
+   * \return Whether the global function exists.
+   */
+  static bool GlobalExist(const std::string& name);
+  /*!
   * \brief Get the names of currently registered global function.
   */
  static std::vector<std::string> ListGlobalNames();
@@ -267,10 +273,14 @@ class TVMArgValue : public TVMPODValue_ {
  operator std::string() const {
    if (type_code_ == kTVMType) {
      return TVMType2String(operator TVMType());
-    }
+    } else if (type_code_ == kBytes) {
+      TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle);
+      return std::string(arr->data, arr->size);
+    } else {
      TVM_CHECK_TYPE_CODE(type_code_, kStr);
      return std::string(value_.v_str);
    }
+  }
  operator TVMType() const {
    if (type_code_ == kStr) {
      return String2TVMType(operator std::string());
@@ -452,7 +462,8 @@ class TVMRetValue : public TVMPODValue_ {
  template<typename T>
  void Assign(const T& other) {
    switch (other.type_code()) {
-      case kStr: {
+      case kStr:
+      case kBytes: {
        SwitchToClass<std::string>(kStr, other);
        break;
      }
......
 # coding: utf-8
-# pylint: disable=invalid-name, protected-access
+# pylint: disable=invalid-name, protected-access, too-many-branches
 """Symbolic configuration API."""
 from __future__ import absolute_import
@@ -9,7 +9,7 @@ from numbers import Number, Integral
 from .._base import _LIB, check_call
 from .._base import c_str, py_str, string_types
-from ._types import TVMValue, TypeCode, TVMType
+from ._types import TVMValue, TypeCode, TVMType, TVMByteArray
 from ._types import TVMPackedCFunc, TVMCFuncFinalizer
 from ._types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH
 from ._node import NodeBase, SliceBase, convert_to_node
@@ -92,6 +92,15 @@ def _make_tvm_args(args, temp_args):
         elif isinstance(arg, TVMType):
             values[i].v_str = c_str(str(arg))
             type_codes[i] = TypeCode.STR
+        elif isinstance(arg, bytearray):
+            arr = TVMByteArray()
+            arr.data = ctypes.cast(
+                (ctypes.c_byte * len(arg)).from_buffer(arg),
+                ctypes.POINTER(ctypes.c_byte))
+            arr.size = len(arg)
+            values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
+            temp_args.append(arr)
+            type_codes[i] = TypeCode.BYTES
         elif isinstance(arg, string_types):
             values[i].v_str = c_str(arg)
             type_codes[i] = TypeCode.STR
......
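
The branch added above packs a Python bytearray into a TVMByteArray handle with type code BYTES; on the receiving side, _return_bytes (added below in _types.py) copies it back into a fresh bytearray. A small round-trip sketch, assuming only the behavior the new test_byte_array test relies on:

import tvm

def consume(data):
    # the bytearray argument arrives reconstructed by _return_bytes
    assert isinstance(data, bytearray)
    assert data == bytearray(b"hello")

f = tvm.convert(consume)
f(bytearray(b"hello"))
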
@@ -18,8 +18,9 @@ class TypeCode(object):
     ARRAY_HANDLE = 5
     TVM_TYPE = 6
     NODE_HANDLE = 7
-    STR = 8
-    FUNC_HANDLE = 9
+    FUNC_HANDLE = 8
+    STR = 9
+    BYTES = 10

 def _api_type(code):
     """create a type accepted by API"""
@@ -88,6 +89,11 @@ class TVMValue(ctypes.Union):
                 ("v_handle", ctypes.c_void_p),
                 ("v_str", ctypes.c_char_p)]

+class TVMByteArray(ctypes.Structure):
+    """Temporary byte array structure, mirrors TVMByteArray in the C API."""
+    _fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
+                ("size", ctypes.c_size_t)]
+
 TVMPackedCFunc = ctypes.CFUNCTYPE(
     None,
@@ -110,20 +116,34 @@ def _return_handle(x):
     handle = ctypes.c_void_p(handle)
     return handle

+def _return_bytes(x):
+    """Convert a TVMByteArray return value into a Python bytearray."""
+    handle = x.v_handle
+    if not isinstance(handle, ctypes.c_void_p):
+        handle = ctypes.c_void_p(handle)
+    arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0]
+    size = arr.size
+    res = bytearray(size)
+    rptr = (ctypes.c_byte * size).from_buffer(res)
+    if not ctypes.memmove(rptr, arr.data, size):
+        raise RuntimeError('memmove failed')
+    return res
+
 RETURN_SWITCH = {
     TypeCode.INT: lambda x: x.v_int64,
     TypeCode.FLOAT: lambda x: x.v_float64,
     TypeCode.HANDLE: _return_handle,
     TypeCode.NULL: lambda x: None,
-    TypeCode.STR: lambda x: py_str(x.v_str)
+    TypeCode.STR: lambda x: py_str(x.v_str),
+    TypeCode.BYTES: _return_bytes
 }

 C_TO_PY_ARG_SWITCH = {
     TypeCode.INT: lambda x: x.v_int64,
     TypeCode.FLOAT: lambda x: x.v_float64,
     TypeCode.HANDLE: _return_handle,
     TypeCode.NULL: lambda x: None,
-    TypeCode.STR: lambda x: py_str(x.v_str)
+    TypeCode.STR: lambda x: py_str(x.v_str),
+    TypeCode.BYTES: _return_bytes
 }
"""Addon utilities to python"""
"""Util to compile with NVCC"""
import os
import sys
import tempfile
import subprocess
def compile_source(code, target="cubin"):
"""Compile cuda code with NVCC from env.
Parameters
----------
code : str
The cuda code.
target: str
The target format
Return
------
cubin : bytearray
The bytearray of the cubin
"""
temp_dir = tempfile.mkdtemp()
if target not in ["cubin", "ptx", "fatbin"]:
raise ValueError("target must be in cubin, ptx, fatbin")
path_code = os.path.join(temp_dir, "my_kernel.cu")
path_target = os.path.join(temp_dir, "my_kernel.%s" % target)
with open(path_code, "w") as out_file:
out_file.write(code)
cmd = ["nvcc"]
cmd += ["--%s" % target, "-O3"]
cmd += ["-o", path_target]
cmd += [path_code]
args = ' '.join(cmd)
proc = subprocess.Popen(
args, shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
sys.stderr.write("Compilation error:\n")
sys.stderr.write(out)
sys.stderr.flush()
cubin = None
else:
cubin = bytearray(open(path_target, "rb").read())
os.remove(path_code)
if os.path.exists(path_target):
os.remove(path_target)
os.rmdir(temp_dir)
return cubin
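
A hedged usage sketch of the new helper, assuming nvcc is on the PATH; the kernel source below is a placeholder for illustration and is not part of this commit:

from tvm.addon import nvcc_compiler

kernel = r'''
extern "C" __global__ void fill_one(float* x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] = 1.0f;
}
'''
# compile_source returns a bytearray with the PTX, or None on failure
ptx = nvcc_compiler.compile_source(kernel, target="ptx")
if ptx is None:
    raise RuntimeError("nvcc failed; the compiler log was written to stderr")
print("compiled %d bytes of PTX" % len(ptx))
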
@@ -158,7 +158,8 @@ class Canonical::Internal : public IRMutator {
   }
   // functions
   Stmt Mutate(Stmt stmt) final {
-    return IRMutator::Mutate(stmt);
+    stmt = IRMutator::Mutate(stmt);
+    return stmt;
   }
   Expr MutateExpr_(Expr expr) {
     static const FMutateExpr& f = Internal::vtable_expr();
@@ -176,6 +177,7 @@ class Canonical::Internal : public IRMutator {
     ret_entry_.has_side_effect = stack_.back().has_side_effect;
     ret_entry_.max_level = stack_.back().max_level;
     stack_.pop_back();
+    CHECK(expr.defined());
     return expr;
   }
   // call produce to get a cache entry.
@@ -399,6 +401,7 @@ class Canonical::Internal : public IRMutator {
   // subroutine to do produce
   Expr SumAdd(CacheEntry a, CacheEntry b, int bscale) {
     ret_entry_.sum = SumAdd_(a.AsSum(), b.AsSum(), bscale);
+    CHECK_NE(stack_.size(), 0U);
     ret_entry_.max_level = stack_.back().max_level;
     ret_entry_.has_side_effect = stack_.back().has_side_effect;
     auto it = cache_sum_.find(ret_entry_.sum);
@@ -408,8 +411,6 @@ class Canonical::Internal : public IRMutator {
       ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type());
       cache_sum_[ret_entry_.sum] = ret_entry_;
     }
-    ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type());
-    cache_sum_[ret_entry_.sum] = ret_entry_;
     return ret_entry_.value;
   }
   // convert sum to expr
@@ -444,7 +445,11 @@ class Canonical::Internal : public IRMutator {
         }
       }
     }
+    if (vsum.defined()) {
      return vsum;
+    } else {
+      return make_zero(t);
+    }
   }
 };
......
@@ -50,7 +50,19 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
     os << CodeGenCUDA().Compile(f, output_ssa);
     os << '\n';
   }
-  std::string ptx = runtime::NVRTCCompile(os.str());
+  std::string code = os.str();
+  if (PackedFunc::GlobalExist("tvm_callback_cuda_postproc")) {
+    const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc");
+    code = f(code).operator std::string();
+  }
+  std::string ptx;
+  if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) {
+    const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile");
+    ptx = f(code).operator std::string();
+  } else {
+    ptx = runtime::NVRTCCompile(code);
+  }
   std::unordered_map<LoweredFunc, PackedFunc> ret;
   runtime::CUDAModule m = runtime::CUDAModule::Create(ptx);
......
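
The two hooks consulted above are ordinary global packed functions looked up by name, so registering them from Python before tvm.build runs is enough to intercept code generation. A minimal sketch of the post-processing hook (the injected header comment is purely illustrative):

import tvm

@tvm.register_func
def tvm_callback_cuda_postproc(code):
    # receives the generated CUDA source as a str and must return the
    # (possibly modified) source that is handed on to the compiler
    return "// post-processed by tvm_callback_cuda_postproc\n" + code
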
@@ -46,6 +46,12 @@ const PackedFunc& PackedFunc::GetGlobal(const std::string& name) {
   return *(it->second);
 }

+bool PackedFunc::GlobalExist(const std::string& name) {
+  PackedFuncRegistry* r = PackedFuncRegistry::Global();
+  auto it = r->fmap.find(name);
+  return it != r->fmap.end();
+}
+
 std::vector<std::string> PackedFunc::ListGlobalNames() {
   PackedFuncRegistry* r = PackedFuncRegistry::Global();
   std::vector<std::string> keys;
......
 import tvm
+from tvm.addon import nvcc_compiler
 import numpy as np

+@tvm.register_func
+def tvm_callback_cuda_compile(code):
+    ptx = nvcc_compiler.compile_source(code, target="ptx")
+    print(ptx.decode("utf-8"))
+    return ptx
+
+@tvm.register_func
+def tvm_callback_cuda_postproc(code):
+    print(code)
+    return code
+
 def test_gemm():
     # graph
     nn = 1024
@@ -23,7 +35,6 @@ def test_gemm():
     s = tvm.Schedule(C.op)
     xtile, ytile = 32, 32
     s[AA].set_scope("shared")
-    #s[CC].set_scope("global")
     s[BB].set_scope("shared")
     scale = 8
@@ -60,8 +71,6 @@ def test_gemm():
     codes = []
     f = tvm.build(s, [A, B, C], target, record_codes=codes,
                   max_auto_unroll_step=max_auto_unroll_step)
-    for c in codes[1:]:
-        print(c)
     if target == "cuda":
         ctx = tvm.gpu(0)
     else:
@@ -77,13 +86,14 @@ def test_gemm():
         a = tvm.nd.array(a_np, ctx)
         b = tvm.nd.array(b_np, ctx)
         c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
+        for i in range(4):
             f(a, b, c)
         np.testing.assert_allclose(
             c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)

-    tvm.init_opencl()
     check_device("cuda")
-    check_device("opencl")
+    #tvm.init_opencl()
+    #check_device("opencl")

 if __name__ == "__main__":
     test_gemm()
@@ -35,9 +35,17 @@ def test_convert():
     assert isinstance(f, tvm.nd.Function)
     f(*targs)

+def test_byte_array():
+    s = "hello"
+    a = bytearray(s, encoding="ascii")
+    def myfunc(ss):
+        assert ss == a
+    f = tvm.convert(myfunc)
+    f(a)
+
 if __name__ == "__main__":
+    test_function()
     test_convert()
     test_get_global()
     test_return_func()
+    test_byte_array()