Commit 12cf343a by MORITA Kazutaka Committed by Tianqi Chen

[RUNTIME][SDACCEL] Add support for multiple kernels (#1424)

parent b4043855
...@@ -6,16 +6,14 @@ from ..api import register_func ...@@ -6,16 +6,14 @@ from ..api import register_func
@register_func("tvm_callback_sdaccel_compile") @register_func("tvm_callback_sdaccel_compile")
def compile_vhls(code, kernel): def compile_vhls(kernel_info):
"""Compile Vivado HLS code for SDAccel. """Compile Vivado HLS code for SDAccel.
Parameters Parameters
---------- ----------
code : str kernel_info : list of (str, str)
The Vivado HLS code. List of kernel information. The kernel information is a tuple of
function name and source code.
kernel : str
The kernel to compile or link.
Return Return
------ ------
...@@ -23,12 +21,6 @@ def compile_vhls(code, kernel): ...@@ -23,12 +21,6 @@ def compile_vhls(code, kernel):
The bytearray of the xclbin The bytearray of the xclbin
""" """
tmp_dir = util.tempdir() tmp_dir = util.tempdir()
tmp_cpp = tmp_dir.relpath("input.cpp")
tmp_xo = tmp_dir.relpath("output.xo")
tmp_xclbin = tmp_dir.relpath("output.xclbin")
with open(tmp_cpp, "wb") as out_file:
out_file.write(bytes(code))
sdk = os.environ.get("XILINX_SDX", None) sdk = os.environ.get("XILINX_SDX", None)
xocc = os.path.join(sdk, "bin/xocc") if sdk else "xocc" xocc = os.path.join(sdk, "bin/xocc") if sdk else "xocc"
...@@ -41,15 +33,29 @@ def compile_vhls(code, kernel): ...@@ -41,15 +33,29 @@ def compile_vhls(code, kernel):
if platform is None: if platform is None:
raise RuntimeError("No Xlinx device specified.") raise RuntimeError("No Xlinx device specified.")
tmp_xo_files = []
for funcname, code in kernel_info:
funcname = funcname.value
code = code.value
tmp_cpp = tmp_dir.relpath(funcname + ".cpp")
tmp_xo = tmp_dir.relpath(funcname + ".xo")
with open(tmp_cpp, "wb") as out_file:
out_file.write(bytes(code))
# build xo # build xo
args = [xocc, "-c", "-t", target, "--platform", platform, "-o", tmp_xo, "-k", kernel] + \ args = [xocc, "-c", "-t", target, "--platform", platform, "-o", tmp_xo, "-k", funcname] + \
advanced_params + [tmp_cpp] advanced_params + [tmp_cpp]
returncode = subprocess.call(args) returncode = subprocess.call(args)
if returncode != 0: if returncode != 0:
raise RuntimeError("Compile error") raise RuntimeError("Compile error")
tmp_xo_files.append(tmp_xo)
# build xclbin # build xclbin
args = [xocc, "-l", "-t", target, "--platform", platform, "-o", tmp_xclbin, tmp_xo] + \ tmp_xclbin = tmp_dir.relpath("output.xclbin")
args = [xocc, "-l", "-t", target, "--platform", platform, "-o", tmp_xclbin] + tmp_xo_files + \
advanced_params advanced_params
returncode = subprocess.call(args) returncode = subprocess.call(args)
if returncode != 0: if returncode != 0:
......
...@@ -72,26 +72,33 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs) { ...@@ -72,26 +72,33 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs) {
bool output_ssa = false; bool output_ssa = false;
CodeGenVivadoHLS cg; CodeGenVivadoHLS cg;
CHECK_EQ(funcs.size(), 1); // Generate source code for get_source().
const std::string funcname = funcs[0]->name;
cg.Init(output_ssa); cg.Init(output_ssa);
for (LoweredFunc f : funcs) { for (LoweredFunc f : funcs) {
cg.AddFunction(f); cg.AddFunction(f);
} }
std::string whole_code = cg.Finish();
// Generate source code for compilation.
Array<Array<Expr> > kernel_info;
for (LoweredFunc f : funcs) {
CodeGenVivadoHLS cg;
cg.Init(output_ssa);
cg.AddFunction(f);
std::string code = cg.Finish(); std::string code = cg.Finish();
if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) { if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) {
code = (*f)(code).operator std::string(); code = (*f)(code).operator std::string();
} }
kernel_info.push_back(Array<Expr>({f->name, code}));
}
std::string xclbin; std::string xclbin;
if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) { if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) {
xclbin = (*f)(code, funcname).operator std::string(); xclbin = (*f)(kernel_info).operator std::string();
} else { } else {
LOG(FATAL) << "Cannot compile Vivado HLS code."; LOG(FATAL) << "Cannot compile Vivado HLS code.";
} }
return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(funcs), code); return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(funcs), whole_code);
} }
TVM_REGISTER_API("codegen.build_sdaccel") TVM_REGISTER_API("codegen.build_sdaccel")
......
...@@ -21,7 +21,7 @@ def test_exp(): ...@@ -21,7 +21,7 @@ def test_exp():
s[B].bind(px, tvm.thread_axis("pipeline")) s[B].bind(px, tvm.thread_axis("pipeline"))
# one line to build the function. # one line to build the function.
def check_device(device, host="stackvm"): def check_device(device, host="llvm"):
if not tvm.module.enabled(host): if not tvm.module.enabled(host):
return return
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
...@@ -42,5 +42,44 @@ def test_exp(): ...@@ -42,5 +42,44 @@ def test_exp():
check_device("sdaccel") check_device("sdaccel")
def test_multi_kernel():
# graph
n = tvm.convert(1024)
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
D = tvm.compute(A.shape, lambda *i: A(*i) + C(*i), name='D')
s = tvm.create_schedule(D.op)
# create iter var and assign them tags.
px, x = s[C].split(C.op.axis[0], nparts=1)
s[C].bind(px, tvm.thread_axis("pipeline"))
px, x = s[D].split(D.op.axis[0], nparts=1)
s[D].bind(px, tvm.thread_axis("pipeline"))
# one line to build the function.
def check_device(device, host="llvm"):
if not tvm.module.enabled(host):
return
ctx = tvm.context(device, 0)
if not ctx.exist:
return
fadd = tvm.build(s, [A, B, C, D],
device, host,
name="myadd")
ctx = tvm.context(device, 0)
# launch the kernel.
n = 1024
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
c = tvm.nd.array(np.random.uniform(size=n).astype(C.dtype), ctx)
d = tvm.nd.array(np.random.uniform(size=n).astype(D.dtype), ctx)
fadd(a, b, c, d)
np.testing.assert_allclose(
d.asnumpy(), a.asnumpy() * 2 + b.asnumpy(), rtol=1e-5)
check_device("sdaccel")
if __name__ == "__main__": if __name__ == "__main__":
test_exp() test_exp()
test_multi_kernel()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment