Commit 9473dca2 by Steven S. Lyubomirsky Committed by Tianqi Chen

[Relay][Op] Add compute, schedule, and tests for expand_dims and squeeze (#2133)

parent b1cf70a8
#pylint: disable=invalid-name, unused-argument #pylint: disable=invalid-name, unused-argument
"""Backend compiler related feature registration""" """Backend compiler related feature registration"""
from __future__ import absolute_import from __future__ import absolute_import
import topi
import topi.cuda
from tvm import container
from . import op as _reg from . import op as _reg
from .op import schedule_injective, OpPattern from .op import (schedule_injective, register_compute, register_schedule,
register_pattern, OpPattern)
schedule_broadcast = schedule_injective
# squeeze
@register_compute("squeeze")
def squeeze_compiler(attrs, inputs, output_type, target):
"""Compiler for squeeze dims."""
assert len(inputs) == 1
if attrs.axis is None:
axis = None
elif isinstance(attrs.axis, container.Array):
axis = tuple(attrs.axis)
else:
axis = int(attrs.axis)
return [topi.squeeze(inputs[0], axis)]
register_pattern("squeeze", OpPattern.INJECTIVE)
register_schedule("squeeze", schedule_injective)
# expand_dims
@register_compute("expand_dims")
def expand_dims_compiler(attrs, inputs, output_type, target):
"""Compiler for expand_dims."""
assert len(inputs) == 1
new_axis = int(attrs.num_newaxis)
assert new_axis >= 0
# axis should be in range [-data.ndim - 1, data.ndim]
axis = int(attrs.axis)
assert axis >= -len(inputs[0].shape) - 1
assert axis <= len(inputs[0].shape)
return [topi.expand_dims(inputs[0], axis, new_axis)]
_reg.register_schedule("expand_dims", schedule_broadcast)
_reg.register_pattern("expand_dims", OpPattern.BROADCAST)
# strided_slice # strided_slice
_reg.register_schedule("strided_slice", schedule_injective) _reg.register_schedule("strided_slice", schedule_injective)
......
...@@ -90,6 +90,22 @@ def test_binary_op(): ...@@ -90,6 +90,22 @@ def test_binary_op():
check_binary_op(opfunc, ref) check_binary_op(opfunc, ref)
def test_expand_dims():
# based on topi test
def verify_expand_dims(dshape, dtype, oshape, axis, num_newaxis):
x = relay.Var("x", relay.TensorType(dshape, dtype))
func = relay.Function([x], relay.expand_dims(x, axis, num_newaxis))
for target, ctx in ctx_list():
data = np.random.uniform(size=dshape).astype(dtype)
ref_res = data.reshape(oshape)
intrp = relay.create_executor("graph", ctx=ctx, target=target)
op_res = intrp.evaluate(func)(data)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
verify_expand_dims((3, 10), 'float32', (3, 10, 1, 1), 2, 2)
verify_expand_dims((3, 10), 'float32', (1, 3, 10), -3, 1)
def test_bias_add(): def test_bias_add():
xshape=(10, 2, 3, 4) xshape=(10, 2, 3, 4)
bshape=(2,) bshape=(2,)
...@@ -295,6 +311,7 @@ if __name__ == "__main__": ...@@ -295,6 +311,7 @@ if __name__ == "__main__":
test_binary_op() test_binary_op()
test_expand_dims_infer_type() test_expand_dims_infer_type()
test_concatenate() test_concatenate()
test_expand_dims()
test_softmax() test_softmax()
test_log_softmax() test_log_softmax()
test_dropout() test_dropout()
......
...@@ -60,6 +60,22 @@ def test_clip(): ...@@ -60,6 +60,22 @@ def test_clip():
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01) np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
def test_squeeze():
def verify_squeeze(shape, dtype, axis):
x = relay.var("x", relay.TensorType(shape, dtype))
squeeze = relay.squeeze(x, axis=axis)
np_axis = tuple(axis) if axis is not None else None
data = np.random.random_sample(shape).astype(dtype)
intrp = create_executor()
op_res = intrp.evaluate(squeeze, { x : relay.const(data) })
ref_res = np.squeeze(data, axis=np_axis)
np.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=0.01)
verify_squeeze((1, 3, 2, 5), "float32", None)
verify_squeeze((1, 3, 1), "float32", [0])
verify_squeeze((1, 2, 1, 2, 1), "float32", [0, 2])
def test_transpose_infer_type(): def test_transpose_infer_type():
...@@ -308,6 +324,7 @@ if __name__ == "__main__": ...@@ -308,6 +324,7 @@ if __name__ == "__main__":
test_full_like() test_full_like()
test_infer_type_leaky_relu() test_infer_type_leaky_relu()
test_infer_type_prelu() test_infer_type_prelu()
test_squeeze()
test_squeeze_infer_type() test_squeeze_infer_type()
test_squeeze_bad_axes_infer_type() test_squeeze_bad_axes_infer_type()
test_split_infer_type() test_split_infer_type()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment