Commit 4468c576 by Xingjian Shi, committed by Tianqi Chen

[TOPI] add argmax, argmin (#515)

* add argmax argmin

* remove coder saver
parent 46657ed1
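
Before the diff itself, a minimal usage sketch of the two operators this commit adds. It assumes the TVM 0.x / TOPI Python API that appears in the hunks below (tvm.placeholder, topi.argmax, topi.cuda.schedule_reduce, tvm.build); the shapes, tensor names, and the CUDA target are illustrative only.

import numpy as np
import tvm
import topi

A = tvm.placeholder((32, 128), name="A", dtype="float32")
B = topi.argmax(A, axis=1)                     # int32 indices, shape (32,)
s = topi.cuda.schedule_reduce(B)               # the CUDA schedule touched below
if tvm.module.enabled("cuda"):
    f = tvm.build(s, [A, B], "cuda", name="argmax")
    ctx = tvm.gpu(0)
    a_np = np.random.uniform(size=(32, 128)).astype("float32")
    a = tvm.nd.array(a_np, ctx=ctx)
    b = tvm.nd.empty(shape=(32,), ctx=ctx, dtype="int32")
    f(a, b)
    np.testing.assert_allclose(b.asnumpy(), a_np.argmax(axis=1))
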
@@ -4,14 +4,17 @@ from __future__ import absolute_import as _abs
import tvm
from .. import tag
def _schedule_reduce(op, sch):
def _schedule_reduce(op, sch, is_idx_reduce=False):
if is_idx_reduce:
data_out = op.input_tensors[0]
else:
data_in = op.input_tensors[0]
data_out = op.output(0)
assert len(sch[data_out].op.reduce_axis) > 0, "reduce_axis must be bigger than zero!"
if len(sch[data_out].op.axis) > 0:
all_reduce = False
num_thread = 16
num_thread = 32
block_x = tvm.thread_axis("blockIdx.x")
thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
@@ -24,21 +27,38 @@ def _schedule_reduce(op, sch):
fused_reduce = sch[data_out].fuse(*[sch[data_out].op.reduce_axis[i]
for i in range(len(sch[data_out].op.reduce_axis))])
ko, ki = sch[data_out].split(fused_reduce, factor=num_thread)
if is_idx_reduce:
data_out_rf, _ = sch.rfactor(data_out, ki)
else:
data_out_rf = sch.rfactor(data_out, ki)
sch[data_out_rf].compute_at(sch[data_out], sch[data_out].op.reduce_axis[0])
tx = sch[data_out].op.reduce_axis[0]
sch[data_out].bind(tx, thread_x)
sch[data_out_rf].compute_at(sch[data_out], tx)
if is_idx_reduce:
real_output = op.output(0)
temp_idx_input = data_out.op.output(0)
temp_val_input = data_out.op.output(1)
else:
real_output = data_out
if not all_reduce:
# Fuse and split the axis
fused_outer = sch[data_out].fuse(*[sch[data_out].op.axis[i]
for i in range(len(sch[data_out].op.axis))])
bx, outer_in = sch[data_out].split(fused_outer, factor=num_thread)
fused_outer = sch[real_output].fuse(*[sch[real_output].op.axis[i]
for i in range(len(sch[real_output].op.axis))])
bx, outer_in = sch[real_output].split(fused_outer, factor=num_thread)
# Bind the axes to threads and blocks
sch[data_out].bind(sch[data_out].op.reduce_axis[0], thread_x)
sch[data_out].set_store_predicate(thread_x.equal(0))
sch[data_out].bind(outer_in, thread_y)
sch[data_out].bind(bx, block_x)
sch[real_output].bind(outer_in, thread_y)
sch[real_output].bind(bx, block_x)
if is_idx_reduce:
sch[temp_idx_input].compute_at(sch[real_output], outer_in)
sch[temp_val_input].compute_at(sch[real_output], outer_in)
else:
sch[data_out].bind(sch[data_out].op.reduce_axis[0], thread_x)
if is_idx_reduce:
sch[temp_idx_input].compute_at(sch[real_output],
sch[real_output].op.axis[0])
sch[temp_val_input].compute_at(sch[real_output],
sch[real_output].op.axis[0])
sch[real_output].set_store_predicate(thread_x.equal(0))
return sch
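
A note on the rfactor branch above, with a small sketch (assuming the TVM 0.x schedule API used throughout this file; the shape and split factor are illustrative): when the reduce op is multi-valued, as in the (idx, val) pair produced for argmax/argmin, Schedule.rfactor returns one rfactored tensor per output, which is why the idx branch unpacks a pair while the plain reduce keeps a single tensor.

import tvm
import topi

A = tvm.placeholder((64, 512), name="A")
B = topi.argmax(A, axis=1)
s = tvm.create_schedule(B.op)
red = B.op.input_tensors[0]            # an output of the two-valued (idx, val) reduce
k = s[red].op.reduce_axis[0]
ko, ki = s[red].split(k, factor=32)
rf = s.rfactor(red, ki)                # a list of rfactored tensors, one per output
print(len(rf))                         # expected: 2
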
@@ -73,9 +93,13 @@ def schedule_reduce(outs):
if tag.is_broadcast(operator.tag):
raise RuntimeError("Not yet support ewise after reduce")
elif operator.tag == 'comm_reduce':
_schedule_reduce(operator, sch)
_schedule_reduce(operator, sch, is_idx_reduce=False)
for tensor in operator.input_tensors:
traverse_before_reduce(tensor.op)
elif operator.tag == 'comm_reduce_idx':
_schedule_reduce(operator, sch, is_idx_reduce=True)
for tensor in operator.input_tensors[0].op.input_tensors:
traverse_before_reduce(tensor.op)
else:
raise RuntimeError("Unsupported operator: %s" % operator.tag)
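
For orientation, a sketch of the op graph the new 'comm_reduce_idx' branch assumes (TVM 0.x API, illustrative shapes): the final compute only selects the index component, so the real data producers sit one level deeper, behind input_tensors[0].op, which is why the traversal starts there.

import tvm
import topi

A = tvm.placeholder((32, 128), name="A")
B = topi.argmax(A, axis=1)
print(B.op.tag)                                    # expected: comm_reduce_idx
red_op = B.op.input_tensors[0].op                  # the two-output (idx, val) reduce op
print([t.op.name for t in red_op.input_tensors])   # expected to include 'A'
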
# pylint: disable=redefined-builtin,consider-using-enumerate
# pylint: disable=redefined-builtin,consider-using-enumerate,no-member
"""Reduce operators"""
from __future__ import absolute_import as _abs
import tvm
from . import tag
from .util import ravel_index
def _get_real_axis(ndim, axis):
if axis is None:
@@ -26,6 +27,20 @@ def _get_real_axis(ndim, axis):
def get_reduce_out_shape(src_shape, axis=None, keepdims=False):
"""Get the output shape for the reduction OPs
Parameters
----------
src_shape : tuple of int or tvm.expr.IntImm
axis : None or int or tuple of int
keepdims : bool
Returns
-------
dst_shape : tuple of int or tvm.expr.IntImm
"""
real_axis = _get_real_axis(len(src_shape), axis)
if keepdims:
dst_shape = [1 if i in real_axis else src_shape[i] for i in range(len(src_shape))]
@@ -37,8 +52,36 @@ def get_reduce_out_shape(src_shape, axis=None, keepdims=False):
return dst_shape
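
A tiny worked example of the keepdims rule in the branch above (pure Python, mirroring the list comprehension): reduced axes collapse to size one while the remaining axes keep their extent.

src_shape, real_axis = (32, 24, 32, 24), [2]
print([1 if i in real_axis else src_shape[i] for i in range(len(src_shape))])
# -> [32, 24, 1, 24]
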
@tvm.tag_scope(tag=tag.COMM_REDUCE)
def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum):
def _argmax_comp(lhs, rhs):
"""Comparison function for argmax"""
idx = tvm.make.Select((lhs[1] >= rhs[1]), lhs[0], rhs[0])
val = tvm.make.Select((lhs[1] >= rhs[1]), lhs[1], rhs[1])
return idx, val
def _argmax_init(idx_typ, val_typ):
"""Initial idx and val of argmax"""
return tvm.const(-1, idx_typ), tvm.min_value(val_typ)
def _argmin_comp(lhs, rhs):
"""Comparison function for argmin"""
idx = tvm.make.Select((lhs[1] <= rhs[1]), lhs[0], rhs[0])
val = tvm.make.Select((lhs[1] <= rhs[1]), lhs[1], rhs[1])
return idx, val
def _argmin_init(idx_typ, val_typ):
"""Initial idx and val of argmin"""
return tvm.const(-1, idx_typ), tvm.max_value(val_typ)
def _choose_idx(idx, _, *indices):
"""Choose the idx from idx and val"""
return idx(*indices)
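
The helpers above are wired together through tvm.comm_reducer further down in comm_reduce; the following condensed sketch shows the same pattern in isolation (TVM 0.x API, following TVM's documented multi-valued reducer idiom; the shape, names, and single reduce axis are illustrative).

import tvm

argmax_reducer = tvm.comm_reducer(
    fcombine=lambda lhs, rhs: (tvm.make.Select(lhs[1] >= rhs[1], lhs[0], rhs[0]),
                               tvm.make.Select(lhs[1] >= rhs[1], lhs[1], rhs[1])),
    fidentity=lambda idx_t, val_t: (tvm.const(-1, idx_t), tvm.min_value(val_t)),
    name="argmax")

data = tvm.placeholder((8, 16), name="data")
k = tvm.reduce_axis((0, 16), name="k")
# A multi-valued reducer makes tvm.compute return one tensor per component:
# output(0) carries the winning index, output(1) the winning value.
idx, val = tvm.compute((8,),
                       lambda i: argmax_reducer((k.var, data[i, k]), axis=k),
                       name="argmax_red")
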
def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum, is_idx_reduce=False):
"""Reducing the data
Parameters
@@ -63,9 +106,22 @@ def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum):
-------
ret : tvm.Tensor
"""
def _build_reduce_compute_func(data, real_axis, reduce_axes, keepdims,
func, *args):
ndim = len(data.shape)
real_axis = _get_real_axis(ndim, axis)
if real_axis == list(range(ndim)) and keepdims is False:
raise ValueError("Currently we do not support all reduce + keepdims = False!"
" axis={}, keepdims={}".format(axis, keepdims))
reduce_axes = [tvm.reduce_axis((0, data.shape[i]), "k%d" %i) for i in real_axis]
if keepdims:
target_shape = [1 if i in real_axis else data.shape[i] for i in range(ndim)]
else:
target_shape = []
for i in range(ndim):
if i not in real_axis:
target_shape.append(tvm.convert(data.shape[i]))
def _compute(*indices):
eval_range = []
eval_indices = []
if not keepdims:
arg_counter = 0
else:
@@ -74,38 +130,29 @@ def comm_reduce(data, axis=None, keepdims=False, func=tvm.sum):
for i in range(len(data.shape)):
if i in real_axis:
eval_range.append(reduce_axes[red_counter])
eval_indices.append(reduce_axes[red_counter].var)
red_counter += 1
else:
if not keepdims:
eval_range.append(args[arg_counter])
eval_range.append(indices[arg_counter])
arg_counter += 1
else:
eval_range.append(args[i])
eval_range.append(indices[i])
if not is_idx_reduce:
return func(data[tuple(eval_range)], axis=reduce_axes)
ndim = len(data.shape)
real_axis = _get_real_axis(ndim, axis)
if real_axis == list(range(ndim)) and keepdims is False:
raise ValueError("Currently we do not support all reduce + keepdims = False!"
" axis={}, keepdims={}".format(axis, keepdims))
reduce_axes = [tvm.reduce_axis((0, data.shape[i]), "k%d" %i) for i in real_axis]
if keepdims:
target_shape = [tvm.convert(1) if i in real_axis else tvm.convert(data.shape[i])
for i in range(ndim)]
else:
target_shape = []
for i in range(ndim):
if i not in real_axis:
target_shape.append(tvm.convert(data.shape[i]))
idx = ravel_index(eval_indices, [data.shape[i] for i in real_axis])
return func((idx, data[tuple(eval_range)]), axis=reduce_axes)
if is_idx_reduce:
temp_idx, temp_val = tvm.compute(target_shape, _compute, name=data.name + "_red_temp")
out = tvm.compute(target_shape,
lambda *args: _build_reduce_compute_func(data,
real_axis,
reduce_axes,
keepdims, func, *args),
lambda *indices: _choose_idx(temp_idx, temp_val, *indices),
name=data.name + "_red")
else:
out = tvm.compute(target_shape, _compute, name=data.name + "_red")
return out
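
A note on the idx branch above: ravel_index (imported from .util at the top of this file) flattens the position over the reduced axes into a single row-major offset, so with one reduction axis the stored value is simply the position along that axis. In numpy terms, a quick sanity check:

import numpy as np
# For a reduced sub-shape (4, 5), position (2, 3) flattens to 2 * 5 + 3.
assert np.ravel_multi_index((2, 3), dims=(4, 5)) == 2 * 5 + 3
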
@tvm.tag_scope(tag=tag.COMM_REDUCE)
def sum(data, axis=None, keepdims=False):
"""Sum of array elements over a given axis or a list of axes
@@ -131,6 +178,7 @@ def sum(data, axis=None, keepdims=False):
return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.sum)
@tvm.tag_scope(tag=tag.COMM_REDUCE)
def max(data, axis=None, keepdims=False):
"""Maximum of array elements over a given axis or a list of axes
@@ -156,6 +204,7 @@ def max(data, axis=None, keepdims=False):
return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.max)
@tvm.tag_scope(tag=tag.COMM_REDUCE)
def min(data, axis=None, keepdims=False):
"""Minimum of array elements over a given axis or a list of axes
@@ -179,3 +228,57 @@ def min(data, axis=None, keepdims=False):
ret : tvm.Tensor
"""
return comm_reduce(data, axis=axis, keepdims=keepdims, func=tvm.min)
@tvm.tag_scope(tag=tag.COMM_REDUCE_IDX)
def argmax(data, axis=None, keepdims=False):
"""Returns the indices of the maximum values along an axis.
Parameters
----------
data : tvm.Tensor
The input tvm tensor
axis : None or int or tuple of int
Axis or axes along which the argmax is performed.
The default, axis=None, finds the index of the maximum element over all of the elements of the input array.
If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
Returns
-------
ret : tvm.Tensor
"""
_argmax = tvm.comm_reducer(fcombine=_argmax_comp, fidentity=_argmax_init, name='argmax')
return comm_reduce(data, axis=axis, keepdims=keepdims, func=_argmax, is_idx_reduce=True)
@tvm.tag_scope(tag=tag.COMM_REDUCE_IDX)
def argmin(data, axis=None, keepdims=False):
"""Returns the indices of the minimum values along an axis.
Parameters
----------
data : tvm.Tensor
The input tvm tensor
axis : None or int or tuple of int
Axis or axes along which the argmin is performed.
The default, axis=None, finds the index of the minimum element over all of the elements of the input array.
If axis is negative it counts from the last to the first axis.
keepdims : bool
If this is set to True, the axes which are reduced are left in the result as dimensions
with size one.
With this option, the result will broadcast correctly against the input array.
Returns
-------
ret : tvm.Tensor
"""
_argmin = tvm.comm_reducer(fcombine=_argmin_comp, fidentity=_argmin_init, name='argmin')
return comm_reduce(data, axis=axis, keepdims=keepdims, func=_argmin, is_idx_reduce=True)
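
For orientation (an editorial note, not part of the diff): with a single reduction axis the int32 result of argmax/argmin is meant to match numpy's argmax/argmin along that axis, and with axis=None (which requires keepdims=True in comm_reduce above) the stored value is the row-major flat index, i.e. numpy's default behaviour, which is what the test helpers further below reproduce.

import numpy as np

x = np.random.uniform(size=(3, 4, 5)).astype("float32")
ref_axis1 = x.argmax(axis=1)   # shape (3, 5); topi.argmax(A, axis=1) targets this
ref_flat = x.argmax()          # axis=None: a single row-major flat index
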
@@ -31,6 +31,7 @@ ELEMWISE = "elemwise"
BROADCAST = "broadcast"
INJECTIVE = "injective"
COMM_REDUCE = "comm_reduce"
COMM_REDUCE_IDX = "comm_reduce_idx"
def is_broadcast(tag):
@@ -4,20 +4,48 @@ import numpy as np
import tvm
import topi
def _my_npy_argmax(arr, axis, keepdims):
if not keepdims:
return arr.argmax(axis=axis)
else:
if axis is not None:
out_shape = list(arr.shape)
out_shape[axis] = 1
else:
out_shape = [1 for _ in range(len(arr.shape))]
return arr.argmax(axis=axis).reshape(out_shape)
def _my_npy_argmin(arr, axis, keepdims):
if not keepdims:
return arr.argmin(axis=axis)
else:
if axis is not None:
out_shape = list(arr.shape)
out_shape[axis] = 1
else:
out_shape = [1 for _ in range(len(arr.shape))]
return arr.argmin(axis=axis).reshape(out_shape)
def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
# Build the logic and compile the function
A = tvm.placeholder(shape=in_shape, name="A")
dat_dtype = "float32"
A = tvm.placeholder(shape=in_shape, name="A", dtype=dat_dtype)
A1 = topi.sqrt(topi.exp(A))
out_dtype = "float32"
if type == "sum":
B = topi.sum(A1, axis=axis, keepdims=keepdims)
elif type == "max":
B = topi.max(A1, axis=axis, keepdims=keepdims)
elif type == "min":
B = topi.min(A1, axis=axis, keepdims=keepdims)
elif type == "argmax":
B = topi.argmax(A1, axis=axis, keepdims=keepdims)
out_dtype = "int32"
elif type == "argmin":
B = topi.argmin(A1, axis=axis, keepdims=keepdims)
out_dtype = "int32"
else:
raise NotImplementedError
s = topi.cuda.schedule_reduce(B)
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
@@ -26,18 +54,21 @@ def verify_reduce_map_ele(in_shape, axis, keepdims, type="sum"):
foo = tvm.build(s, [A, B], device, name="sum")
# Test
in_npy = np.random.uniform(size=in_shape).astype(np.float32)
in_npy_map = np.sqrt(np.exp(in_npy))
in_npy_map = np.sqrt(np.exp(in_npy)).astype(np.float32)
if type == "sum":
out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
elif type == "max":
out_npy = in_npy_map.max(axis=axis, keepdims=keepdims)
elif type == "min":
out_npy = in_npy_map.min(axis=axis, keepdims=keepdims)
elif type == "argmax":
out_npy = _my_npy_argmax(in_npy_map, axis=axis, keepdims=keepdims)
elif type == "argmin":
out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims)
else:
raise NotImplementedError
data_tvm = tvm.nd.array(in_npy, ctx=ctx)
out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx)
out_tvm = tvm.nd.empty(shape=out_npy.shape, ctx=ctx, dtype=out_dtype)
for _ in range(1):
foo(data_tvm, out_tvm)
np.testing.assert_allclose(out_tvm.asnumpy(), out_npy, 1E-3, 1E-3)
@@ -64,6 +95,19 @@ def test_reduce_map():
axis=(0, 2),
keepdims=False,
type="min")
verify_reduce_map_ele(in_shape=(32, 128),
axis=1,
keepdims=True,
type="argmax")
verify_reduce_map_ele(in_shape=(32, 24, 32, 24),
axis=2,
keepdims=False,
type="argmin")
verify_reduce_map_ele(in_shape=(31, 21, 15),
axis=None,
keepdims=True,
type="argmax")
if __name__ == "__main__":
test_reduce_map()
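
To exercise one of the new cases by hand (a sketch; it requires a CUDA-enabled TVM build, otherwise check_device only prints a skip message):

verify_reduce_map_ele(in_shape=(32, 128), axis=1, keepdims=False, type="argmax")
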