Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
1a00cab9
Commit
1a00cab9
authored
Jul 19, 2019
by
雾雨魔理沙
Committed by
Wuwei Lin
Jul 20, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] add some check for the ad algorithm (#3585)
* do * fix test
parent
313bc9de
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
41 additions
and
18 deletions
+41
-18
python/tvm/relay/op/_tensor_grad.py
+16
-0
src/relay/pass/gradient.cc
+3
-0
tests/python/relay/test_feature.py
+2
-0
tests/python/relay/test_op_grad_level1.py
+3
-8
tests/python/relay/test_op_grad_level3.py
+2
-7
tests/python/relay/test_pass_gradient.py
+11
-1
tests/python/relay/test_pass_partial_eval.py
+2
-2
tests/python/relay/test_pass_to_cps.py
+2
-0
No files found.
python/tvm/relay/op/_tensor_grad.py
View file @
1a00cab9
...
@@ -95,22 +95,37 @@ def divide_grad(orig, grad):
...
@@ -95,22 +95,37 @@ def divide_grad(orig, grad):
collapse_sum_like
(
-
(
grad
*
orig
/
y
),
y
)]
collapse_sum_like
(
-
(
grad
*
orig
/
y
),
y
)]
@register_gradient
(
"zeros"
)
def
zeros_grad
(
orig
,
grad
):
"""Returns []"""
return
[]
@register_gradient
(
"ones"
)
def
ones_grad
(
orig
,
grad
):
"""Returns []"""
return
[]
@register_gradient
(
"zeros_like"
)
@register_gradient
(
"zeros_like"
)
def
zeros_like_grad
(
orig
,
grad
):
def
zeros_like_grad
(
orig
,
grad
):
"""Returns [0]"""
"""Returns [0]"""
return
[
orig
]
return
[
orig
]
@register_gradient
(
"ones_like"
)
@register_gradient
(
"ones_like"
)
def
ones_like_grad
(
orig
,
grad
):
def
ones_like_grad
(
orig
,
grad
):
"""Returns [0]"""
"""Returns [0]"""
return
[
zeros_like
(
orig
.
args
[
0
])]
return
[
zeros_like
(
orig
.
args
[
0
])]
@register_gradient
(
"collapse_sum_like"
)
@register_gradient
(
"collapse_sum_like"
)
def
collapse_sum_like_grad
(
orig
,
grad
):
def
collapse_sum_like_grad
(
orig
,
grad
):
"""Returns [broadcast_to_like(grad, x), 0]"""
"""Returns [broadcast_to_like(grad, x), 0]"""
x
,
y
=
orig
.
args
x
,
y
=
orig
.
args
return
[
broadcast_to_like
(
grad
,
x
),
zeros_like
(
y
)]
return
[
broadcast_to_like
(
grad
,
x
),
zeros_like
(
y
)]
@register_gradient
(
"abs"
)
@register_gradient
(
"abs"
)
def
abs_grad
(
orig
,
grad
):
def
abs_grad
(
orig
,
grad
):
"""Returns grad * (select(x < 0, -1, 1))."""
"""Returns grad * (select(x < 0, -1, 1))."""
...
@@ -119,6 +134,7 @@ def abs_grad(orig, grad):
...
@@ -119,6 +134,7 @@ def abs_grad(orig, grad):
ones
=
ones_like
(
x
)
ones
=
ones_like
(
x
)
return
[
where
(
less
(
x
,
zeros
),
-
ones
*
grad
,
ones
*
grad
)]
return
[
where
(
less
(
x
,
zeros
),
-
ones
*
grad
,
ones
*
grad
)]
@register_gradient
(
"clip"
)
@register_gradient
(
"clip"
)
def
clip_grad
(
orig
,
grad
):
def
clip_grad
(
orig
,
grad
):
"""Returns grad * (select(x < min || max < x , 0, 1))."""
"""Returns grad * (select(x < min || max < x , 0, 1))."""
...
...
src/relay/pass/gradient.cc
View file @
1a00cab9
...
@@ -333,6 +333,9 @@ Expr Gradient(const Expr& re, const Module& mod) {
...
@@ -333,6 +333,9 @@ Expr Gradient(const Expr& re, const Module& mod) {
auto
f
=
e
.
as
<
FunctionNode
>
();
auto
f
=
e
.
as
<
FunctionNode
>
();
CHECK
(
f
)
<<
"input need to be a function"
;
CHECK
(
f
)
<<
"input need to be a function"
;
CHECK
(
f
->
type_params
.
size
()
==
0
)
<<
"no polymorphism supported for now"
;
CHECK
(
f
->
type_params
.
size
()
==
0
)
<<
"no polymorphism supported for now"
;
for
(
const
auto
&
p
:
f
->
params
)
{
CHECK
(
p
->
checked_type
().
as
<
TensorTypeNode
>
())
<<
"input parameters need to be tensor"
;
}
Expr
body
=
LetList
::
With
([
&
](
LetList
*
ll
)
{
Expr
body
=
LetList
::
With
([
&
](
LetList
*
ll
)
{
Var
bp
=
ll
->
Push
(
BPEmpty
());
Var
bp
=
ll
->
Push
(
BPEmpty
());
Expr
rev
=
ReverseAD
(
bp
)(
e
);
Expr
rev
=
ReverseAD
(
bp
)(
e
);
...
...
tests/python/relay/test_feature.py
View file @
1a00cab9
...
@@ -21,6 +21,7 @@ from tvm.relay.analysis import detect_feature
...
@@ -21,6 +21,7 @@ from tvm.relay.analysis import detect_feature
from
tvm.relay.transform
import
gradient
from
tvm.relay.transform
import
gradient
from
tvm.relay.feature
import
Feature
from
tvm.relay.feature
import
Feature
from
tvm.relay.prelude
import
Prelude
from
tvm.relay.prelude
import
Prelude
from
tvm.relay.testing
import
run_infer_type
def
test_prelude
():
def
test_prelude
():
p
=
Prelude
()
p
=
Prelude
()
...
@@ -47,6 +48,7 @@ def test_ad():
...
@@ -47,6 +48,7 @@ def test_ad():
t
=
relay
.
TensorType
(
shape
,
dtype
)
t
=
relay
.
TensorType
(
shape
,
dtype
)
x
=
relay
.
var
(
"x"
,
t
)
x
=
relay
.
var
(
"x"
,
t
)
func
=
relay
.
Function
([
x
],
x
+
x
)
func
=
relay
.
Function
([
x
],
x
+
x
)
func
=
run_infer_type
(
func
)
mod
=
relay
.
Module
.
from_expr
(
gradient
(
func
))
mod
=
relay
.
Module
.
from_expr
(
gradient
(
func
))
mod
=
relay
.
transform
.
InferType
()(
mod
)
mod
=
relay
.
transform
.
InferType
()(
mod
)
back_func
=
mod
[
"main"
]
back_func
=
mod
[
"main"
]
...
...
tests/python/relay/test_op_grad_level1.py
View file @
1a00cab9
...
@@ -18,14 +18,7 @@ import numpy as np
...
@@ -18,14 +18,7 @@ import numpy as np
import
tvm
import
tvm
from
tvm
import
relay
from
tvm
import
relay
from
tvm.relay.transform
import
gradient
from
tvm.relay.transform
import
gradient
from
tvm.relay.testing
import
ctx_list
from
tvm.relay.testing
import
ctx_list
,
run_infer_type
def
run_infer_type
(
expr
):
mod
=
relay
.
Module
.
from_expr
(
expr
)
mod
=
relay
.
transform
.
InferType
()(
mod
)
return
mod
[
"main"
]
def
sigmoid
(
x
):
def
sigmoid
(
x
):
one
=
np
.
ones_like
(
x
)
one
=
np
.
ones_like
(
x
)
...
@@ -49,6 +42,7 @@ def test_unary_op():
...
@@ -49,6 +42,7 @@ def test_unary_op():
data
=
np
.
random
.
rand
(
*
shape
)
.
astype
(
dtype
)
data
=
np
.
random
.
rand
(
*
shape
)
.
astype
(
dtype
)
ref_grad
=
ref
(
data
)
ref_grad
=
ref
(
data
)
fwd_func
=
relay
.
Function
([
x
],
y
)
fwd_func
=
relay
.
Function
([
x
],
y
)
fwd_func
=
run_infer_type
(
fwd_func
)
bwd_func
=
run_infer_type
(
gradient
(
fwd_func
))
bwd_func
=
run_infer_type
(
gradient
(
fwd_func
))
for
target
,
ctx
in
ctx_list
():
for
target
,
ctx
in
ctx_list
():
...
@@ -81,6 +75,7 @@ def test_binary_op():
...
@@ -81,6 +75,7 @@ def test_binary_op():
y_data
=
np
.
random
.
rand
(
*
s
)
.
astype
(
t
.
dtype
)
y_data
=
np
.
random
.
rand
(
*
s
)
.
astype
(
t
.
dtype
)
ref_grad0
,
ref_grad1
=
ref
(
x_data
,
y_data
)
ref_grad0
,
ref_grad1
=
ref
(
x_data
,
y_data
)
fwd_func
=
relay
.
Function
([
x
,
y
],
z
)
fwd_func
=
relay
.
Function
([
x
,
y
],
z
)
fwd_func
=
run_infer_type
(
fwd_func
)
bwd_func
=
run_infer_type
(
gradient
(
fwd_func
))
bwd_func
=
run_infer_type
(
gradient
(
fwd_func
))
for
target
,
ctx
in
ctx_list
():
for
target
,
ctx
in
ctx_list
():
...
...
tests/python/relay/test_op_grad_level3.py
View file @
1a00cab9
...
@@ -18,13 +18,7 @@ import numpy as np
...
@@ -18,13 +18,7 @@ import numpy as np
import
tvm
import
tvm
from
tvm
import
relay
from
tvm
import
relay
from
tvm.relay.transform
import
gradient
from
tvm.relay.transform
import
gradient
from
tvm.relay.testing
import
ctx_list
from
tvm.relay.testing
import
ctx_list
,
run_infer_type
def
run_infer_type
(
expr
):
mod
=
relay
.
Module
.
from_expr
(
expr
)
mod
=
relay
.
transform
.
InferType
()(
mod
)
return
mod
[
"main"
]
def
test_clip
():
def
test_clip
():
ref
=
(
lambda
x
:
np
.
where
(
x
>
10.0
,
np
.
zeros_like
(
x
),
ref
=
(
lambda
x
:
np
.
where
(
x
>
10.0
,
np
.
zeros_like
(
x
),
...
@@ -35,6 +29,7 @@ def test_clip():
...
@@ -35,6 +29,7 @@ def test_clip():
data
=
np
.
random
.
rand
(
10
,
4
)
.
astype
(
"float32"
)
*
11.0
data
=
np
.
random
.
rand
(
10
,
4
)
.
astype
(
"float32"
)
*
11.0
ref_grad
=
ref
(
data
)
ref_grad
=
ref
(
data
)
fwd_func
=
relay
.
Function
([
x
],
y
)
fwd_func
=
relay
.
Function
([
x
],
y
)
fwd_func
=
run_infer_type
(
fwd_func
)
bwd_func
=
run_infer_type
(
gradient
(
fwd_func
))
bwd_func
=
run_infer_type
(
gradient
(
fwd_func
))
for
target
,
ctx
in
ctx_list
():
for
target
,
ctx
in
ctx_list
():
...
...
tests/python/relay/test_pass_gradient.py
View file @
1a00cab9
...
@@ -35,6 +35,7 @@ def test_id():
...
@@ -35,6 +35,7 @@ def test_id():
t
=
relay
.
TensorType
(
shape
,
dtype
)
t
=
relay
.
TensorType
(
shape
,
dtype
)
x
=
relay
.
var
(
"x"
,
t
)
x
=
relay
.
var
(
"x"
,
t
)
func
=
relay
.
Function
([
x
],
x
)
func
=
relay
.
Function
([
x
],
x
)
func
=
run_infer_type
(
func
)
back_func
=
run_infer_type
(
gradient
(
func
,
mode
=
"first_order"
))
back_func
=
run_infer_type
(
gradient
(
func
,
mode
=
"first_order"
))
assert
back_func
.
checked_type
==
relay
.
FuncType
([
t
],
relay
.
TupleType
([
t
,
relay
.
TupleType
([
t
])]))
assert
back_func
.
checked_type
==
relay
.
FuncType
([
t
],
relay
.
TupleType
([
t
,
relay
.
TupleType
([
t
])]))
ex
=
create_executor
()
ex
=
create_executor
()
...
@@ -50,6 +51,7 @@ def test_add():
...
@@ -50,6 +51,7 @@ def test_add():
t
=
relay
.
TensorType
(
shape
,
dtype
)
t
=
relay
.
TensorType
(
shape
,
dtype
)
x
=
relay
.
var
(
"x"
,
t
)
x
=
relay
.
var
(
"x"
,
t
)
func
=
relay
.
Function
([
x
],
x
+
x
)
func
=
relay
.
Function
([
x
],
x
+
x
)
func
=
run_infer_type
(
func
)
back_func
=
run_infer_type
(
gradient
(
func
))
back_func
=
run_infer_type
(
gradient
(
func
))
assert
back_func
.
checked_type
==
relay
.
FuncType
([
t
],
relay
.
TupleType
([
t
,
relay
.
TupleType
([
t
])]))
assert
back_func
.
checked_type
==
relay
.
FuncType
([
t
],
relay
.
TupleType
([
t
,
relay
.
TupleType
([
t
])]))
ex
=
create_executor
()
ex
=
create_executor
()
...
@@ -66,6 +68,7 @@ def test_temp_add():
...
@@ -66,6 +68,7 @@ def test_temp_add():
x
=
relay
.
var
(
"x"
,
t
)
x
=
relay
.
var
(
"x"
,
t
)
y
=
x
+
x
y
=
x
+
x
func
=
relay
.
Function
([
x
],
y
+
y
)
func
=
relay
.
Function
([
x
],
y
+
y
)
func
=
run_infer_type
(
func
)
back_func
=
run_infer_type
(
gradient
(
func
))
back_func
=
run_infer_type
(
gradient
(
func
))
assert
back_func
.
checked_type
==
relay
.
FuncType
([
t
],
relay
.
TupleType
([
t
,
relay
.
TupleType
([
t
])]))
assert
back_func
.
checked_type
==
relay
.
FuncType
([
t
],
relay
.
TupleType
([
t
,
relay
.
TupleType
([
t
])]))
ex
=
create_executor
()
ex
=
create_executor
()
...
@@ -81,6 +84,7 @@ def test_sub():
...
@@ -81,6 +84,7 @@ def test_sub():
t
=
relay
.
TensorType
(
shape
,
dtype
)
t
=
relay
.
TensorType
(
shape
,
dtype
)
x
=
relay
.
var
(
"x"
,
t
)
x
=
relay
.
var
(
"x"
,
t
)
func
=
relay
.
Function
([
x
],
x
-
x
)
func
=
relay
.
Function
([
x
],
x
-
x
)
func
=
run_infer_type
(
func
)
back_func
=
run_infer_type
(
gradient
(
func
))
back_func
=
run_infer_type
(
gradient
(
func
))
assert
back_func
.
checked_type
==
relay
.
FuncType
([
t
],
relay
.
TupleType
([
t
,
relay
.
TupleType
([
t
])]))
assert
back_func
.
checked_type
==
relay
.
FuncType
([
t
],
relay
.
TupleType
([
t
,
relay
.
TupleType
([
t
])]))
ex
=
create_executor
()
ex
=
create_executor
()
...
@@ -104,6 +108,7 @@ def test_broadcast_add():
...
@@ -104,6 +108,7 @@ def test_broadcast_add():
x
=
relay
.
var
(
"x"
,
t1
)
x
=
relay
.
var
(
"x"
,
t1
)
y
=
relay
.
var
(
"y"
,
t2
)
y
=
relay
.
var
(
"y"
,
t2
)
func
=
relay
.
Function
([
x
,
y
],
x
+
y
)
func
=
relay
.
Function
([
x
,
y
],
x
+
y
)
func
=
run_infer_type
(
func
)
full_func
=
run_infer_type
(
gradient
(
func
))
full_func
=
run_infer_type
(
gradient
(
func
))
assert
full_func
.
checked_type
==
relay
.
FuncType
([
t1
,
t2
],
assert
full_func
.
checked_type
==
relay
.
FuncType
([
t1
,
t2
],
relay
.
TupleType
([
relay
.
TensorType
(
expected_forward
.
shape
,
dtype
),
relay
.
TupleType
([
relay
.
TensorType
(
expected_forward
.
shape
,
dtype
),
...
@@ -131,6 +136,7 @@ def test_broadcast_subtract():
...
@@ -131,6 +136,7 @@ def test_broadcast_subtract():
x
=
relay
.
var
(
"x"
,
t1
)
x
=
relay
.
var
(
"x"
,
t1
)
y
=
relay
.
var
(
"y"
,
t2
)
y
=
relay
.
var
(
"y"
,
t2
)
func
=
relay
.
Function
([
x
,
y
],
x
-
y
)
func
=
relay
.
Function
([
x
,
y
],
x
-
y
)
func
=
run_infer_type
(
func
)
full_func
=
run_infer_type
(
gradient
(
func
))
full_func
=
run_infer_type
(
gradient
(
func
))
assert
full_func
.
checked_type
==
relay
.
FuncType
([
t1
,
t2
],
assert
full_func
.
checked_type
==
relay
.
FuncType
([
t1
,
t2
],
relay
.
TupleType
([
relay
.
TensorType
(
expected_forward
.
shape
,
dtype
),
relay
.
TupleType
([
relay
.
TensorType
(
expected_forward
.
shape
,
dtype
),
...
@@ -156,6 +162,7 @@ def test_tuple():
...
@@ -156,6 +162,7 @@ def test_tuple():
relay
.
TupleGetItem
(
tup
,
0
)
+
relay
.
TupleGetItem
(
tup
,
0
)
+
relay
.
TupleGetItem
(
tup
,
1
)
-
relay
.
TupleGetItem
(
tup
,
1
)
-
relay
.
TupleGetItem
(
tup
,
2
)))
relay
.
TupleGetItem
(
tup
,
2
)))
func
=
run_infer_type
(
func
)
back_func
=
run_infer_type
(
gradient
(
func
))
back_func
=
run_infer_type
(
gradient
(
func
))
assert
back_func
.
checked_type
==
relay
.
FuncType
([
t
,
t
,
t
],
relay
.
TupleType
([
t
,
relay
.
TupleType
([
t
,
t
,
t
])]))
assert
back_func
.
checked_type
==
relay
.
FuncType
([
t
,
t
,
t
],
relay
.
TupleType
([
t
,
relay
.
TupleType
([
t
,
t
,
t
])]))
x_nd
=
rand
(
dtype
,
*
shape
)
x_nd
=
rand
(
dtype
,
*
shape
)
...
@@ -184,8 +191,8 @@ def test_pow():
...
@@ -184,8 +191,8 @@ def test_pow():
double
=
relay
.
Function
([
x
],
x
+
x
)
double
=
relay
.
Function
([
x
],
x
+
x
)
i
=
relay
.
var
(
"i"
,
t
)
i
=
relay
.
var
(
"i"
,
t
)
func
=
relay
.
Function
([
i
],
p
.
nat_iterate
(
double
,
make_nat_expr
(
p
,
3
))(
i
))
func
=
relay
.
Function
([
i
],
p
.
nat_iterate
(
double
,
make_nat_expr
(
p
,
3
))(
i
))
func
=
gradient
(
func
,
mod
=
mod
)
mod
[
"main"
]
=
func
mod
[
"main"
]
=
func
mod
[
"main"
]
=
gradient
(
mod
[
"main"
],
mod
=
mod
)
m
=
transform
.
InferType
()(
mod
)
m
=
transform
.
InferType
()(
mod
)
back_func
=
m
[
"main"
]
back_func
=
m
[
"main"
]
assert
back_func
.
checked_type
==
relay
.
FuncType
([
t
],
relay
.
TupleType
([
t
,
relay
.
TupleType
([
t
])]))
assert
back_func
.
checked_type
==
relay
.
FuncType
([
t
],
relay
.
TupleType
([
t
,
relay
.
TupleType
([
t
])]))
...
@@ -207,6 +214,7 @@ def test_ref():
...
@@ -207,6 +214,7 @@ def test_ref():
body
=
relay
.
Let
(
u
,
relay
.
RefWrite
(
r
,
relay
.
RefRead
(
r
)
+
relay
.
RefRead
(
r
)),
body
)
body
=
relay
.
Let
(
u
,
relay
.
RefWrite
(
r
,
relay
.
RefRead
(
r
)
+
relay
.
RefRead
(
r
)),
body
)
body
=
relay
.
Let
(
r
,
relay
.
RefCreate
(
x
),
body
)
body
=
relay
.
Let
(
r
,
relay
.
RefCreate
(
x
),
body
)
func
=
relay
.
Function
([
x
],
body
)
func
=
relay
.
Function
([
x
],
body
)
func
=
run_infer_type
(
func
)
back_func
=
run_infer_type
(
gradient
(
func
))
back_func
=
run_infer_type
(
gradient
(
func
))
assert
back_func
.
checked_type
==
relay
.
FuncType
([
t
],
relay
.
TupleType
([
t
,
relay
.
TupleType
([
t
])]))
assert
back_func
.
checked_type
==
relay
.
FuncType
([
t
],
relay
.
TupleType
([
t
,
relay
.
TupleType
([
t
])]))
x_nd
=
rand
(
dtype
,
*
shape
)
x_nd
=
rand
(
dtype
,
*
shape
)
...
@@ -222,6 +230,7 @@ def test_square_second_order():
...
@@ -222,6 +230,7 @@ def test_square_second_order():
t
=
relay
.
TensorType
(
shape
,
dtype
)
t
=
relay
.
TensorType
(
shape
,
dtype
)
x
=
relay
.
var
(
"x"
,
t
)
x
=
relay
.
var
(
"x"
,
t
)
func
=
relay
.
Function
([
x
],
x
*
x
)
func
=
relay
.
Function
([
x
],
x
*
x
)
func
=
run_infer_type
(
func
)
back_func
=
run_infer_type
(
gradient
(
func
))
back_func
=
run_infer_type
(
gradient
(
func
))
y
=
relay
.
var
(
"y"
,
t
)
y
=
relay
.
var
(
"y"
,
t
)
back_func_adjusted
=
relay
.
Function
([
y
],
relay
.
TupleGetItem
(
relay
.
TupleGetItem
(
back_func
(
y
),
1
),
0
))
back_func_adjusted
=
relay
.
Function
([
y
],
relay
.
TupleGetItem
(
relay
.
TupleGetItem
(
back_func
(
y
),
1
),
0
))
...
@@ -242,6 +251,7 @@ def test_if():
...
@@ -242,6 +251,7 @@ def test_if():
net
=
relay
.
If
(
cond
,
x
,
y
)
net
=
relay
.
If
(
cond
,
x
,
y
)
net
=
relay
.
log
(
net
)
net
=
relay
.
log
(
net
)
func
=
relay
.
Function
(
free_vars
(
net
),
net
)
func
=
relay
.
Function
(
free_vars
(
net
),
net
)
func
=
run_infer_type
(
func
)
net
=
run_infer_type
(
func
)
net
=
run_infer_type
(
func
)
net
=
gradient
(
net
,
mode
=
'higher_order'
)
net
=
gradient
(
net
,
mode
=
'higher_order'
)
net
=
run_infer_type
(
net
)
net
=
run_infer_type
(
net
)
...
...
tests/python/relay/test_pass_partial_eval.py
View file @
1a00cab9
...
@@ -25,7 +25,7 @@ from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead,
...
@@ -25,7 +25,7 @@ from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead,
from
tvm.relay
import
TensorType
,
Tuple
,
If
,
Module
,
Clause
,
PatternConstructor
,
PatternVar
,
Match
from
tvm.relay
import
TensorType
,
Tuple
,
If
,
Module
,
Clause
,
PatternConstructor
,
PatternVar
,
Match
from
tvm.relay
import
GlobalVar
,
Call
from
tvm.relay
import
GlobalVar
,
Call
from
tvm.relay.transform
import
gradient
from
tvm.relay.transform
import
gradient
from
tvm.relay.testing
import
add_nat_definitions
,
make_nat_expr
from
tvm.relay.testing
import
add_nat_definitions
,
make_nat_expr
,
run_infer_type
def
check_eval
(
expr
,
expected_result
,
mod
=
None
,
rtol
=
1e-07
):
def
check_eval
(
expr
,
expected_result
,
mod
=
None
,
rtol
=
1e-07
):
ctx
=
tvm
.
context
(
"llvm"
,
0
)
ctx
=
tvm
.
context
(
"llvm"
,
0
)
...
@@ -54,7 +54,7 @@ def dcpe(expr, mod=None, grad=False):
...
@@ -54,7 +54,7 @@ def dcpe(expr, mod=None, grad=False):
passes
=
[
transform
.
PartialEvaluate
(),
passes
=
[
transform
.
PartialEvaluate
(),
transform
.
DeadCodeElimination
(
inline_once
=
True
)]
transform
.
DeadCodeElimination
(
inline_once
=
True
)]
if
grad
:
if
grad
:
expr
=
gradient
(
expr
)
expr
=
gradient
(
run_infer_type
(
expr
)
)
if
mod
:
if
mod
:
assert
isinstance
(
expr
,
Function
)
assert
isinstance
(
expr
,
Function
)
mod
[
"main"
]
=
expr
mod
[
"main"
]
=
expr
...
...
tests/python/relay/test_pass_to_cps.py
View file @
1a00cab9
...
@@ -81,6 +81,7 @@ def test_cps_pe():
...
@@ -81,6 +81,7 @@ def test_cps_pe():
destroy_ref
(
F
)
destroy_ref
(
F
)
G
=
relay
.
Function
([
cond
],
relay
.
If
(
cond
,
one
,
two
))
G
=
relay
.
Function
([
cond
],
relay
.
If
(
cond
,
one
,
two
))
G
=
run_infer_type
(
G
)
G
=
relay
.
transform
.
gradient
(
G
)
G
=
relay
.
transform
.
gradient
(
G
)
destroy_ref
(
G
)
destroy_ref
(
G
)
...
@@ -91,6 +92,7 @@ def test_cps_pe():
...
@@ -91,6 +92,7 @@ def test_cps_pe():
H
=
relay
.
If
(
cond
,
x
,
y
)
H
=
relay
.
If
(
cond
,
x
,
y
)
H
=
relay
.
add
(
H
,
z
)
H
=
relay
.
add
(
H
,
z
)
H
=
relay
.
Function
([
cond
,
x
,
y
,
z
],
H
)
H
=
relay
.
Function
([
cond
,
x
,
y
,
z
],
H
)
H
=
run_infer_type
(
H
)
H
=
relay
.
transform
.
gradient
(
H
)
H
=
relay
.
transform
.
gradient
(
H
)
destroy_ref
(
H
)
destroy_ref
(
H
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment