Commit fcc5b422 by Steven S. Lyubomirsky Committed by Tianqi Chen

Ensure interpreted functions can take values that are not TensorValues (#3015)

parent 561e422b
...@@ -24,7 +24,7 @@ from . import _backend ...@@ -24,7 +24,7 @@ from . import _backend
from .. import _make, ir_pass from .. import _make, ir_pass
from ... import register_func, nd from ... import register_func, nd
from ..base import NodeBase, register_relay_node from ..base import NodeBase, register_relay_node
from ..expr import Call, Constant, GlobalVar, Function, const from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const
from ..scope_builder import ScopeBuilder from ..scope_builder import ScopeBuilder
class Value(NodeBase): class Value(NodeBase):
...@@ -112,6 +112,12 @@ class RefValue(Value): ...@@ -112,6 +112,12 @@ class RefValue(Value):
def _arg_to_ast(arg): def _arg_to_ast(arg):
if isinstance(arg, TensorValue): if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(nd.cpu(0))) return Constant(arg.data.copyto(nd.cpu(0)))
elif isinstance(arg, TupleValue):
return Tuple([_arg_to_ast(field) for field in arg.fields])
elif isinstance(arg, RefValue):
return RefCreate(_arg_to_ast(arg.value))
elif isinstance(arg, ConstructorValue):
return Call(arg.constructor, [_arg_to_ast(field) for field in arg.fields])
elif isinstance(arg, np.ndarray): elif isinstance(arg, np.ndarray):
return Constant(nd.array(arg)) return Constant(nd.array(arg))
elif isinstance(arg, Constant): elif isinstance(arg, Constant):
......
...@@ -19,6 +19,7 @@ import tvm ...@@ -19,6 +19,7 @@ import tvm
import tvm.testing import tvm.testing
from tvm import relay from tvm import relay
from tvm.relay.backend.interpreter import Value, TupleValue, TensorValue from tvm.relay.backend.interpreter import Value, TupleValue, TensorValue
from tvm.relay.backend.interpreter import RefValue, ConstructorValue
from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor from tvm.relay import testing, create_executor
...@@ -156,6 +157,7 @@ def test_tensor_value(): ...@@ -156,6 +157,7 @@ def test_tensor_value():
xx = np.ones((1, 10)).astype("float32") xx = np.ones((1, 10)).astype("float32")
check_eval(relay.Function([x], x), [TensorValue(xx)], xx) check_eval(relay.Function([x], x), [TensorValue(xx)], xx)
def test_kwargs_params(): def test_kwargs_params():
x = relay.var("x", shape=(1, 10)) x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10)) y = relay.var("y", shape=(1, 10))
...@@ -170,6 +172,46 @@ def test_kwargs_params(): ...@@ -170,6 +172,46 @@ def test_kwargs_params():
tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data) tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data)
def test_function_taking_adt_ref_tuple():
mod = relay.Module()
prelude = relay.prelude.Prelude(mod)
intrp = create_executor("debug", mod)
nil_value = ConstructorValue(prelude.nil, [], [])
cons_value = ConstructorValue(prelude.cons, [
TensorValue(np.random.rand(1, 10).astype('float32')),
nil_value
], [relay.TensorType((1, 10), 'float32')])
ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32')))
tuple_value = TupleValue(*[
TensorValue(np.random.rand(1, 10).astype('float32')) for _ in range(10)
])
id_func = intrp.evaluate(prelude.id)
res_nil = id_func(nil_value)
assert res_nil.constructor == nil_value.constructor
assert len(res_nil.fields) == 0
res_cons = id_func(cons_value)
assert res_cons.constructor == cons_value.constructor
assert len(res_cons.fields) == len(cons_value.fields)
tvm.testing.assert_allclose(res_cons.fields[0].asnumpy(),
cons_value.fields[0].asnumpy())
assert isinstance(res_cons.fields[1], ConstructorValue)
assert res_cons.fields[1].constructor == prelude.nil
assert len(res_cons.fields[1].fields) == 0
res_ref = id_func(ref_value)
tvm.testing.assert_allclose(res_ref.value.asnumpy(), ref_value.value.asnumpy())
res_tuple = id_func(tuple_value)
for i in range(10):
tvm.testing.assert_allclose(res_tuple.fields[i].asnumpy(),
tuple_value.fields[i].asnumpy())
if __name__ == "__main__": if __name__ == "__main__":
test_id() test_id()
test_add_const() test_add_const()
...@@ -181,3 +223,4 @@ if __name__ == "__main__": ...@@ -181,3 +223,4 @@ if __name__ == "__main__":
test_kwargs_params() test_kwargs_params()
test_ref() test_ref()
test_tensor_value() test_tensor_value()
test_function_taking_adt_ref_tuple()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment