Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
1146b816
Commit
1146b816
authored
Apr 26, 2019
by
Pedro Larroy
Committed by
Tianqi Chen
Apr 26, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Check that the node is not null, add contains to OpMap (#3037)
parent
b405f68b
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
47 additions
and
20 deletions
+47
-20
3rdparty/dmlc-core
+1
-1
nnvm/include/nnvm/graph.h
+7
-3
nnvm/include/nnvm/op.h
+20
-2
nnvm/src/core/graph.cc
+1
-0
nnvm/src/pass/gradient.cc
+18
-14
No files found.
dmlc-core
@
82bf4c2e
Subproject commit
3ffea8694adf9c0363f9abbf162dc0e4a45b22c5
Subproject commit
82bf4c2e2af312b3d52513aa727483803a2f8734
nnvm/include/nnvm/graph.h
View file @
1146b816
...
...
@@ -315,12 +315,16 @@ inline void DFSVisit(const std::vector<NodeEntry>& heads,
});
PostOrderDFSVisit
<
GNode
,
Node
*>
(
head_nodes
,
[
fvisit
](
GNode
n
)
{
fvisit
(
*
n
);
},
// FVisit
[](
GNode
n
)
->
Node
*
{
return
n
->
get
();
},
// HashFunc
[
fvisit
](
GNode
n
)
{
fvisit
(
*
n
);
},
// FVisit
[](
GNode
n
)
->
Node
*
{
return
n
->
get
();
},
// HashFunc
[](
GNode
n
)
->
uint32_t
{
// InDegree
if
(
!
(
*
n
))
return
0
;
return
(
*
n
)
->
inputs
.
size
()
+
(
*
n
)
->
control_deps
.
size
();
},
},
[](
GNode
n
,
uint32_t
index
)
->
GNode
{
// GetInput
if
(
index
<
(
*
n
)
->
inputs
.
size
())
{
return
&
(
*
n
)
->
inputs
.
at
(
index
).
node
;
...
...
nnvm/include/nnvm/op.h
View file @
1146b816
...
...
@@ -340,6 +340,13 @@ class OpMap {
*/
inline
int
count
(
const
Op
*
op
)
const
;
/*!
* \brief Check if the map has op as key.
* \param op The key to the map
* \return true if op is contained in map, false otherwise.
*/
inline
bool
contains
(
const
Op
*
op
)
const
;
private
:
friend
class
Op
;
// internal attribute name
...
...
@@ -539,9 +546,20 @@ inline Op& Op::set_attr_parser(std::function<void (NodeAttrs* attrs)> fn) { //
// member functions of OpMap
template
<
typename
ValueType
>
inline
int
OpMap
<
ValueType
>::
count
(
const
Op
*
op
)
const
{
if
(
op
==
nullptr
)
return
0
;
if
(
contains
(
op
))
{
return
1
;
}
else
{
return
0
;
}
}
template
<
typename
ValueType
>
inline
bool
OpMap
<
ValueType
>::
contains
(
const
Op
*
op
)
const
{
if
(
op
==
nullptr
)
{
return
false
;
}
const
uint32_t
idx
=
op
->
index_
;
return
idx
<
data_
.
size
()
?
(
data_
[
idx
].
second
!=
0
)
:
0
;
return
idx
<
data_
.
size
()
?
(
data_
[
idx
].
second
!=
0
)
:
false
;
}
template
<
typename
ValueType
>
...
...
nnvm/src/core/graph.cc
View file @
1146b816
...
...
@@ -78,6 +78,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
(
const
NodePtr
&
n
)
{
CHECK_LT
(
nodes_
.
size
(),
std
::
numeric_limits
<
uint32_t
>::
max
());
uint32_t
nid
=
static_cast
<
uint32_t
>
(
nodes_
.
size
());
CHECK
(
n
);
for
(
const
auto
&
subgraph
:
n
->
attrs
.
subgraphs
)
subgraphs
.
push_back
(
subgraph
);
// nodes_
...
...
nnvm/src/pass/gradient.cc
View file @
1146b816
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -143,13 +143,13 @@ Graph Gradient(Graph src) {
<<
"because it is unreachable from the outputs."
;
}
// construct mirror
reduece memory
strategy if needed
// construct mirror
as memory reduction
strategy if needed
std
::
unordered_map
<
Node
*
,
NodePtr
>
mirror_map
;
if
(
mirror_fun
!=
nullptr
)
{
for
(
const
NodePtr
&
n
:
topo_order
)
{
if
(
mirror_fun
(
*
n
))
{
for
(
const
NodePtr
&
n
ode_ptr
:
topo_order
)
{
if
(
mirror_fun
(
*
n
ode_ptr
))
{
NodePtr
new_node
=
Node
::
Create
();
*
new_node
=
*
n
;
*
new_node
=
*
n
ode_ptr
;
new_node
->
attrs
.
name
+=
"_mirror"
;
for
(
auto
&
e
:
new_node
->
inputs
)
{
e
.
node
=
mirror_map
.
at
(
e
.
node
.
get
());
...
...
@@ -157,9 +157,9 @@ Graph Gradient(Graph src) {
for
(
auto
&
n
:
new_node
->
control_deps
)
{
n
=
mirror_map
.
at
(
n
.
get
());
}
mirror_map
[
n
.
get
()]
=
std
::
move
(
new_node
);
mirror_map
[
n
ode_ptr
.
get
()]
=
std
::
move
(
new_node
);
}
else
{
mirror_map
[
n
.
get
()]
=
n
;
mirror_map
[
n
ode_ptr
.
get
()]
=
node_ptr
;
}
}
}
...
...
@@ -185,7 +185,8 @@ Graph Gradient(Graph src) {
if
((
*
rit
)
->
inputs
.
size
()
!=
0
)
{
NodePtr
fwd_node
=
(
mirror_map
.
size
()
==
0
?
ptr
:
mirror_map
.
at
(
ptr
.
get
()));
std
::
vector
<
NodeEntry
>
input_grads
;
if
(
grad_fun_map
.
count
(
ptr
->
op
()))
{
// Check for FGradient
if
(
grad_fun_map
.
contains
(
ptr
->
op
()))
{
input_grads
=
grad_fun_map
[
ptr
->
op
()](
fwd_node
,
out_agg_grads
);
CHECK_EQ
((
*
rit
)
->
inputs
.
size
(),
input_grads
.
size
())
<<
"Gradient function not returning enough gradient"
;
...
...
@@ -205,20 +206,23 @@ Graph Gradient(Graph src) {
if
(
p
->
op
()
->
attr_parser
!=
nullptr
)
{
p
->
op
()
->
attr_parser
(
&
(
p
->
attrs
));
}
input_grads
.
emplace_back
(
nnvm
::
NodeEntry
{
p
,
0
,
0
}
);
input_grads
.
emplace_back
(
p
,
0
,
0
);
}
}
else
{
LOG
(
FATAL
)
<<
"Operator "
<<
fwd_node
->
op
()
->
name
<<
" is non-differentiable "
<<
"because it didn't register FGradient attribute."
;
}
for
(
const
auto
&
nodeEntry
:
input_grads
)
CHECK
(
nodeEntry
.
node
);
auto
git
=
input_grads
.
begin
();
CHECK
((
*
rit
)
->
inputs
.
size
()
<=
input_grads
.
size
());
for
(
auto
it
=
(
*
rit
)
->
inputs
.
begin
();
it
!=
(
*
rit
)
->
inputs
.
end
();
++
it
,
++
git
)
{
auto
&
ge
=
output_grads
[
it
->
node
.
get
()][
it
->
index
];
auto
&
output_grad_entry
=
output_grads
[
it
->
node
.
get
()][
it
->
index
];
// if any of the backward op can do shape inference, the hint is not necessary.
if
(
finfer_shape
.
co
unt
(
git
->
node
->
op
()))
{
ge
.
need_attr_hint
=
false
;
if
(
finfer_shape
.
co
ntains
(
git
->
node
->
op
()))
{
output_grad_entry
.
need_attr_hint
=
false
;
}
ge
.
grads
.
emplace_back
(
std
::
move
(
*
git
));
output_grad_entry
.
grads
.
emplace_back
(
std
::
move
(
*
git
));
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment