Commit d0c45648 by Jared Roesch Committed by Tianqi Chen

[Relay][Backend] Fix interpreter argument conversion for tuples. (#3349)

* Support taking a tuple as an argument

* Add test
parent 499adfdb
......@@ -118,6 +118,8 @@ def _arg_to_ast(arg):
return Constant(arg.data.copyto(nd.cpu(0)))
elif isinstance(arg, TupleValue):
return Tuple([_arg_to_ast(field) for field in arg.fields])
elif isinstance(arg, tuple):
return Tuple([_arg_to_ast(field) for field in arg])
elif isinstance(arg, RefValue):
return RefCreate(_arg_to_ast(arg.value))
elif isinstance(arg, ConstructorValue):
......
......@@ -217,6 +217,31 @@ def test_function_taking_adt_ref_tuple():
tvm.testing.assert_allclose(res_tuple.fields[i].asnumpy(),
tuple_value.fields[i].asnumpy())
def test_tuple_passing():
x = relay.var('x', type_annotation=relay.ty.TupleType([
relay.ty.TensorType((), 'int64'),
relay.ty.TensorType((), 'int64')]))
fn = relay.Function([x], relay.expr.TupleGetItem(x, 0))
mod = relay.Module({})
gv = relay.GlobalVar('fn')
mod[gv] = fn
mod.entry_func = gv
mod[gv] = relay.ir_pass.infer_type(mod[gv], mod=mod)
ctx = tvm.cpu()
target = tvm.target.create('llvm')
exec = relay.create_executor(mod=mod, ctx=ctx, target=target)
f = exec.evaluate(gv)
# First use a Python tuple.
out = f((10, 8))
tvm.testing.assert_allclose(out.asnumpy(), np.array(10))
# Second use a tuple value.
value_tuple = TupleValue(
TensorValue(np.array(11)),
TensorValue(np.array(12)))
out = f(value_tuple)
tvm.testing.assert_allclose(out.asnumpy(), np.array(11))
if __name__ == "__main__":
test_id()
......@@ -232,3 +257,4 @@ if __name__ == "__main__":
test_tuple_value()
test_tuple_getitem()
test_function_taking_adt_ref_tuple()
test_tuple_passing()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment