Commit 0bf64ee0 by 雾雨魔理沙 Committed by Tianqi Chen

fix typo in backend interpreter (#2752)

parent f3b7c0ac
...@@ -95,7 +95,7 @@ class RefValue(Value): ...@@ -95,7 +95,7 @@ class RefValue(Value):
def _arg_to_ast(arg): def _arg_to_ast(arg):
if isinstance(arg, TensorValue): if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(_nd.cpu(0))) return Constant(arg.data.copyto(nd.cpu(0)))
elif isinstance(arg, np.ndarray): elif isinstance(arg, np.ndarray):
return Constant(nd.array(arg)) return Constant(nd.array(arg))
elif isinstance(arg, Constant): elif isinstance(arg, Constant):
......
...@@ -2,7 +2,7 @@ import numpy as np ...@@ -2,7 +2,7 @@ import numpy as np
import tvm import tvm
import tvm.testing import tvm.testing
from tvm import relay from tvm import relay
from tvm.relay.backend.interpreter import Value, TupleValue from tvm.relay.backend.interpreter import Value, TupleValue, TensorValue
from tvm.relay.scope_builder import ScopeBuilder from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor from tvm.relay import testing, create_executor
...@@ -135,6 +135,11 @@ def test_binds(): ...@@ -135,6 +135,11 @@ def test_binds():
tvm.testing.assert_allclose(xx + xx, res) tvm.testing.assert_allclose(xx + xx, res)
def test_tensor_value():
x = relay.var("x", shape=(1, 10))
xx = np.ones((1, 10)).astype("float32")
check_eval(relay.Function([x], x), [TensorValue(xx)], xx)
def test_kwargs_params(): def test_kwargs_params():
x = relay.var("x", shape=(1, 10)) x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10)) y = relay.var("y", shape=(1, 10))
...@@ -159,3 +164,4 @@ if __name__ == "__main__": ...@@ -159,3 +164,4 @@ if __name__ == "__main__":
test_binds() test_binds()
test_kwargs_params() test_kwargs_params()
test_ref() test_ref()
test_tensor_value()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment