Commit 52d5cf89 by Haichen Shen Committed by Leyuan Wang

[Bugfix][Relay][Frontend] Fix bug in mxnet converter for slick_like (#2744)

* Fix bug in mxnet converter for slick_like

* More tolerance for topi_conv2d_NCHWc
parent 66cad621
......@@ -194,6 +194,13 @@ def _mx_slice(inputs, attrs):
return _op.strided_slice(inputs[0], **new_attrs)
def _mx_slice_like(inputs, attrs):
assert len(inputs) == 2
new_attrs = {}
new_attrs["axes"] = attrs.get_int_tuple("axes", None)
return _op.slice_like(*inputs, **new_attrs)
def _mx_slice_axis(inputs, attrs):
assert len(inputs) == 1
shape = ir_pass.infer_type(inputs[0]).checked_type.shape
......@@ -383,7 +390,6 @@ _identity_list = [
"exp",
"negative",
"reshape_like",
"slice_like",
"zeros_like",
"ones_like",
"where",
......@@ -473,6 +479,7 @@ _convert_map = {
"BatchNorm_v1" : _mx_batch_norm,
"LRN" : _mx_lrn,
"slice" : _mx_slice,
"slice_like" : _mx_slice_like,
"slice_axis" : _mx_slice_axis,
"SliceChannel" : _mx_split,
"split" : _mx_split,
......
......@@ -336,7 +336,6 @@ def test_forward_scalar_ops():
op_res = intrp.evaluate(new_sym)(a_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
def test_forward_slice_axis():
def verify(shape, axis, begin, end):
data_np = np.random.uniform(size=shape).astype("float32")
......@@ -354,6 +353,27 @@ def test_forward_slice_axis():
verify((3, 4), 1, -3, -1)
verify((3, 4), -1, -3, -1)
def test_forward_slice_like():
def verify(x_shape, y_shape, axes):
x_np = np.random.uniform(size=x_shape).astype("float32")
y_np = np.random.uniform(size=y_shape).astype("float32")
if axes is None:
ref_res = mx.nd.slice_like(mx.nd.array(x_np), mx.nd.array(y_np))
mx_sym = mx.sym.slice_like(mx.sym.var("x"), mx.sym.var("y"))
else:
ref_res = mx.nd.slice_like(mx.nd.array(x_np), mx.nd.array(y_np), axes=axes)
mx_sym = mx.sym.slice_like(mx.sym.var("x"), mx.sym.var("y"), axes=axes)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": x_shape, "y": y_shape})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(x_np, y_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((3, 4), (2, 3), None)
verify((3, 4), (2, 3), (0, 1))
verify((3, 4), (2, 3), (0))
verify((3, 4), (2, 3), (-1))
if __name__ == '__main__':
test_forward_mlp()
......@@ -382,3 +402,4 @@ if __name__ == '__main__':
test_forward_elemwise_ops()
test_forward_scalar_ops()
test_forward_slice_axis()
test_forward_slice_like()
......@@ -105,7 +105,7 @@ def verify_conv2d_NCHWc(batch, in_channel, in_size, num_filter, kernel, stride,
name="relu_%d_%d_%d_%d_%d_%d_%d_%d" %
(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
func(a, w, c)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-3)
# test llvm only for now since conv2d_NCHWc implement is missing in other backend.
for device in ["llvm"]:
......@@ -202,4 +202,4 @@ def test_conv2d_NCHWc():
verify_conv2d_NCHWc(1, 256, 3, 126, 3, 1, 1)
if __name__ == "__main__":
test_conv2d_NCHWc()
\ No newline at end of file
test_conv2d_NCHWc()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment