Commit ca2ad6d4 by MORITA Kazutaka, committed by Tianqi Chen

Add support for Xilinx FPGA board with SDAccel (#1278)

parent c3df7726
HLS Backend Example
===================
TVM supports Xilinx FPGA boards with SDAccel. Here is a tutorial on how to deploy TVM to an AWS F1 FPGA instance.
***Note***: This feature is still experimental. We cannot use SDAccel to deploy an end-to-end neural network for now.
We use two Python scripts for this tutorial (an optional sanity-check snippet follows them).
- build.py - a script to synthesize FPGA bitstream.
```python
import tvm
from tvm.contrib import cc

tgt_host = "llvm"
tgt = "sdaccel"

# Describe a simple vector addition.
n = tvm.var("n")
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
C = tvm.compute(A.shape, lambda i: A[i] + B[i], name="C")

# Schedule: map the loop onto a single "pipeline" thread axis,
# which the HLS backend turns into a pipelined kernel.
s = tvm.create_schedule(C.op)
px, x = s[C].split(C.op.axis[0], nparts=1)
s[C].bind(px, tvm.thread_axis("pipeline"))

# Build the host module (LLVM) and the device module (SDAccel xclbin).
fadd = tvm.build(s, [A, B, C], tgt, target_host=tgt_host, name="myadd")
fadd.save("myadd.o")
fadd.imported_modules[0].save("myadd.xclbin")

cc.create_shared("myadd.so", ["myadd.o"])
```
- run.py - a script that uses the FPGA as an accelerator.
```python
import tvm
import numpy as np
import os

tgt = "sdaccel"

# Load the host module and import the device image.
fadd = tvm.module.load("myadd.so")
if os.environ.get("XCL_EMULATION_MODE"):
    fadd_dev = tvm.module.load("myadd.xclbin")
else:
    fadd_dev = tvm.module.load("myadd.awsxclbin")
fadd.import_module(fadd_dev)

ctx = tvm.context(tgt, 0)

# Run the kernel and verify the result on the host.
n = 1024
a = tvm.nd.array(np.random.uniform(size=n).astype("float32"), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype("float32"), ctx)
c = tvm.nd.array(np.zeros(n, dtype="float32"), ctx)
fadd(a, b, c)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
```
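A couple of optional sanity checks can be appended to these scripts. The snippet below is only a sketch reusing the names from build.py and run.py: `get_source` works because this commit keeps the generated Vivado HLS source alongside the xclbin, and `time_evaluator` is TVM's standard timing helper.
```python
# Optional checks (a sketch; names follow build.py / run.py above).

# In build.py, before saving: inspect the generated Vivado HLS source
# that will be handed to the SDAccel compiler.
print(fadd.imported_modules[0].get_source())

# In run.py, after the correctness check: time the kernel over several
# runs instead of a single call.
timer = fadd.time_evaluator("myadd", ctx, number=10)
print("mean time: %g s" % timer(a, b, c).mean)
```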
Setup
-----
- Launch an instance using the FPGA Developer AMI. We don't need an F1 instance for emulation or synthesis, so it is recommended to use a lower-cost instance for those steps.
- Set up the AWS FPGA development kit.
```bash
git clone https://github.com/aws/aws-fpga.git
cd aws-fpga
source sdaccel_setup.sh
source ${XILINX_SDX}/settings64.sh
```
- Set up TVM with OpenCL enabled.
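Since the `sdaccel` target is routed through the OpenCL runtime, a quick way to confirm that the TVM build picked up OpenCL support is the following sketch, run from Python:
```python
import tvm

# Both checks go through the same OpenCL device API; "sdaccel" is
# recognized thanks to the runtime change in this commit.
print(tvm.module.enabled("opencl"))
print(tvm.module.enabled("sdaccel"))
```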
Emulation
---------
- Create emconfig.json for emulation.
```bash
emconfigutil --platform ${AWS_PLATFORM} --nd 1
```
- Copy emconfig.json to the Python binary directory, because the current Xilinx toolkit assumes that the host binary and the emconfig.json file are in the same path (see the quick check after this list).
```bash
cp emconfig.json $(dirname $(which python))
```
- Run software emulation
```bash
export XCL_EMULATION_MODE=1
export XCL_TARGET=sw_emu
python build.py
python run.py
```
- Run hardware emulation
```bash
export XCL_EMULATION_MODE=1
export XCL_TARGET=hw_emu
python build.py
python run.py
```
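To confirm that emconfig.json really ended up next to the host binary (the Python interpreter in this tutorial), a small check such as the following can help; this is only a sketch of the directory assumption described above.
```python
import os
import sys

# The Xilinx emulation runtime looks for emconfig.json next to the host binary.
emconfig = os.path.join(os.path.dirname(sys.executable), "emconfig.json")
print(emconfig, "exists:", os.path.exists(emconfig))
```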
Synthesis
---------
- Run synthesis with the following script. `XCL_EMULATION_MODE` must be set to 1 at this stage.
```bash
export XCL_EMULATION_MODE=1
export XCL_TARGET=hw
python build.py
```
The result shows a CL_INVALID_PROGRAM error. This is because AWS SDAccel expects an awsxclbin binary, but we pass an xclbin instead. We don't load the FPGA image at this stage, so simply ignore the error for now.
- Create AWS FPGA image and upload it to AWS S3.
```bash
${SDACCEL_DIR}/tools/create_sdaccel_afi.sh -xclbin=myadd.xclbin -o=myadd \
-s3_bucket=<bucket-name> -s3_dcp_key=<dcp-folder-name> -s3_logs_key=<logs-folder-name>
```
This also generates an awsxclbin file, which is necessary to use the AWS FPGA image on F1 instances.
Run
---
- Launch Amazon EC2 F1 instance.
- Copy `myadd.so`, `myadd.awsxclbin`, and `run.py` to the F1 instance.
- Set up the AWS FPGA development kit.
```bash
git clone https://github.com/aws/aws-fpga.git
cd aws-fpga
source sdaccel_setup.sh
```
- Set up TVM with OpenCL enabled.
- Become root and set up the environment variables.
```bash
sudo sh
source ${INSTALL_ROOT}/setup.sh
```
- Run
```bash
python run.py
```
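If the FPGA image is loaded correctly, the board shows up to TVM as an OpenCL accelerator device, so a minimal availability check before running run.py is the following sketch:
```python
import tvm

# The device selection added in this commit prefers "accelerator" devices,
# so the F1 FPGA should be picked up as OpenCL device 0.
ctx = tvm.context("sdaccel", 0)
print("device found:", ctx.exist)
```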
......@@ -35,4 +35,4 @@ from .build_module import build, lower, build_config
from .tag import tag_scope
# Contrib initializers
-from .contrib import rocm as _rocm, nvcc as _nvcc
+from .contrib import rocm as _rocm, nvcc as _nvcc, sdaccel as _sdaccel
......@@ -111,6 +111,7 @@ class TVMContext(ctypes.Structure):
'nvptx': 2,
'cl': 4,
'opencl': 4,
'sdaccel': 4,
'vulkan': 7,
'metal': 8,
'vpi': 9,
......
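Because `'sdaccel'` maps to the same device-type id as OpenCL (4 in the table above), a context created with either name refers to the same device; a quick illustration:
```python
import tvm

# 'sdaccel' and 'opencl' share device type 4.
print(tvm.context("sdaccel", 0).device_type)  # 4
print(tvm.context("opencl", 0).device_type)   # 4
```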
"""Utility for Interacting with SDAccel Tools"""
import subprocess
import os
import re
from . import util
from ..api import register_func
def _vhls_to_opencl(code):
    """Convert source code from Vivado HLS to OpenCL."""
    out = ''
    for line in code.split('\n'):
        if re.match(r'#include', line):
            # OpenCL doesn't support include.
            continue
        if re.match(r'#pragma', line):
            # Remove Vivado HLS specific pragmas.
            continue
        if re.match(r'extern "C"', line):
            line = re.sub(r'^extern "C"', "__kernel", line)
            # Add __global to pointer parameters.
            line = re.sub(r'(\w+)\s*\*', r"__global \1*", line)
        out += line + '\n'
    return out
def _fake_compile_vhls(code):
    """Fake compile Vivado HLS code for SDAccel.

    Compile the Vivado HLS code as an OpenCL code, and generate a program
    binary for GPU which can be used instead of xclbin.

    Parameters
    ----------
    code : str
        The Vivado HLS code.

    Return
    ------
    binary : bytearray
        The program binary which can be passed to clCreateProgramWithBinary
    """
    try:
        import pyopencl as cl
    except ImportError:
        raise ImportError('PyOpenCL is required for testing SDAccel backend.')
    ctx = cl.Context(dev_type=cl.device_type.GPU)
    program = cl.Program(ctx, _vhls_to_opencl(code)).build()
    binary = bytearray(program.binaries[0])
    return binary
@register_func("tvm_callback_sdaccel_compile")
def compile_vhls(code, kernel):
    """Compile Vivado HLS code for SDAccel.

    Parameters
    ----------
    code : str
        The Vivado HLS code.
    kernel : str
        The kernel to compile or link.

    Return
    ------
    xclbin : bytearray
        The bytearray of the xclbin
    """
    tmp_dir = util.tempdir()
    tmp_cpp = tmp_dir.relpath("input.cpp")
    tmp_xo = tmp_dir.relpath("output.xo")
    tmp_xclbin = tmp_dir.relpath("output.xclbin")

    with open(tmp_cpp, "wb") as out_file:
        out_file.write(bytes(code))

    sdk = os.environ.get("XILINX_SDX", None)
    xocc = os.path.join(sdk, "bin/xocc") if sdk else "xocc"
    target = os.environ.get("XCL_TARGET",
                            "sw_emu" if os.environ.get("XCL_EMULATION_MODE") else "hw")
    advanced_params = ["--xp", "param:compiler.preserveHlsOutput=1",
                       "--xp", "param:compiler.generateExtraRunData=true"]
    platform = os.environ.get("XCL_PLATFORM", os.environ.get("AWS_PLATFORM"))

    if platform is None:
        # If we don't have the Xilinx toolchain, create a program binary for
        # GPU and use it for testing.
        return _fake_compile_vhls(code)

    # build xo
    args = [xocc, "-c", "-t", target, "--platform", platform, "-o", tmp_xo, "-k", kernel] + \
           advanced_params + [tmp_cpp]
    returncode = subprocess.call(args)
    if returncode != 0:
        raise RuntimeError("Compile error")

    # build xclbin
    args = [xocc, "-l", "-t", target, "--platform", platform, "-o", tmp_xclbin, tmp_xo] + \
           advanced_params
    returncode = subprocess.call(args)
    if returncode != 0:
        raise RuntimeError("Link error")

    return bytearray(open(tmp_xclbin, "rb").read())
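To make the flow above concrete, here is an illustrative use of the two helpers. The kernel string is a made-up toy example, and calling `compile_vhls` directly is purely for demonstration; TVM normally invokes it through the registered `tvm_callback_sdaccel_compile` callback.
```python
# Illustrative only: a toy Vivado HLS kernel (not generated by TVM).
toy_vhls = (
    '#include <ap_int.h>\n'
    'extern "C" void myadd_kernel0(float* A, float* B) {\n'
    '#pragma HLS INTERFACE m_axi port=A offset=slave bundle=gmem\n'
    '}\n'
)

# _vhls_to_opencl drops the include and the pragmas, and rewrites the
# signature into an OpenCL kernel with __global pointer parameters:
#   __kernel void myadd_kernel0(__global float* A, __global float* B) {
#   }
print(_vhls_to_opencl(toy_vhls))

# Hypothetical direct invocation of the compile callback. It requires
# XILINX_SDX and XCL_PLATFORM/AWS_PLATFORM; otherwise it falls back to the
# PyOpenCL-based fake compilation used for testing.
xclbin = compile_vhls(toy_vhls, "myadd_kernel0")
open("toy.xclbin", "wb").write(xclbin)
```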
......@@ -88,6 +88,9 @@ Target CreateTarget(const std::string& target_name,
t->keys_array.push_back(ir::StringImm::make(target_name));
t->keys_array.push_back(ir::StringImm::make("gpu"));
t->max_num_threads = 256;
} else if (target_name == "sdaccel") {
t->device_type = kDLOpenCL;
t->keys_array.push_back(ir::StringImm::make("sdaccel"));
} else if (target_name == "opengl") {
t->device_type = kOpenGL;
t->keys_array.push_back(ir::StringImm::make("opengl"));
......
......@@ -60,6 +60,7 @@ void CodeGenC::AddFunction(LoweredFunc f) {
stream << ' ' << vid;
}
stream << ") {\n";
this->PreFunctionBody(f);
int func_scope = this->BeginScope();
this->PrintStmt(f->body);
this->EndScope(func_scope);
......
......@@ -74,6 +74,11 @@ class CodeGenC :
}
// The following parts are overloadable print operations.
/*!
* \brief Insert statement before function body.
* \param f The function to be compiled.
*/
virtual void PreFunctionBody(LoweredFunc f) {}
/*!
* \brief Initialize codegen state for generating f.
* \param f The function to be compiled.
*/
......
......@@ -218,7 +218,7 @@ runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) {
code = (*f)(code).operator std::string();
}
-  return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(funcs));
+  return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(funcs), code);
}
TVM_REGISTER_API("codegen.build_opencl")
......
/*!
* Copyright (c) 2018 by Contributors
* \file codegen_vhls.cc
*/
#include <vector>
#include <string>
#include "./codegen_vhls.h"
#include "./build_common.h"
#include "../runtime/opencl/opencl_module.h"
namespace tvm {
namespace codegen {
void CodeGenVivadoHLS::Init(bool output_ssa) {
  CodeGenC::Init(output_ssa);
  this->stream << "#include <ap_int.h>\n\n";
}

void CodeGenVivadoHLS::PrintType(Type t, std::ostream& os) {
  if (t.is_uint()) {
    switch (t.bits()) {
      case 8:
        os << "unsigned char"; break;
      case 16:
        os << "unsigned short"; break;
      case 32:
        os << "unsigned int"; break;
      case 64:
        os << "unsigned long long"; break;
      default:
        os << "ap_uint<" << t.bits() << ">"; break;
    }
  } else if (t.is_int()) {
    switch (t.bits()) {
      case 8:
        os << "char"; break;
      case 16:
        os << "short"; break;
      case 32:
        os << "int"; break;
      case 64:
        os << "long long"; break;
      default:
        os << "ap_int<" << t.bits() << ">"; break;
    }
  } else {
    CodeGenC::PrintType(t, os);
  }
}

void CodeGenVivadoHLS::AddFunction(LoweredFunc f) {
  this->stream << "extern \"C\" ";
  CodeGenC::AddFunction(f);
}
void CodeGenVivadoHLS::PreFunctionBody(LoweredFunc f) {
  for (size_t i = 0; i < f->args.size(); ++i) {
    Var v = f->args[i];
    std::string vid = GetVarID(v.get());
    if (v.type().is_handle()) {
      this->stream << "#pragma HLS INTERFACE m_axi port=" << vid << " offset=slave bundle=gmem\n";
    }
    this->stream << "#pragma HLS INTERFACE s_axilite port=" << vid << " bundle=control\n";
  }
  this->stream << "#pragma HLS INTERFACE s_axilite port=return bundle=control\n\n";
}
runtime::Module BuildSDAccel(Array<LoweredFunc> funcs) {
  using tvm::runtime::Registry;
  bool output_ssa = false;
  CodeGenVivadoHLS cg;

  CHECK_EQ(funcs.size(), 1);
  const std::string funcname = funcs[0]->name;
  cg.Init(output_ssa);
  for (LoweredFunc f : funcs) {
    cg.AddFunction(f);
  }
  std::string code = cg.Finish();
  if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) {
    code = (*f)(code).operator std::string();
  }

  std::string xclbin;
  if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) {
    xclbin = (*f)(code, funcname).operator std::string();
  } else {
    LOG(FATAL) << "Cannot compile Vivado HLS code.";
  }
  return OpenCLModuleCreate(xclbin, "xclbin", ExtractFuncInfo(funcs), code);
}

TVM_REGISTER_API("codegen.build_sdaccel")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = BuildSDAccel(args[0]);
  });
} // namespace codegen
} // namespace tvm
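`PreFunctionBody` above is what injects the `#pragma HLS INTERFACE` lines for every kernel argument. One convenient way to see the full generated Vivado HLS source before it is handed to xocc is to register the `tvm_callback_vhls_postproc` hook that `BuildSDAccel` consults (the unit test at the end of this commit uses the same mechanism); a sketch:
```python
import tvm

@tvm.register_func
def tvm_callback_vhls_postproc(code):
    # Print the generated Vivado HLS source (AXI interface pragmas included)
    # and return it unchanged so compilation proceeds normally.
    print(code)
    return code
```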
/*!
* Copyright (c) 2018 by Contributors
* \file codegen_vhls.h
* \brief Utility to generate vhls code
*/
#ifndef TVM_CODEGEN_CODEGEN_VHLS_H_
#define TVM_CODEGEN_CODEGEN_VHLS_H_
#include <tvm/codegen.h>
#include <tvm/packed_func_ext.h>
#include <string>
#include "./codegen_c.h"
namespace tvm {
namespace codegen {
class CodeGenVivadoHLS final : public CodeGenC {
 public:
  void Init(bool output_ssa);
  void PrintType(Type t, std::ostream& os);
  void AddFunction(LoweredFunc f);
  void PreFunctionBody(LoweredFunc f);
};
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_VHLS_H_
/*!
* Copyright (c) 2018 by Contributors
* \file intrin_rule_vhls.cc
* \brief VHLS intrinsic rules.
*/
#include "./intrin_rule.h"
namespace tvm {
namespace codegen {
namespace intrin {
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.log")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.tanh")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.sqrt")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.pow")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.popcount")
.set_body(DispatchExtern<Direct>);
} // namespace intrin
} // namespace codegen
} // namespace tvm
......@@ -11,7 +11,8 @@ namespace runtime {
Module OpenCLModuleCreate(
std::string data,
std::string fmt,
-    std::unordered_map<std::string, FunctionInfo> fmap) {
+    std::unordered_map<std::string, FunctionInfo> fmap,
+    std::string source) {
LOG(WARNING) << "OpenCL runtime not enabled, return a source module...";
return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "opencl");
}
......
......@@ -179,7 +179,7 @@ class HostDeviceSplitter : public IRMutator {
private:
Stmt SplitDeviceFunc(Stmt body) {
std::ostringstream os;
-    os << name_ << "__kernel" << device_funcs_.size();
+    os << name_ << "_kernel" << device_funcs_.size();
std::shared_ptr<LoweredFuncNode> n = std::make_shared<LoweredFuncNode>();
// isolate the device function.
IRUseDefAnalysis m;
......
......@@ -106,7 +106,7 @@ bool RuntimeEnabled(const std::string& target) {
return true;
} else if (target == "cuda" || target == "gpu") {
f_name = "device_api.gpu";
-  } else if (target == "cl" || target == "opencl") {
+  } else if (target == "cl" || target == "opencl" || target == "sdaccel") {
f_name = "device_api.opencl";
} else if (target == "gl" || target == "opengl") {
f_name = "device_api.opengl";
......
......@@ -244,16 +244,18 @@ void OpenCLWorkspace::Init() {
this->platform_id = platform_matched[0];
LOG(INFO) << "Initialize OpenCL platform \'"
<< cl::GetPlatformInfo(this->platform_id, CL_PLATFORM_NAME) << '\'';
-  std::vector<cl_device_id> devices_matched =
-      cl::GetDeviceIDs(this->platform_id, "gpu");
-  if (devices_matched.size() == 0) {
-    LOG(WARNING) << "No OpenCL device any device matched given the options: gpu mode";
-    LOG(WARNING) << "Now try OpenCL cpu mode";
-    devices_matched = cl::GetDeviceIDs(this->platform_id, "cpu");
-    if (devices_matched.size() == 0) {
-      LOG(WARNING) << "No OpenCL device any device matched given the options: cpu mode";
-      return;
+  std::string device_types[] = {"accelerator", "gpu", "cpu"};
+  std::vector<cl_device_id> devices_matched;
+  for (auto type : device_types) {
+    devices_matched = cl::GetDeviceIDs(this->platform_id, type);
+    if (devices_matched.size() > 0) {
+      break;
+    }
+    LOG(INFO) << "No OpenCL device any device matched given the options: " << type << " mode";
+  }
+  if (devices_matched.size() == 0) {
+    LOG(WARNING) << "No OpenCL device";
+    return;
+  }
this->devices = devices_matched;
cl_int err_code;
......
......@@ -32,8 +32,9 @@ class OpenCLModuleNode : public ModuleNode {
};
explicit OpenCLModuleNode(std::string data,
std::string fmt,
-                            std::unordered_map<std::string, FunctionInfo> fmap)
-      : data_(data), fmt_(fmt), fmap_(fmap) {}
+                            std::unordered_map<std::string, FunctionInfo> fmap,
+                            std::string source)
+      : data_(data), fmt_(fmt), fmap_(fmap), source_(source) {}
// destructor
~OpenCLModuleNode() {
{
......@@ -81,7 +82,7 @@ class OpenCLModuleNode : public ModuleNode {
if (fmt_ == "cl") {
return data_;
} else {
return "";
return source_;
}
}
......@@ -97,6 +98,15 @@ class OpenCLModuleNode : public ModuleNode {
program_ = clCreateProgramWithSource(
workspace_->context, 1, &s, &len, &err);
OPENCL_CHECK_ERROR(err);
} else if (fmt_ == "xclbin" || fmt_ == "awsxclbin") {
const unsigned char* s = (const unsigned char *)data_.c_str();
size_t len = data_.length();
cl_int err;
program_ = clCreateProgramWithBinary(
workspace_->context, 1, &(workspace_->devices[0]), &len, &s, NULL, &err);
if (err != CL_SUCCESS) {
LOG(ERROR) << "OpenCL Error: " << cl::CLGetErrorString(err);
}
} else {
LOG(FATAL) << "Unknown OpenCL format " << fmt_;
}
......@@ -162,6 +172,8 @@ class OpenCLModuleNode : public ModuleNode {
std::unordered_map<std::string, FunctionInfo> fmap_;
// Module local mutex
std::mutex build_lock_;
// The OpenCL source.
std::string source_;
// the binary data
cl_program program_{nullptr};
// build info
......@@ -270,9 +282,10 @@ PackedFunc OpenCLModuleNode::GetFunction(
Module OpenCLModuleCreate(
std::string data,
std::string fmt,
-    std::unordered_map<std::string, FunctionInfo> fmap) {
+    std::unordered_map<std::string, FunctionInfo> fmap,
+    std::string source) {
   std::shared_ptr<OpenCLModuleNode> n =
-      std::make_shared<OpenCLModuleNode>(data, fmt, fmap);
+      std::make_shared<OpenCLModuleNode>(data, fmt, fmap, source);
n->Init();
return Module(n);
}
......@@ -286,7 +299,7 @@ Module OpenCLModuleLoadFile(const std::string& file_name,
std::string meta_file = GetMetaFilePath(file_name);
LoadBinaryFromFile(file_name, &data);
LoadMetaDataFromFile(meta_file, &fmap);
-  return OpenCLModuleCreate(data, fmt, fmap);
+  return OpenCLModuleCreate(data, fmt, fmap, std::string());
}
Module OpenCLModuleLoadBinary(void* strm) {
......@@ -297,7 +310,7 @@ Module OpenCLModuleLoadBinary(void* strm) {
stream->Read(&fmt);
stream->Read(&fmap);
stream->Read(&data);
-  return OpenCLModuleCreate(data, fmt, fmap);
+  return OpenCLModuleCreate(data, fmt, fmap, std::string());
}
TVM_REGISTER_GLOBAL("module.loadfile_cl")
......@@ -310,6 +323,16 @@ TVM_REGISTER_GLOBAL("module.loadfile_clbin")
*rv = OpenCLModuleLoadFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module.loadfile_xclbin")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenCLModuleLoadFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module.loadfile_awsxclbin")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenCLModuleLoadFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module.loadbinary_opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenCLModuleLoadBinary(args[0]);
......
......@@ -24,7 +24,8 @@ namespace runtime {
Module OpenCLModuleCreate(
std::string data,
std::string fmt,
-    std::unordered_map<std::string, FunctionInfo> fmap);
+    std::unordered_map<std::string, FunctionInfo> fmap,
+    std::string source);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_
import tvm
import numpy as np
import os

os.environ["XCL_EMULATION_MODE"] = "1"

@tvm.register_func
def tvm_callback_vhls_postproc(code):
    """Hook to inspect the Vivado HLS code before actually running it"""
    print(code)
    return code

def test_exp():
    # graph
    n = tvm.convert(1024)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda *i: tvm.exp(A(*i)), name='B')
    s = tvm.create_schedule(B.op)
    # create iter var and assign them tags.
    px, x = s[B].split(B.op.axis[0], nparts=1)
    s[B].bind(px, tvm.thread_axis("pipeline"))

    # one line to build the function.
    def check_device(device, host="stackvm"):
        if not tvm.module.enabled(host):
            return
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            return
        fexp = tvm.build(s, [A, B],
                         device, host,
                         name="myexp")
        ctx = tvm.context(device, 0)
        # launch the kernel.
        n = 1024
        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
        fexp(a, b)
        np.testing.assert_allclose(
            b.asnumpy(), np.exp(a.asnumpy()), rtol=1e-5)

    check_device("sdaccel")

if __name__ == "__main__":
    test_exp()
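Note that when neither `XCL_PLATFORM` nor `AWS_PLATFORM` is set, `compile_vhls` falls back to `_fake_compile_vhls`, which requires PyOpenCL and a GPU OpenCL device. A quick availability check before running the test (sketch):
```python
# Check whether the PyOpenCL-based fallback used by _fake_compile_vhls
# can run on this machine (it builds the kernel for a GPU device).
try:
    import pyopencl as cl
    cl.Context(dev_type=cl.device_type.GPU)
    print("fake SDAccel backend available")
except Exception as exc:
    print("fake SDAccel backend unavailable:", exc)
```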