Commit 89acfeb2 by Haichen Shen Committed by Leyuan Wang

[Relay][Frontend] Add ops in mxnet converter (#2844)

* Add ops in mxnet converter

* trigger ci
parent f81e2873
......@@ -213,7 +213,7 @@ def _mx_slice_axis(inputs, attrs):
ax_end = attrs.get_str("end")
if axis < 0:
axis += len(shape)
assert axis >= 0 and axis < len(shape)
assert 0 <= axis < len(shape)
if ax_end == "None":
ax_end = int(shape[axis])
else:
......@@ -222,8 +222,8 @@ def _mx_slice_axis(inputs, attrs):
ax_beg += int(shape[axis])
if ax_end < 0:
ax_end += int(shape[axis])
assert ax_beg >= 0 and ax_beg < int(shape[axis])
assert ax_end > ax_beg and ax_end <= int(shape[axis])
assert 0 <= ax_beg < int(shape[axis])
assert ax_beg < ax_end <= int(shape[axis])
begin = []
end = []
for i, dim in enumerate(shape):
......@@ -527,11 +527,53 @@ def _mx_shape_array(inputs, attrs):
return _op.shape_of(inputs[0], dtype='int64')
def _mx_full(inputs, attrs):
assert len(inputs) == 0
val = attrs.get_float("value")
shape = attrs.get_int_tuple("shape")
dtype = attrs.get_str("dtype", "float32")
return _op.full(_expr.const(val, dtype), shape, dtype)
def _mx_squeeze(inputs, attrs):
assert len(inputs) == 1
axis = attrs.get_int_tuple("axis", None)
return _op.squeeze(inputs[0], axis)
def _mx_broadcast_axis(inputs, attrs):
assert len(inputs) == 1
axis = attrs.get_int_tuple("axis", [])
size = attrs.get_int_tuple("size", [])
assert len(axis) == len(size)
if len(axis) == 0:
return inputs[0]
src_shape = ir_pass.infer_type(inputs[0])._checked_type_.shape
tgt_shape = []
for i, dim in enumerate(src_shape):
if i not in axis:
tgt_shape.append(dim)
else:
assert int(dim) == 1
idx = axis.index(i)
tgt_shape.append(size[idx])
return _op.broadcast_to(inputs[0], tgt_shape)
def _mx_embedding(inputs, _):
assert len(inputs) == 2
indices, weight = inputs
return _op.take(weight, indices.astype('int32'), axis=0)
# Note: due to attribute conversion constraint
# ops in the identity set must be attribute free
_identity_list = [
"log",
"exp",
"sqrt",
"floor",
"ceil",
"sigmoid",
"tanh",
"negative",
......@@ -567,7 +609,6 @@ _convert_map = {
"Flatten" : _rename(_op.nn.batch_flatten),
# scalar power
"square" : _mx_make_power(2),
"sqrt" : _mx_make_power(1/2),
"rsqrt" : _mx_make_power(-1/2),
"cbrt" : _mx_make_power(1/3),
"rcbrt" : _mx_make_power(-1/3),
......@@ -649,11 +690,15 @@ _convert_map = {
"batch_dot" : _mx_batch_dot,
"LeakyReLU" : _mx_leaky_relu,
"_arange" : _mx_arange,
"_full" : _mx_full,
"repeat" : _mx_repeat,
"tile" : _mx_tile,
"reverse" : _mx_reverse,
"squeeze" : _mx_squeeze,
"broadcast_axis": _mx_broadcast_axis,
"BlockGrad" : _mx_BlockGrad,
"shape_array" : _mx_shape_array,
"Embedding" : _mx_embedding,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
# vision
......
......@@ -379,7 +379,6 @@ def test_forward_l2_normalize():
mx_sym = mx.sym.L2Normalization(data, mode="channel")
verify_mxnet_frontend_impl(mx_sym, (2, 3, 4, 5), (2, 3, 4, 5))
def test_forward_shape_array():
def verify(shape):
x_np = np.random.uniform(size=shape).astype("float32")
......@@ -395,6 +394,75 @@ def test_forward_shape_array():
verify((3, 4, 5))
verify((3, 4, 5, 6))
def test_forward_squeeze():
def verify(shape, axis):
x_np = np.random.uniform(size=shape).astype("float32")
if axis is None:
ref_res = mx.nd.squeeze(mx.nd.array(x_np))
mx_sym = mx.sym.squeeze(mx.sym.var("x"))
else:
ref_res = mx.nd.squeeze(mx.nd.array(x_np), axis=axis)
mx_sym = mx.sym.squeeze(mx.sym.var("x"), axis=axis)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(x_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((1, 3, 1), None)
verify((1, 3, 1), 0)
verify((1, 3, 1), 2)
verify((1, 3, 1), (0, 2))
def test_forward_broadcast_axis():
def verify(shape, axis, size):
x_np = np.random.uniform(size=shape).astype("float32")
ref_res = mx.nd.broadcast_axis(mx.nd.array(x_np), axis=axis, size=size)
mx_sym = mx.sym.broadcast_axis(mx.sym.var("x"), axis=axis, size=size)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": shape})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(x_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((1, 2, 1), 2, 3)
verify((1, 2, 1), (0, 2), (2, 3))
def test_forward_full():
def verify(val, shape, dtype):
ctx = mx.cpu()
ref_res = mx.nd.full(shape, val, dtype=dtype)
mx_sym = mx.sym.full(shape, val, dtype=dtype)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {})
for target, ctx in ctx_list():
# Skip testing graph runtime because this op will be optimized out
# by constant folding.
for kind in ["debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)()
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify(2, (3, 4), "float32")
verify(2, (3, 4), "int32")
verify(3.5, (1, 3, 4), "float32")
def test_forward_embedding():
def verify(data_shape, weight_shape):
in_dim, out_dim = weight_shape
x_np = np.random.randint(0, weight_shape[0], size=data_shape).astype("float32")
w_np = np.random.uniform(size=weight_shape).astype("float32")
ref_res = mx.nd.Embedding(mx.nd.array(x_np), mx.nd.array(w_np),
input_dim=in_dim, output_dim=out_dim)
mx_sym = mx.sym.Embedding(mx.sym.var("x"), mx.sym.var("w"),
input_dim=in_dim, output_dim=out_dim)
new_sym, _ = relay.frontend.from_mxnet(
mx_sym, {"x": data_shape, "w": weight_shape})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(x=x_np, w=w_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((2, 2), (4, 5))
verify((2, 3, 4), (4, 5))
if __name__ == '__main__':
test_forward_mlp()
......@@ -426,3 +494,7 @@ if __name__ == '__main__':
test_forward_slice_axis()
test_forward_l2_normalize()
test_forward_shape_array()
test_forward_squeeze()
test_forward_broadcast_axis()
test_forward_full()
test_forward_embedding()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment