Commit a706ad16 by Leyuan Wang Committed by Haichen Shen

[Relay][TOPI] Gluncv SSD support on the GPU (#2784)

* ssd gluoncv gpu op updated

* ssd gluoncv gpu op updated

* tutorials and testes modified

* tutorials and testes modified

* fix lint

* fix lint

* address comment

* multibox bug fixed

* space line added

* use less threads per block

* use less threads per block

* less threads per block for get valid count

* less threads per block for get valid count

* merge with master

* Revert "less threads per block for get valid count"

This reverts commit 08896cfccc34b0b2a1646d01d01ea4cad73941c4.

* Revert "less threads per block for get valid count"

This reverts commit 08896cfccc34b0b2a1646d01d01ea4cad73941c4.

* typo fixed

* elem length made to a variable

* fix lint error

* fix lint error

* lint fixed

* bug fixed

* bug fixed

* lint fixed

* error fixed

* error fixed

* test ci

* test ci

* seperate argsort to be an independent op

* seperate argsort to be an independent op

* fix lint

* fix lint

* remove unsupported models

* typo fixed

* argsort added to realy

* solve conflicts with master

* fix lint

* fix lint

* test push

* Revert "test push"

This reverts commit 6db00883fab6cc06bddf564c926bb27c874397d8.

* fix lint error

* fix more lint

* cpu test_sort udpated

* debug ci

* nms fixed

* expose argsort to relay frontend

* test ci

* fix lint

* sort register error fixed

* fix nnvm

* nms type fixed

* adaptive pooling added to relay

* Revert "adaptive pooling added to relay"

This reverts commit 1119f1f2c055753e0cc5611627597749134c5c8c.

* fix lint

* expose argsort op

* fix lint

* fix lint

* fix lint

* sort test updated

* sort bug fixed

* nnvm error fixed

* fix argsort default data type returned to be float insteaf of int

* fix lint

* fix lint

* test fixed

* fix valid count

* fix titanx bug

* tutorial add both targets

* titanx error fixed

* try to fix CI old gpu error

* try to solve CI GPU error

* get_valid_count added

* reverse get_valid_count

* get valid count optimized

* address comments

* fix ci error

* remove unessesary block sync

* add back one sync

* address comments

* address more comments

* more comments

* move sort to be indepent algorithm

* typo fixed

* more typos

* comments addressed

* doc updated

* fix pylint

* address final comments

* apache license added
parent 9d002e8e
......@@ -165,6 +165,14 @@ This level enables additional math and transform operators.
tvm.relay.vision.yolo_reorg
**Level 6: Algorithm Operators**
.. autosummary::
:nosignatures:
tvm.relay.argsort
**Level 10: Temporary Operators**
This level support backpropagation of broadcast operators. It is temporary.
......@@ -294,6 +302,11 @@ Level 5 Definitions
.. autofunction:: tvm.relay.vision.yolo_reorg
Level 6 Definitions
-------------------
.. autofunction:: tvm.relay.argsort
Level 10 Definitions
--------------------
.. autofunction:: tvm.relay.broadcast_to_like
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relay/attrs/vision.h
* \brief Auxiliary attributes for vision operators.
*/
#ifndef TVM_RELAY_ATTRS_ALGORITHM_H_
#define TVM_RELAY_ATTRS_ALGORITHM_H_
#include <tvm/attrs.h>
#include <string>
namespace tvm {
namespace relay {
/*! \brief Attributes used in argsort operators */
struct ArgsortAttrs : public tvm::AttrsNode<ArgsortAttrs> {
int axis;
bool is_ascend;
DataType dtype;
TVM_DECLARE_ATTRS(ArgsortAttrs, "relay.attrs.ArgsortAttrs") {
TVM_ATTR_FIELD(axis).set_default(-1)
.describe("Axis along which to sort the input tensor."
"If not given, the flattened array is used.");
TVM_ATTR_FIELD(is_ascend).set_default(true)
.describe("Whether to sort in ascending or descending order."
"By default, sort in ascending order");
TVM_ATTR_FIELD(dtype).set_default(NullValue<DataType>())
.describe("DType of the output indices.");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ALGORITHM_H_
......@@ -92,6 +92,8 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionA
double iou_threshold;
bool force_suppress;
int top_k;
int coord_start;
int score_index;
int id_index;
bool return_indices;
bool invalid_to_bottom;
......@@ -106,6 +108,10 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode<NonMaximumSuppressionA
.describe("Suppress all detections regardless of class_id.");
TVM_ATTR_FIELD(top_k).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
TVM_ATTR_FIELD(coord_start).set_default(2)
.describe("Start index of the consecutive 4 coordinates.");
TVM_ATTR_FIELD(score_index).set_default(1)
.describe("Index of the scores/confidence of boxes.");
TVM_ATTR_FIELD(id_index).set_default(0)
.describe("Axis index of id.");
TVM_ATTR_FIELD(return_indices).set_default(true)
......
......@@ -488,6 +488,8 @@ struct NonMaximumSuppressionParam : public dmlc::Parameter<NonMaximumSuppression
bool force_suppress;
int top_k;
int id_index;
int coord_start;
int score_index;
int max_output_size;
bool invalid_to_bottom;
DMLC_DECLARE_PARAMETER(NonMaximumSuppressionParam) {
......@@ -500,6 +502,10 @@ struct NonMaximumSuppressionParam : public dmlc::Parameter<NonMaximumSuppression
.describe("Suppress all detections regardless of class_id.");
DMLC_DECLARE_FIELD(top_k).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
DMLC_DECLARE_FIELD(coord_start).set_default(2)
.describe("Start index of the consecutive 4 coordinates.");
DMLC_DECLARE_FIELD(score_index).set_default(1)
.describe("Index of the scores/confidence of boxes.");
DMLC_DECLARE_FIELD(id_index).set_default(0)
.describe("Axis index of id.");
DMLC_DECLARE_FIELD(return_indices).set_default(true)
......
......@@ -94,8 +94,12 @@ def compute_nms(attrs, inputs, _):
id_index = attrs.get_int('id_index')
invalid_to_bottom = attrs.get_bool('invalid_to_bottom')
return topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size,
iou_threshold, force_suppress, top_k,
id_index, return_indices, invalid_to_bottom)
return topi.vision.non_max_suppression(inputs[0], inputs[1],
max_output_size=max_output_size,
iou_threshold=iou_threshold,
force_suppress=force_suppress,
top_k=top_k, id_index=id_index,
return_indices=return_indices,
invalid_to_bottom=invalid_to_bottom)
reg.register_pattern("non_max_suppression", OpPattern.OPAQUE)
......@@ -543,14 +543,13 @@ def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1),
if clip:
np_out = np.clip(np_out, 0, 1)
target = "llvm"
ctx = tvm.cpu()
graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape})
m = graph_runtime.create(graph, lib, ctx)
m.set_input("data", np.random.uniform(size=dshape).astype(dtype))
m.run()
out = m.get_output(0, tvm.nd.empty(np_out.shape, dtype))
tvm.testing.assert_allclose(out.asnumpy(), np_out, atol=1e-5, rtol=1e-5)
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape})
m = graph_runtime.create(graph, lib, ctx)
m.set_input("data", np.random.uniform(size=dshape).astype(dtype))
m.run()
tvm_out = m.get_output(0, tvm.nd.empty(np_out.shape, dtype))
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_out, atol=1e-5, rtol=1e-5)
def test_multibox_prior():
verify_multibox_prior((1, 3, 50, 50))
......@@ -577,17 +576,16 @@ def test_multibox_transform_loc():
[0, 0.44999999, 1, 1, 1, 1],
[0, 0.30000001, 0, 0, 0.22903419, 0.20435292]]])
target = "llvm"
dtype = "float32"
ctx = tvm.cpu()
graph, lib, _ = nnvm.compiler.build(out, target, {"cls_prob": (batch_size, num_anchors, num_classes),
"loc_preds": (batch_size, num_anchors * 4),
"anchors": (1, num_anchors, 4)})
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**{"cls_prob": np_cls_prob.astype(dtype), "loc_preds": np_loc_preds.astype(dtype), "anchors": np_anchors.astype(dtype)})
m.run()
out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype))
tvm.testing.assert_allclose(out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5)
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(out, target, {"cls_prob": (batch_size, num_anchors, num_classes),
"loc_preds": (batch_size, num_anchors * 4),
"anchors": (1, num_anchors, 4)})
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**{"cls_prob": np_cls_prob.astype(dtype), "loc_preds": np_loc_preds.astype(dtype), "anchors": np_anchors.astype(dtype)})
m.run()
tvm_out = m.get_output(0, tvm.nd.empty(expected_np_out.shape, dtype))
tvm.testing.assert_allclose(tvm_out.asnumpy(), expected_np_out, atol=1e-5, rtol=1e-5)
def test_non_max_suppression():
dshape = (1, 5, 6)
......@@ -607,15 +605,14 @@ def test_non_max_suppression():
[-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1],
[-1, -1, -1, -1, -1, -1]]])
target = "llvm"
ctx = tvm.cpu()
graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape, "valid_count": (dshape[0],)},
dtype={"data": "float32", "valid_count": "int32"})
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**{"data": np_data, "valid_count": np_valid_count})
m.run()
out = m.get_output(0, tvm.nd.empty(np_result.shape, "float32"))
tvm.testing.assert_allclose(out.asnumpy(), np_result, atol=1e-5, rtol=1e-5)
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(out, target, {"data": dshape, "valid_count": (dshape[0],)},
dtype={"data": "float32", "valid_count": "int32"})
m = graph_runtime.create(graph, lib, ctx)
m.set_input(**{"data": np_data, "valid_count": np_valid_count})
m.run()
tvm_out = m.get_output(0, tvm.nd.empty(np_result.shape, "float32"))
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result, atol=1e-5, rtol=1e-5)
def np_slice_like(np_data, np_shape_like, axis=[]):
begin_idx = [0 for _ in np_data.shape]
......
......@@ -36,6 +36,7 @@ from .op import Op
from .op.reduce import *
from .op.tensor import *
from .op.transform import *
from .op.algorithm import *
from . import nn
from . import annotation
from . import vision
......
......@@ -186,6 +186,13 @@ def _mx_pooling(inputs, attrs):
'Operator {} Pooling is not supported for frontend MXNet.'.format(pool_type.capitalize()))
def _mx_adaptive_avg_pooling(inputs, attrs):
output_size = attrs.get_int_tuple("output_size", [])
if output_size != (1,):
raise RuntimeError("AdaptiveAvgPooling with output_size other than 1 is not supported yet.")
return _op.nn.global_avg_pool2d(inputs[0])
def _mx_dropout(inputs, attrs):
rate = attrs.get_float("p", 0.5)
return _op.nn.dropout(inputs[0], rate=rate)
......@@ -529,15 +536,6 @@ def _mx_box_nms(inputs, attrs):
id_index = attrs.get_int('id_index', -1)
in_format = attrs.get_str('in_format', 'corner')
out_format = attrs.get_str('out_format', 'corner')
if coord_start != 2:
raise tvm.error.OpAttributeInvalid(
'Value of attribute "coord_start" must equal 2 for operator box_nms.')
if score_index != 1:
raise tvm.error.OpAttributeInvalid(
'Value of attribute "score_index" must equal 1 for operator box_nms.')
if id_index != -1 and int(id_index) != 0:
raise tvm.error.OpAttributeInvalid(
'Value of attribute "id_index" must equal either -1 or 0 for operator box_nms.')
if in_format != 'corner':
raise tvm.error.OpAttributeInvalid(
'Value of attribute "in_format" must equal "corner" for operator box_nms.')
......@@ -551,6 +549,8 @@ def _mx_box_nms(inputs, attrs):
iou_threshold=iou_thresh,
force_suppress=force_suppress,
top_k=top_k,
coord_start=coord_start,
score_index=score_index,
id_index=id_index,
return_indices=False,
invalid_to_bottom=True)
......@@ -648,6 +648,15 @@ def _mx_deformable_convolution(inputs, attrs):
return res
def _mx_argsort(inputs, attrs):
assert len(inputs) == 1
new_attrs = {}
new_attrs["axis"] = attrs.get_int("axis", -1)
new_attrs["is_ascend"] = attrs.get_bool("is_ascend", True)
new_attrs["dtype"] = attrs.get_str("dtype", "float32")
return _op.argsort(inputs[0], **new_attrs)
# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
......@@ -783,6 +792,7 @@ _convert_map = {
"BlockGrad" : _mx_BlockGrad,
"shape_array" : _mx_shape_array,
"Embedding" : _mx_embedding,
"argsort" : _mx_argsort,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
"smooth_l1" : _mx_smooth_l1,
......@@ -796,6 +806,7 @@ _convert_map = {
"_contrib_MultiProposal" : _mx_proposal,
"_contrib_box_nms" : _mx_box_nms,
"_contrib_DeformableConvolution" : _mx_deformable_convolution,
"_contrib_AdaptiveAvgPooling2D" : _mx_adaptive_avg_pooling,
# List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators.
#
......
......@@ -24,6 +24,7 @@ from .op import get, register, register_schedule, register_compute, register_gra
from .reduce import *
from .tensor import *
from .transform import *
from .algorithm import *
from . import nn
from . import annotation
from . import image
......@@ -36,6 +37,7 @@ from . import _tensor
from . import _tensor_grad
from . import _transform
from . import _reduce
from . import _algorithm
from ..expr import Expr
from ..base import register_relay_node
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"Definition of classic algorithms"
# pylint: disable=invalid-name,unused-argument
from __future__ import absolute_import
import topi
from topi.util import get_const_int
from ..op import OpPattern, register_compute, register_schedule, register_pattern
@register_schedule("argsort")
def schedule_argsort(_, outs, target):
"""Schedule definition of argsort"""
with target:
return topi.generic.schedule_argsort(outs)
@register_compute("argsort")
def compute_argsort(attrs, inputs, _, target):
"""Compute definition of argsort"""
axis = get_const_int(attrs.axis)
is_ascend = bool(get_const_int(attrs.is_ascend))
dtype = str(attrs.dtype)
return [
topi.argsort(inputs[0], None, axis=axis, is_ascend=is_ascend, \
dtype=dtype, flag=False)
]
register_pattern("argsort", OpPattern.OPAQUE)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Classic algorithm operation"""
from __future__ import absolute_import as _abs
from . import _make
def argsort(data, axis=-1, is_ascend=1, dtype="float32"):
"""Performs sorting along the given axis and returns an array of indicies
having same shape as an input array that index data in sorted order.
Parameters
----------
data : relay.Expr
The input data tensor.
valid_count : tvm.Tensor
The number of valid elements to be sorted.
axis : int, optional
Axis long which to sort the input tensor.
is_ascend : boolean, optional
Whether to sort in ascending or descending order.
dtype : string, optional
DType of the output indices.
Returns
-------
out : relay.Expr
Tensor with same shape as data.
"""
return _make.argsort(data, axis, is_ascend, dtype)
......@@ -710,6 +710,30 @@ def concatenate(data, axis):
return _make.concatenate(Tuple(data), axis)
def stack(data, axis):
"""Join a sequence of arrays along a new axis.
Parameters
----------
data : Union(List[relay.Expr], Tuple(relay.Expr))
A list of tensors.
axis : int
The axis in the result array along which the input arrays are stacked.
Returns
-------
ret : relay.Expr
The stacked tensor.
"""
data = list(data)
if not data:
raise ValueError("relay.stack requires data to be non-empty.")
if not isinstance(axis, int):
raise ValueError("For now, we only support integer axis")
return _make.stack(Tuple(data), axis)
def copy(data):
"""Copy a tensor.
......
......@@ -315,28 +315,6 @@ def arange(start, stop=None, step=1, dtype="float32"):
return _make.arange(start, stop, step, dtype)
def stack(data, axis):
"""Join a sequence of arrays along a new axis.
Parameters
----------
data : relay.Expr
The input data to the operator.
axis : int
The axis in the result array along which the input arrays are stacked.
.. note::
Each array in the input array sequence must have the same shape.
Returns
-------
ret : relay.Expr
The computed result.
"""
return _make.stack(data, axis)
def repeat(data, repeats, axis):
"""Repeats elements of an array.
By default, repeat flattens the input array into 1-D and then repeats the elements.
......@@ -698,5 +676,4 @@ def gather_nd(data, indices):
indices = [[0, 1], [1, 0]]
relay.gather_nd(data, indices) = [[3, 4], [5, 6]]
"""
return _make.gather_nd(data, indices)
......@@ -103,12 +103,15 @@ def compute_nms(attrs, inputs, _, target):
iou_threshold = get_const_float(attrs.iou_threshold)
force_suppress = bool(get_const_int(attrs.force_suppress))
top_k = get_const_int(attrs.top_k)
coord_start = get_const_int(attrs.coord_start)
score_index = get_const_int(attrs.score_index)
id_index = get_const_int(attrs.id_index)
invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom))
return [
topi.vision.non_max_suppression(inputs[0], inputs[1], max_output_size,
iou_threshold, force_suppress, top_k,
id_index, return_indices, invalid_to_bottom)
coord_start, score_index, id_index,
return_indices, invalid_to_bottom)
]
......
......@@ -49,6 +49,8 @@ def non_max_suppression(data,
iou_threshold=0.5,
force_suppress=False,
top_k=-1,
coord_start=2,
score_index=1,
id_index=0,
return_indices=True,
invalid_to_bottom=False):
......@@ -77,6 +79,12 @@ def non_max_suppression(data,
top_k : int, optional
Keep maximum top k detections before nms, -1 for no limit.
coord_start : int, optional
The starting index of the consecutive 4 coordinates.
score_index : int, optional
Index of the scores/confidence of boxes.
id_index : int, optional
index of the class categories, -1 to disable.
......@@ -93,4 +101,5 @@ def non_max_suppression(data,
"""
return _make.non_max_suppression(data, valid_count, max_output_size,
iou_threshold, force_suppress, top_k,
id_index, return_indices, invalid_to_bottom)
coord_start, score_index, id_index,
return_indices, invalid_to_bottom)
......@@ -46,20 +46,20 @@ bool CompareDescend(const std::pair<int32_t, DType>& lhs,
}
// Argsort implemented C library sort.
// Argsort implemented C library sort for nms.
// Return indices of sorted tensor.
// By default, the last axis will be used to sort.
// sort_num specify the number of elements to be sorted.
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
// and sort axis is dk. sort_num should have dimension of
// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *input = args[0];
DLTensor *sort_num = args[1];
DLTensor *output = args[2];
int32_t axis = args[3];
bool is_descend = args[4];
bool is_ascend = args[4];
auto dtype = input->dtype;
auto data_ptr = static_cast<float *>(input->data);
......@@ -97,10 +97,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
int64_t full_idx = base_idx + k * axis_mul_after;
sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx)));
}
if (is_descend) {
std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
} else {
if (is_ascend) {
std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
} else {
std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
}
for (int32_t k = 0; k < input->shape[axis]; ++k) {
*(static_cast<int32_t *>(output->data) + base_idx + k * axis_mul_after)
......@@ -110,5 +110,68 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
}
});
// Argsort implemented C library sort.
// Return indices of sorted tensor.
// By default, the last axis will be used to sort.
// sort_num specify the number of elements to be sorted.
// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1))
// and sort axis is dk. sort_num should have dimension of
// (d1, d2, ..., d(k-1), d(k+1), ..., dn).
TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort")
.set_body([](TVMArgs args, TVMRetValue *ret) {
DLTensor *input = args[0];
DLTensor *output = args[1];
int32_t axis = args[2];
bool is_ascend = args[3];
auto dtype = input->dtype;
auto data_ptr = static_cast<float *>(input->data);
std::vector<std::pair<float, float>> sorter;
int64_t axis_mul_before = 1;
int64_t axis_mul_after = 1;
if (axis < 0) {
axis = input->ndim + axis;
}
// Currently only supports input dtype to be float32.
CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype "
"to be float32.";
CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype "
"to be float32.";
CHECK_LT(axis, input->ndim) << "Axis out of boundary for "
"input ndim " << input->ndim;
for (int i = 0; i < input->ndim; ++i) {
if (i < axis) {
axis_mul_before *= input->shape[i];
} else if (i > axis) {
axis_mul_after *= input->shape[i];
}
}
int32_t current_sort_num = input->shape[axis];
for (int64_t i = 0 ; i < axis_mul_before; ++i) {
for (int64_t j = 0 ; j < axis_mul_after; ++j) {
sorter.clear();
int64_t base_idx = i * input->shape[axis] * axis_mul_after + j;
for (int64_t k = 0; k < current_sort_num; ++k) {
int64_t full_idx = base_idx + k * axis_mul_after;
sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx)));
}
if (is_ascend) {
std::stable_sort(sorter.begin(), sorter.end(), CompareAscend<float>);
} else {
std::stable_sort(sorter.begin(), sorter.end(), CompareDescend<float>);
}
for (int32_t k = 0; k < input->shape[axis]; ++k) {
*(static_cast<float *>(output->data) + base_idx + k * axis_mul_after)
= k < static_cast<float>(sorter.size()) ? sorter[k].first : k;
}
}
}
});
} // namespace contrib
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file nms.cc
* \brief Non-maximum suppression operators
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/algorithm.h>
namespace tvm {
namespace relay {
TVM_REGISTER_NODE_TYPE(ArgsortAttrs);
bool ArgsortRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [data, result]
const ArgsortAttrs* param = attrs.as<ArgsortAttrs>();
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
if (data == nullptr) {
CHECK(types[0].as<IncompleteTypeNode>())
<< "Argsort: expect input type to be TensorType but get "
<< types[0];
return false;
}
CHECK_EQ(param->dtype, Float(32));
reporter->Assign(types[1], TensorTypeNode::make(data->shape, param->dtype));
return true;
}
Expr MakeArgsort(Expr data,
int axis,
bool is_ascend,
DataType dtype) {
auto attrs = make_node<ArgsortAttrs>();
attrs->axis = axis;
attrs->is_ascend = is_ascend;
attrs->dtype = dtype;
static const Op& op = Op::Get("argsort");
return CallNode::make(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_API("relay.op._make.argsort")
.set_body_typed(MakeArgsort);
RELAY_REGISTER_OP("argsort")
.describe(R"doc(Returns the indices that would sort an
input array along the given axis.
)doc" TVM_ADD_FILELINE)
.set_num_inputs(1)
.set_attrs_type_key("relay.attrs.ArgsortAttrs")
.add_argument("data", "Tensor", "Input data.")
.set_support_level(6)
.add_type_rel("Argsort", ArgsortRel);
} // namespace relay
} // namespace tvm
......@@ -106,6 +106,8 @@ Expr MakeNMS(Expr data,
double iou_threshold,
bool force_suppress,
int top_k,
int coord_start,
int score_index,
int id_index,
bool return_indices,
bool invalid_to_bottom) {
......@@ -114,6 +116,8 @@ Expr MakeNMS(Expr data,
attrs->iou_threshold = iou_threshold;
attrs->force_suppress = force_suppress;
attrs->top_k = top_k;
attrs->coord_start = coord_start;
attrs->score_index = score_index;
attrs->id_index = id_index;
attrs->return_indices = return_indices;
attrs->invalid_to_bottom = invalid_to_bottom;
......
......@@ -24,11 +24,11 @@ def test_sort():
data = tvm.placeholder((n, l, m), name='data')
sort_num = tvm.placeholder((n, m), name="sort_num", dtype="int32")
axis = 1
is_descend = True
is_ascend = False
out = tvm.extern(data.shape, [data, sort_num],
lambda ins, outs: tvm.call_packed(
"tvm.contrib.sort.argsort", ins[0],
ins[1], outs[0], axis, is_descend),
"tvm.contrib.sort.argsort_nms", ins[0],
ins[1], outs[0], axis, is_ascend),
dtype='int32', name="sort_tensor")
input = [[[1, 2, 3], [2, 4.5, 3.5], [1.1, 0.5, 1], [3.2, -5, 0.5], [1.5, 0, 0]],
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]]]
......@@ -50,13 +50,13 @@ def test_sort_np():
dshape = (1, 2, 3, 4, 5, 6)
axis = 4
reduced_shape = (1, 2, 3, 4, 6)
is_descend = False
is_ascend = True
data = tvm.placeholder(dshape, name='data')
sort_num = tvm.placeholder(reduced_shape, name="sort_num", dtype="int32")
out = tvm.extern(data.shape, [data, sort_num],
lambda ins, outs: tvm.call_packed(
"tvm.contrib.sort.argsort", ins[0],
ins[1], outs[0], axis, is_descend),
"tvm.contrib.sort.argsort_nms", ins[0],
ins[1], outs[0], axis, is_ascend),
dtype='int32', name="sort_tensor")
ctx = tvm.cpu(0)
......
......@@ -177,12 +177,13 @@ def test_get_valid_counts():
assert "score_threshold" in z.astext()
func = relay.Function([x], z.astuple())
func = relay.ir_pass.infer_type(func)
ctx_list = [("llvm", tvm.cpu(0))]
for target, ctx in ctx_list:
for target, ctx in ctx_list():
if target == 'cuda':
return
intrp = relay.create_executor("debug", ctx=ctx, target=target)
out = intrp.evaluate(func)(np_data)
tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3)
tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04)
tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04)
verify_get_valid_counts((1, 2500, 6), 0)
verify_get_valid_counts((1, 2500, 6), -1)
......@@ -195,9 +196,13 @@ def test_non_max_suppression():
iou_threshold=0.5, force_suppress=False, top_k=-1,
check_type_only=False):
x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32"))
x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int"))
z = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k, return_indices=False)
z_indices = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k)
x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int32"))
z = relay.vision.non_max_suppression(x0, x1, max_output_size = -1, \
iou_threshold = iou_threshold, force_suppress = force_suppress, \
top_k = top_k, return_indices=False)
z_indices = relay.vision.non_max_suppression(x0, x1, max_output_size = -1, \
iou_threshold = iou_threshold, force_suppress = force_suppress, \
top_k = top_k)
assert "iou_threshold" in z.astext()
assert "iou_threshold" in z_indices.astext()
zz = relay.ir_pass.infer_type(z)
......@@ -212,8 +217,7 @@ def test_non_max_suppression():
func = relay.ir_pass.infer_type(func)
func_indices = relay.Function([x0, x1], z_indices)
func_indices = relay.ir_pass.infer_type(func_indices)
ctx_list = [("llvm", tvm.cpu(0))]
for target, ctx in ctx_list:
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(x0_data, x1_data)
op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data)
......@@ -296,8 +300,7 @@ def test_multibox_transform_loc():
nms = relay.vision.non_max_suppression(mtl[0], mtl[1], return_indices=False)
func = relay.Function([cls_prob, loc_pred, anchors], nms)
func = relay.ir_pass.infer_type(func)
ctx_list = [("llvm", tvm.cpu(0))]
for target, ctx in ctx_list:
for target, ctx in ctx_list():
intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
op_res1 = intrp1.evaluate(func)(np_cls_prob, np_loc_preds,
np_anchors)
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Support level6 operator test cases.
"""
import math
import numpy as np
import tvm
from tvm import relay
from tvm.relay.testing import ctx_list
import topi.testing
def test_argsort():
def verify_argsort(shape, axis, is_ascend):
x = relay.var("x", relay.TensorType(shape, "float32"))
z = relay.argsort(x, axis=axis, is_ascend=is_ascend)
zz = relay.ir_pass.infer_type(z)
func = relay.Function([x], z)
x_data = np.random.uniform(size=shape).astype("float32")
if is_ascend:
ref_res = np.argsort(x_data, axis=axis)
else:
ref_res = np.argsort(-x_data, axis=axis)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype("float"), rtol=1e-5)
verify_argsort((2, 3, 4), axis=0, is_ascend=False)
verify_argsort((1, 4, 6), axis=1, is_ascend=True)
verify_argsort((3, 5, 6), axis=-1, is_ascend=False)
if __name__ == "__main__":
test_argsort()
......@@ -21,6 +21,7 @@ from .generic_op_impl import *
from .reduction import *
from .transform import *
from .broadcast import *
from .sort import *
from . import nn
from . import x86
from . import cuda
......
......@@ -32,11 +32,15 @@ def _default_schedule(outs):
def traverse(op):
"""inline all one-to-one-mapping operators except the last stage (output)"""
if "nms" in op.tag:
sort = op.input_tensors[1]
if op.tag in ["nms", "invalid_to_bottom"]:
if op.tag == "nms":
sort = op.input_tensors[1]
else:
out = op.input_tensors[0]
sort = s[out].op.input_tensors[1]
score = s[sort].op.input_tensors[0]
fused = s[score].fuse(*s[score].op.axis)
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
num_thread = int(tvm.target.current_target(allow_none=False).max_num_threads)
bx, tx = s[score].split(fused, factor=num_thread)
s[score].bind(bx, tvm.thread_axis("blockIdx.x"))
s[score].bind(tx, tvm.thread_axis("threadIdx.x"))
......@@ -199,3 +203,30 @@ def schedule_get_valid_counts(outs):
The computation schedule for the op.
"""
return _default_schedule(outs)
@generic.schedule_argsort.register(["cuda", "gpu"])
def schedule_argsort(outs):
"""Schedule for argsort operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of argsort
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
from .injective import _schedule_injective
def traverse(op):
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
scheduled_ops.append(op)
traverse(outs[0].op)
return s
......@@ -19,3 +19,4 @@ from .nn import *
from .injective import *
from .extern import *
from .vision import *
from .sort import *
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, no-member
"""Generic vision operators"""
from __future__ import absolute_import as _abs
import tvm
from .vision import _default_schedule
@tvm.target.generic_func
def schedule_argsort(outs):
"""Schedule for argsort operator.
Parameters
----------
outs: Array of Tensor
The indices that would sort an input array along
the given axis.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
return _default_schedule(outs, False)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=too-many-arguments
"""Argsort operator"""
import tvm
from tvm import api
@tvm.target.generic_func
def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0):
"""Performs sorting along the given axis and returns an array
of indices having the same shape as an input array that index
data in sorted order.
Parameters
----------
data : tvm.Tensor
The input tensor.
valid_count : tvm.Tensor
1-D tensor for valid number of boxes only for ssd.
axis : optional, int
Axis along which to sort the input tensor.
By default the flattened array is used.
is_ascend : optional, boolean
Whether to sort in ascending or descending order.
dtype : optional, string
DType of the output indices.
flag : optional, boolean
Whether valid_count is valid.
Returns
-------
out : tvm.Tensor
Sorted index tensor.
Example
--------
.. code-block:: python
# An example to use argsort
dshape = (1, 5, 6)
data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
axis = 0
is_ascend = False
flag = False
out = argsort(data, valid_count, axis, is_ascend, flag)
np_data = np.random.uniform(dshape)
np_valid_count = np.array([4])
s = topi.generic.schedule_argsort(out)
f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu()
tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
f(tvm_data, tvm_valid_count, tvm_out)
"""
data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
if flag:
valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
"valid_count_buf", data_alignment=4)
out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=8)
out = \
tvm.extern(data.shape,
[data, valid_count],
lambda ins, outs: tvm.call_packed(
"tvm.contrib.sort.argsort_nms", ins[0], ins[1],
outs[0], axis, is_ascend),
dtype="int32",
in_buffers=[data_buf, valid_count_buf],
out_buffers=out_buf,
name="argsort_nms_cpu",
tag="argsort_nms_cpu")
else:
out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
out = \
tvm.extern(data.shape,
[data],
lambda ins, outs: tvm.call_packed(
"tvm.contrib.sort.argsort", ins[0],
outs[0], axis, is_ascend),
dtype=dtype,
in_buffers=[data_buf],
out_buffers=out_buf,
name="argsort_cpu",
tag="argsort_cpu")
return out
......@@ -18,7 +18,8 @@
"""Non-maximum suppression operator"""
import tvm
from tvm import api, hybrid
from tvm import hybrid
from ..sort import argsort
@hybrid.script
def hybrid_rearrange_out(data):
......@@ -129,7 +130,7 @@ def get_valid_counts(data, score_threshold=0):
@hybrid.script
def hybrid_nms(data, sorted_index, valid_count,
max_output_size, iou_threshold, force_suppress,
top_k, id_index):
top_k, coord_start, id_index):
"""Hybrid routing for non-maximum suppression.
Parameters
......@@ -158,6 +159,9 @@ def hybrid_nms(data, sorted_index, valid_count,
top_k : tvm.const
Keep maximum top k detections before nms, -1 for no limit.
coord_start : tvm.const
Start index of the consecutive 4 coordinates.
id_index : tvm.const
index of the class categories, -1 to disable.
......@@ -208,7 +212,7 @@ def hybrid_nms(data, sorted_index, valid_count,
batch_idx = i
box_a_idx = j
box_b_idx = k
box_start_idx = 2
box_start_idx = coord_start
a_t = output[batch_idx, box_a_idx, box_start_idx + 1]
a_b = output[batch_idx, box_a_idx, box_start_idx + 3]
a_l = output[batch_idx, box_a_idx, box_start_idx]
......@@ -252,7 +256,8 @@ def hybrid_nms(data, sorted_index, valid_count,
@tvm.target.generic_func
def non_max_suppression(data, valid_count, max_output_size=-1,
iou_threshold=0.5, force_suppress=False, top_k=-1,
id_index=0, return_indices=True, invalid_to_bottom=False):
coord_start=2, score_index=1, id_index=0,
return_indices=True, invalid_to_bottom=False):
"""Non-maximum suppression operator for object detection.
Parameters
......@@ -278,6 +283,12 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
top_k : optional, int
Keep maximum top k detections before nms, -1 for no limit.
coord_start : required, int
Start index of the consecutive 4 coordinates.
score_index: optional, int
Index of the scores/confidence of boxes.
id_index : optional, int
index of the class categories, -1 to disable.
......@@ -317,32 +328,16 @@ def non_max_suppression(data, valid_count, max_output_size=-1,
"""
batch_size = data.shape[0]
num_anchors = data.shape[1]
valid_count_dtype = "int32"
valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype,
"valid_count_buf", data_alignment=4)
score_axis = 1
score_axis = score_index
score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis])
score_tensor_buf = api.decl_buffer(score_tensor.shape, data.dtype,
"score_tensor_buf", data_alignment=8)
sort_tensor_dtype = "int32"
sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype,
"sort_tensor_buf", data_alignment=8)
sort_tensor = \
tvm.extern(score_shape,
[score_tensor, valid_count],
lambda ins, outs: tvm.call_packed(
"tvm.contrib.sort.argsort", ins[0], ins[1],
outs[0], score_axis, True),
dtype=sort_tensor_dtype,
in_buffers=[score_tensor_buf, valid_count_buf],
out_buffers=sort_tensor_buf,
name="nms_sort")
sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
out, box_indices = hybrid_nms(data, sort_tensor, valid_count,
tvm.const(max_output_size, dtype="int32"),
tvm.const(iou_threshold, dtype="float32"),
tvm.const(force_suppress, dtype="bool"),
tvm.const(top_k, dtype="int32"),
tvm.const(coord_start, dtype="int32"),
tvm.const(id_index, dtype="int32"))
if not return_indices and invalid_to_bottom:
out = hybrid_rearrange_out(out)
......
......@@ -308,7 +308,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm
"""
inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor,
clip, threshold, variances)
out = non_max_suppression(inter_out[0], inter_out[1], -1,
nms_threshold, force_suppress, nms_topk,
return_indices=False)
out = non_max_suppression(inter_out[0], inter_out[1], max_output_size=-1,
iou_threshold=nms_threshold, force_suppress=force_suppress,
top_k=nms_topk, return_indices=False)
return out
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for vision package"""
from __future__ import print_function
import math
import numpy as np
import tvm
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
from topi import argsort
def test_argsort():
dshape = (1, 8)
valid_count_shape = (2,)
data = tvm.placeholder(dshape, name="data", dtype="float32")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
np_data = np.random.rand(dshape[0], dshape[1]).astype(data.dtype)
np_valid_count = np.array([4]).astype(valid_count.dtype)
np_result = np.argsort(-np_data)
def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
out = argsort(data, valid_count, axis = -1, is_ascend = False, flag=False)
s = topi.generic.schedule_argsort(out)
tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype="float32"), ctx)
f = tvm.build(s, [data, valid_count, out], device)
f(tvm_data, tvm_valid_count, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result.astype("float32"), rtol=1e0)
for device in ['llvm', 'cuda', 'opencl']:
check_device(device)
if __name__ == "__main__":
test_argsort()
......@@ -66,7 +66,7 @@ def verify_get_valid_counts(dshape, score_threshold):
tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)
for device in ['llvm']:
for device in ['llvm', 'cuda', 'opencl']:
check_device(device)
......@@ -124,7 +124,7 @@ def test_non_max_suppression():
f(tvm_data, tvm_valid_count, tvm_indices_out)
tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4)
for device in ['llvm']:
for device in ['llvm', 'cuda', 'opencl']:
check_device(device)
......@@ -231,7 +231,7 @@ def test_multibox_detection():
f(tvm_cls_prob, tvm_loc_preds, tvm_anchors, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), expected_np_out, rtol=1e-4)
for device in ['llvm', 'opencl']:
for device in ['llvm', 'opencl', 'cuda']:
check_device(device)
......@@ -275,7 +275,7 @@ def verify_roi_align(batch, in_channel, in_size, num_roi, pooled_size, spatial_s
f(tvm_a, tvm_rois, tvm_b)
tvm.testing.assert_allclose(tvm_b.asnumpy(), b_np, rtol=1e-3)
for device in ['llvm', 'cuda']:
for device in ['llvm', 'cuda', 'opencl']:
check_device(device)
......
......@@ -18,6 +18,7 @@
Deploy Single Shot Multibox Detector(SSD) model
===============================================
**Author**: `Yao Wang <https://github.com/kevinthesun>`_
`Leyuan Wang <https://github.com/Laurawly>`_
This article is an introductory tutorial to deploy SSD models with TVM.
We will use GluonCV pre-trained SSD model and convert it to Relay IR
......@@ -37,30 +38,29 @@ from gluoncv import model_zoo, data, utils
# ------------------------------
# .. note::
#
# Currently we support compiling SSD on CPU only.
# GPU support is in progress.
# We support compiling SSD on bot CPUs and GPUs now.
#
# To get best inference performance on CPU, change
# target argument according to your device and
# follow the :ref:`tune_relay_x86` to tune x86 CPU and
# :ref:`tune_relay_arm` for arm cpu.
#
# To get best performance fo SSD on Intel graphics,
# change target argument to 'opencl -device=intel_graphics'
#
# SSD with VGG as body network is not supported yet since
# x86 conv2d schedule doesn't support dilation.
supported_model = [
'ssd_512_resnet18_v1_voc',
'ssd_512_resnet18_v1_coco',
'ssd_512_resnet50_v1_voc',
'ssd_512_resnet50_v1_coco',
'ssd_512_resnet101_v2_voc',
'ssd_512_mobilenet1_0_voc',
'ssd_512_mobilenet1_0_coco',
'ssd_512_mobilenet1.0_voc',
'ssd_512_mobilenet1.0_coco',
]
model_name = "ssd_512_resnet50_v1_voc"
model_name = supported_model[0]
dshape = (1, 3, 512, 512)
dtype = "float32"
target_list = ctx_list()
######################################################################
......@@ -76,7 +76,7 @@ x, img = data.transforms.presets.ssd.load_test(im_fname, short=512)
block = model_zoo.get_model(model_name, pretrained=True)
def compile(target):
def build(target):
net, params = relay.frontend.from_mxnet(block, {"data": dshape})
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(net, target, params=params)
......@@ -98,10 +98,7 @@ def run(graph, lib, params, ctx):
return class_IDs, scores, bounding_boxs
for target, ctx in target_list:
if target == "cuda":
print("GPU not supported yet, skip.")
continue
graph, lib, params = compile(target)
graph, lib, params = build(target)
class_IDs, scores, bounding_boxs = run(graph, lib, params, ctx)
######################################################################
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment