wenyuanbo / tic / Commits

Commit 46fa6eeb
Authored Oct 16, 2019 by Altan Haan; committed by Wuwei Lin, Oct 16, 2019
[Relay][Training] Add and fix gradients (#4126)

* add and fix gradients
* fix linter issues
parent 1c0e7435
Showing 3 changed files with 111 additions and 19 deletions (+111 −19):

    python/tvm/relay/op/_tensor_grad.py         +74 −8
    tests/python/relay/test_op_grad_level2.py   +26 −3
    tests/python/relay/test_op_grad_level4.py   +11 −8
python/tvm/relay/op/_tensor_grad.py

@@ -48,6 +48,9 @@ from .transform import (
     tile,
     transpose,
     where,
+    repeat,
+    expand_dims,
+    full_like,
 )
@@ -198,6 +201,7 @@ def clip_grad(orig, grad):
 @register_gradient("nn.max_pool2d")
 def max_pool2d_grad(orig, grad):
+    """Returns the gradient of max_pool2d."""
     attrs = orig.attrs
     pool_grad = _nn.max_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                     strides=attrs.strides, padding=attrs.padding,
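The semantics being documented here: max pooling routes each output gradient back to the position that produced the window's max. A minimal NumPy sketch of that behavior (illustrative only; max_pool2d_grad_ref is a hypothetical reference, not TVM's _nn.max_pool2d_grad, and it assumes non-overlapping windows on a single channel):

import numpy as np

def max_pool2d_grad_ref(x, grad, pool_size):
    # route each output gradient to the argmax of its pooling window
    kh, kw = pool_size
    dx = np.zeros_like(x)
    oh, ow = grad.shape
    for i in range(oh):
        for j in range(ow):
            window = x[i*kh:(i+1)*kh, j*kw:(j+1)*kw]
            r, c = np.unravel_index(np.argmax(window), window.shape)
            dx[i*kh + r, j*kw + c] += grad[i, j]
    return dx

x = np.arange(16, dtype="float32").reshape(4, 4)
print(max_pool2d_grad_ref(x, np.ones((2, 2), dtype="float32"), (2, 2)))
# gradient lands only on each window's max element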
@@ -207,6 +211,7 @@ def max_pool2d_grad(orig, grad):
 @register_gradient("nn.avg_pool2d")
 def avg_pool2d_grad(orig, grad):
+    """Returns the gradient of avg_pool2d."""
     attrs = orig.attrs
     pool_grad = _nn.avg_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                     strides=attrs.strides, padding=attrs.padding,
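By contrast, avg_pool2d spreads each output gradient uniformly over its window, which is what _nn.avg_pool2d_grad computes. A hedged NumPy sketch under the same simplifying assumptions (non-overlapping windows, one channel, no padding; avg_pool2d_grad_ref is hypothetical):

import numpy as np

def avg_pool2d_grad_ref(x_shape, grad, pool_size):
    kh, kw = pool_size
    dx = np.zeros(x_shape, dtype=grad.dtype)
    oh, ow = grad.shape
    for i in range(oh):
        for j in range(ow):
            # every input in the window gets an equal 1/(kh*kw) share
            dx[i*kh:(i+1)*kh, j*kw:(j+1)*kw] += grad[i, j] / (kh * kw)
    return dx

print(avg_pool2d_grad_ref((4, 4), np.ones((2, 2), dtype="float32"), (2, 2)))
# uniform 0.25 everywhere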
@@ -215,6 +220,26 @@ def avg_pool2d_grad(orig, grad):
     return [pool_grad]
 
 
+@register_gradient("nn.global_avg_pool2d")
+def global_avg_pool2d_grad(orig, grad):
+    """Returns the gradient of global_avg_pool2d."""
+    data = orig.args[0]
+    shape = data.checked_type.shape
+    layout = orig.attrs.layout
+    # we assume NCHW or NHWC layout for now, but easy to add more
+    assert layout in ["NCHW", "NHWC"]
+
+    if layout == "NCHW":
+        pool_size = shape[2], shape[3]
+    elif layout == "NHWC":
+        pool_size = shape[1], shape[2]
+
+    pool_grad = _nn.avg_pool2d_grad(grad, data, pool_size=pool_size,
+                                    strides=(1, 1), padding=(0, 0),
+                                    layout=layout)
+    return [pool_grad]
+
+
 # not implemented, this is only for testing.
 @register_gradient("concatenate")
 def concatenate_grad(orig, grad):
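The new gradient reuses avg_pool2d_grad with pool_size equal to the full spatial extent, so every input position receives grad / (H * W). A quick NumPy check of that identity for NCHW (illustrative, not part of the commit):

import numpy as np

x = np.random.rand(1, 4, 16, 16).astype("float32")
dy = np.ones((1, 4, 1, 1), dtype="float32")   # upstream gradient, one per (n, c)
h, w = x.shape[2], x.shape[3]
dx = np.broadcast_to(dy / (h * w), x.shape)   # each position gets dy / (H*W)
print(dx[0, 0, 0, 0])                         # 1/256 for a 16x16 map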
@@ -287,16 +312,53 @@ def conv2d_grad(orig, grad):
     return [backward_data, backward_weight]
 
 
+def _get_reduce_axis(call):
+    """Helper function that returns the reduce axis of the call as plain python ints."""
+    x, axis = call.args[0], call.attrs.axis
+    shape = x.checked_type.concrete_shape
+
+    # should never exclude when axis is None
+    assert not (axis is None and call.attrs.exclude)
+
+    if axis is None:
+        return None
+
+    # convert to nonnegative integers and sort
+    axis = sorted([ax if ax >= 0 else len(shape) + ax for ax in map(int, axis)])
+    if call.attrs.exclude:
+        axis = [ax for ax in range(len(shape)) if ax not in axis]
+    return axis
+
+
+def _unreduce_expand(x, axis):
+    """Helper function that returns x expanded on the reduced dimensions in axis."""
+    # assume axis is sorted nonnegative ints
+    for ax in axis:
+        x = expand_dims(x, ax)
+    return x
+
+
 @register_gradient("max")
 def max_grad(orig, grad):
     """Returns the gradient of max"""
-    # Only support axis=0, since broadcasting orig to x behaves incorrectly
-    x, axis = orig.args[0], orig.attrs.axis
-    assert(axis is not None and len(axis) == 1 and int(axis[0]) == 0)
-    orig = broadcast_to_like(orig, x)
-    grad = broadcast_to_like(grad, x)
-    indicators = cast_like(equal(orig, x), grad)
-    return [indicators * grad]
+    x, axis = orig.args[0], _get_reduce_axis(orig)
+    shape = x.checked_type.concrete_shape
+
+    repeated = orig
+    if axis is None:
+        repeated = full_like(x, repeated)
+    else:
+        # expand dims (if necessary) and repeat along each axis
+        if not orig.attrs.keepdims:
+            repeated = _unreduce_expand(repeated, axis)
+            grad = _unreduce_expand(grad, axis)
+        for ax in axis:
+            repeated = repeat(repeated, shape[ax], ax)
+
+    indicators = cast_like(equal(repeated, x), grad)
+    num_selected = _sum(indicators, axis, keepdims=True)
+    # spread error across all max weights
+    return [indicators * grad / num_selected]
 
 
 @register_gradient("nn.softmax")
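The rewritten max_grad handles arbitrary reduce axes: it repeats the reduced max back over the input shape, masks the positions that attain the max, and divides by num_selected so ties split the gradient evenly. A NumPy rendering of the same idea for a single axis (max_grad_ref is a hypothetical sketch, not the Relay code):

import numpy as np

def max_grad_ref(x, grad, axis):
    # indicator mask of max positions, with ties sharing the gradient
    m = np.max(x, axis=axis, keepdims=True)
    indicators = (x == m).astype(grad.dtype)
    num_selected = indicators.sum(axis=axis, keepdims=True)
    return indicators * np.expand_dims(grad, axis) / num_selected

x = np.array([[1., 3., 3.], [2., 2., 0.]])
print(max_grad_ref(x, np.ones(2), axis=1))
# [[0.  0.5 0.5]
#  [0.5 0.5 0. ]]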
@@ -372,7 +434,11 @@ def negative_grad(orig, grad):
...
@@ -372,7 +434,11 @@ def negative_grad(orig, grad):
@register_gradient
(
"sum"
)
@register_gradient
(
"sum"
)
def
sum_grad
(
orig
,
grad
):
def
sum_grad
(
orig
,
grad
):
"""Returns grad broadcasted to data dims"""
"""Returns grad broadcasted to data dims"""
data
=
orig
.
args
[
0
]
data
,
axis
=
orig
.
args
[
0
],
_get_reduce_axis
(
orig
)
if
not
orig
.
attrs
.
keepdims
:
if
axis
is
None
:
axis
=
list
(
range
(
len
(
data
.
checked_type
.
concrete_shape
)))
grad
=
_unreduce_expand
(
grad
,
axis
)
return
[
broadcast_to_like
(
grad
,
data
)]
return
[
broadcast_to_like
(
grad
,
data
)]
...
...
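sum_grad now re-inserts the reduced axes with _unreduce_expand when keepdims=False before broadcasting the gradient back to the input shape. An equivalent NumPy sketch (sum_grad_ref is hypothetical and assumes axis holds sorted nonnegative ints, as _get_reduce_axis guarantees):

import numpy as np

def sum_grad_ref(x_shape, grad, axis, keepdims=False):
    if not keepdims:
        if axis is None:
            axis = range(len(x_shape))
        for ax in sorted(axis):
            grad = np.expand_dims(grad, ax)  # the _unreduce_expand step
    return np.broadcast_to(grad, x_shape)    # gradient of sum is all ones

g = sum_grad_ref((4, 2, 1), np.ones(2), axis=(0, 2))
print(g.shape)  # (4, 2, 1), every element 1.0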
tests/python/relay/test_op_grad_level2.py

@@ -48,8 +48,7 @@ def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode):
 
 def test_max_pool2d_grad():
     verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
                            ceil_mode=False)
-    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1),
-                           ceil_mode=False)
+    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1), ceil_mode=False)
@@ -75,7 +74,6 @@ def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, count_include_pad):
         op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
         np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
 
 def test_avg_pool2d_grad():
-    verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
-                           ceil_mode=False, count_include_pad=True)
+    verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False, count_include_pad=True)
@@ -83,6 +81,30 @@ def test_avg_pool2d_grad():
                            ceil_mode=False, count_include_pad=False)
 
 
+def verify_global_avg_pool2d_grad(x_shape):
+    x = relay.var("x", relay.TensorType(x_shape, "float32"))
+    y = tvm.relay.nn.global_avg_pool2d(x)
+    fwd_func = relay.Function([x], y)
+    fwd_func = run_infer_type(fwd_func)
+    bwd_func = run_infer_type(gradient(fwd_func))
+
+    data = np.random.rand(*x_shape).astype("float32")
+    y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
+    out_grad = np.ones(shape=y_shape)
+    ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=(x_shape[2], x_shape[3]),
+                                           strides=(1, 1), padding=[0, 0, 0, 0],
+                                           pool_type='avg', ceil_mode=False)
+
+    for target, ctx in ctx_list():
+        intrp = relay.create_executor(ctx=ctx, target=target)
+        op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
+        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
+
+
+def test_global_avg_pool2d_grad():
+    verify_global_avg_pool2d_grad((1, 4, 16, 16))
+    verify_global_avg_pool2d_grad((1, 8, 8, 24))
+
+
 def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode='higher_order'):
     try:
         import torch
@@ -155,6 +177,7 @@ def test_batch_flatten_grad():
 if __name__ == "__main__":
     test_max_pool2d_grad()
     test_avg_pool2d_grad()
+    test_global_avg_pool2d_grad()
     test_conv2d_grad()
     test_dense_grad()
     test_batch_flatten_grad()
tests/python/relay/test_op_grad_level4.py

@@ -29,18 +29,21 @@ def test_sum_grad():
     verify_sum_grad((4, 2))
     verify_sum_grad((4, 2), axis=-1, keepdims=True)
     verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
+    verify_sum_grad((4, 2, 1), axis=1)
 
 
-def test_max_grad():
-    s = (10, 10)
-    t = relay.TensorType(s)
-    x = relay.var("x", t)
-    axis = 0
-    z = relay.max(x, axis)
-    fwd_func = relay.Function([x], z)
+def verify_max_grad(d_shape, axis=None, keepdims=False, exclude=False):
+    data = relay.var("data", relay.TensorType(d_shape, "float32"))
+    fwd_func = relay.Function([data], relay.max(data, axis=axis, keepdims=keepdims, exclude=exclude))
     check_grad(fwd_func, scale=1e-3)
 
 
+def test_max_grad():
+    verify_max_grad((10, 10), axis=None)
+    verify_max_grad((10, 10), axis=-1)
+    verify_max_grad((6, 3, 2), axis=(1, 2), keepdims=True)
+    verify_max_grad((5, 4, 3), axis=(0, 2), exclude=True)
+
+
 if __name__ == "__main__":
     pytest.main()
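Both test files ultimately compare analytic gradients against numeric references (check_grad here, topi.testing.pool_grad_nchw above). For readers unfamiliar with the pattern, a self-contained central finite-difference checker in the same spirit (numeric_grad is an illustrative sketch, not TVM's check_grad):

import numpy as np

def numeric_grad(f, x, eps=1e-3):
    # central finite differences: g[i] = (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps)
    g = np.zeros_like(x)
    it = np.nditer(x, flags=["multi_index"])
    for _ in it:
        idx = it.multi_index
        x[idx] += eps
        fp = f(x)
        x[idx] -= 2 * eps
        fm = f(x)
        x[idx] += eps            # restore
        g[idx] = (fp - fm) / (2 * eps)
    return g

x = np.random.rand(4, 2)
# analytic gradient of sum(x) is all ones
np.testing.assert_allclose(numeric_grad(lambda a: a.sum(), x),
                           np.ones_like(x), rtol=1e-2)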