Commit 98a91af9 by Yao Wang Committed by Wuwei Lin

Improve non_max_suppression and get_valid_counts for CPU (#3305)

* Improve non_max_suppression for CPU

* Improve get_valid_counts

* Minor change

* Skip some unnecessary computes
parent a4bc50eb
...@@ -79,10 +79,16 @@ struct MultiBoxTransformLocAttrs ...@@ -79,10 +79,16 @@ struct MultiBoxTransformLocAttrs
/*! \brief Attributes used in get_valid_counts operator */ /*! \brief Attributes used in get_valid_counts operator */
struct GetValidCountsAttrs : public tvm::AttrsNode<GetValidCountsAttrs> { struct GetValidCountsAttrs : public tvm::AttrsNode<GetValidCountsAttrs> {
double score_threshold; double score_threshold;
int id_index;
int score_index;
TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") { TVM_DECLARE_ATTRS(GetValidCountsAttrs, "relay.attrs.GetValidCountsAttrs") {
TVM_ATTR_FIELD(score_threshold).set_default(0.0) TVM_ATTR_FIELD(score_threshold).set_default(0.0)
.describe("Lower limit of score for valid bounding boxes."); .describe("Lower limit of score for valid bounding boxes.");
TVM_ATTR_FIELD(id_index).set_default(0)
.describe("Axis index of id.");
TVM_ATTR_FIELD(score_index).set_default(1)
.describe("Index of the scores/confidence of boxes.");
} }
}; };
......
...@@ -569,7 +569,8 @@ def _mx_box_nms(inputs, attrs): ...@@ -569,7 +569,8 @@ def _mx_box_nms(inputs, attrs):
raise tvm.error.OpAttributeInvalid( raise tvm.error.OpAttributeInvalid(
'Value of attribute "out_format" must equal "corner" for operator box_nms.') 'Value of attribute "out_format" must equal "corner" for operator box_nms.')
ret = _op.vision.get_valid_counts(inputs[0], score_threshold=valid_thresh) ret = _op.vision.get_valid_counts(inputs[0], score_threshold=valid_thresh,
id_index=id_index, score_index=score_index)
nms_out = _op.vision.non_max_suppression(ret[1], nms_out = _op.vision.non_max_suppression(ret[1],
ret[0], ret[0],
iou_threshold=iou_thresh, iou_threshold=iou_thresh,
......
...@@ -82,7 +82,10 @@ def schedule_get_valid_counts(_, outs, target): ...@@ -82,7 +82,10 @@ def schedule_get_valid_counts(_, outs, target):
def compute_get_valid_counts(attrs, inputs, _, target): def compute_get_valid_counts(attrs, inputs, _, target):
"""Compute definition of get_valid_counts""" """Compute definition of get_valid_counts"""
score_threshold = get_const_float(attrs.score_threshold) score_threshold = get_const_float(attrs.score_threshold)
return topi.vision.get_valid_counts(inputs[0], score_threshold) id_index = get_const_int(attrs.id_index)
score_index = get_const_int(attrs.score_index)
return topi.vision.get_valid_counts(inputs[0], score_threshold,
id_index, score_index)
reg.register_pattern("vision.get_valid_counts", OpPattern.OPAQUE) reg.register_pattern("vision.get_valid_counts", OpPattern.OPAQUE)
......
...@@ -20,7 +20,9 @@ from . import _make ...@@ -20,7 +20,9 @@ from . import _make
from ...expr import TupleWrapper from ...expr import TupleWrapper
def get_valid_counts(data, def get_valid_counts(data,
score_threshold): score_threshold,
id_index=0,
score_index=1):
"""Get valid count of bounding boxes given a score threshold. """Get valid count of bounding boxes given a score threshold.
Also moves valid boxes to the top of input data. Also moves valid boxes to the top of input data.
...@@ -32,6 +34,12 @@ def get_valid_counts(data, ...@@ -32,6 +34,12 @@ def get_valid_counts(data,
score_threshold : optional, float score_threshold : optional, float
Lower limit of score for valid bounding boxes. Lower limit of score for valid bounding boxes.
id_index : optional, int
index of the class categories, -1 to disable.
score_index: optional, int
Index of the scores/confidence of boxes.
Returns Returns
------- -------
valid_count : relay.Expr valid_count : relay.Expr
...@@ -40,7 +48,8 @@ def get_valid_counts(data, ...@@ -40,7 +48,8 @@ def get_valid_counts(data,
out_tensor : relay.Expr out_tensor : relay.Expr
Rearranged data tensor. Rearranged data tensor.
""" """
return TupleWrapper(_make.get_valid_counts(data, score_threshold), 2) return TupleWrapper(_make.get_valid_counts(data, score_threshold,
id_index, score_index), 2)
def non_max_suppression(data, def non_max_suppression(data,
......
...@@ -50,9 +50,13 @@ bool GetValidCountRel(const Array<Type>& types, ...@@ -50,9 +50,13 @@ bool GetValidCountRel(const Array<Type>& types,
} }
Expr MakeGetValidCounts(Expr data, Expr MakeGetValidCounts(Expr data,
double score_threshold) { double score_threshold,
int id_index,
int score_index) {
auto attrs = make_node<GetValidCountsAttrs>(); auto attrs = make_node<GetValidCountsAttrs>();
attrs->score_threshold = score_threshold; attrs->score_threshold = score_threshold;
attrs->id_index = id_index;
attrs->score_index = score_index;
static const Op& op = Op::Get("vision.get_valid_counts"); static const Op& op = Op::Get("vision.get_valid_counts");
return CallNode::make(op, {data}, Attrs(attrs), {}); return CallNode::make(op, {data}, Attrs(attrs), {});
} }
......
...@@ -152,28 +152,28 @@ def test_multibox_prior(): ...@@ -152,28 +152,28 @@ def test_multibox_prior():
def test_get_valid_counts(): def test_get_valid_counts():
def verify_get_valid_counts(dshape, score_threshold): def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
dtype = "float32" dtype = "float32"
batch_size, num_anchor, elem_length = dshape batch_size, num_anchor, elem_length = dshape
np_data = np.random.uniform(size=dshape).astype(dtype) np_data = np.random.uniform(low=-2, high=2, size=dshape).astype(dtype)
np_out1 = np.zeros(shape=(batch_size,)) np_out1 = np.zeros(shape=(batch_size,))
np_out2 = np.zeros(shape=dshape).astype(dtype) np_out2 = np.zeros(shape=dshape).astype(dtype)
for i in range(batch_size): for i in range(batch_size):
np_out1[i] = 0 np_out1[i] = 0
inter_idx = 0 inter_idx = 0
for j in range(num_anchor): for j in range(num_anchor):
score = np_data[i, j, 1] score = np_data[i, j, score_index]
if score >= score_threshold: if score > score_threshold and (id_index < 0 or np_data[i, j, id_index] >= 0):
for k in range(elem_length): for k in range(elem_length):
np_out2[i, inter_idx, k] = np_data[i, j, k] np_out2[i, inter_idx, k] = np_data[i, j, k]
np_out1[i] += 1 np_out1[i] += 1
inter_idx += 1 inter_idx += 1
if j >= np_out1[i]: if j >= np_out1[i]:
for k in range(elem_length): for k in range(elem_length):
np_out2[i, j, k] = -1 np_out2[i, j, k] = -1.0
x = relay.var("x", relay.ty.TensorType(dshape, dtype)) x = relay.var("x", relay.ty.TensorType(dshape, dtype))
z = relay.vision.get_valid_counts(x, score_threshold) z = relay.vision.get_valid_counts(x, score_threshold, id_index, score_index)
assert "score_threshold" in z.astext() assert "score_threshold" in z.astext()
func = relay.Function([x], z.astuple()) func = relay.Function([x], z.astuple())
func = relay.ir_pass.infer_type(func) func = relay.ir_pass.infer_type(func)
...@@ -185,10 +185,10 @@ def test_get_valid_counts(): ...@@ -185,10 +185,10 @@ def test_get_valid_counts():
tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04) tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04)
tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04) tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04)
verify_get_valid_counts((1, 2500, 6), 0) verify_get_valid_counts((1, 2500, 6), 0, 0, 1)
verify_get_valid_counts((1, 2500, 6), -1) verify_get_valid_counts((1, 2500, 5), -1, -1, 0)
verify_get_valid_counts((3, 1000, 6), 0.55) verify_get_valid_counts((3, 1000, 6), 0.55, 1, 0)
verify_get_valid_counts((16, 500, 6), 0.95) verify_get_valid_counts((16, 500, 5), 0.95, -1, 0)
def test_non_max_suppression(): def test_non_max_suppression():
......
...@@ -313,7 +313,7 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out): ...@@ -313,7 +313,7 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
@get_valid_counts.register(["cuda", "gpu"]) @get_valid_counts.register(["cuda", "gpu"])
def get_valid_counts_gpu(data, score_threshold=0): def get_valid_counts_gpu(data, score_threshold=0, id_index=0, score_index=1):
"""Get valid count of bounding boxes given a score threshold. """Get valid count of bounding boxes given a score threshold.
Also moves valid boxes to the top of input data. Also moves valid boxes to the top of input data.
...@@ -325,6 +325,12 @@ def get_valid_counts_gpu(data, score_threshold=0): ...@@ -325,6 +325,12 @@ def get_valid_counts_gpu(data, score_threshold=0):
score_threshold : optional, float score_threshold : optional, float
Lower limit of score for valid bounding boxes. Lower limit of score for valid bounding boxes.
id_index : optional, int
index of the class categories, -1 to disable.
score_index: optional, int
Index of the scores/confidence of boxes.
Returns Returns
------- -------
valid_count : tvm.Tensor valid_count : tvm.Tensor
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements # pylint: disable=import-error, invalid-name, no-member, too-many-locals, too-many-arguments, undefined-variable, too-many-nested-blocks, too-many-branches, too-many-statements, too-many-function-args
"""Non-maximum suppression operator""" """Non-maximum suppression operator"""
import tvm import tvm
...@@ -60,7 +60,7 @@ def hybrid_rearrange_out(data): ...@@ -60,7 +60,7 @@ def hybrid_rearrange_out(data):
@hybrid.script @hybrid.script
def hybrid_get_valid_counts(data, score_threshold): def hybrid_get_valid_counts(data, score_threshold, id_index, score_index):
"""Hybrid routine to get valid count of bounding boxes """Hybrid routine to get valid count of bounding boxes
given a score threshold. Also moves valid boxes to the given a score threshold. Also moves valid boxes to the
top of input data. top of input data.
...@@ -68,11 +68,18 @@ def hybrid_get_valid_counts(data, score_threshold): ...@@ -68,11 +68,18 @@ def hybrid_get_valid_counts(data, score_threshold):
Parameters Parameters
---------- ----------
data : tvm.Tensor or numpy NDArray data : tvm.Tensor or numpy NDArray
Input data. 3-D tensor with shape [batch_size, num_anchors, 6]. Input data. 3-D tensor with shape [batch_size, num_anchors, 6]
or [batch_size, num_anchors, 5].
score_threshold : tvm.const score_threshold : tvm.const
Lower limit of score for valid bounding boxes. Lower limit of score for valid bounding boxes.
id_index : tvm.const
index of the class categories, -1 to disable.
score_index: tvm.const
Index of the scores/confidence of boxes.
Returns Returns
------- -------
out_tensor : tvm.Tensor or numpy NDArray out_tensor : tvm.Tensor or numpy NDArray
...@@ -92,8 +99,9 @@ def hybrid_get_valid_counts(data, score_threshold): ...@@ -92,8 +99,9 @@ def hybrid_get_valid_counts(data, score_threshold):
for i in parallel(batch_size): for i in parallel(batch_size):
valid_count[i] = 0 valid_count[i] = 0
for j in range(num_anchors): for j in range(num_anchors):
score = data[i, j, 1] score = data[i, j, score_index]
if score > score_threshold: if score > score_threshold and \
(id_index < 0 or data[i, j, id_index] >= 0):
for k in range(box_data_length): for k in range(box_data_length):
out_tensor[i, valid_count[i], k] = data[i, j, k] out_tensor[i, valid_count[i], k] = data[i, j, k]
valid_count[i] += 1 valid_count[i] += 1
...@@ -103,18 +111,25 @@ def hybrid_get_valid_counts(data, score_threshold): ...@@ -103,18 +111,25 @@ def hybrid_get_valid_counts(data, score_threshold):
return valid_count, out_tensor return valid_count, out_tensor
@tvm.target.generic_func @tvm.target.generic_func
def get_valid_counts(data, score_threshold=0): def get_valid_counts(data, score_threshold=0, id_index=0, score_index=1):
"""Get valid count of bounding boxes given a score threshold. """Get valid count of bounding boxes given a score threshold.
Also moves valid boxes to the top of input data. Also moves valid boxes to the top of input data.
Parameters Parameters
---------- ----------
data : tvm.Tensor data : tvm.Tensor
Input data. 3-D tensor with shape [batch_size, num_anchors, 6]. Input data. 3-D tensor with shape [batch_size, num_anchors, 6]
or [batch_size, num_anchors, 5].
score_threshold : optional, float score_threshold : optional, float
Lower limit of score for valid bounding boxes. Lower limit of score for valid bounding boxes.
id_index : optional, int
index of the class categories, -1 to disable.
score_index: optional, int
Index of the scores/confidence of boxes.
Returns Returns
------- -------
out_tensor : tvm.Tensor out_tensor : tvm.Tensor
...@@ -123,14 +138,17 @@ def get_valid_counts(data, score_threshold=0): ...@@ -123,14 +138,17 @@ def get_valid_counts(data, score_threshold=0):
valid_count : tvm.Tensor valid_count : tvm.Tensor
1-D tensor for valid number of boxes. 1-D tensor for valid number of boxes.
""" """
score_threshold_const = tvm.const(score_threshold, "float") score_threshold_const = tvm.const(score_threshold, "float32")
return hybrid_get_valid_counts(data, score_threshold_const) id_index_const = tvm.const(id_index, "int32")
score_index_const = tvm.const(score_index, "int32")
return hybrid_get_valid_counts(data, score_threshold_const,
id_index_const, score_index_const)
@hybrid.script @hybrid.script
def hybrid_nms(data, sorted_index, valid_count, def hybrid_nms(data, sorted_index, valid_count,
max_output_size, iou_threshold, force_suppress, max_output_size, iou_threshold, force_suppress,
top_k, coord_start, id_index): top_k, coord_start, id_index, score_index):
"""Hybrid routing for non-maximum suppression. """Hybrid routing for non-maximum suppression.
Parameters Parameters
...@@ -165,6 +183,9 @@ def hybrid_nms(data, sorted_index, valid_count, ...@@ -165,6 +183,9 @@ def hybrid_nms(data, sorted_index, valid_count,
id_index : tvm.const id_index : tvm.const
index of the class categories, -1 to disable. index of the class categories, -1 to disable.
score_index: tvm.const
Index of the scores/confidence of boxes.
Returns Returns
------- -------
output : tvm.Tensor output : tvm.Tensor
...@@ -182,41 +203,42 @@ def hybrid_nms(data, sorted_index, valid_count, ...@@ -182,41 +203,42 @@ def hybrid_nms(data, sorted_index, valid_count,
box_data_length,), box_data_length,),
data.dtype) data.dtype)
for i in parallel(batch_size): for i in range(batch_size):
if iou_threshold > 0: if iou_threshold > 0:
if valid_count[i] > 0: if valid_count[i] > 0:
# Reorder output # Reorder output
nkeep = valid_count[i] nkeep = valid_count[i]
if 0 < top_k < nkeep: if 0 < top_k < nkeep:
nkeep = top_k nkeep = top_k
for j in range(nkeep): for j in parallel(nkeep):
for k in range(box_data_length): for k in range(box_data_length):
output[i, j, k] = data[i, sorted_index[i, j], k] output[i, j, k] = data[i, sorted_index[i, j], k]
box_indices[i, j] = sorted_index[i, j] box_indices[i, j] = sorted_index[i, j]
if 0 < top_k < valid_count[i]: if 0 < top_k < valid_count[i]:
for j in range(valid_count[i] - nkeep): for j in parallel(valid_count[i] - nkeep):
for k in range(box_data_length): for k in range(box_data_length):
output[i, j + nkeep, k] = -1.0 output[i, j + nkeep, k] = -1.0
box_indices[i, j + nkeep] = -1 box_indices[i, j + nkeep] = -1
# Apply nms # Apply nms
box_start_idx = coord_start
batch_idx = i
for j in range(valid_count[i]): for j in range(valid_count[i]):
if output[i, j, 0] >= 0: if output[i, j, score_index] > 0 and (id_index < 0 or output[i, j, id_index] >= 0):
for k in range(valid_count[i]): box_a_idx = j
for k in parallel(valid_count[i]):
check_iou = 0 check_iou = 0
if k > j and output[i, k, 0] >= 0: if k > j and output[i, k, score_index] > 0 \
and (id_index < 0 or output[i, k, id_index] >= 0):
if force_suppress: if force_suppress:
check_iou = 1 check_iou = 1
elif id_index < 0 or output[i, j, 0] == output[i, k, 0]: elif id_index < 0 or output[i, j, id_index] == output[i, k, id_index]:
check_iou = 1 check_iou = 1
if check_iou > 0: if check_iou > 0:
batch_idx = i
box_a_idx = j
box_b_idx = k
box_start_idx = coord_start
a_t = output[batch_idx, box_a_idx, box_start_idx + 1]
a_b = output[batch_idx, box_a_idx, box_start_idx + 3]
a_l = output[batch_idx, box_a_idx, box_start_idx] a_l = output[batch_idx, box_a_idx, box_start_idx]
a_t = output[batch_idx, box_a_idx, box_start_idx + 1]
a_r = output[batch_idx, box_a_idx, box_start_idx + 2] a_r = output[batch_idx, box_a_idx, box_start_idx + 2]
a_b = output[batch_idx, box_a_idx, box_start_idx + 3]
box_b_idx = k
b_t = output[batch_idx, box_b_idx, box_start_idx + 1] b_t = output[batch_idx, box_b_idx, box_start_idx + 1]
b_b = output[batch_idx, box_b_idx, box_start_idx + 3] b_b = output[batch_idx, box_b_idx, box_start_idx + 3]
b_l = output[batch_idx, box_b_idx, box_start_idx] b_l = output[batch_idx, box_b_idx, box_start_idx]
...@@ -227,22 +249,24 @@ def hybrid_nms(data, sorted_index, valid_count, ...@@ -227,22 +249,24 @@ def hybrid_nms(data, sorted_index, valid_count,
u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area u = (a_r - a_l) * (a_b - a_t) + (b_r - b_l) * (b_b - b_t) - area
iou = 0.0 if u <= 0.0 else area / u iou = 0.0 if u <= 0.0 else area / u
if iou >= iou_threshold: if iou >= iou_threshold:
output[i, k, 0] = -1.0 output[i, k, score_index] = -1.0
if id_index >= 0:
output[i, k, id_index] = -1.0
box_indices[i, k] = -1 box_indices[i, k] = -1
else: else:
for j in range(valid_count[i]): for j in parallel(valid_count[i]):
for k in range(box_data_length): for k in range(box_data_length):
output[i, j, k] = data[i, j, k] output[i, j, k] = data[i, j, k]
box_indices[i, j] = j box_indices[i, j] = j
# Set invalid entry to be -1 # Set invalid entry to be -1
for j in range(num_anchors - valid_count[i]): for j in parallel(num_anchors - valid_count[i]):
for k in range(box_data_length): for k in range(box_data_length):
output[i, j + valid_count[i], k] = -1.0 output[i, j + valid_count[i], k] = -1.0
box_indices[i, j + valid_count[i]] = -1 box_indices[i, j + valid_count[i]] = -1
# Only return max_output_size valid boxes # Only return max_output_size valid boxes
num_valid_boxes = 0 num_valid_boxes = 0
if max_output_size > 0: if max_output_size > 0:
for j in range(valid_count[i]): for j in parallel(valid_count[i]):
if output[i, j, 0] >= 0: if output[i, j, 0] >= 0:
if num_valid_boxes == max_output_size: if num_valid_boxes == max_output_size:
for k in range(box_data_length): for k in range(box_data_length):
...@@ -263,9 +287,7 @@ def non_max_suppression(data, valid_count, max_output_size=-1, ...@@ -263,9 +287,7 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
Parameters Parameters
---------- ----------
data : tvm.Tensor data : tvm.Tensor
3-D tensor with shape [batch_size, num_anchors, 6]. 3-D tensor with shape [batch_size, num_anchors, 6] or [batch_size, num_anchors, 5].
The last dimension should be in format of
[class_id, score, box_left, box_top, box_right, box_bottom].
valid_count : tvm.Tensor valid_count : tvm.Tensor
1-D tensor for valid number of boxes. 1-D tensor for valid number of boxes.
...@@ -338,7 +360,8 @@ def non_max_suppression(data, valid_count, max_output_size=-1, ...@@ -338,7 +360,8 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
tvm.const(force_suppress, dtype="bool"), tvm.const(force_suppress, dtype="bool"),
tvm.const(top_k, dtype="int32"), tvm.const(top_k, dtype="int32"),
tvm.const(coord_start, dtype="int32"), tvm.const(coord_start, dtype="int32"),
tvm.const(id_index, dtype="int32")) tvm.const(id_index, dtype="int32"),
tvm.const(score_index, dtype="int32"))
if not return_indices and invalid_to_bottom: if not return_indices and invalid_to_bottom:
out = hybrid_rearrange_out(out) out = hybrid_rearrange_out(out)
......
...@@ -27,18 +27,18 @@ from topi.util import get_const_tuple ...@@ -27,18 +27,18 @@ from topi.util import get_const_tuple
from topi.vision import ssd, non_max_suppression, get_valid_counts from topi.vision import ssd, non_max_suppression, get_valid_counts
def verify_get_valid_counts(dshape, score_threshold): def verify_get_valid_counts(dshape, score_threshold, id_index, score_index):
dtype = "float32" dtype = "float32"
batch_size, num_anchor, elem_length = dshape batch_size, num_anchor, elem_length = dshape
np_data = np.random.uniform(size=dshape).astype(dtype) np_data = np.random.uniform(low=-2, high=2, size=dshape).astype(dtype)
np_out1 = np.zeros(shape=(batch_size,)) np_out1 = np.zeros(shape=(batch_size,))
np_out2 = np.zeros(shape=dshape).astype(dtype) np_out2 = np.zeros(shape=dshape).astype(dtype)
for i in range(batch_size): for i in range(batch_size):
np_out1[i] = 0 np_out1[i] = 0
inter_idx = 0 inter_idx = 0
for j in range(num_anchor): for j in range(num_anchor):
score = np_data[i, j, 1] score = np_data[i, j, score_index]
if score > score_threshold: if score > score_threshold and (id_index < 0 or np_data[i, j, id_index] >= 0):
for k in range(elem_length): for k in range(elem_length):
np_out2[i, inter_idx, k] = np_data[i, j, k] np_out2[i, inter_idx, k] = np_data[i, j, k]
np_out1[i] += 1 np_out1[i] += 1
...@@ -55,8 +55,8 @@ def verify_get_valid_counts(dshape, score_threshold): ...@@ -55,8 +55,8 @@ def verify_get_valid_counts(dshape, score_threshold):
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
data = tvm.placeholder(dshape, name="data", dtype=dtype) data = tvm.placeholder(dshape, name="data", dtype=dtype)
outs = get_valid_counts(data, score_threshold) outs = get_valid_counts(data, score_threshold, id_index, score_index)
s = topi.generic.schedule_multibox_prior(outs) s = topi.generic.schedule_get_valid_counts(outs)
tvm_input_data = tvm.nd.array(np_data, ctx) tvm_input_data = tvm.nd.array(np_data, ctx)
tvm_out1 = tvm.nd.array(np.zeros(np_out1.shape, dtype="int32"), ctx) tvm_out1 = tvm.nd.array(np.zeros(np_out1.shape, dtype="int32"), ctx)
...@@ -67,33 +67,26 @@ def verify_get_valid_counts(dshape, score_threshold): ...@@ -67,33 +67,26 @@ def verify_get_valid_counts(dshape, score_threshold):
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
for device in ['llvm', 'cuda', 'opencl']: for device in ['llvm', 'cuda', 'opencl']:
# Disable gpu test for now
if device != "llvm":
continue
check_device(device) check_device(device)
def test_get_valid_counts(): def test_get_valid_counts():
verify_get_valid_counts((1, 2500, 6), 0) verify_get_valid_counts((1, 2500, 6), 0, 0, 1)
verify_get_valid_counts((1, 2500, 6), -1) verify_get_valid_counts((1, 2500, 5), -1, -1, 0)
verify_get_valid_counts((3, 1000, 6), 0.55) verify_get_valid_counts((3, 1000, 6), 0.55, 1, 0)
verify_get_valid_counts((16, 500, 6), 0.95) verify_get_valid_counts((16, 500, 5), 0.95, -1, 1)
def test_non_max_suppression(): def verify_non_max_suppression(np_data, np_valid_count, np_result, np_indices_result, iou_threshold,
dshape = (1, 5, 6) force_suppress, top_k, coord_start, score_index, id_index):
indices_dshape = (1, 5) dshape = np_data.shape
batch, num_anchors, _ = dshape
indices_dshape = (batch, num_anchors)
data = tvm.placeholder(dshape, name="data") data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count") valid_count = tvm.placeholder((batch,), dtype="int32", name="valid_count")
nms_threshold = 0.7
force_suppress = True
nms_topk = 2
np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80],
[0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
[1, 0.5, 100, 60, 70, 110]]]).astype(data.dtype)
np_valid_count = np.array([4]).astype(valid_count.dtype)
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]])
np_indices_result = np.array([[3, 0, -1, -1, -1]])
def check_device(device): def check_device(device):
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
...@@ -103,11 +96,17 @@ def test_non_max_suppression(): ...@@ -103,11 +96,17 @@ def test_non_max_suppression():
print("Running on target: %s" % device) print("Running on target: %s" % device)
with tvm.target.create(device): with tvm.target.create(device):
if device == 'llvm': if device == 'llvm':
out = non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk, return_indices=False) out = non_max_suppression(data, valid_count, -1, iou_threshold, force_suppress, top_k,
indices_out = non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk) coord_start=coord_start, score_index=score_index, id_index=id_index,
return_indices=False)
indices_out = non_max_suppression(data, valid_count, -1, iou_threshold, force_suppress, top_k,
coord_start=coord_start, score_index=score_index, id_index=id_index)
else: else:
out = topi.cuda.non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk, return_indices=False) out = topi.cuda.non_max_suppression(data, valid_count, -1, iou_threshold, force_suppress, top_k,
indices_out = topi.cuda.non_max_suppression(data, valid_count, -1, nms_threshold, force_suppress, nms_topk) coord_start=coord_start, score_index=score_index, id_index=id_index,
return_indices=False)
indices_out = topi.cuda.non_max_suppression(data, valid_count, -1, iou_threshold, force_suppress, top_k,
coord_start=coord_start, score_index=score_index, id_index=id_index)
s = topi.generic.schedule_nms(out) s = topi.generic.schedule_nms(out)
indices_s = topi.generic.schedule_nms(indices_out) indices_s = topi.generic.schedule_nms(indices_out)
...@@ -128,6 +127,30 @@ def test_non_max_suppression(): ...@@ -128,6 +127,30 @@ def test_non_max_suppression():
check_device(device) check_device(device)
def test_non_max_suppression():
np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80],
[0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79],
[1, 0.5, 100, 60, 70, 110]]]).astype("float32")
np_valid_count = np.array([4]).astype("int32")
np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45],
[-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]])
np_indices_result = np.array([[3, 0, -1, -1, -1]])
verify_non_max_suppression(np_data, np_valid_count, np_result, np_indices_result, 0.7, True, 2, 2, 1, 0)
np_data = np.array([[[0.8, 1, 20, 25, 45], [0.7, 30, 60, 50, 80],
[0.4, 4, 21, 19, 40], [0.9, 35, 61, 52, 79],
[0.5, 100, 60, 70, 110]]]).astype("float32")
np_valid_count = np.array([4]).astype("int32")
np_result = np.array([[[0.9, 35, 61, 52, 79], [0.8, 1, 20, 25, 45],
[-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1]]])
np_indices_result = np.array([[3, 0, -1, -1, -1]])
verify_non_max_suppression(np_data, np_valid_count, np_result, np_indices_result, 0.7, False, 2, 1, 0, -1)
def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False): def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False):
data = tvm.placeholder(dshape, name="data") data = tvm.placeholder(dshape, name="data")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment