Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
69312744
Commit
69312744
authored
Jul 27, 2018
by
sgrechanik-h
Committed by
Tianqi Chen
Jul 27, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NNVM] Fix grads for sum and expand_like (#1455)
parent
ddd249f2
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
50 additions
and
9 deletions
+50
-9
nnvm/python/nnvm/top/transform.py
+4
-0
nnvm/src/top/tensor/reduce.cc
+8
-2
nnvm/src/top/tensor/transform.cc
+16
-3
nnvm/tests/python/compiler/test_top_level4.py
+19
-1
topi/python/topi/transform.py
+3
-3
No files found.
nnvm/python/nnvm/top/transform.py
View file @
69312744
...
@@ -15,6 +15,10 @@ reg.register_schedule("expand_dims", _fschedule_broadcast)
...
@@ -15,6 +15,10 @@ reg.register_schedule("expand_dims", _fschedule_broadcast)
@reg.register_compute
(
"expand_like"
)
@reg.register_compute
(
"expand_like"
)
def
compute_expand_like
(
attrs
,
inputs
,
_
):
def
compute_expand_like
(
attrs
,
inputs
,
_
):
"""Compute definition of expand_like"""
"""Compute definition of expand_like"""
if
len
(
inputs
[
0
]
.
shape
)
==
len
(
inputs
[
1
]
.
shape
):
# If the number of dimensions is not changed then it is just a broadcasting
return
topi
.
broadcast_to
(
inputs
[
0
],
inputs
[
1
]
.
shape
)
exclude
=
attrs
.
get_bool
(
"exclude"
)
exclude
=
attrs
.
get_bool
(
"exclude"
)
axis
=
attrs
.
get_int_tuple
(
"axis"
)
axis
=
attrs
.
get_int_tuple
(
"axis"
)
if
exclude
:
if
exclude
:
...
...
nnvm/src/top/tensor/reduce.cc
View file @
69312744
...
@@ -170,12 +170,18 @@ Example::
...
@@ -170,12 +170,18 @@ Example::
"FGradient"
,
[](
const
NodePtr
&
n
,
"FGradient"
,
[](
const
NodePtr
&
n
,
const
std
::
vector
<
NodeEntry
>&
ograds
){
const
std
::
vector
<
NodeEntry
>&
ograds
){
const
ReduceParam
&
param
=
nnvm
::
get
<
ReduceParam
>
(
n
->
attrs
.
parsed
);
const
ReduceParam
&
param
=
nnvm
::
get
<
ReduceParam
>
(
n
->
attrs
.
parsed
);
std
::
ostringstream
axis
;
axis
<<
param
.
axis
;
bool
exclude
=
param
.
exclude
;
TShape
p_axis
=
param
.
axis
;
if
(
!
param
.
exclude
&&
param
.
axis
.
ndim
()
==
0
)
{
exclude
=
true
;
p_axis
=
TShape
();
}
std
::
ostringstream
axis
;
axis
<<
p_axis
;
return
std
::
vector
<
NodeEntry
>
{
return
std
::
vector
<
NodeEntry
>
{
MakeNode
(
"expand_like"
,
n
->
attrs
.
name
+
"_grad"
,
MakeNode
(
"expand_like"
,
n
->
attrs
.
name
+
"_grad"
,
{
ograds
[
0
],
n
->
inputs
[
0
]},
{
ograds
[
0
],
n
->
inputs
[
0
]},
{{
"axis"
,
axis
.
str
()},
{{
"axis"
,
axis
.
str
()},
{
"exclude"
,
std
::
to_string
(
param
.
exclude
)}})
{
"exclude"
,
std
::
to_string
(
exclude
)}})
};
};
});
});
...
...
nnvm/src/top/tensor/transform.cc
View file @
69312744
...
@@ -251,7 +251,8 @@ will return a new array with shape ``(2,1,1,1,1,1,3,4)``.
...
@@ -251,7 +251,8 @@ will return a new array with shape ``(2,1,1,1,1,1,3,4)``.
NNVM_REGISTER_OP
(
expand_like
)
NNVM_REGISTER_OP
(
expand_like
)
.
describe
(
R"code(Expand an input array with the shape of second array.
.
describe
(
R"code(Expand an input array with the shape of second array.
This operation can always be composed of unsqueezing and expanding dims.
This operation can be thought of as a composition of expand_dims and broadcast_to.
If the dimensions are already expanded then it just broadcasts.
Examples::
Examples::
input = [ 12. 19. 27.]
input = [ 12. 19. 27.]
input.shape = (3,)
input.shape = (3,)
...
@@ -282,11 +283,23 @@ Examples::
...
@@ -282,11 +283,23 @@ Examples::
std
::
ostringstream
axis
;
std
::
ostringstream
axis
;
axis
<<
param
.
axis
;
axis
<<
param
.
axis
;
if
(
param
.
axis
.
ndim
()
==
0
&&
!
param
.
exclude
)
{
// Special case needed because sum interprets axis=[] differently
return
std
::
vector
<
NodeEntry
>
{
return
std
::
vector
<
NodeEntry
>
{
MakeNode
(
"sum"
,
n
->
attrs
.
name
+
"_grad"
,
ograds
[
0
],
MakeNode
(
"zeros_like"
,
n
->
attrs
.
name
+
"_zero_grad"
,
{
n
->
inputs
[
1
]})
};
}
auto
sum_node
=
MakeNode
(
"sum"
,
n
->
attrs
.
name
+
"_sum_grad"
,
{
ograds
[
0
]},
{
ograds
[
0
]},
{{
"axis"
,
axis
.
str
()},
{{
"axis"
,
axis
.
str
()},
{
"exclude"
,
std
::
to_string
(
param
.
exclude
)}}),
{
"exclude"
,
std
::
to_string
(
param
.
exclude
)}});
return
std
::
vector
<
NodeEntry
>
{
MakeNode
(
"reshape_like"
,
n
->
attrs
.
name
+
"_grad"
,
{
sum_node
,
n
->
inputs
[
0
]}),
MakeNode
(
"zeros_like"
,
n
->
attrs
.
name
+
"_zero_grad"
,
{
n
->
inputs
[
1
]})
MakeNode
(
"zeros_like"
,
n
->
attrs
.
name
+
"_zero_grad"
,
{
n
->
inputs
[
1
]})
};
};
})
})
...
...
nnvm/tests/python/compiler/test_top_level4.py
View file @
69312744
...
@@ -378,6 +378,13 @@ def verify_expand_like(in_shape, out_shape, axis, exclude):
...
@@ -378,6 +378,13 @@ def verify_expand_like(in_shape, out_shape, axis, exclude):
def
forward
(
x
,
y
):
def
forward
(
x
,
y
):
odim
=
len
(
out_shape
)
odim
=
len
(
out_shape
)
if
len
(
x
.
shape
)
==
len
(
y
.
shape
):
return
np
.
broadcast_to
(
x
,
y
.
shape
)
if
x
.
shape
==
(
1
,)
and
len
(
y
.
shape
)
==
odim
:
x
=
np
.
reshape
(
x
,
())
real_axis
=
[
i
if
i
>=
0
else
i
+
odim
for
i
in
axis
]
real_axis
=
[
i
if
i
>=
0
else
i
+
odim
for
i
in
axis
]
real_axis
=
sorted
(
real_axis
)
real_axis
=
sorted
(
real_axis
)
if
exclude
:
if
exclude
:
...
@@ -391,11 +398,17 @@ def verify_expand_like(in_shape, out_shape, axis, exclude):
...
@@ -391,11 +398,17 @@ def verify_expand_like(in_shape, out_shape, axis, exclude):
def
backward
(
head_grads
,
x
,
y
):
def
backward
(
head_grads
,
x
,
y
):
odim
=
len
(
out_shape
)
odim
=
len
(
out_shape
)
keepdims
=
len
(
x
.
shape
)
==
len
(
y
.
shape
)
if
x
.
shape
==
(
1
,)
and
len
(
y
.
shape
)
==
odim
:
x
=
np
.
reshape
(
x
,
())
real_axis
=
[
i
if
i
>=
0
else
i
+
odim
for
i
in
axis
]
real_axis
=
[
i
if
i
>=
0
else
i
+
odim
for
i
in
axis
]
real_axis
=
sorted
(
real_axis
)
real_axis
=
sorted
(
real_axis
)
if
exclude
:
if
exclude
:
real_axis
=
list
(
set
(
range
(
odim
))
-
set
(
real_axis
))
real_axis
=
list
(
set
(
range
(
odim
))
-
set
(
real_axis
))
return
[
np
.
sum
(
head_grads
,
axis
=
tuple
(
real_axis
)),
return
[
np
.
sum
(
head_grads
,
axis
=
tuple
(
real_axis
)
,
keepdims
=
keepdims
),
np
.
zeros_like
(
y
)]
np
.
zeros_like
(
y
)]
...
@@ -410,6 +423,11 @@ def test_expand_like():
...
@@ -410,6 +423,11 @@ def test_expand_like():
verify_expand_like
((
2
,),
(
2
,
3
),
[
1
],
False
)
verify_expand_like
((
2
,),
(
2
,
3
),
[
1
],
False
)
verify_expand_like
((
3
,
4
),
(
3
,
5
,
4
),
[
1
],
False
)
verify_expand_like
((
3
,
4
),
(
3
,
5
,
4
),
[
1
],
False
)
verify_expand_like
((
5
,
7
),
(
5
,
6
,
7
,
8
),
[
0
,
2
],
True
)
verify_expand_like
((
5
,
7
),
(
5
,
6
,
7
,
8
),
[
0
,
2
],
True
)
verify_expand_like
((
2
,
3
),
(
2
,
3
),
[],
False
)
verify_expand_like
((
1
,),
(
2
,
3
),
[
0
,
1
],
False
)
verify_expand_like
((
1
,
1
),
(
2
,
3
),
[
0
,
1
],
False
)
verify_expand_like
((
2
,
1
),
(
2
,
3
),
[
1
],
False
)
verify_expand_like
((
1
,
3
),
(
2
,
3
),
[
0
],
False
)
def
verify_elemwise_sum
(
num_args
):
def
verify_elemwise_sum
(
num_args
):
...
...
topi/python/topi/transform.py
View file @
69312744
...
@@ -65,15 +65,15 @@ def expand_like(a, shape_like, axis):
...
@@ -65,15 +65,15 @@ def expand_like(a, shape_like, axis):
"""
"""
odim
=
len
(
axis
)
+
len
(
a
.
shape
)
odim
=
len
(
axis
)
+
len
(
a
.
shape
)
if
odim
!=
len
(
shape_like
.
shape
):
if
odim
!=
len
(
shape_like
.
shape
):
if
len
(
a
.
shape
)
==
1
and
len
(
axis
)
==
len
(
shape_like
.
shape
):
# A special case: `a` is a scalar represented as a 1-dim tensor
return
tvm
.
compute
(
shape_like
.
shape
,
lambda
*
idxs
:
a
(
0
))
raise
ValueError
(
"shape inconsistent when expand_like ({}, {}, {})"
.
format
(
raise
ValueError
(
"shape inconsistent when expand_like ({}, {}, {})"
.
format
(
len
(
axis
),
len
(
a
.
shape
),
len
(
shape_like
.
shape
)))
len
(
axis
),
len
(
a
.
shape
),
len
(
shape_like
.
shape
)))
real_axis
=
topi
.
reduction
.
_get_real_axis
(
len
(
shape_like
.
shape
),
axis
)
real_axis
=
topi
.
reduction
.
_get_real_axis
(
len
(
shape_like
.
shape
),
axis
)
real_axis
=
sorted
(
real_axis
)
real_axis
=
sorted
(
real_axis
)
if
not
real_axis
:
return
a
def
_compute
(
*
idxs
):
def
_compute
(
*
idxs
):
indices
=
[]
indices
=
[]
axis_index
=
0
axis_index
=
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment