Commit 08505e34 by Tianqi Chen Committed by GitHub

[ADDON] Allow piggy back nvcc compiler and code (#35)

parent 88377988
......@@ -51,8 +51,9 @@ typedef enum {
kArrayHandle = 5U,
kTVMType = 6U,
kNodeHandle = 7U,
kStr = 8U,
kFuncHandle = 9U
kFuncHandle = 8U,
kStr = 9U,
kBytes = 10U
} TVMTypeCode;
/*!
......@@ -87,6 +88,15 @@ typedef union {
} TVMValue;
/*!
* \brief Byte array type used to pass in byte array
* When kBytes is used as data type.
*/
typedef struct {
const char* data;
size_t size;
} TVMByteArray;
/*!
* \brief The device type
*/
typedef enum {
......
......@@ -112,6 +112,12 @@ class PackedFunc {
*/
static const PackedFunc& GetGlobal(const std::string& name);
/*!
* \brief Whether the global function exist
* \param name The name of the function.
* \return Whetehr the global function exist.
*/
static bool GlobalExist(const std::string& name);
/*!
* \brief Get the names of currently registered global function.
*/
static std::vector<std::string> ListGlobalNames();
......@@ -267,9 +273,13 @@ class TVMArgValue : public TVMPODValue_ {
operator std::string() const {
if (type_code_ == kTVMType) {
return TVMType2String(operator TVMType());
} else if (type_code_ == kBytes) {
TVMByteArray* arr = static_cast<TVMByteArray*>(value_.v_handle);
return std::string(arr->data, arr->size);
} else {
TVM_CHECK_TYPE_CODE(type_code_, kStr);
return std::string(value_.v_str);
}
TVM_CHECK_TYPE_CODE(type_code_, kStr);
return std::string(value_.v_str);
}
operator TVMType() const {
if (type_code_ == kStr) {
......@@ -452,7 +462,8 @@ class TVMRetValue : public TVMPODValue_ {
template<typename T>
void Assign(const T& other) {
switch (other.type_code()) {
case kStr: {
case kStr:
case kBytes: {
SwitchToClass<std::string>(kStr, other);
break;
}
......
# coding: utf-8
# pylint: disable=invalid-name, protected-access
# pylint: disable=invalid-name, protected-access, too-many-branches
"""Symbolic configuration API."""
from __future__ import absolute_import
......@@ -9,7 +9,7 @@ from numbers import Number, Integral
from .._base import _LIB, check_call
from .._base import c_str, py_str, string_types
from ._types import TVMValue, TypeCode, TVMType
from ._types import TVMValue, TypeCode, TVMType, TVMByteArray
from ._types import TVMPackedCFunc, TVMCFuncFinalizer
from ._types import RETURN_SWITCH, C_TO_PY_ARG_SWITCH
from ._node import NodeBase, SliceBase, convert_to_node
......@@ -92,6 +92,15 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, TVMType):
values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR
elif isinstance(arg, bytearray):
arr = TVMByteArray()
arr.data = ctypes.cast(
(ctypes.c_byte * len(arg)).from_buffer(arg),
ctypes.POINTER(ctypes.c_byte))
arr.size = len(arg)
values[i].v_handle = ctypes.c_void_p(ctypes.addressof(arr))
temp_args.append(arr)
type_codes[i] = TypeCode.BYTES
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
......
......@@ -18,8 +18,9 @@ class TypeCode(object):
ARRAY_HANDLE = 5
TVM_TYPE = 6
NODE_HANDLE = 7
STR = 8
FUNC_HANDLE = 9
FUNC_HANDLE = 8
STR = 9
BYTES = 10
def _api_type(code):
"""create a type accepted by API"""
......@@ -88,6 +89,11 @@ class TVMValue(ctypes.Union):
("v_handle", ctypes.c_void_p),
("v_str", ctypes.c_char_p)]
class TVMByteArray(ctypes.Structure):
"""TVM datatype structure"""
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)]
TVMPackedCFunc = ctypes.CFUNCTYPE(
None,
......@@ -110,20 +116,34 @@ def _return_handle(x):
handle = ctypes.c_void_p(handle)
return handle
def _return_bytes(x):
"""return handle"""
handle = x.v_handle
if not isinstance(handle, ctypes.c_void_p):
handle = ctypes.c_void_p(handle)
arr = ctypes.cast(handle, ctypes.POINTER(TVMByteArray))[0]
size = arr.size
res = bytearray(size)
rptr = (ctypes.c_byte * size).from_buffer(res)
if not ctypes.memmove(rptr, arr.data, size):
raise RuntimeError('memmove failed')
return res
RETURN_SWITCH = {
TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64,
TypeCode.HANDLE: _return_handle,
TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str)
TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes
}
C_TO_PY_ARG_SWITCH = {
TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64,
TypeCode.HANDLE: _return_handle,
TypeCode.NULL: lambda x: None,
TypeCode.STR: lambda x: py_str(x.v_str)
TypeCode.STR: lambda x: py_str(x.v_str),
TypeCode.BYTES: _return_bytes
}
"""Addon utilities to python"""
"""Util to compile with NVCC"""
import os
import sys
import tempfile
import subprocess
def compile_source(code, target="cubin"):
"""Compile cuda code with NVCC from env.
Parameters
----------
code : str
The cuda code.
target: str
The target format
Return
------
cubin : bytearray
The bytearray of the cubin
"""
temp_dir = tempfile.mkdtemp()
if target not in ["cubin", "ptx", "fatbin"]:
raise ValueError("target must be in cubin, ptx, fatbin")
path_code = os.path.join(temp_dir, "my_kernel.cu")
path_target = os.path.join(temp_dir, "my_kernel.%s" % target)
with open(path_code, "w") as out_file:
out_file.write(code)
cmd = ["nvcc"]
cmd += ["--%s" % target, "-O3"]
cmd += ["-o", path_target]
cmd += [path_code]
args = ' '.join(cmd)
proc = subprocess.Popen(
args, shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
sys.stderr.write("Compilation error:\n")
sys.stderr.write(out)
sys.stderr.flush()
cubin = None
else:
cubin = bytearray(open(path_target, "rb").read())
os.remove(path_code)
if os.path.exists(path_target):
os.remove(path_target)
os.rmdir(temp_dir)
return cubin
......@@ -158,7 +158,8 @@ class Canonical::Internal : public IRMutator {
}
// functions
Stmt Mutate(Stmt stmt) final {
return IRMutator::Mutate(stmt);
stmt = IRMutator::Mutate(stmt);
return stmt;
}
Expr MutateExpr_(Expr expr) {
static const FMutateExpr& f = Internal::vtable_expr();
......@@ -176,6 +177,7 @@ class Canonical::Internal : public IRMutator {
ret_entry_.has_side_effect = stack_.back().has_side_effect;
ret_entry_.max_level = stack_.back().max_level;
stack_.pop_back();
CHECK(expr.defined());
return expr;
}
// call produce to get a cache entry.
......@@ -399,6 +401,7 @@ class Canonical::Internal : public IRMutator {
// subroutine to do produce
Expr SumAdd(CacheEntry a, CacheEntry b, int bscale) {
ret_entry_.sum = SumAdd_(a.AsSum(), b.AsSum(), bscale);
CHECK_NE(stack_.size(), 0U);
ret_entry_.max_level = stack_.back().max_level;
ret_entry_.has_side_effect = stack_.back().has_side_effect;
auto it = cache_sum_.find(ret_entry_.sum);
......@@ -408,8 +411,6 @@ class Canonical::Internal : public IRMutator {
ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type());
cache_sum_[ret_entry_.sum] = ret_entry_;
}
ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type());
cache_sum_[ret_entry_.sum] = ret_entry_;
return ret_entry_.value;
}
// convert sum to expr
......@@ -444,7 +445,11 @@ class Canonical::Internal : public IRMutator {
}
}
}
return vsum;
if (vsum.defined()) {
return vsum;
} else {
return make_zero(t);
}
}
};
......
......@@ -50,7 +50,19 @@ MakeNVRTC(Array<LoweredFunc> funcs) {
os << CodeGenCUDA().Compile(f, output_ssa);
os << '\n';
}
std::string ptx = runtime::NVRTCCompile(os.str());
std::string code = os.str();
if (PackedFunc::GlobalExist("tvm_callback_cuda_postproc")) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc");
code = f(code).operator std::string();
}
std::string ptx;
if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile");
ptx = f(code).operator std::string();
} else {
ptx = runtime::NVRTCCompile(os.str());
}
std::unordered_map<LoweredFunc, PackedFunc> ret;
runtime::CUDAModule m = runtime::CUDAModule::Create(ptx);
......
......@@ -46,6 +46,12 @@ const PackedFunc& PackedFunc::GetGlobal(const std::string& name) {
return *(it->second);
}
bool PackedFunc::GlobalExist(const std::string& name) {
PackedFuncRegistry* r = PackedFuncRegistry::Global();
auto it = r->fmap.find(name);
return it != r->fmap.end();
}
std::vector<std::string> PackedFunc::ListGlobalNames() {
PackedFuncRegistry* r = PackedFuncRegistry::Global();
std::vector<std::string> keys;
......
import tvm
from tvm.addon import nvcc_compiler
import numpy as np
@tvm.register_func
def tvm_callback_cuda_compile(code):
ptx = nvcc_compiler.compile_source(code, target="ptx")
print(ptx.decode("utf-8"))
return ptx
@tvm.register_func
def tvm_callback_cuda_postproc(code):
print(code)
return code
def test_gemm():
# graph
nn = 1024
......@@ -23,7 +35,6 @@ def test_gemm():
s = tvm.Schedule(C.op)
xtile, ytile = 32, 32
s[AA].set_scope("shared")
#s[CC].set_scope("global")
s[BB].set_scope("shared")
scale = 8
......@@ -60,8 +71,6 @@ def test_gemm():
codes = []
f = tvm.build(s, [A, B, C], target, record_codes=codes,
max_auto_unroll_step=max_auto_unroll_step)
for c in codes[1:]:
print(c)
if target == "cuda":
ctx = tvm.gpu(0)
else:
......@@ -77,13 +86,14 @@ def test_gemm():
a = tvm.nd.array(a_np, ctx)
b = tvm.nd.array(b_np, ctx)
c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), ctx)
f(a, b, c)
for i in range(4):
f(a, b, c)
np.testing.assert_allclose(
c.asnumpy(), np.dot(a_np, b_np.T), rtol=1e-5)
tvm.init_opencl()
check_device("cuda")
check_device("opencl")
#tvm.init_opencl()
#check_device("opencl")
if __name__ == "__main__":
test_gemm()
......@@ -35,9 +35,17 @@ def test_convert():
assert isinstance(f, tvm.nd.Function)
f(*targs)
def test_byte_array():
s = "hello"
a = bytearray(s, encoding="ascii")
def myfunc(ss):
assert ss == a
f = tvm.convert(myfunc)
f(a)
if __name__ == "__main__":
test_function()
test_convert()
test_get_global()
test_return_func()
test_byte_array()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment