Commit 0bf64ee0 by 雾雨魔理沙 Committed by Tianqi Chen

fix typo in backend interpreter (#2752)

parent f3b7c0ac
......@@ -95,7 +95,7 @@ class RefValue(Value):
def _arg_to_ast(arg):
if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(_nd.cpu(0)))
return Constant(arg.data.copyto(nd.cpu(0)))
elif isinstance(arg, np.ndarray):
return Constant(nd.array(arg))
elif isinstance(arg, Constant):
......
......@@ -2,7 +2,7 @@ import numpy as np
import tvm
import tvm.testing
from tvm import relay
from tvm.relay.backend.interpreter import Value, TupleValue
from tvm.relay.backend.interpreter import Value, TupleValue, TensorValue
from tvm.relay.scope_builder import ScopeBuilder
from tvm.relay import testing, create_executor
......@@ -135,6 +135,11 @@ def test_binds():
tvm.testing.assert_allclose(xx + xx, res)
def test_tensor_value():
x = relay.var("x", shape=(1, 10))
xx = np.ones((1, 10)).astype("float32")
check_eval(relay.Function([x], x), [TensorValue(xx)], xx)
def test_kwargs_params():
x = relay.var("x", shape=(1, 10))
y = relay.var("y", shape=(1, 10))
......@@ -159,3 +164,4 @@ if __name__ == "__main__":
test_binds()
test_kwargs_params()
test_ref()
test_tensor_value()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment