Commit eb6d64f1 by Siva Committed by Tianqi Chen

[RELAY] bugfix. (#2215)

parent 20afa0a5
......@@ -45,8 +45,10 @@ def transpose(data, axes=None):
result : relay.Expr
The transposed result.
"""
axes = axes or []
return _make.transpose(data, list(axes))
if axes is not None:
axes = list(axes)
return _make.transpose(data, axes)
def squeeze(data, axis=None):
......
......@@ -82,11 +82,18 @@ def test_transpose_infer_type():
n, t, d = tvm.var("n"), tvm.var("t"), 100
x = relay.var("x", relay.TensorType((n, t, d), "float32"))
y = relay.transpose(x, axes=(1, 0, 2))
"axes=" in y.astext()
assert "axes=" in y.astext()
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType(
(t, n, 100), "float32")
y = relay.transpose(x)
assert "axes=" in y.astext()
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.TensorType(
(100, t, n), "float32")
def test_transpose():
def verify_transpose(dshape, axes):
x = relay.var("x", relay.TensorType(dshape, "float32"))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment