Commit 85483c37, authored by MORITA Kazutaka, committed by Thierry Moreau

[TOPI] add injective scheduler for HLS backends (#1553)

* [TOPI] add injective scheduler for HLS backends

* Introduced PrintBinaryExpr
parent 53d24311
......@@ -91,9 +91,11 @@ Target CreateTarget(const std::string& target_name,
} else if (target_name == "sdaccel") {
t->device_type = kDLOpenCL;
t->keys_array.push_back(ir::StringImm::make("sdaccel"));
t->keys_array.push_back(ir::StringImm::make("hls"));
} else if (target_name == "aocl") {
t->device_type = kDLAOCL;
t->keys_array.push_back(ir::StringImm::make("aocl"));
t->keys_array.push_back(ir::StringImm::make("hls"));
} else if (target_name == "opengl") {
t->device_type = kOpenGL;
t->keys_array.push_back(ir::StringImm::make("opengl"));
......
......@@ -16,6 +16,7 @@ void CodeGenVivadoHLS::Init(bool output_ssa) {
CodeGenC::Init(output_ssa);
this->stream << "#include <ap_int.h>\n\n";
this->stream << "#include <algorithm>\n\n";
}
void CodeGenVivadoHLS::PrintType(Type t, std::ostream& os) {
......@@ -67,6 +68,46 @@ void CodeGenVivadoHLS::PreFunctionBody(LoweredFunc f) {
this->stream << "#pragma HLS INTERFACE s_axilite port=return bundle=control\n\n";
}
// Writes a two-operand node to `os` in function-call form:
//   <opstr>(<a>, <b>)
//
// op:    node exposing operands `a` and `b`.
// opstr: text emitted in front of the opening parenthesis.
// os:    destination stream; both operands are printed into it through `p`.
// p:     code generator whose PrintExpr renders each operand.
template<typename T>
inline void PrintBinaryExpr(const T* op,
                            const char *opstr,
                            std::ostream& os,  // NOLINT(*)
                            CodeGenVivadoHLS* p) {
  // Callee name and opening parenthesis.
  os << opstr;
  os << '(';
  // Left operand, separator, right operand -- emitted strictly in this order
  // because PrintExpr writes straight into the same stream.
  p->PrintExpr(op->a, os);
  os << ", ";
  p->PrintExpr(op->b, os);
  // Close the call.
  os << ')';
}
// Emits a Min node as a two-argument function call.
// The callee is "fminf" for 32-bit float results, "fmin" for 64-bit float
// results, and "std::min" for every other result type.
void CodeGenVivadoHLS::VisitExpr_(const Min *op, std::ostream& os) {  // NOLINT(*)
  const char *callee = "std::min";
  if (op->type.is_float()) {
    const auto width = op->type.bits();
    if (width == 32) {
      callee = "fminf";
    } else if (width == 64) {
      callee = "fmin";
    }
    // Any other float width keeps the std::min default.
  }
  PrintBinaryExpr(op, callee, os, this);
}
// Emits a Max node as a two-argument function call.
// The callee is "fmaxf" for 32-bit float results, "fmax" for 64-bit float
// results, and "std::max" for every other result type.
void CodeGenVivadoHLS::VisitExpr_(const Max *op, std::ostream& os) {  // NOLINT(*)
  const char *callee = "std::max";
  if (op->type.is_float()) {
    const auto width = op->type.bits();
    if (width == 32) {
      callee = "fmaxf";
    } else if (width == 64) {
      callee = "fmax";
    }
    // Any other float width keeps the std::max default.
  }
  PrintBinaryExpr(op, callee, os, this);
}
runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
using tvm::runtime::Registry;
......
......@@ -20,6 +20,8 @@ class CodeGenVivadoHLS final : public CodeGenC {
void PrintType(Type t, std::ostream& os);
void AddFunction(LoweredFunc f);
void PreFunctionBody(LoweredFunc f);
void VisitExpr_(const Min *op, std::ostream& os);
void VisitExpr_(const Max *op, std::ostream& os);
};
} // namespace codegen
......
......@@ -9,6 +9,21 @@ namespace tvm {
namespace codegen {
namespace intrin {
// Intrinsic rules for the "sdaccel" target. Each entry registers a global
// function named "tvm.intrin.rule.sdaccel.<op>" whose body is
// DispatchExtern<Direct>.
// NOTE(review): presumably `Direct` forwards the intrinsic to an extern call
// of the same name -- confirm against the DispatchExtern definition.
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.floor")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.ceil")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.trunc")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.fabs")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round")
.set_body(DispatchExtern<Direct>);
TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp")
.set_body(DispatchExtern<Direct>);
......
......@@ -32,6 +32,7 @@ from . import util
from . import rocm
from . import vision
from . import image
from . import hls
# not import testing by default
# because testing can have extra deps that are not necessary
# we can import them from test cases explicitly
......
# pylint: disable=redefined-builtin, wildcard-import
"""HLS specific declaration and schedules."""
from __future__ import absolute_import as _abs
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
# pylint: disable=invalid-name, unused-variable,
"""Schedule for composition of injective operator"""
import tvm
from .. import generic
@generic.schedule_injective.register(["hls"])
def schedule_injective(outs):
    """Build the HLS schedule for injective operators.

    Every output stage has all of its axes fused into a single axis, which
    is then split with ``nparts=1``; the outer part is bound to the
    ``"pipeline"`` thread axis.

    Parameters
    ----------
    outs: Tensor or Array of Tensor
        The output tensor(s) of the injective computation. A single
        tensor is accepted and treated as a one-element list.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    # Accept a bare tensor as well as a sequence of tensors.
    if isinstance(outs, tvm.tensor.Tensor):
        outs = [outs]
    sched = tvm.create_schedule([tensor.op for tensor in outs])
    tvm.schedule.AutoInlineInjective(sched)
    for tensor in outs:
        stage = sched[tensor]
        flat_axis = stage.fuse(*stage.op.axis)
        # Only the outer part of the split is used; the inner part is ignored.
        outer, _ = stage.split(flat_axis, nparts=1)
        stage.bind(outer, tvm.thread_axis("pipeline"))
    return sched
# Elementwise and broadcast schedules are plain aliases of the injective
# schedule defined in this module.
schedule_elemwise = schedule_injective
schedule_broadcast = schedule_injective
......@@ -31,6 +31,7 @@ def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
check_device("metal")
check_device("rocm")
check_device("nvptx")
check_device("sdaccel")
def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
......@@ -87,6 +88,7 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
check_device("metal")
check_device("rocm")
check_device("nvptx")
check_device("sdaccel")
def test_broadcast_to():
verify_broadcast_to_ele((1,), (10,), topi.broadcast_to)
......
......@@ -34,7 +34,7 @@ def verify_clip(N, a_min, a_max, dtype):
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['llvm', 'opencl']:
for device in ['llvm', 'opencl', 'sdaccel']:
check_device(device)
def test_clip():
......
......@@ -39,7 +39,7 @@ def test_ewise():
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm', 'nvptx']:
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm', 'nvptx', 'sdaccel']:
check_device(device)
......
......@@ -27,7 +27,7 @@ def verify_relu(m, n):
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx', 'sdaccel']:
check_device(device)
......
......@@ -22,7 +22,7 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
check_device(device)
......@@ -45,7 +45,7 @@ def verify_tranpose(in_shape, axes):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
check_device(device)
......@@ -68,7 +68,7 @@ def verify_reshape(src_shape, dst_shape):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
check_device(device)
......@@ -96,7 +96,7 @@ def verify_squeeze(src_shape, axis):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
check_device(device)
def verify_concatenate(shapes, axis):
......@@ -121,7 +121,7 @@ def verify_concatenate(shapes, axis):
foo(*(data_nds + [out_nd]))
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
check_device(device)
......@@ -146,7 +146,7 @@ def verify_split(src_shape, indices_or_sections, axis):
for out_nd, out_npy in zip(out_nds, out_npys):
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
check_device(device)
......@@ -204,7 +204,7 @@ def verify_flip(in_shape, axis):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "cuda", "opencl"]:
for device in ["llvm", "cuda", "opencl", "sdaccel"]:
check_device(device)
def verify_take(src_shape, indices_src, axis=None):
......@@ -243,7 +243,7 @@ def verify_take(src_shape, indices_src, axis=None):
foo(data_nd, indices_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npys)
for device in ["llvm", "opencl"]:
for device in ["llvm", "opencl", "sdaccel"]:
check_device(device)
def verify_strided_slice(in_shape, begin, end, stride=None):
......@@ -270,7 +270,7 @@ def verify_strided_slice(in_shape, begin, end, stride=None):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "opencl"]:
for device in ["llvm", "opencl", "sdaccel"]:
check_device(device)
def test_strided_slice():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment