Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
c48812dd
Commit
c48812dd
authored
Oct 31, 2018
by
Jared Roesch
Committed by
Tianqi Chen
Oct 31, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY][RUNTIME] Refactor interpreter and graph_runtime into consistent interface. (#2042)
parent
0319f99d
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
282 additions
and
232 deletions
+282
-232
python/tvm/relay/__init__.py
+2
-0
python/tvm/relay/build_module.py
+46
-0
python/tvm/relay/expr.py
+112
-0
python/tvm/relay/graph_runtime_codegen.py
+3
-186
python/tvm/relay/interpreter.py
+105
-27
tests/python/relay/test_graph_runtime.py
+10
-11
tests/python/relay/test_interpreter.py
+4
-8
No files found.
python/tvm/relay/__init__.py
View file @
c48812dd
...
...
@@ -7,6 +7,8 @@ from . import ty
from
.
import
expr
from
.
import
env
from
.
import
ir_pass
from
.build_module
import
build
from
.interpreter
import
create_executor
# Root operators
from
.op
import
Op
...
...
python/tvm/relay/build_module.py
0 → 100644
View file @
c48812dd
"""
Construct the necessary state for the TVM graph runtime
from a Relay expression.
"""
from
..build_module
import
build
as
tvm_build_module
from
.
graph_runtime_codegen
import
GraphRuntimeCodegen
from
.
import
ir_pass
from
.env
import
Environment
def
build
(
func
,
params
=
None
,
target
=
None
,
env
=
None
):
"""
Compile a single function to the components needed by the
TVM RTS.
Parameters
----------
func: relay.Expr
The function to build.
target: optional str
The target platform.
Returns
-------
(graph_json, mod, params): tuple of (str, tvm.Module, dict)
The outputs of building a Relay function for the TVM runtime.
"""
if
target
is
None
:
target
=
'llvm'
if
env
is
None
:
env
=
Environment
({})
comp
=
GraphRuntimeCodegen
(
env
)
# NB(@jroesch) This creates lowered functions, and generates names for them
#
# We need these names to emit the correct graph as these are names of the
# functions contained in the module.
lowered_ops
=
ir_pass
.
lower_ops
(
env
,
func
)
mod
=
tvm_build_module
([
lf
.
lowered_func
for
lf
in
lowered_ops
],
target
)
# Therefore the call to compile must come after.
comp
.
codegen
(
func
)
graph_json
=
comp
.
to_json
()
return
graph_json
,
mod
,
params
python/tvm/relay/expr.py
View file @
c48812dd
...
...
@@ -319,6 +319,118 @@ class TupleGetItem(Expr):
self
.
__init_handle_by_constructor__
(
_make
.
TupleGetItem
,
tuple_value
,
index
)
class
ExprFunctor
(
object
):
"""
An abstract visitor defined over Expr.
A Python version of the class defined in `expr_functor.h`.
Defines the default dispatch over expressions, and
implements memoization.
"""
def
__init__
(
self
):
self
.
memo_map
=
{}
# pylint: disable=no-else-return
def
visit
(
self
,
expr
):
"""Apply the visitor to an expression."""
found
=
self
.
memo_map
.
get
(
expr
)
if
found
:
return
found
if
isinstance
(
expr
,
Function
):
res
=
self
.
visit_function
(
expr
)
elif
isinstance
(
expr
,
Call
):
res
=
self
.
visit_call
(
expr
)
elif
isinstance
(
expr
,
Let
):
res
=
self
.
visit_let
(
expr
)
elif
isinstance
(
expr
,
Var
):
res
=
self
.
visit_var
(
expr
)
elif
isinstance
(
expr
,
GlobalVar
):
res
=
self
.
visit_global_var
(
expr
)
elif
isinstance
(
expr
,
If
):
res
=
self
.
visit_if
(
expr
)
elif
isinstance
(
expr
,
Tuple
):
res
=
self
.
visit_tuple
(
expr
)
elif
isinstance
(
expr
,
Constant
):
res
=
self
.
visit_constant
(
expr
)
else
:
raise
Exception
(
"warning unhandled case: {0}"
.
format
(
type
(
expr
)))
self
.
memo_map
[
expr
]
=
res
return
res
def
visit_function
(
self
,
_
):
raise
Exception
(
"Abstract method please implement me."
)
def
visit_let
(
self
,
_
):
raise
Exception
(
"Abstract method please implement me."
)
def
visit_call
(
self
,
_
):
raise
Exception
(
"Abstract method please implement me."
)
def
visit_var
(
self
,
_
):
raise
Exception
(
"Abstract method please implement me."
)
def
visit_type
(
self
,
typ
):
return
typ
def
visit_if
(
self
,
_
):
raise
Exception
(
"Abstract method please implement me."
)
def
visit_tuple
(
self
,
_
):
raise
Exception
(
"Abstract method please implement me."
)
def
visit_constant
(
self
,
_
):
raise
Exception
(
"Abstract method please implement me."
)
def
visit_global_var
(
self
,
_
):
raise
Exception
(
"Abstract method please implement me."
)
class
ExprMutator
(
ExprFunctor
):
"""
A functional visitor over Expr.
The default behavior recursively traverses the AST
and reconstructs the AST.
"""
def
visit_function
(
self
,
fn
):
new_body
=
self
.
visit
(
fn
.
body
)
return
Function
(
list
(
fn
.
params
),
fn
.
ret_type
,
new_body
,
fn
.
type_params
)
def
visit_let
(
self
,
let
):
new_var
=
self
.
visit
(
let
.
var
)
new_val
=
self
.
visit
(
let
.
value
)
new_body
=
self
.
visit
(
let
.
body
)
return
Let
(
new_var
,
new_val
,
new_body
)
def
visit_call
(
self
,
call
):
new_fn
=
self
.
visit
(
call
.
op
)
new_args
=
[
self
.
visit
(
arg
)
for
arg
in
call
.
args
]
return
Call
(
new_fn
,
new_args
,
call
.
attrs
)
def
visit_var
(
self
,
rvar
):
return
rvar
def
visit_global_id
(
self
,
global_var
):
return
global_var
def
visit_if
(
self
,
ite
):
return
If
(
self
.
visit
(
ite
.
guard
),
self
.
visit
(
ite
.
true_b
),
self
.
visit
(
ite
.
false_b
))
def
visit_tuple
(
self
,
tup
):
return
Tuple
([
self
.
visit
(
field
)
for
field
in
tup
.
fields
])
def
visit_constant
(
self
,
rconst
):
return
rconst
class
TupleWrapper
(
object
):
"""TupleWrapper.
...
...
python/tvm/relay/graph_runtime_codegen.py
View file @
c48812dd
...
...
@@ -25,113 +25,7 @@ import json
import
attr
from
.
import
ir_pass
from
.op
import
Op
from
.expr
import
Var
,
Function
,
Call
,
If
,
GlobalVar
,
Constant
,
Let
,
Tuple
from
..build_module
import
build
as
tvm_build_module
from
..
contrib
import
graph_runtime
from
.ir_pass
import
infer_type
from
..
import
cpu
class
AbstractExprVisitor
(
object
):
"""A visitor over Expr in Python."""
def
__init__
(
self
):
self
.
memo_map
=
{}
# pylint: disable=no-else-return
def
visit
(
self
,
expr
):
"""Apply the visitor to an expression."""
found
=
self
.
memo_map
.
get
(
expr
)
if
found
:
return
found
if
isinstance
(
expr
,
Function
):
res
=
self
.
visit_function
(
expr
)
elif
isinstance
(
expr
,
Call
):
res
=
self
.
visit_call
(
expr
)
elif
isinstance
(
expr
,
Let
):
res
=
self
.
visit_let
(
expr
)
elif
isinstance
(
expr
,
Var
):
res
=
self
.
visit_var
(
expr
)
elif
isinstance
(
expr
,
GlobalVar
):
res
=
self
.
visit_global_var
(
expr
)
elif
isinstance
(
expr
,
If
):
res
=
self
.
visit_if
(
expr
)
elif
isinstance
(
expr
,
Tuple
):
res
=
self
.
visit_tuple
(
expr
)
elif
isinstance
(
expr
,
Constant
):
res
=
self
.
visit_constant
(
expr
)
else
:
raise
Exception
(
"warning unhandled case: {0}"
.
format
(
type
(
expr
)))
self
.
memo_map
[
expr
]
=
res
return
res
def
visit_function
(
self
,
_
):
raise
Exception
(
"Abstract method please implement me."
)
def
visit_let
(
self
,
_
):
raise
Exception
(
"Abstract method please implement me."
)
def
visit_call
(
self
,
_
):
raise
Exception
(
"Abstract method please implement me."
)
def
visit_var
(
self
,
_
):
raise
Exception
(
"Abstract method please implement me."
)
def
visit_type
(
self
,
typ
):
return
typ
def
visit_if
(
self
,
_
):
raise
Exception
(
"Abstract method please implement me."
)
def
visit_tuple
(
self
,
_
):
raise
Exception
(
"Abstract method please implement me."
)
def
visit_constant
(
self
,
_
):
raise
Exception
(
"Abstract method please implement me."
)
def
visit_global_var
(
self
,
_
):
raise
Exception
(
"Abstract method please implement me."
)
class
ExprMutator
(
AbstractExprVisitor
):
"""A functional visitor over Expr in Python."""
def
visit_function
(
self
,
fn
):
new_body
=
self
.
visit
(
fn
.
body
)
return
Function
(
list
(
fn
.
params
),
fn
.
ret_type
,
new_body
,
fn
.
type_params
)
def
visit_let
(
self
,
let
):
new_var
=
self
.
visit
(
let
.
var
)
new_val
=
self
.
visit
(
let
.
value
)
new_body
=
self
.
visit
(
let
.
body
)
return
Let
(
new_var
,
new_val
,
new_body
)
def
visit_call
(
self
,
call
):
new_fn
=
self
.
visit
(
call
.
op
)
new_args
=
[
self
.
visit
(
arg
)
for
arg
in
call
.
args
]
return
Call
(
new_fn
,
new_args
,
call
.
attrs
)
def
visit_var
(
self
,
var
):
return
var
def
visit_global_id
(
self
,
global_var
):
return
global_var
def
visit_if
(
self
,
ite
):
return
If
(
self
.
visit
(
ite
.
guard
),
self
.
visit
(
ite
.
true_b
),
self
.
visit
(
ite
.
false_b
))
def
visit_tuple
(
self
,
tup
):
return
Tuple
([
self
.
visit
(
field
)
for
field
in
tup
.
fields
])
def
visit_constant
(
self
,
const
):
return
const
from
.expr
import
Function
,
GlobalVar
,
ExprMutator
@attr.s
...
...
@@ -359,8 +253,8 @@ class GraphRuntimeCodegen(ExprMutator):
self
.
add_binding
(
ident
,
val_ref
)
return
self
.
visit
(
body
)
def
visit_var
(
self
,
var
):
return
self
.
lookup
(
var
)
def
visit_var
(
self
,
r
var
):
return
self
.
lookup
(
r
var
)
def
visit_call
(
self
,
call
):
"""Transform a ::tvm.relay.Call into an operator in the TVM graph."""
...
...
@@ -472,80 +366,3 @@ class GraphRuntimeCodegen(ExprMutator):
}
return
json
.
dumps
(
json_dict
)
def
build
(
env
,
func
,
target
=
None
):
"""
Compile a single function to the components needed by the
TVM RTS.
Parameters
----------
func: relay.Expr
The function to build.
target: optional str
The target platform.
Returns
-------
(graph_json, mod, params): tuple of (str, tvm.Module, dict)
The outputs of building a Relay function for the TVM runtime.
"""
if
target
is
None
:
target
=
'llvm'
comp
=
GraphRuntimeCodegen
(
env
)
# NB(@jroesch) This creates lowered functions, and generates names for them
#
# We need these names to emit the correct graph as these are names of the
# functions contained in the module.
lowered_ops
=
ir_pass
.
lower_ops
(
env
,
func
)
mod
=
tvm_build_module
([
lf
.
lowered_func
for
lf
in
lowered_ops
],
target
)
# Therefore the call to compile must come after.
comp
.
codegen
(
func
)
graph_json
=
comp
.
to_json
()
return
graph_json
,
mod
,
None
# params currently isn't supported by API
def
graph_evaluate
(
env
,
func
,
*
args
):
"""
Corresponding function to tvm.relay.eval.evaluate.
This function evaluates a Relay expression on the
TVM graph_runtime.
Parameters
----------
env: tvm.relay.Environment
The global environment used.
expr: tvm.relay.Expr
The expression to evaluate.
args: list of tvm.relay.Expr
The arguments to apply to the expression, only works
if the expression has a function type.
Returns
-------
value: tvm.NDArray
The output Tensor produced by evaluating the expression.
"""
func
=
infer_type
(
func
,
env
)
func
=
ir_pass
.
fuse_ops
(
env
,
func
)
func
=
infer_type
(
func
,
env
)
graph_json
,
mod
,
params
=
build
(
env
,
func
)
assert
params
is
None
gmodule
=
graph_runtime
.
create
(
graph_json
,
mod
,
cpu
(
0
))
# Create map of inputs.
inputs
=
{}
for
i
,
arg
in
enumerate
(
args
):
inputs
[
func
.
params
[
i
]
.
name_hint
]
=
arg
# Set the inputs here.
gmodule
.
set_input
(
**
inputs
)
# Run the module, and fetch the output.
gmodule
.
run
()
return
gmodule
.
get_output
(
0
)
python/tvm/relay/interpreter.py
View file @
c48812dd
...
...
@@ -4,12 +4,16 @@ from __future__ import absolute_import
import
numpy
as
np
from
..
import
register_func
,
nd
from
.base
import
NodeBase
,
register_relay_node
from
.
import
build_module
from
.
import
_make
from
.
import
_interpreter
from
.
import
ir_pass
from
.expr
import
Call
,
Constant
,
GlobalVar
from
.
import
const
from
.env
import
Environment
from
.expr
import
Call
,
Constant
,
GlobalVar
,
Function
,
const
from
.scope_builder
import
ScopeBuilder
from
.._ffi.base
import
integer_types
from
..contrib
import
graph_runtime
as
tvm_runtime
from
..
import
cpu
class
Value
(
NodeBase
):
"""Base class of all values.
...
...
@@ -83,48 +87,122 @@ def _arg_to_ast(arg):
else
:
return
const
(
arg
)
class
Executor
(
object
):
"""An abstract interface for executing Relay programs."""
def
__init__
(
self
,
env
=
None
):
"""
Parameters
----------
env: relay.Environment
The environment.
"""
if
env
is
None
:
self
.
env
=
Environment
({})
else
:
self
.
env
=
env
def
apply_passes
(
expr
,
env
=
None
):
ck_expr
=
ir_pass
.
infer_type
(
expr
,
env
=
env
)
fused_expr
=
ir_pass
.
fuse_ops
(
env
,
ck_expr
)
return
fused_expr
def
optimize
(
self
,
expr
):
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr
=
ir_pass
.
infer_type
(
expr
,
env
=
self
.
env
)
fused_expr
=
ir_pass
.
fuse_ops
(
self
.
env
,
ck_expr
)
ck_fused
=
ir_pass
.
infer_type
(
fused_expr
,
env
=
self
.
env
)
return
ck_fused
def
evaluate
(
env
,
expr
,
*
args
):
def
_make_executor
(
self
,
_
):
"""
Evaluate a Relay expression on the interpreter.
Construct a Python function that implements the evaluation
of expression.
Parameters
----------
env: tvm.relay.Environment
The global environment used.
expr: relay.Expr
The Relay expression to execute.
Returns
-------
executor: function
A Python function which implements the behavior of `expr`.
"""
raise
Exception
(
"abstract method: please implement me."
)
def
evaluate
(
self
,
expr
,
params
=
None
):
"""
Evaluate a Relay expression on the interpreter.
Parameters
----------
expr: tvm.relay.Expr
The expression to evaluate.
"""
if
params
:
scope_builder
=
ScopeBuilder
()
for
key
,
value
in
params
:
scope_builder
.
let
(
key
,
value
)
scope_builder
.
ret
(
expr
)
expr
=
scope_builder
.
get
()
args: list of tvm.relay.Expr
The arguments to apply to the expression, only works
if the expression has a function type.
if
isinstance
(
expr
,
Function
):
assert
not
ir_pass
.
free_vars
(
expr
)
Returns
-------
value: tvm.relay.eval.Value
The value produced by evaluating the expression.
return
self
.
_make_executor
(
expr
)
class
Interpreter
(
Executor
):
"""
A wrapper around the Relay interpreter, implements the excecutor interface.
"""
# assert len(args) == 0
def
__init__
(
self
,
env
=
None
):
Executor
.
__init__
(
self
,
env
)
def
_make_executor
(
self
,
expr
):
def
_interp_wrapper
(
*
args
):
relay_args
=
[]
for
arg
in
args
:
relay_args
.
append
(
_arg_to_ast
(
arg
))
# TODO: We need to move this optimization code into the optimizer/pass manager
if
isinstance
(
expr
,
GlobalVar
):
func
=
env
[
expr
]
func
=
apply_passes
(
func
,
env
)
env
.
_add
(
expr
,
func
,
True
)
func
=
self
.
env
[
expr
]
func
=
self
.
optimize
(
func
)
self
.
env
.
_add
(
expr
,
func
,
True
)
opt_expr
=
Call
(
expr
,
relay_args
)
# import pdb; pdb.set_trace()
return
_interpreter
.
evaluate
(
env
,
opt_expr
)
return
_interpreter
.
evaluate
(
self
.
env
,
opt_expr
)
else
:
call
=
Call
(
expr
,
relay_args
)
opt_expr
=
self
.
optimize
(
call
)
return
_interpreter
.
evaluate
(
self
.
env
,
opt_expr
)
return
_interp_wrapper
class
GraphRuntime
(
Executor
):
"""A wrapper around the TVM graph runtime, implements the Executor interface."""
def
__init__
(
self
,
env
=
None
):
Executor
.
__init__
(
self
,
env
)
def
_make_executor
(
self
,
expr
):
def
_graph_wrapper
(
*
args
):
func
=
self
.
optimize
(
expr
)
graph_json
,
mod
,
params
=
build_module
.
build
(
func
,
env
=
self
.
env
)
assert
params
is
None
gmodule
=
tvm_runtime
.
create
(
graph_json
,
mod
,
cpu
(
0
))
# Create map of inputs.
inputs
=
{}
for
i
,
arg
in
enumerate
(
args
):
inputs
[
func
.
params
[
i
]
.
name_hint
]
=
arg
# Set the inputs here.
gmodule
.
set_input
(
**
inputs
)
# Run the module, and fetch the output.
gmodule
.
run
()
return
gmodule
.
get_output
(
0
)
return
_graph_wrapper
def
create_executor
(
mode
=
'debug'
,
env
=
None
):
if
mode
==
'debug'
:
return
Interpreter
(
env
)
elif
mode
==
'graph'
:
return
GraphRuntime
(
env
)
else
:
expr
=
Call
(
expr
,
relay_args
)
opt_expr
=
apply_passes
(
expr
,
env
)
return
_interpreter
.
evaluate
(
env
,
opt_expr
)
raise
Exception
(
"unknown mode {0}"
.
format
(
mode
))
tests/python/relay/test_graph_runtime.py
View file @
c48812dd
import
numpy
as
np
from
tvm
import
relay
from
tvm.relay
import
create_executor
from
tvm.relay.ir_pass
import
infer_type
from
tvm.relay.interpreter
import
evaluate
from
tvm.relay.graph_runtime_codegen
import
graph_evaluate
from
tvm.relay.interpreter
import
Interpreter
from
tvm.relay.scope_builder
import
ScopeBuilder
from
tvm.relay.op
import
add
from
tvm.relay.env
import
Environment
# @tq, @jr should we put this in testing ns?
def
check_rts
(
e
nv
,
expr
,
args
,
expected_result
):
def
check_rts
(
e
xpr
,
args
,
expected_result
,
env
=
None
):
"""
Check that evaluating `expr` applied to the arguments produces
`result` on both the evaluator and TVM runtime.
...
...
@@ -25,8 +25,10 @@ def check_rts(env, expr, args, expected_result):
expected_result:
The expected result of running the expression.
"""
eval_result
=
evaluate
(
env
,
expr
,
*
args
)
rts_result
=
graph_evaluate
(
env
,
expr
,
*
args
)
intrp
=
create_executor
(
'graph'
,
env
=
env
)
graph
=
create_executor
(
'graph'
,
env
=
env
)
eval_result
=
intrp
.
evaluate
(
expr
)(
*
args
)
rts_result
=
graph
.
evaluate
(
expr
)(
*
args
)
np
.
testing
.
assert_allclose
(
eval_result
.
asnumpy
(),
rts_result
.
asnumpy
())
def
test_add_op_scalar
():
...
...
@@ -36,13 +38,12 @@ def test_add_op_scalar():
return x + y;
}
"""
env
=
Environment
()
x
=
relay
.
var
(
'x'
,
shape
=
())
y
=
relay
.
var
(
'y'
,
shape
=
())
func
=
relay
.
Function
([
x
,
y
],
add
(
x
,
y
))
x_data
=
np
.
array
(
10.0
,
dtype
=
'float32'
)
y_data
=
np
.
array
(
1.0
,
dtype
=
'float32'
)
check_rts
(
env
,
func
,
[
x_data
,
y_data
],
x_data
+
y_data
)
check_rts
(
func
,
[
x_data
,
y_data
],
x_data
+
y_data
)
def
test_add_op_tensor
():
"""
...
...
@@ -51,13 +52,12 @@ def test_add_op_tensor():
return x + y;
}
"""
env
=
Environment
()
x
=
relay
.
var
(
'x'
,
shape
=
(
10
,
5
))
y
=
relay
.
var
(
'y'
,
shape
=
(
10
,
5
))
func
=
relay
.
Function
([
x
,
y
],
add
(
x
,
y
))
x_data
=
np
.
random
.
rand
(
10
,
5
)
.
astype
(
'float32'
)
y_data
=
np
.
random
.
rand
(
10
,
5
)
.
astype
(
'float32'
)
check_rts
(
env
,
func
,
[
x_data
,
y_data
],
x_data
+
y_data
)
check_rts
(
func
,
[
x_data
,
y_data
],
x_data
+
y_data
)
def
test_add_op_broadcast
():
"""
...
...
@@ -66,13 +66,12 @@ def test_add_op_broadcast():
return x + y;
}
"""
env
=
Environment
()
x
=
relay
.
var
(
'x'
,
shape
=
(
10
,
5
))
y
=
relay
.
var
(
'y'
,
shape
=
(
1
,
5
))
func
=
relay
.
Function
([
x
,
y
],
add
(
x
,
y
))
x_data
=
np
.
random
.
rand
(
10
,
5
)
.
astype
(
'float32'
)
y_data
=
np
.
random
.
rand
(
1
,
5
)
.
astype
(
'float32'
)
check_rts
(
env
,
func
,
[
x_data
,
y_data
],
x_data
+
y_data
)
check_rts
(
func
,
[
x_data
,
y_data
],
x_data
+
y_data
)
if
__name__
==
"__main__"
:
test_add_op_scalar
()
...
...
tests/python/relay/test_interpreter.py
View file @
c48812dd
import
numpy
as
np
import
tvm
from
tvm
import
relay
from
tvm.relay.interpreter
import
Value
,
TupleValue
,
evaluate
from
tvm.relay.interpreter
import
Value
,
TupleValue
from
tvm.relay
import
op
from
tvm.relay.scope_builder
import
ScopeBuilder
from
tvm.relay
import
testing
from
tvm.relay
import
testing
,
create_executor
def
check_eval
(
expr
,
args
,
expected_result
,
env
=
None
,
rtol
=
1e-07
):
if
env
is
None
:
env
=
relay
.
env
.
Environment
({})
result
=
evaluate
(
env
,
expr
,
*
args
)
intrp
=
create_executor
(
env
=
env
)
result
=
intrp
.
evaluate
(
expr
)(
*
args
)
np
.
testing
.
assert_allclose
(
result
.
asnumpy
(),
expected_result
,
rtol
=
rtol
)
...
...
@@ -32,8 +30,6 @@ def test_tuple_value():
def
test_id
():
x
=
relay
.
var
(
'x'
,
'float32'
)
ident
=
relay
.
Function
([
x
],
x
)
env
=
relay
.
env
.
Environment
({})
res
=
evaluate
(
env
,
ident
,
1.0
)
check_eval
(
ident
,
[
1.0
],
1.0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment