Commit 09960e30 by Leyuan Wang Committed by Haichen Shen

[Bugfix] Fix sort changing original input data issue (#3212)

* sort bugfix for not rearranging input data

* separate sort schedule

* fix lint

* use identity op instead

* fix lint

* remove redundent code
parent f7d7fdcd
...@@ -72,7 +72,10 @@ Operation ExternOpNode::make(std::string name, ...@@ -72,7 +72,10 @@ Operation ExternOpNode::make(std::string name,
CHECK_EQ(inputs.size(), input_placeholders.size()); CHECK_EQ(inputs.size(), input_placeholders.size());
for (size_t i = 0; i < inputs.size(); ++i) { for (size_t i = 0; i < inputs.size(); ++i) {
CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype); CHECK_EQ(inputs[i]->dtype, input_placeholders[i]->dtype);
CHECK(inputs[i]->shape.same_as(input_placeholders[i]->shape)); CHECK_EQ(inputs[i]->shape.size(), input_placeholders[i]->shape.size());
for (size_t dim = 0; dim < inputs[i]->shape.size(); ++dim) {
CHECK(inputs[i]->shape[dim].same_as(input_placeholders[i]->shape[dim]));
}
CHECK_EQ(input_placeholders[i]->strides.size(), 0U); CHECK_EQ(input_placeholders[i]->strides.size(), 0U);
} }
n->inputs = std::move(inputs); n->inputs = std::move(inputs);
......
...@@ -24,6 +24,7 @@ from tvm.generic import cast ...@@ -24,6 +24,7 @@ from tvm.generic import cast
from tvm.intrin import if_then_else, log, power from tvm.intrin import if_then_else, log, power
from topi.vision import non_max_suppression, get_valid_counts from topi.vision import non_max_suppression, get_valid_counts
from .sort import argsort from .sort import argsort
from .. import tag
def get_valid_counts_pre(data, flag, idx, score_threshold): def get_valid_counts_pre(data, flag, idx, score_threshold):
...@@ -730,7 +731,7 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1, ...@@ -730,7 +731,7 @@ def non_max_suppression_gpu(data, valid_count, max_output_size=-1,
"valid_count_buf", data_alignment=4) "valid_count_buf", data_alignment=4)
score_axis = score_index score_axis = score_index
score_shape = (batch_size, num_anchors) score_shape = (batch_size, num_anchors)
score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis]) score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis], tag=tag.ELEMWISE)
sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True) sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True)
sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype, sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype,
......
...@@ -20,6 +20,10 @@ import tvm ...@@ -20,6 +20,10 @@ import tvm
from tvm import api from tvm import api
from topi.sort import argsort from topi.sort import argsort
from topi.math import identity
from .. import generic
from .. import tag
def sort_ir(data, output, axis, is_ascend): def sort_ir(data, output, axis, is_ascend):
"""Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
...@@ -104,8 +108,6 @@ def sort_ir(data, output, axis, is_ascend): ...@@ -104,8 +108,6 @@ def sort_ir(data, output, axis, is_ascend):
return ib.get() return ib.get()
def sort_nms_ir(data, valid_count, output, axis, is_ascend): def sort_nms_ir(data, valid_count, output, axis, is_ascend):
"""Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.
...@@ -221,29 +223,60 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0 ...@@ -221,29 +223,60 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0
out : tvm.Tensor out : tvm.Tensor
The output of this function. The output of this function.
""" """
data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) sorted_data_buf = api.decl_buffer(data.shape, data.dtype, "sorted_data_buf", data_alignment=8)
sorted_data = identity(data)
if flag: if flag:
valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype, valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
"valid_count_buf", data_alignment=4) "valid_count_buf", data_alignment=4)
out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4) out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4)
out = tvm.extern([data.shape], out = tvm.extern([data.shape],
[data, valid_count], [sorted_data, valid_count],
lambda ins, outs: sort_nms_ir( lambda ins, outs: sort_nms_ir(
ins[0], ins[1], outs[0], axis, is_ascend), ins[0], ins[1], outs[0], axis, is_ascend),
dtype="int32", dtype="int32",
in_buffers=[data_buf, valid_count_buf], in_buffers=[sorted_data_buf, valid_count_buf],
out_buffers=[out_buf], out_buffers=[out_buf],
name="argsort_nms_gpu", name="argsort_nms_gpu",
tag="argsort_nms_gpu") tag="argsort_nms_gpu")
else: else:
out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8)
out = tvm.extern([data.shape], out = tvm.extern([data.shape],
[data], [sorted_data],
lambda ins, outs: sort_ir( lambda ins, outs: sort_ir(
ins[0], outs[0], axis, is_ascend), ins[0], outs[0], axis, is_ascend),
dtype=dtype, dtype=dtype,
in_buffers=[data_buf], in_buffers=[sorted_data_buf],
out_buffers=[out_buf], out_buffers=[out_buf],
name="argsort_gpu", name="argsort_gpu",
tag="argsort_gpu") tag="argsort_gpu")
return out return out
@generic.schedule_argsort.register(["cuda", "gpu"])
def schedule_argsort(outs):
"""Schedule for argsort operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of argsort
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
from .injective import _schedule_injective
def traverse(op):
if tag.is_broadcast(op.tag):
_schedule_injective(op, s)
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
scheduled_ops.append(op)
traverse(outs[0].op)
return s
...@@ -25,41 +25,17 @@ from .pooling import schedule_pool ...@@ -25,41 +25,17 @@ from .pooling import schedule_pool
def _default_schedule(outs): def _default_schedule(outs):
"""Default schedule for gpu.""" """Default schedule for gpu."""
target = tvm.target.current_target()
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs]) s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = [] scheduled_ops = []
from .injective import _schedule_injective
def traverse(op): def traverse(op):
"""inline all one-to-one-mapping operators except the last stage (output)""" if tag.is_broadcast(op.tag) or op.tag in ['bbox_score', 'sorted_bbox']:
if op.tag in ["nms", "invalid_to_bottom"]: _schedule_injective(op, s)
if op.tag == "nms":
sort = op.input_tensors[1]
else:
out = op.input_tensors[0]
sort = s[out].op.input_tensors[1]
score = s[sort].op.input_tensors[0]
fused = s[score].fuse(*s[score].op.axis)
num_thread = int(tvm.target.current_target(allow_none=False).max_num_threads)
bx, tx = s[score].split(fused, factor=num_thread)
s[score].bind(bx, tvm.thread_axis("blockIdx.x"))
s[score].bind(tx, tvm.thread_axis("threadIdx.x"))
if tag.is_broadcast(op.tag):
if op not in s.outputs:
s[op].compute_inline()
else:
x = op.output(0)
fused = s[x].fuse(*s[x].op.axis)
num_thread = tvm.target.current_target(allow_none=False).max_num_threads
bx, tx = s[x].split(fused, factor=num_thread)
s[x].bind(bx, tvm.thread_axis("blockIdx.x"))
s[x].bind(tx, tvm.thread_axis("threadIdx.x"))
for tensor in op.input_tensors: for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops: if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op) traverse(tensor.op)
scheduled_ops.append(op) scheduled_ops.append(op)
traverse(outs[0].op) traverse(outs[0].op)
return s return s
...@@ -173,19 +149,7 @@ def schedule_proposal(outs): ...@@ -173,19 +149,7 @@ def schedule_proposal(outs):
s: Schedule s: Schedule
The computation schedule for the op. The computation schedule for the op.
""" """
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs return _default_schedule(outs)
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
from .injective import _schedule_injective
def traverse(op):
if op.tag in ['bbox_score', 'sorted_bbox']:
_schedule_injective(op, s)
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
scheduled_ops.append(op)
traverse(outs[0].op)
return s
@generic.schedule_get_valid_counts.register(["cuda", "gpu"]) @generic.schedule_get_valid_counts.register(["cuda", "gpu"])
def schedule_get_valid_counts(outs): def schedule_get_valid_counts(outs):
...@@ -203,30 +167,3 @@ def schedule_get_valid_counts(outs): ...@@ -203,30 +167,3 @@ def schedule_get_valid_counts(outs):
The computation schedule for the op. The computation schedule for the op.
""" """
return _default_schedule(outs) return _default_schedule(outs)
@generic.schedule_argsort.register(["cuda", "gpu"])
def schedule_argsort(outs):
"""Schedule for argsort operator.
Parameters
----------
outs: Array of Tensor
The computation graph description of argsort
in the format of an array of tensors.
Returns
-------
s: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
s = tvm.create_schedule([x.op for x in outs])
scheduled_ops = []
from .injective import _schedule_injective
def traverse(op):
for tensor in op.input_tensors:
if tensor.op.input_tensors and tensor.op not in scheduled_ops:
traverse(tensor.op)
scheduled_ops.append(op)
traverse(outs[0].op)
return s
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment