Commit 2e94a4b5 by yuruofeifei Committed by Tianqi Chen

[TOPI] Add compute for more operators (#849)

* [TOPI] Add compute for more operators

* Remove device except llvm

* Address comments

* Remove matmul compute

* Add outtype to boolean operator

* Address coments
parent 96cc008d
......@@ -19,6 +19,7 @@ from . import schedule as _schedule
from . import container as _container
from . import tag as _tag
int8 = "int8"
int32 = "int32"
float32 = "float32"
handle = "handle"
......
......@@ -12,6 +12,7 @@ from __future__ import absolute_import as _abs
from tvm._ffi.libinfo import __version__
from .math import *
from .tensor import *
from .reduction import *
from .transform import *
from .broadcast import *
......@@ -23,4 +24,3 @@ from . import mali
from . import testing
from . import util
from . import rocm
from . import cpp
# pylint: disable=invalid-name,consider-using-enumerate,unused-argument,len-as-condition
"""Elementwise operators"""
from __future__ import absolute_import as _abs
import tvm
from . import tag
@tvm.tag_scope(tag=tag.ELEMWISE)
def elemwise_sum(xs, num_args):
"""Perform element-wise sum on inputs
Parameters
----------
xs : list of tvm.Tensor
Input arguments.
num_args : int
Number of arguments
Returns
-------
y : tvm.Tensor
The result.
"""
assert len(xs) > 0, "elemwise sum must have at least one input tensor."
def _compute(*i):
return sum([x(*i) for x in xs])
return tvm.compute(xs[0].shape, _compute)
@tvm.tag_scope(tag=tag.ELEMWISE)
def full(shape, dtype, fill_value):
"""Fill tensor with fill_value
Parameters
----------
shape : tuple
Input tensor shape.
dtype : str
Data type
fill_value : float
Value to be filled
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(shape, lambda *i: tvm.const(fill_value, dtype))
@tvm.tag_scope(tag=tag.ELEMWISE)
def full_like(x, fill_value):
"""Construct a tensor with same shape as input tensor,
then fill tensor with fill_value.
Parameters
----------
x : tvm.Tensor
Input argument.
fill_value : float
Value to be filled
Returns
-------
y : tvm.Tensor
The result.
"""
dtype = x.dtype
return tvm.compute(x.shape, lambda *i: tvm.const(fill_value, dtype))
@tvm.tag_scope(tag=tag.ELEMWISE)
def greater(lhs, rhs, out_type=tvm.int8):
"""Compare two input tensors element-wise and return an mask tensor
which contains 1 if lhs > rhs holds else 0
Parameters
----------
lhs : tvm.Tensor
Left input argument.
rhs : tvm.Tensor
Right argument.
out_type: str
Output data type. Default is int8
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(lhs.shape,
lambda *i: tvm.select(lhs(*i) > rhs(*i),
tvm.const(1, out_type),
tvm.const(0, out_type)))
@tvm.tag_scope(tag=tag.ELEMWISE)
def less(lhs, rhs, out_type=tvm.int8):
"""Compare two input tensors element-wise and return an mask tensor
which contains 1 if lhs < rhs holds else 0
Parameters
----------
lhs : tvm.Tensor
Left input argument.
rhs : tvm.Tensor
Right argument.
out_type: str
Output data type. Default is int8
Returns
-------
y : tvm.Tensor
The result.
"""
return tvm.compute(lhs.shape,
lambda *i: tvm.select(lhs(*i) < rhs(*i),
tvm.const(1, out_type),
tvm.const(0, out_type)))
......@@ -2,6 +2,7 @@
"""Injective transformation operators"""
from __future__ import absolute_import as _abs
import tvm
import topi
from . import tag
from .util import ravel_index, unravel_index, get_const_int, get_const_tuple
......@@ -29,6 +30,60 @@ def expand_dims(a, axis, num_newaxis=1):
return tvm.compute(new_shape, _compute)
@tvm.tag_scope(tag=tag.BROADCAST)
def expand_like(a, shape_like, axis):
"""Expand an input array with the shape of second array.
This operation can always be composed of unsqueezing and
expanding dims on those unsqueezed axes.
Examples::
input = [ 12. 19. 27.]
input.shape = (3,)
new_shape_array = [[[1,2],[2,3],[1,3]],
[[1,4],[4,3],[5,2]],
[[7,1],[7,2],[7,3]]]
new_shape_array.shape = (3, 3, 2)
expand_like(input, [1,2], new_shape_array) =
[[[12,12],[12,12],[12,12]],
[[19,19],[19,19],[19,19]],
[[27,27],[27,27],[27,27]]]
Parameters
----------
a : tvm.Tensor
The tensor to be expanded.
shape_like : tvm.Tensor
The tensor to with target shape.
axis: list of int
axis to be expanded on
Returns
-------
ret : tvm.Tensor
"""
odim = len(axis) + len(a.shape)
if odim != len(shape_like.shape):
raise ValueError("shape inconsistent when expand_like ({}, {}, {})".format(
len(axis), len(a.shape), len(shape_like.shape)))
real_axis = topi.reduction._get_real_axis(len(shape_like.shape), axis)
real_axis = sorted(real_axis)
if not real_axis:
return a
def _compute(*idxs):
indices = []
axis_index = 0
for i in range(0, len(idxs)):
if i not in real_axis:
indices.append(idxs[i])
axis_index += 1
return a(*indices)
return tvm.compute(shape_like.shape, _compute)
@tvm.tag_scope(tag=tag.INJECTIVE)
def transpose(a, axes=None):
"""Permute the dimensions of an array.
......
import numpy as np
import tvm
import topi
from topi import util
def test_util():
x = tvm.const(100)
assert util.get_const_int(x) == 100
assert util.get_const_tuple((x, x)) == (100, 100)
def test_ewise():
m = tvm.var('m')
l = tvm.var('l')
A = tvm.placeholder((m, l), name='A')
def test_apply(func, name):
B = func(A)
assert tuple(B.shape) == tuple(A.shape)
assert B.op.body[0].name == name
test_apply(topi.exp, "exp")
test_apply(topi.tanh, "tanh")
test_apply(topi.sigmoid, "sigmoid")
test_apply(topi.log, "log")
test_apply(topi.sqrt, "sqrt")
if __name__ == "__main__":
test_util()
test_ewise()
"""Test code for tensor operator"""
import numpy as np
import tvm
import topi
from tvm.contrib.pickle_memoize import memoize
def verify_elemwise_sum(num_args, dtype):
shape = (3,5,4)
tvm_placeholders = []
for i in range(num_args):
tvm_placeholders.append(
tvm.placeholder(shape, name="data"+str(i), dtype=dtype))
esum = topi.elemwise_sum(tvm_placeholders, num_args=num_args)
s = tvm.create_schedule([esum.op])
@memoize("topi.tests.test_topi_elemwise_sum")
def get_ref_data():
np_nd = [np.random.uniform(0, 10, size=shape).astype(dtype)
for i in range(num_args)]
return np_nd
np_nd = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.context(device, 0)
out = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx)
f = tvm.build(s, tvm_placeholders + [esum], device, name="elemwise_sum")
tvm_nd = [tvm.nd.array(nd, ctx) for nd in np_nd] + [out]
f(*tvm_nd)
np_out = np.sum(np.array(np_nd), axis=0)
np.testing.assert_allclose(out.asnumpy(), np_out, rtol=1e-5)
for device in ["llvm"]:
check_device(device)
def verify_full(shape, dtype, fill_value):
A = tvm.placeholder(shape, dtype=dtype, name="A")
B = topi.full_like(A, fill_value=fill_value)
C = topi.full(shape=shape, dtype=dtype, fill_value=fill_value)
s1 = tvm.create_schedule([B.op])
s2 = tvm.create_schedule([C.op])
@memoize("topi.tests.test_topi_full")
def get_ref_data():
return np.full(shape, fill_value, dtype)
np_nd = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.context(device, 0)
out = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx)
f = tvm.build(s1, [A, B], device, name="full_like")
f(tvm.nd.array(np.zeros(shape, dtype), ctx), out)
np.testing.assert_allclose(out.asnumpy(), np_nd, rtol=1e-5)
f = tvm.build(s2, [C], device, name="full")
f(out)
np.testing.assert_allclose(out.asnumpy(), np_nd, rtol=1e-5)
for device in ["llvm"]:
check_device(device)
def verify_comparator(shape, dtype, out_type='int8'):
A = tvm.placeholder(shape, dtype, name="A")
B = tvm.placeholder(shape, dtype, name="B")
C = topi.less(A, B)
s_less = tvm.create_schedule([C.op])
D = tvm.placeholder(shape, dtype, name="D")
E = tvm.placeholder(shape, dtype, name="E")
F = topi.greater(D, E, out_type)
s_greater = tvm.create_schedule([F.op])
@memoize("topi.tests.test_topi_indicator")
def get_ref_data():
return [np.random.uniform(0, 10, size=shape).astype(dtype),
np.random.uniform(0, 10, size=shape).astype(dtype)]
[np_l, np_r] = get_ref_data()
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
ctx = tvm.context(device, 0)
out = tvm.nd.array(np.zeros(shape, dtype=out_type), ctx)
tvm_l = tvm.nd.array(np_l, ctx)
tvm_r = tvm.nd.array(np_r, ctx)
f = tvm.build(s_less, [A, B, C], device, name="less")
f(tvm_l, tvm_r, out)
np.testing.assert_allclose(out.asnumpy(), np.less(np_l, np_r).astype(out_type), rtol=1e-5)
f = tvm.build(s_greater, [D, E, F], device, name="greater")
f(tvm_l, tvm_r, out)
np.testing.assert_allclose(out.asnumpy(), np.greater(np_l, np_r).astype(out_type), rtol=1e-5)
for device in ["llvm"]:
check_device(device)
def test_elemwise_sum():
verify_elemwise_sum(1, "float32")
verify_elemwise_sum(5, "float32")
verify_elemwise_sum(4, "int32")
def test_full():
verify_full((3,4,5), "float32", 3.14)
verify_full((10,), "int32", 7)
def test_comparator():
verify_comparator((3,4,5), "float32")
verify_comparator((7,), "int32")
verify_comparator((3,4,5), "float32", "int8")
if __name__ == "__main__":
test_elemwise_sum()
test_full()
test_comparator()
......@@ -150,6 +150,41 @@ def verify_split(src_shape, indices_or_sections, axis):
check_device(device)
def verify_expand_like(in_shape, out_shape, axis):
A = tvm.placeholder(shape=in_shape, name="A")
B = tvm.placeholder(shape=out_shape, name="B")
C = topi.expand_like(A, B, axis)
s = tvm.create_schedule([C.op])
def check_device(device):
if not tvm.module.enabled(device):
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
ctx = tvm.context(device, 0)
f = tvm.build(s, [A, B, C], device, name="expand_like")
input = np.random.uniform(size=in_shape).astype(A.dtype)
tvm_input = tvm.nd.array(input, ctx)
odim = len(out_shape)
real_axis = [x if x >= 0 else x + odim for x in axis]
real_axis = sorted(real_axis)
for x in real_axis:
input = np.expand_dims(input, x).astype(A.dtype)
for x in real_axis:
input = np.concatenate([input]*out_shape[x], axis=x).astype(A.dtype)
assert input.shape == out_shape
tvm_shape_like = tvm.nd.array(np.zeros(out_shape).astype(B.dtype), ctx)
out = tvm.nd.array(np.zeros(out_shape).astype(A.dtype), ctx)
f(tvm_input, tvm_shape_like, out)
np.testing.assert_allclose(out.asnumpy(), input)
for device in ["llvm"]:
check_device(device)
def test_expand_dims():
verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
verify_expand_dims((3, 10), (1, 3, 10), -3, 1)
......@@ -191,6 +226,14 @@ def test_split():
verify_split((2, 12, 3), [2, 4], 1)
verify_split((10, 12, 24), [5, 7, 9], -1)
def test_expand_like():
verify_expand_like((3,), (2, 3), [0])
verify_expand_like((2,), (2, 3), [1])
verify_expand_like((3, 4), (3, 5, 4), [1])
verify_expand_like((5, 7), (5, 6, 7, 8), [1, 3])
if __name__ == "__main__":
test_concatenate()
test_tranpose()
......@@ -198,3 +241,4 @@ if __name__ == "__main__":
test_reshape()
test_squeeze()
test_split()
test_expand_like()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment