Unverified Commit 1c56c722 by Yizhi Liu Committed by GitHub

[topi] enable fp16 sort for arm (#4084)

parent ec375a85
...@@ -74,9 +74,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") ...@@ -74,9 +74,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms")
// Currently only supports input dtype to be float32. // Currently only supports input dtype to be float32.
CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype " CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype "
"to be float32."; "to be float.";
#if (__ARM_FP16_FORMAT_IEEE != 1)
CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype " CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype "
"to be float32."; "to be float32.";
#endif
CHECK_LT(axis, input->ndim) << "Axis out of boundary for " CHECK_LT(axis, input->ndim) << "Axis out of boundary for "
"input ndim " << input->ndim; "input ndim " << input->ndim;
...@@ -98,9 +100,25 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") ...@@ -98,9 +100,25 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms")
sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx))); sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx)));
} }
if (is_ascend) { if (is_ascend) {
#if (__ARM_FP16_FORMAT_IEEE == 1)
if (dtype.bits == 16) {
std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<__fp16>);
} else {
#endif
std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>); std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
#if (__ARM_FP16_FORMAT_IEEE == 1)
}
#endif
} else { } else {
#if (__ARM_FP16_FORMAT_IEEE == 1)
if (dtype.bits == 16) {
std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<__fp16>);
} else {
#endif
std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>); std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
#if (__ARM_FP16_FORMAT_IEEE == 1)
}
#endif
} }
for (int32_t k = 0; k < input->shape[axis]; ++k) { for (int32_t k = 0; k < input->shape[axis]; ++k) {
*(static_cast<int32_t *>(output->data) + base_idx + k * axis_mul_after) *(static_cast<int32_t *>(output->data) + base_idx + k * axis_mul_after)
...@@ -192,6 +210,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") ...@@ -192,6 +210,14 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
} else { } else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype; LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
} }
#if (__ARM_FP16_FORMAT_IEEE == 1)
} else if (data_dtype == "float16") {
if (out_dtype == "float16") {
argsort<__fp16, __fp16>(input, output, axis, is_ascend);
} else {
LOG(FATAL) << "Unsupported output dtype: " << out_dtype;
}
#endif
} else if (data_dtype == "int32") { } else if (data_dtype == "int32") {
if (out_dtype == "int32") { if (out_dtype == "int32") {
argsort<int32_t, int32_t>(input, output, axis, is_ascend); argsort<int32_t, int32_t>(input, output, axis, is_ascend);
......
...@@ -22,7 +22,7 @@ from tvm import hybrid ...@@ -22,7 +22,7 @@ from tvm import hybrid
from ..sort import argsort from ..sort import argsort
@hybrid.script @hybrid.script
def hybrid_rearrange_out(data): def hybrid_rearrange_out(data, one):
"""Hybrid routine to rearrange nms output to """Hybrid routine to rearrange nms output to
move all valid entries to top. move all valid entries to top.
...@@ -32,6 +32,9 @@ def hybrid_rearrange_out(data): ...@@ -32,6 +32,9 @@ def hybrid_rearrange_out(data):
NMS output. 3-D tensor with shape NMS output. 3-D tensor with shape
[batch_size, num_anchors, 6]. [batch_size, num_anchors, 6].
one: tvm.const
Constant one with the same dtype as data.
Returns Returns
------- -------
output : tvm.Tensor or numpy NDArray output : tvm.Tensor or numpy NDArray
...@@ -55,12 +58,12 @@ def hybrid_rearrange_out(data): ...@@ -55,12 +58,12 @@ def hybrid_rearrange_out(data):
valid_idx += 1 valid_idx += 1
if j >= valid_idx: if j >= valid_idx:
for k in range(elem_length): for k in range(elem_length):
output[i, j, k] = -1.0 output[i, j, k] = -one
return output return output
@hybrid.script @hybrid.script
def hybrid_get_valid_counts(data, score_threshold, id_index, score_index): def hybrid_get_valid_counts(data, score_threshold, id_index, score_index, one):
"""Hybrid routine to get valid count of bounding boxes """Hybrid routine to get valid count of bounding boxes
given a score threshold. Also moves valid boxes to the given a score threshold. Also moves valid boxes to the
top of input data. top of input data.
...@@ -80,6 +83,9 @@ def hybrid_get_valid_counts(data, score_threshold, id_index, score_index): ...@@ -80,6 +83,9 @@ def hybrid_get_valid_counts(data, score_threshold, id_index, score_index):
score_index: tvm.const score_index: tvm.const
Index of the scores/confidence of boxes. Index of the scores/confidence of boxes.
one: tvm.const
Constant one with the same dtype as data.
Returns Returns
------- -------
out_tensor : tvm.Tensor or numpy NDArray out_tensor : tvm.Tensor or numpy NDArray
...@@ -107,7 +113,7 @@ def hybrid_get_valid_counts(data, score_threshold, id_index, score_index): ...@@ -107,7 +113,7 @@ def hybrid_get_valid_counts(data, score_threshold, id_index, score_index):
valid_count[i] += 1 valid_count[i] += 1
if j >= valid_count[i]: if j >= valid_count[i]:
for k in range(box_data_length): for k in range(box_data_length):
out_tensor[i, j, k] = -1.0 out_tensor[i, j, k] = -one
return valid_count, out_tensor return valid_count, out_tensor
@tvm.target.generic_func @tvm.target.generic_func
...@@ -138,17 +144,18 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1): ...@@ -138,17 +144,18 @@ def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
valid_count : tvm.Tensor valid_count : tvm.Tensor
1-D tensor for valid number of boxes. 1-D tensor for valid number of boxes.
""" """
score_threshold_const = tvm.const(score_threshold, "float32") score_threshold_const = tvm.const(score_threshold, data.dtype)
id_index_const = tvm.const(id_index, "int32") id_index_const = tvm.const(id_index, "int32")
score_index_const = tvm.const(score_index, "int32") score_index_const = tvm.const(score_index, "int32")
return hybrid_get_valid_counts(data, score_threshold_const, return hybrid_get_valid_counts(data, score_threshold_const,
id_index_const, score_index_const) id_index_const, score_index_const,
tvm.const(1, data.dtype))
@hybrid.script @hybrid.script
def hybrid_nms(data, sorted_index, valid_count, def hybrid_nms(data, sorted_index, valid_count,
max_output_size, iou_threshold, force_suppress, max_output_size, iou_threshold, force_suppress,
top_k, coord_start, id_index, score_index): top_k, coord_start, id_index, score_index, zero, one):
"""Hybrid routing for non-maximum suppression. """Hybrid routing for non-maximum suppression.
Parameters Parameters
...@@ -186,6 +193,12 @@ def hybrid_nms(data, sorted_index, valid_count, ...@@ -186,6 +193,12 @@ def hybrid_nms(data, sorted_index, valid_count,
score_index: tvm.const score_index: tvm.const
Index of the scores/confidence of boxes. Index of the scores/confidence of boxes.
zero: tvm.const
Constant zero with the same dtype as data.
one: tvm.const
Constant one with the same dtype as data.
Returns Returns
------- -------
output : tvm.Tensor output : tvm.Tensor
...@@ -200,8 +213,7 @@ def hybrid_nms(data, sorted_index, valid_count, ...@@ -200,8 +213,7 @@ def hybrid_nms(data, sorted_index, valid_count,
box_indices = output_tensor((batch_size, num_anchors), "int32") box_indices = output_tensor((batch_size, num_anchors), "int32")
output = output_tensor((batch_size, output = output_tensor((batch_size,
num_anchors, num_anchors,
box_data_length,), box_data_length,), data.dtype)
data.dtype)
for i in range(batch_size): for i in range(batch_size):
if iou_threshold > 0: if iou_threshold > 0:
...@@ -217,7 +229,7 @@ def hybrid_nms(data, sorted_index, valid_count, ...@@ -217,7 +229,7 @@ def hybrid_nms(data, sorted_index, valid_count,
if 0 < top_k < valid_count[i]: if 0 < top_k < valid_count[i]:
for j in parallel(valid_count[i] - nkeep): for j in parallel(valid_count[i] - nkeep):
for k in range(box_data_length): for k in range(box_data_length):
output[i, j + nkeep, k] = -1.0 output[i, j + nkeep, k] = -one
box_indices[i, j + nkeep] = -1 box_indices[i, j + nkeep] = -1
# Apply nms # Apply nms
box_start_idx = coord_start box_start_idx = coord_start
...@@ -243,15 +255,15 @@ def hybrid_nms(data, sorted_index, valid_count, ...@@ -243,15 +255,15 @@ def hybrid_nms(data, sorted_index, valid_count,
b_b = output[batch_idx, box_b_idx, box_start_idx + 3] b_b = output[batch_idx, box_b_idx, box_start_idx + 3]
b_l = output[batch_idx, box_b_idx, box_start_idx] b_l = output[batch_idx, box_b_idx, box_start_idx]
b_r = output[batch_idx, box_b_idx, box_start_idx + 2] b_r = output[batch_idx, box_b_idx, box_start_idx + 2]
w = max(0.0, min(a_r, b_r) - max(a_l, b_l)) w = max(zero, min(a_r, b_r) - max(a_l, b_l))
h = max(0.0, min(a_b, b_b) - max(a_t, b_t)) h = max(zero, min(a_b, b_b) - max(a_t, b_t))
area = h * w area = h * w
u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area
iou = 0.0 if u <= 0.0 else area / u iou = zero if u <= zero else area / u
if iou >= iou_threshold: if iou >= iou_threshold:
output[i, k, score_index] = -1.0 output[i, k, score_index] = -one
if id_index >= 0: if id_index >= 0:
output[i, k, id_index] = -1.0 output[i, k, id_index] = -one
box_indices[i, k] = -1 box_indices[i, k] = -1
else: else:
for j in parallel(valid_count[i]): for j in parallel(valid_count[i]):
...@@ -261,16 +273,16 @@ def hybrid_nms(data, sorted_index, valid_count, ...@@ -261,16 +273,16 @@ def hybrid_nms(data, sorted_index, valid_count,
# Set invalid entry to be -1 # Set invalid entry to be -1
for j in parallel(num_anchors - valid_count[i]): for j in parallel(num_anchors - valid_count[i]):
for k in range(box_data_length): for k in range(box_data_length):
output[i, j + valid_count[i], k] = -1.0 output[i, j + valid_count[i], k] = -one
box_indices[i, j + valid_count[i]] = -1 box_indices[i, j + valid_count[i]] = -1
# Only return max_output_size valid boxes # Only return max_output_size valid boxes
num_valid_boxes = 0 num_valid_boxes = 0
if max_output_size > 0: if max_output_size > 0:
for j in parallel(valid_count[i]): for j in parallel(valid_count[i]):
if output[i, j, 0] >= 0: if output[i, j, 0] >= zero:
if num_valid_boxes == max_output_size: if num_valid_boxes == max_output_size:
for k in range(box_data_length): for k in range(box_data_length):
output[i, j, k] = -1.0 output[i, j, k] = -one
box_indices[i, j] = -1 box_indices[i, j] = -1
else: else:
num_valid_boxes += 1 num_valid_boxes += 1
...@@ -356,13 +368,15 @@ def non_max_suppression(data, valid_count, max_output_size=-1, ...@@ -356,13 +368,15 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False) sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False)
out, box_indices = hybrid_nms(data, sort_tensor, valid_count, out, box_indices = hybrid_nms(data, sort_tensor, valid_count,
tvm.const(max_output_size, dtype="int32"), tvm.const(max_output_size, dtype="int32"),
tvm.const(iou_threshold, dtype="float32"), tvm.const(iou_threshold, dtype=data.dtype),
tvm.const(force_suppress, dtype="bool"), tvm.const(force_suppress, dtype="bool"),
tvm.const(top_k, dtype="int32"), tvm.const(top_k, dtype="int32"),
tvm.const(coord_start, dtype="int32"), tvm.const(coord_start, dtype="int32"),
tvm.const(id_index, dtype="int32"), tvm.const(id_index, dtype="int32"),
tvm.const(score_index, dtype="int32")) tvm.const(score_index, dtype="int32"),
zero=tvm.const(0, dtype=data.dtype),
one=tvm.const(1, dtype=data.dtype))
if not return_indices and invalid_to_bottom: if not return_indices and invalid_to_bottom:
out = hybrid_rearrange_out(out) out = hybrid_rearrange_out(out, one=tvm.const(1, dtype=data.dtype))
return box_indices if return_indices else out return box_indices if return_indices else out
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment