Commit fcc5b422 by Steven S. Lyubomirsky Committed by Tianqi Chen

Ensure interpreted functions can take values that are not TensorValues (#3015)

parent 561e422b
......@@ -24,7 +24,7 @@ from . import _backend
from .. import _make, ir_pass
from ... import register_func, nd
from ..base import NodeBase, register_relay_node
from ..expr import Call, Constant, GlobalVar, Function, const
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
from ..scope_builder import ScopeBuilder
class Value(NodeBase):
......@@ -112,6 +112,12 @@ class RefValue(Value):
def _arg_to_ast(arg):
if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(nd.cpu(0)))
elif isinstance(arg, TupleValue):
return Tuple([_arg_to_ast(field) for field in arg.fields])
elif isinstance(arg, RefValue):
return RefCreate(_arg_to_ast(arg.value))
elif isinstance(arg, ConstructorValue):
return Call(arg.constructor, [_arg_to_ast(field) for field in arg.fields])
elif isinstance(arg, np.ndarray):
return Constant(nd.array(arg))
elif isinstance(arg, Constant):
......
......@@ -19,6 +19,7 @@ import tvm
import tvm.testing
from tvm import relay
from tvm.relay.backend.interpreter import Value, TupleValue, TensorValue
from tvm.relay.backend.interpreter import RefValue, ConstructorValue
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor
......@@ -156,6 +157,7 @@ def test_tensor_value():
xx = np.ones((1, 10)).astype("float32")
check_eval(relay.Function([x], x), [TensorValue(xx)], xx)
def test_kwargs_params():
x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10))
......@@ -170,6 +172,46 @@ def test_kwargs_params():
tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data)
def test_function_taking_adt_ref_tuple():
mod = relay.Module()
prelude = relay.prelude.Prelude(mod)
intrp = create_executor("debug", mod)
nil_value = ConstructorValue(prelude.nil, [], [])
cons_value = ConstructorValue(prelude.cons, [
TensorValue(np.random.rand(1, 10).astype('float32')),
nil_value
], [relay.TensorType((1, 10), 'float32')])
ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32')))
tuple_value = TupleValue(*[
TensorValue(np.random.rand(1, 10).astype('float32')) for _ in range(10)
])
id_func = intrp.evaluate(prelude.id)
res_nil = id_func(nil_value)
assert res_nil.constructor == nil_value.constructor
assert len(res_nil.fields) == 0
res_cons = id_func(cons_value)
assert res_cons.constructor == cons_value.constructor
assert len(res_cons.fields) == len(cons_value.fields)
tvm.testing.assert_allclose(res_cons.fields[0].asnumpy(),
cons_value.fields[0].asnumpy())
assert isinstance(res_cons.fields[1], ConstructorValue)
assert res_cons.fields[1].constructor == prelude.nil
assert len(res_cons.fields[1].fields) == 0
res_ref = id_func(ref_value)
tvm.testing.assert_allclose(res_ref.value.asnumpy(), ref_value.value.asnumpy())
res_tuple = id_func(tuple_value)
for i in range(10):
tvm.testing.assert_allclose(res_tuple.fields[i].asnumpy(),
tuple_value.fields[i].asnumpy())
if __name__ == "__main__":
test_id()
test_add_const()
......@@ -181,3 +223,4 @@ if __name__ == "__main__":
test_kwargs_params()
test_ref()
test_tensor_value()
test_function_taking_adt_ref_tuple()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment