Commit 85483c37 by MORITA Kazutaka, committed by Thierry Moreau

[TOPI] add injective scheduler for HLS backends (#1553)

* [TOPI] add injective scheduler for HLS backends

* Introduced PrintBinaryExpr
parent 53d24311
@@ -91,9 +91,11 @@ Target CreateTarget(const std::string& target_name,
   } else if (target_name == "sdaccel") {
     t->device_type = kDLOpenCL;
     t->keys_array.push_back(ir::StringImm::make("sdaccel"));
+    t->keys_array.push_back(ir::StringImm::make("hls"));
   } else if (target_name == "aocl") {
     t->device_type = kDLAOCL;
     t->keys_array.push_back(ir::StringImm::make("aocl"));
+    t->keys_array.push_back(ir::StringImm::make("hls"));
   } else if (target_name == "opengl") {
     t->device_type = kOpenGL;
     t->keys_array.push_back(ir::StringImm::make("opengl"));
...
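The extra "hls" key is what generic TOPI schedule dispatch matches against, so schedules registered for "hls" now apply to both the sdaccel and aocl targets. A minimal sketch of the effect, assuming the Python target API of this TVM version (tvm.target.create, Target.keys):

import tvm

# Sketch: the target object created for "sdaccel" should now carry both keys.
tgt = tvm.target.create("sdaccel")
print(tgt.keys)  # expected to include "sdaccel" and "hls"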
@@ -16,6 +16,7 @@ void CodeGenVivadoHLS::Init(bool output_ssa) {
   CodeGenC::Init(output_ssa);
   this->stream << "#include <ap_int.h>\n\n";
+  this->stream << "#include <algorithm>\n\n";
 }
 
 void CodeGenVivadoHLS::PrintType(Type t, std::ostream& os) {
@@ -67,6 +68,46 @@ void CodeGenVivadoHLS::PreFunctionBody(LoweredFunc f) {
   this->stream << "#pragma HLS INTERFACE s_axilite port=return bundle=control\n\n";
 }
 
+template<typename T>
+inline void PrintBinaryExpr(const T* op,
+                            const char *opstr,
+                            std::ostream& os,  // NOLINT(*)
+                            CodeGenVivadoHLS* p) {
+  os << opstr << '(';
+  p->PrintExpr(op->a, os);
+  os << ", ";
+  p->PrintExpr(op->b, os);
+  os << ')';
+}
+
+void CodeGenVivadoHLS::VisitExpr_(const Min *op, std::ostream& os) {  // NOLINT(*)
+  const char *opstr = "std::min";
+  if (op->type.is_float()) {
+    switch (op->type.bits()) {
+      case 32:
+        opstr = "fminf"; break;
+      case 64:
+        opstr = "fmin"; break;
+    }
+  }
+  PrintBinaryExpr(op, opstr, os, this);
+}
+
+void CodeGenVivadoHLS::VisitExpr_(const Max *op, std::ostream& os) {  // NOLINT(*)
+  const char *opstr = "std::max";
+  if (op->type.is_float()) {
+    switch (op->type.bits()) {
+      case 32:
+        opstr = "fmaxf"; break;
+      case 64:
+        opstr = "fmax"; break;
+    }
+  }
+  PrintBinaryExpr(op, opstr, os, this);
+}
+
 runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
   using tvm::runtime::Registry;
...
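The two visitors pick the emitted function name from the operand type and hand it to PrintBinaryExpr, which prints it in call form, e.g. fminf(a, b). The selection boils down to the following standalone sketch (illustrative Python, not TVM code; the helper name is made up):

def _minmax_fn(kind, dtype):
    """Mirror of the opstr selection in VisitExpr_(Min/Max): float32/float64 map to
    the C math functions, every other type falls back to std::min/std::max."""
    if dtype == "float32":
        return "fminf" if kind == "min" else "fmaxf"
    if dtype == "float64":
        return "fmin" if kind == "min" else "fmax"
    return "std::min" if kind == "min" else "std::max"

print(_minmax_fn("min", "float32"))  # fminf, printed as fminf(a, b)
print(_minmax_fn("max", "int32"))    # std::max, which is why <algorithm> is now included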
@@ -20,6 +20,8 @@ class CodeGenVivadoHLS final : public CodeGenC {
   void PrintType(Type t, std::ostream& os);
   void AddFunction(LoweredFunc f);
   void PreFunctionBody(LoweredFunc f);
+  void VisitExpr_(const Min *op, std::ostream& os);
+  void VisitExpr_(const Max *op, std::ostream& os);
 };
 
 } // namespace codegen
...
@@ -9,6 +9,21 @@ namespace tvm {
 namespace codegen {
 namespace intrin {
 
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.floor")
+.set_body(DispatchExtern<Direct>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.ceil")
+.set_body(DispatchExtern<Direct>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.trunc")
+.set_body(DispatchExtern<Direct>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.fabs")
+.set_body(DispatchExtern<Direct>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.round")
+.set_body(DispatchExtern<Direct>);
+
 TVM_REGISTER_GLOBAL("tvm.intrin.rule.sdaccel.exp")
 .set_body(DispatchExtern<Direct>);
...
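These rules route the listed math intrinsics to direct extern calls when lowering for sdaccel, matching the ops now exercised by the test changes below. A small sketch of where they apply, assuming the ~v0.4 TOPI API (topi.floor etc.); actually compiling the kernel still requires the Xilinx toolchain:

import tvm
import topi

n = tvm.var("n")
A = tvm.placeholder((n,), name="A", dtype="float32")
# topi.floor lowers to a floor(...) intrinsic; the Direct rule above makes the sdaccel
# codegen emit it verbatim as an extern call in the generated kernel source.
B = topi.floor(A)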
@@ -32,6 +32,7 @@ from . import util
 from . import rocm
 from . import vision
 from . import image
+from . import hls
 # not import testing by default
 # because testing can have extra deps that are not necessary
 # we can import them from test cases explicitly
...
# pylint: disable=redefined-builtin, wildcard-import
"""HLS specific declaration and schedules."""
from __future__ import absolute_import as _abs
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
# pylint: disable=invalid-name, unused-variable,
"""Schedule for composition of injective operator"""
import tvm
from .. import generic

@generic.schedule_injective.register(["hls"])
def schedule_injective(outs):
    """Schedule for injective op.

    Parameters
    ----------
    outs: Array of Tensor
        The computation graph description of injective in the format
        of an array of tensors.

    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
    s = tvm.create_schedule([x.op for x in outs])
    tvm.schedule.AutoInlineInjective(s)
    for out in outs:
        fused = s[out].fuse(*s[out].op.axis)
        px, x = s[out].split(fused, nparts=1)
        s[out].bind(px, tvm.thread_axis("pipeline"))
    return s

schedule_elemwise = schedule_injective
schedule_broadcast = schedule_injective
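A short usage sketch for the new schedule, assuming the ~v0.4 Python API and following the pattern of the tests below (the availability check matters because an sdaccel build needs the Xilinx SDAccel toolchain):

import tvm
import topi

n = tvm.var("n")
A = tvm.placeholder((n,), name="A")
B = topi.nn.relu(A)

if tvm.module.enabled("sdaccel"):
    # The "hls" target key makes generic dispatch resolve to topi.hls.schedule_injective,
    # which fuses the output axes and binds them to the "pipeline" thread axis.
    with tvm.target.create("sdaccel"):
        s = topi.generic.schedule_injective(B)
    f = tvm.build(s, [A, B], "sdaccel")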
@@ -31,6 +31,7 @@ def verify_broadcast_to_ele(in_shape, out_shape, fbcast):
     check_device("metal")
     check_device("rocm")
     check_device("nvptx")
+    check_device("sdaccel")
 
 
 def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
@@ -87,6 +88,7 @@ def verify_broadcast_binary_ele(lhs_shape, rhs_shape,
     check_device("metal")
     check_device("rocm")
     check_device("nvptx")
+    check_device("sdaccel")
 
 def test_broadcast_to():
     verify_broadcast_to_ele((1,), (10,), topi.broadcast_to)
...
@@ -34,7 +34,7 @@ def verify_clip(N, a_min, a_max, dtype):
         f(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['llvm', 'opencl']:
+    for device in ['llvm', 'opencl', 'sdaccel']:
         check_device(device)
 
 def test_clip():
...
@@ -39,7 +39,7 @@ def test_ewise():
             foo(a, b)
             np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5, atol=1e-5)
 
-        for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm', 'nvptx']:
+        for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'llvm', 'nvptx', 'sdaccel']:
             check_device(device)
...
@@ -27,7 +27,7 @@ def verify_relu(m, n):
         foo(a, b)
         np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
+    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx', 'sdaccel']:
         check_device(device)
...
@@ -22,7 +22,7 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
         check_device(device)
@@ -45,7 +45,7 @@ def verify_tranpose(in_shape, axes):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
         check_device(device)
@@ -68,7 +68,7 @@ def verify_reshape(src_shape, dst_shape):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
         check_device(device)
@@ -96,7 +96,7 @@ def verify_squeeze(src_shape, axis):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
         check_device(device)
 
 def verify_concatenate(shapes, axis):
@@ -121,7 +121,7 @@ def verify_concatenate(shapes, axis):
         foo(*(data_nds + [out_nd]))
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
         check_device(device)
@@ -146,7 +146,7 @@ def verify_split(src_shape, indices_or_sections, axis):
         for out_nd, out_npy in zip(out_nds, out_npys):
             np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
+    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
         check_device(device)
@@ -204,7 +204,7 @@ def verify_flip(in_shape, axis):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    for device in ["llvm", "cuda", "opencl"]:
+    for device in ["llvm", "cuda", "opencl", "sdaccel"]:
         check_device(device)
 
 def verify_take(src_shape, indices_src, axis=None):
@@ -243,7 +243,7 @@ def verify_take(src_shape, indices_src, axis=None):
         foo(data_nd, indices_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npys)
 
-    for device in ["llvm", "opencl"]:
+    for device in ["llvm", "opencl", "sdaccel"]:
         check_device(device)
 
 def verify_strided_slice(in_shape, begin, end, stride=None):
@@ -270,7 +270,7 @@ def verify_strided_slice(in_shape, begin, end, stride=None):
         foo(data_nd, out_nd)
         np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
 
-    for device in ["llvm", "opencl"]:
+    for device in ["llvm", "opencl", "sdaccel"]:
         check_device(device)
 
 def test_strided_slice():
...