Commit 12cf343a by MORITA Kazutaka, committed by Tianqi Chen

[RUNTIME][SDACCEL] Add support for multiple kernels (#1424)

parent b4043855
--- a/python/tvm/contrib/sdaccel.py
+++ b/python/tvm/contrib/sdaccel.py
@@ -6,16 +6,14 @@ from ..api import register_func
 @register_func("tvm_callback_sdaccel_compile")
-def compile_vhls(code, kernel):
+def compile_vhls(kernel_info):
     """Compile Vivado HLS code for SDAccel.
 
     Parameters
     ----------
-    code : str
-        The Vivado HLS code.
-
-    kernel : str
-        The kernel to compile or link.
+    kernel_info : list of (str, str)
+        List of kernel information. The kernel information is a tuple of
+        function name and source code.
 
     Return
     ------
@@ -23,12 +21,6 @@ def compile_vhls(code, kernel):
         The bytearray of the xclbin
     """
     tmp_dir = util.tempdir()
-    tmp_cpp = tmp_dir.relpath("input.cpp")
-    tmp_xo = tmp_dir.relpath("output.xo")
-    tmp_xclbin = tmp_dir.relpath("output.xclbin")
-
-    with open(tmp_cpp, "wb") as out_file:
-        out_file.write(bytes(code))
 
     sdk = os.environ.get("XILINX_SDX", None)
     xocc = os.path.join(sdk, "bin/xocc") if sdk else "xocc"
@@ -41,15 +33,29 @@ def compile_vhls(code, kernel):
     if platform is None:
         raise RuntimeError("No Xlinx device specified.")
 
+    tmp_xo_files = []
+    for funcname, code in kernel_info:
+        funcname = funcname.value
+        code = code.value
+        tmp_cpp = tmp_dir.relpath(funcname + ".cpp")
+        tmp_xo = tmp_dir.relpath(funcname + ".xo")
+
+        with open(tmp_cpp, "wb") as out_file:
+            out_file.write(bytes(code))
+
         # build xo
-        args = [xocc, "-c", "-t", target, "--platform", platform, "-o", tmp_xo, "-k", kernel] + \
+        args = [xocc, "-c", "-t", target, "--platform", platform, "-o", tmp_xo, "-k", funcname] + \
                advanced_params + [tmp_cpp]
         returncode = subprocess.call(args)
         if returncode != 0:
             raise RuntimeError("Compile error")
+        tmp_xo_files.append(tmp_xo)
 
     # build xclbin
-    args = [xocc, "-l", "-t", target, "--platform", platform, "-o", tmp_xclbin, tmp_xo] + \
+    tmp_xclbin = tmp_dir.relpath("output.xclbin")
+    args = [xocc, "-l", "-t", target, "--platform", platform, "-o", tmp_xclbin] + tmp_xo_files + \
            advanced_params
     returncode = subprocess.call(args)
     if returncode != 0:
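With this change, each kernel is compiled into its own .xo object and all objects are then linked into a single xclbin, replacing the previous single compile-and-link pass over one source file. A condensed sketch of that flow, not the exact implementation: it assumes kernel_info holds plain (name, source) string pairs, tmp_dir behaves like tvm.contrib.util.tempdir(), and the --xp advanced parameters and error reporting above are elided.

    # Sketch of the per-kernel compile + single link flow introduced above.
    import subprocess

    def build_xclbin(kernel_info, tmp_dir, xocc, target, platform):
        xo_files = []
        for funcname, code in kernel_info:
            cpp = tmp_dir.relpath(funcname + ".cpp")
            xo = tmp_dir.relpath(funcname + ".xo")
            with open(cpp, "w") as out_file:
                out_file.write(code)
            # stage 1: compile one kernel into its own .xo object
            subprocess.check_call([xocc, "-c", "-t", target, "--platform", platform,
                                   "-o", xo, "-k", funcname, cpp])
            xo_files.append(xo)
        # stage 2: link all .xo objects into one xclbin
        xclbin = tmp_dir.relpath("output.xclbin")
        subprocess.check_call([xocc, "-l", "-t", target, "--platform", platform,
                               "-o", xclbin] + xo_files)
        with open(xclbin, "rb") as f:
            return bytearray(f.read())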
--- a/src/codegen/codegen_vhls.cc
+++ b/src/codegen/codegen_vhls.cc
@@ -72,26 +72,33 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs) {
   bool output_ssa = false;
   CodeGenVivadoHLS cg;
-  CHECK_EQ(funcs.size(), 1);
-  const std::string funcname = funcs[0]->name;
+  // Generate source code for get_source().
   cg.Init(output_ssa);
   for (LoweredFunc f : funcs) {
     cg.AddFunction(f);
   }
+  std::string whole_code = cg.Finish();
+
+  // Generate source code for compilation.
+  Array<Array<Expr> > kernel_info;
+  for (LoweredFunc f : funcs) {
+    CodeGenVivadoHLS cg;
+    cg.Init(output_ssa);
+    cg.AddFunction(f);
     std::string code = cg.Finish();
     if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) {
       code = (*f)(code).operator std::string();
     }
+    kernel_info.push_back(Array<Expr>({f->name, code}));
+  }
 
   std::string xclbin;
   if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) {
-    xclbin = (*f)(code, funcname).operator std::string();
+    xclbin = (*f)(kernel_info).operator std::string();
   } else {
     LOG(FATAL) << "Cannot compile Vivado HLS code.";
   }
-  return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(funcs), code);
+  return SDAccelModuleCreate(xclbin, "xclbin", ExtractFuncInfo(funcs), whole_code);
 }
 
 TVM_REGISTER_API("codegen.build_sdaccel")
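BuildSDAccel now emits two kinds of source output: whole_code, the concatenation of all functions kept for the module's get_source(), and one source string per function, handed to the Python callback as an array of (name, code) pairs. Those pairs cross the PackedFunc boundary as expression nodes rather than plain Python strings, which is why compile_vhls reads funcname.value and code.value. A minimal illustration of that unwrapping on the Python side; the callback name here is hypothetical, and it assumes the tvm Python API of this era.

    import tvm

    @tvm.register_func("tvm_callback_sdaccel_compile_demo")  # hypothetical name
    def _compile_demo(kernel_info):
        for funcname, code in kernel_info:
            # each element arrives as a StringImm expression node;
            # .value recovers the underlying Python string
            print(funcname.value, len(code.value))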
--- a/tests/python/integration/test_ewise_fpga.py
+++ b/tests/python/integration/test_ewise_fpga.py
@@ -21,7 +21,7 @@ def test_exp():
     s[B].bind(px, tvm.thread_axis("pipeline"))
 
     # one line to build the function.
-    def check_device(device, host="stackvm"):
+    def check_device(device, host="llvm"):
         if not tvm.module.enabled(host):
             return
         ctx = tvm.context(device, 0)
@@ -42,5 +42,44 @@ def test_exp():
     check_device("sdaccel")
 
+
+def test_multi_kernel():
+    # graph
+    n = tvm.convert(1024)
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.placeholder((n,), name='B')
+    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    D = tvm.compute(A.shape, lambda *i: A(*i) + C(*i), name='D')
+    s = tvm.create_schedule(D.op)
+    # create iter var and assign them tags.
+    px, x = s[C].split(C.op.axis[0], nparts=1)
+    s[C].bind(px, tvm.thread_axis("pipeline"))
+    px, x = s[D].split(D.op.axis[0], nparts=1)
+    s[D].bind(px, tvm.thread_axis("pipeline"))
+
+    # one line to build the function.
+    def check_device(device, host="llvm"):
+        if not tvm.module.enabled(host):
+            return
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            return
+        fadd = tvm.build(s, [A, B, C, D],
+                         device, host,
+                         name="myadd")
+        ctx = tvm.context(device, 0)
+        # launch the kernel.
+        n = 1024
+        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
+        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
+        c = tvm.nd.array(np.random.uniform(size=n).astype(C.dtype), ctx)
+        d = tvm.nd.array(np.random.uniform(size=n).astype(D.dtype), ctx)
+        fadd(a, b, c, d)
+        np.testing.assert_allclose(
+            d.asnumpy(), a.asnumpy() * 2 + b.asnumpy(), rtol=1e-5)
+
+    check_device("sdaccel")
+
+
 if __name__ == "__main__":
     test_exp()
+    test_multi_kernel()
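In test_multi_kernel, C and D are each bound to their own "pipeline" axis, so the build emits two HLS kernels and exercises the new multi-.xo link path. The expected value follows from the dataflow: C = A + B and D = A + C, hence D = 2A + B, which is what the assertion checks. A NumPy restatement of that arithmetic, self-contained, with names mirroring the test:

    import numpy as np

    a = np.random.uniform(size=1024).astype("float32")
    b = np.random.uniform(size=1024).astype("float32")
    c = a + b   # kernel C: C = A + B
    d = a + c   # kernel D: D = A + C = 2A + B
    np.testing.assert_allclose(d, 2 * a + b, rtol=1e-5)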