Commit 7c4dd0bd by MORITA Kazutaka Committed by Thierry Moreau

[TOPI] add nn schedulers for HLS backends (#1663)

* [TOPI] add nn schedulers for HLS backends

* fix pylint

* fix topi transform test
parent 12839e6d
......@@ -10,5 +10,5 @@ def ctx_list():
device_list = (device_list.split(",") if device_list
else ["llvm", "cuda"])
device_list = set(device_list)
res = [("llvm", tvm.cpu(0)), ("cuda", tvm.gpu(0))]
return [x for x in res if x[1].exist and x[0] in device_list]
res = [(device, tvm.context(device, 0)) for device in device_list]
return [x for x in res if x[1].exist]
......@@ -33,6 +33,8 @@ runtime::Module BuildAOCL(Array<LoweredFunc> funcs, std::string target_str,
// Compile the .cl file.
std::string cmd = "aoc aocl.cl";
// AOCL supports fp64.
cmd += " -Dcl_khr_fp64";
Target target = Target::create(target_str);
if (target->device_name != "") {
cmd += " -board=" + target->device_name;
......
......@@ -3,3 +3,4 @@
from __future__ import absolute_import as _abs
from .injective import schedule_injective, schedule_elemwise, schedule_broadcast
from .nn import *
# pylint: disable=invalid-name,unused-variable,unused-argument
"""HLS nn operators"""
from __future__ import absolute_import as _abs
import tvm
from .. import tag
from .. import generic
def _schedule_conv2d(outs):
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_injective(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule conv2d
elif OP.tag.find("conv2d") >= 0:
Conv2d = OP.output(0)
if not Conv2d.op in s.outputs:
Out = outs[0].op.output(0)
s[Conv2d].compute_at(s[Out], s[Out].op.axis[1])
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
traverse(outs[0].op)
px, x = s[outs[0]].split(outs[0].op.axis[0], nparts=1)
s[outs[0]].bind(px, tvm.thread_axis("pipeline"))
return s
@generic.schedule_conv2d_nchw.register(["hls"])
def schedule_conv2d_nchw(outs):
"""Schedule for conv2d_nchw
Parameters
----------
outs: Array of Tensor
The computation graph description of conv2d_nchw
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _schedule_conv2d(outs)
@generic.schedule_conv2d_nhwc.register(["hls"])
def schedule_conv2d_nhwc(outs):
"""Schedule for conv2d_nhwc
Parameters
----------
outs: Array of Tensor
The computation graph description of conv2d_nchw
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _schedule_conv2d(outs)
@generic.schedule_conv2d_NCHWc.register(["hls"])
def schedule_conv2d_NCHWc(num_filter, kernel_size, strides,
padding, layout, out_layout, outs):
"""Schedule for conv2d_NCHW[x]c
Parameters
----------
num_filter : int
The number of filter, i.e., the output channel.
kernel_size : tuple of int
(kernel_height, kernel_width)
strides : tuple of int
(stride_of_height, stride_of_width)
padding : tuple of int
(pad_of_height, pad_of_width)
layout : str
Input data layout
out_layout : str
Output data layout
outs : Array of Tensor
The computation graph description of conv2d_NCHWc
in the format of an array of tensors.
Returns
-------
sch : Schedule
The computation schedule for the op.
"""
return _schedule_conv2d(outs)
@generic.schedule_conv2d_transpose_nchw.register(["hls"])
def schedule_conv2d_transpose_nchw(outs):
"""Schedule for conv2d_transpose_nchw
Parameters
----------
outs: Array of Tensor
The computation graph description of conv2d_transpose_nchw
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _schedule_conv2d(outs)
@generic.schedule_depthwise_conv2d_nchw.register(["hls"])
def schedule_depthwise_conv2d_nchw(outs):
"""Schedule for depthwise_conv2d_nchw
Parameters
----------
outs: Array of Tensor
The computation graph description of depthwise_conv2d_nchw
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _schedule_conv2d(outs)
@generic.schedule_depthwise_conv2d_nhwc.register(["hls"])
def schedule_depthwise_conv2d_nhwc(outs):
"""Schedule for depthwise_conv2d_nhwc
Parameters
----------
outs: Array of Tensor
The computation graph description of depthwise_conv2d_nhwc
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _schedule_conv2d(outs)
@generic.schedule_bitserial_conv2d_nchw.register(["hls"])
def schedule_bitserial_conv2d_nchw(outs):
"""Schedule for bitserial_conv2d_nchw
Parameters
----------
outs: Array of Tensor
The computation graph description of bitserial_conv2d_nchw
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _schedule_conv2d(outs)
@generic.schedule_bitserial_conv2d_nhwc.register(["hls"])
def schedule_bitserial_conv2d_nhwc(outs):
"""Schedule for bitserial_conv2d_nhwc
Parameters
----------
outs: Array of Tensor
The computation graph description of bitserial_conv2d_nchw
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
return _schedule_conv2d(outs)
@generic.schedule_reduce.register(["hls"])
def schedule_reduce(outs):
"""Schedule for reduction
Parameters
----------
outs: Array of Tensor
The computation graph description of reduce
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
elif OP.tag in ["comm_reduce", "comm_reduce_idx"]:
if OP.tag == "comm_reduce":
Reduce = OP.output(0)
else:
Reduce = OP.input_tensors[0]
if not Reduce.op in s.outputs:
Out = outs[0].op.output(0)
s[Reduce].compute_at(s[Out], s[Out].op.axis[0])
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
traverse(outs[0].op)
fused = s[outs[0]].fuse()
px, x = s[outs[0]].split(fused, nparts=1)
s[outs[0]].bind(px, tvm.thread_axis("pipeline"))
return s
@generic.schedule_softmax.register(["hls"])
def schedule_softmax(outs):
"""Schedule for softmax
Parameters
----------
outs: Array of Tensor
The computation graph description of softmax
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
softmax = outs[0]
max_elem = softmax.op.input_tensors[1]
expsum = softmax.op.input_tensors[2]
s[expsum].compute_at(s[softmax], s[softmax].op.axis[1])
s[max_elem].compute_at(s[softmax], s[softmax].op.axis[1])
px, x = s[softmax].split(softmax.op.axis[0], nparts=1)
s[softmax].bind(px, tvm.thread_axis("pipeline"))
return s
@generic.schedule_dense.register(["hls"])
def schedule_dense(outs):
"""Schedule for dense
Parameters
----------
outs: Array of Tensor
The computation graph description of dense
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule dense
elif OP.tag == 'dense':
Dense = OP.output(0)
if not Dense.op in s.outputs:
Out = outs[0].op.output(0)
s[Dense].compute_at(s[Out], s[Out].op.axis[1])
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
traverse(outs[0].op)
px, x = s[outs[0]].split(outs[0].op.axis[0], nparts=1)
s[outs[0]].bind(px, tvm.thread_axis("pipeline"))
return s
@generic.schedule_pool.register(["hls"])
def schedule_pool(outs, layout):
"""Schedule for pool
Parameters
----------
outs: Array of Tensor
The computation graph description of pool
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule pool
elif OP.tag.startswith('pool'):
Pool = OP.output(0)
if not Pool.op in s.outputs:
Out = outs[0].op.output(0)
s[Pool].compute_at(s[Out], s[Out].op.axis[1])
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
traverse(outs[0].op)
px, x = s[outs[0]].split(outs[0].op.axis[0], nparts=1)
s[outs[0]].bind(px, tvm.thread_axis("pipeline"))
return s
@generic.schedule_global_pool.register(["hls"])
def schedule_global_pool(outs):
"""Schedule for global pool
Parameters
----------
outs: Array of Tensor
The computation graph description of global pool
in the format of an array of tensors.
Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
tvm.schedule.AutoInlineInjective(s)
def traverse(OP):
"""Internal travserse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
# schedule global_pool
elif OP.tag.startswith('global_pool'):
Pool = OP.output(0)
if not Pool.op in s.outputs:
Out = outs[0].op.output(0)
s[Pool].compute_at(s[Out], s[Out].op.axis[1])
else:
raise RuntimeError("Unsupported operator: %s" % OP.tag)
traverse(outs[0].op)
px, x = s[outs[0]].split(outs[0].op.axis[0], nparts=1)
s[outs[0]].bind(px, tvm.thread_axis("pipeline"))
return s
......@@ -9,4 +9,4 @@ def get_all_backend():
A list of all supported targets
"""
return ['llvm', 'cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx',
'llvm -device=arm_cpu']
'llvm -device=arm_cpu', 'aocl_sw_emu']
......@@ -5,6 +5,8 @@ import topi
import math
from topi.util import get_const_tuple
from common import get_all_backend
def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_pad=True):
iw = ih
kw = kh
......@@ -64,7 +66,7 @@ def verify_pool(n, ic, ih, kh, sh, padding, pool_type, ceil_mode, count_include_
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
for device in get_all_backend():
check_device(device)
def test_pool():
......@@ -109,7 +111,7 @@ def verify_global_pool(n, c, h, w, pool_type):
f(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
for device in get_all_backend():
check_device(device)
def test_global_pool():
......
......@@ -4,6 +4,8 @@ import numpy as np
import tvm
import topi
from common import get_all_backend
def _my_npy_argmax(arr, axis, keepdims):
if not keepdims:
return arr.argmax(axis=axis)
......@@ -90,7 +92,7 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum", dtype="float32")
np.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1E-3, 1E-3)
else:
np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
for device in ["cuda", "opencl", "metal", "llvm", "rocm", "vulkan", "nvptx"]:
for device in get_all_backend():
check_device(device)
......
......@@ -5,6 +5,8 @@ import tvm
import topi
from topi.util import get_const_tuple
from common import get_all_backend
def verify_relu(m, n):
A = tvm.placeholder((m, n), name='A')
B = topi.nn.relu(A)
......@@ -27,7 +29,7 @@ def verify_relu(m, n):
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx', 'sdaccel']:
for device in get_all_backend():
check_device(device)
......
......@@ -7,6 +7,8 @@ import topi.testing
import logging
from topi.util import get_const_tuple
from common import get_all_backend
def verify_softmax(m, n, dtype="float32"):
A = tvm.placeholder((m, n), dtype=dtype, name='A')
B = topi.nn.softmax(A)
......@@ -63,7 +65,7 @@ def verify_log_softmax(m, n, dtype="float32"):
foo(a, b)
np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
for device in ["cuda", "opencl", "metal", "rocm", "vulkan", "nvptx"]:
for device in get_all_backend():
check_device(device)
......
......@@ -3,6 +3,8 @@ import numpy as np
import tvm
import topi
from common import get_all_backend
def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
A = tvm.placeholder(shape=in_shape, name="A")
B = topi.expand_dims(A, axis, num_newaxis)
......@@ -22,7 +24,7 @@ def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
for device in get_all_backend():
check_device(device)
......@@ -45,7 +47,7 @@ def verify_tranpose(in_shape, axes):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
for device in get_all_backend():
check_device(device)
......@@ -68,7 +70,7 @@ def verify_reshape(src_shape, dst_shape):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
for device in get_all_backend():
check_device(device)
......@@ -96,7 +98,7 @@ def verify_squeeze(src_shape, axis):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
for device in get_all_backend():
check_device(device)
def verify_concatenate(shapes, axis):
......@@ -121,7 +123,7 @@ def verify_concatenate(shapes, axis):
foo(*(data_nds + [out_nd]))
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
for device in get_all_backend():
check_device(device)
......@@ -146,7 +148,7 @@ def verify_split(src_shape, indices_or_sections, axis):
for out_nd, out_npy in zip(out_nds, out_npys):
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan", "sdaccel"]:
for device in get_all_backend():
check_device(device)
......@@ -204,7 +206,7 @@ def verify_flip(in_shape, axis):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "cuda", "opencl", "sdaccel"]:
for device in ["llvm", "cuda", "opencl", "sdaccel", "aocl_sw_emu"]:
check_device(device)
def verify_take(src_shape, indices_src, axis=None):
......@@ -243,7 +245,7 @@ def verify_take(src_shape, indices_src, axis=None):
foo(data_nd, indices_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npys)
for device in ["llvm", "opencl", "sdaccel"]:
for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
check_device(device)
def verify_strided_slice(in_shape, begin, end, stride=None):
......@@ -270,7 +272,7 @@ def verify_strided_slice(in_shape, begin, end, stride=None):
foo(data_nd, out_nd)
np.testing.assert_allclose(out_nd.asnumpy(), out_npy)
for device in ["llvm", "opencl", "sdaccel"]:
for device in ["llvm", "opencl", "sdaccel", "aocl_sw_emu"]:
check_device(device)
def test_strided_slice():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment