Commit c8373ece by Haichen Shen Committed by Tianqi Chen

[Relay][Frontend] Add a few mxnet ops in relay frontend (#2704)

parent af69f873
...@@ -64,6 +64,13 @@ def _mx_activations(inputs, attrs): ...@@ -64,6 +64,13 @@ def _mx_activations(inputs, attrs):
raise RuntimeError("Do not support act_type: {}".format(act_type)) raise RuntimeError("Do not support act_type: {}".format(act_type))
def _mx_compare(new_op, wrapper):
def impl(inputs, attrs):
dtype = ir_pass.infer_type(inputs[0]).checked_type.dtype
return wrapper(new_op)(inputs, attrs).astype(dtype)
return impl
def _mx_conv2d(inputs, attrs): def _mx_conv2d(inputs, attrs):
kernel_size = attrs.get_int_tuple("kernel") kernel_size = attrs.get_int_tuple("kernel")
if len(kernel_size) != 2: if len(kernel_size) != 2:
...@@ -333,32 +340,52 @@ _identity_list = [ ...@@ -333,32 +340,52 @@ _identity_list = [
] ]
_convert_map = { _convert_map = {
"_copy" : _rename(_op.copy), "_copy" : _rename(_op.copy),
"relu" : _rename(_op.nn.relu), "relu" : _rename(_op.nn.relu),
"broadcast_add" : _rename(_op.add), "broadcast_add" : _rename(_op.add),
"broadcast_sub" : _rename(_op.subtract), "broadcast_sub" : _rename(_op.subtract),
"broadcast_mul" : _rename(_op.multiply), "broadcast_mul" : _rename(_op.multiply),
"broadcast_div" : _rename(_op.divide), "broadcast_div" : _rename(_op.divide),
"elemwise_add" : _rename(_op.add), "broadcast_mod" : _rename(_op.mod),
"elemwise_sub" : _rename(_op.subtract), "broadcast_maximum" : _rename(_op.maximum),
"elemwise_mul" : _rename(_op.multiply), "broadcast_minimum" : _rename(_op.minimum),
"elemwise_div" : _rename(_op.divide), "broadcast_equal" : _mx_compare(_op.equal, _rename),
"flatten" : _rename(_op.nn.batch_flatten), "broadcast_not_equal" : _mx_compare(_op.not_equal, _rename),
"Flatten" : _rename(_op.nn.batch_flatten), "broadcast_greater" : _mx_compare(_op.greater, _rename),
"_plus_scalar" : _binop_scalar(_op.add), "broadcast_greater_equal": _mx_compare(_op.greater_equal, _rename),
"__add_scalar__": _binop_scalar(_op.add), "broadcast_lesser" : _mx_compare(_op.less, _rename),
"__sub_scalar__": _binop_scalar(_op.subtract), "broadcast_lesser_equal" : _mx_compare(_op.less_equal, _rename),
"_minus_scalar" : _binop_scalar(_op.subtract), "elemwise_add" : _rename(_op.add),
"__mul_scalar__": _binop_scalar(_op.multiply), "elemwise_sub" : _rename(_op.subtract),
"_mul_scalar" : _binop_scalar(_op.multiply), "elemwise_mul" : _rename(_op.multiply),
"__div_scalar__": _binop_scalar(_op.divide), "elemwise_div" : _rename(_op.divide),
"_div_scalar" : _binop_scalar(_op.divide), "_maximum" : _rename(_op.maximum),
"__pow_scalar__": _binop_scalar(_op.power), "_minimum" : _rename(_op.minimum),
"_rminus_scalar": _rbinop_scalar(_op.subtract), "flatten" : _rename(_op.nn.batch_flatten),
"__rsub_scalar__": _rbinop_scalar(_op.subtract), "Flatten" : _rename(_op.nn.batch_flatten),
"_rdiv_scalar" : _rbinop_scalar(_op.divide), "__add_scalar__" : _binop_scalar(_op.add),
"__rdiv_scalar__" : _rbinop_scalar(_op.divide), "_plus_scalar" : _binop_scalar(_op.add),
"__rpow_scalar__": _rbinop_scalar(_op.power), "__sub_scalar__" : _binop_scalar(_op.subtract),
"_minus_scalar" : _binop_scalar(_op.subtract),
"__mul_scalar__" : _binop_scalar(_op.multiply),
"_mul_scalar" : _binop_scalar(_op.multiply),
"__div_scalar__" : _binop_scalar(_op.divide),
"_div_scalar" : _binop_scalar(_op.divide),
"__pow_scalar__" : _binop_scalar(_op.power),
"_power_scalar" : _binop_scalar(_op.power),
"__rsub_scalar__" : _rbinop_scalar(_op.subtract),
"_rminus_scalar" : _rbinop_scalar(_op.subtract),
"__rdiv_scalar__" : _rbinop_scalar(_op.divide),
"_rdiv_scalar" : _rbinop_scalar(_op.divide),
"__rpow_scalar__" : _rbinop_scalar(_op.power),
"_equal_scalar" : _mx_compare(_op.equal, _binop_scalar),
"_not_equal_scalar" : _mx_compare(_op.not_equal, _binop_scalar),
"_greater_scalar" : _mx_compare(_op.greater, _binop_scalar),
"_greater_equal_scalar" : _mx_compare(_op.greater_equal, _binop_scalar),
"_lesser_scalar" : _mx_compare(_op.less, _binop_scalar),
"_lesser_equal_scalar" : _mx_compare(_op.less_equal, _binop_scalar),
"_maximum_scalar" : _binop_scalar(_op.maximum),
"_minimum_scalar" : _binop_scalar(_op.minimum),
# reduction ops # reduction ops
"max" : _reduce(_op.max), "max" : _reduce(_op.max),
"min" : _reduce(_op.min), "min" : _reduce(_op.min),
......
import numpy as np import numpy as np
import operator
import tvm import tvm
from tvm.contrib import graph_runtime from tvm.contrib import graph_runtime
...@@ -256,6 +257,85 @@ def test_forward_arange(): ...@@ -256,6 +257,85 @@ def test_forward_arange():
verify(20, 1, -1) verify(20, 1, -1)
verify(20, 1, -1.5) verify(20, 1, -1.5)
def _mx_symbol(F, op_name, inputs):
op = getattr(F, op_name)
return op(*inputs)
def test_forward_broadcast_ops():
for op in ["broadcast_add", "broadcast_sub", "broadcast_mul",
"broadcast_div", "broadcast_mod", "broadcast_maximum",
"broadcast_minimum", "broadcast_equal", "broadcast_not_equal",
"broadcast_greater", "broadcast_greater_equal",
"broadcast_lesser", "broadcast_lesser_equal"]:
a_shape = (3, 4, 5)
b_shape = (4, 5)
if op == "broadcast_mod":
dtype = 'int32'
a_np = np.random.randint(1, 100, size=a_shape).astype(dtype)
b_np = np.random.randint(1, 100, size=b_shape).astype(dtype)
else:
dtype = 'float32'
a_np = np.random.uniform(size=a_shape).astype(dtype)
b_np = np.random.uniform(size=b_shape).astype(dtype)
mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')])
ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)])
shapes = {'a': a_shape, 'b': b_shape}
new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(a_np, b_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
def test_forward_elemwise_ops():
for op in ["elemwise_add", "elemwise_sub", "elemwise_mul",
"elemwise_div", "maximum", "minimum"]:
shape = (3, 4, 5)
dtype = 'float32'
a_np = np.random.uniform(size=shape).astype(dtype)
b_np = np.random.uniform(size=shape).astype(dtype)
mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), mx.sym.var('b')])
ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), mx.nd.array(b_np)])
shapes = {'a': shape, 'b': shape}
new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(a_np, b_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
def test_forward_scalar_ops():
for op in [operator.add, operator.sub, operator.mul, operator.truediv,
operator.pow, operator.lt, operator.le, operator.eq,
operator.ne, operator.gt, operator.ge]:
dtype='float32'
a_shape = (3, 4, 5)
a_np = np.random.uniform(size=a_shape).astype(dtype)
b_scalar = 2.3
mx_sym = op(mx.sym.var('a'), b_scalar)
ref_res = op(mx.nd.array(a_np), b_scalar)
shapes = {'a': a_shape}
new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(a_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
for op in ["maximum", "minimum"]:
dtype='float32'
a_shape = (3, 4, 5)
a_np = np.random.uniform(size=a_shape).astype(dtype)
b_scalar = 2.3
mx_sym = _mx_symbol(mx.sym, op, [mx.sym.var('a'), b_scalar])
ref_res = _mx_symbol(mx.nd, op, [mx.nd.array(a_np), b_scalar])
shapes = {'a': a_shape}
new_sym, _ = relay.frontend.from_mxnet(mx_sym, shapes, dtype)
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(a_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
if __name__ == '__main__': if __name__ == '__main__':
test_forward_mlp() test_forward_mlp()
...@@ -280,3 +360,6 @@ if __name__ == '__main__': ...@@ -280,3 +360,6 @@ if __name__ == '__main__':
test_forward_argmin() test_forward_argmin()
test_forward_where() test_forward_where()
test_forward_arange() test_forward_arange()
test_forward_broadcast_ops()
test_forward_elemwise_ops()
test_forward_scalar_ops()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment