Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
de076999
Commit
de076999
authored
Aug 29, 2016
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NODE] Move op inside node attribute (#30)
parent
ac070f83
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
58 additions
and
54 deletions
+58
-54
nnvm/example/src/operator.cc
+1
-1
nnvm/include/nnvm/node.h
+19
-14
nnvm/include/nnvm/pass_functions.h
+1
-0
nnvm/src/core/graph.cc
+3
-3
nnvm/src/core/symbolic.cc
+15
-16
nnvm/src/pass/gradient.cc
+3
-3
nnvm/src/pass/infer_shape_type.cc
+3
-3
nnvm/src/pass/order_mutation.cc
+4
-5
nnvm/src/pass/place_device.cc
+1
-1
nnvm/src/pass/plan_memory.cc
+2
-2
nnvm/src/pass/saveload_json.cc
+6
-6
No files found.
nnvm/example/src/operator.cc
View file @
de076999
...
@@ -46,7 +46,7 @@ inline NodeEntry MakeNode(const char* op_name,
...
@@ -46,7 +46,7 @@ inline NodeEntry MakeNode(const char* op_name,
std
::
string
node_name
,
std
::
string
node_name
,
std
::
vector
<
NodeEntry
>
inputs
)
{
std
::
vector
<
NodeEntry
>
inputs
)
{
NodePtr
p
=
Node
::
Create
();
NodePtr
p
=
Node
::
Create
();
p
->
op
=
nnvm
::
Op
::
Get
(
op_name
);
p
->
attrs
.
op
=
nnvm
::
Op
::
Get
(
op_name
);
p
->
attrs
.
name
=
std
::
move
(
node_name
);
p
->
attrs
.
name
=
std
::
move
(
node_name
);
p
->
inputs
=
std
::
move
(
inputs
);
p
->
inputs
=
std
::
move
(
inputs
);
return
NodeEntry
{
p
,
0
,
0
};
return
NodeEntry
{
p
,
0
,
0
};
...
...
nnvm/include/nnvm/node.h
View file @
de076999
...
@@ -46,6 +46,11 @@ struct NodeEntry {
...
@@ -46,6 +46,11 @@ struct NodeEntry {
* Usually are additional parameters like axis,
* Usually are additional parameters like axis,
*/
*/
struct
NodeAttrs
{
struct
NodeAttrs
{
/*!
* \brief The operator this node uses.
* For place holder variable, op == nullptr.
*/
const
Op
*
op
{
nullptr
};
/*! \brief name of the node */
/*! \brief name of the node */
std
::
string
name
;
std
::
string
name
;
/*! \brief Vector representation of positional attributes */
/*! \brief Vector representation of positional attributes */
...
@@ -65,11 +70,8 @@ struct NodeAttrs {
...
@@ -65,11 +70,8 @@ struct NodeAttrs {
*/
*/
class
Node
{
class
Node
{
public
:
public
:
/*!
/*! \brief The attributes in the node. */
* \brief The operator this node uses.
NodeAttrs
attrs
;
* For place holder variable, op == nullptr.
*/
const
Op
*
op
{
nullptr
};
/*! \brief inputs to this node */
/*! \brief inputs to this node */
std
::
vector
<
NodeEntry
>
inputs
;
std
::
vector
<
NodeEntry
>
inputs
;
/*!
/*!
...
@@ -77,10 +79,10 @@ class Node {
...
@@ -77,10 +79,10 @@ class Node {
* Gives operation must be performed before this operation.
* Gives operation must be performed before this operation.
*/
*/
std
::
vector
<
NodePtr
>
control_deps
;
std
::
vector
<
NodePtr
>
control_deps
;
/*! \brief The attributes in the node. */
NodeAttrs
attrs
;
/*! \brief destructor of node */
/*! \brief destructor of node */
~
Node
();
~
Node
();
/*! \return operator in this node */
inline
const
Op
*
op
()
const
;
/*!
/*!
* \brief return whether node is placeholder variable.
* \brief return whether node is placeholder variable.
* This is equivalent to op == nullptr
* This is equivalent to op == nullptr
...
@@ -99,25 +101,28 @@ class Node {
...
@@ -99,25 +101,28 @@ class Node {
};
};
// implementation of functions.
// implementation of functions.
inline
const
Op
*
Node
::
op
()
const
{
return
this
->
attrs
.
op
;
}
inline
bool
Node
::
is_variable
()
const
{
inline
bool
Node
::
is_variable
()
const
{
return
this
->
op
==
nullptr
;
return
this
->
op
()
==
nullptr
;
}
}
inline
uint32_t
Node
::
num_outputs
()
const
{
inline
uint32_t
Node
::
num_outputs
()
const
{
if
(
is_variable
())
return
1
;
if
(
is_variable
())
return
1
;
if
(
this
->
op
->
get_num_outputs
==
nullptr
)
{
if
(
this
->
op
()
->
get_num_outputs
==
nullptr
)
{
return
this
->
op
->
num_outputs
;
return
this
->
op
()
->
num_outputs
;
}
else
{
}
else
{
return
this
->
op
->
get_num_outputs
(
this
->
attrs
);
return
this
->
op
()
->
get_num_outputs
(
this
->
attrs
);
}
}
}
}
inline
uint32_t
Node
::
num_inputs
()
const
{
inline
uint32_t
Node
::
num_inputs
()
const
{
if
(
is_variable
())
return
1
;
if
(
is_variable
())
return
1
;
if
(
this
->
op
->
get_num_inputs
==
nullptr
)
{
if
(
this
->
op
()
->
get_num_inputs
==
nullptr
)
{
return
this
->
op
->
num_inputs
;
return
this
->
op
()
->
num_inputs
;
}
else
{
}
else
{
return
this
->
op
->
get_num_inputs
(
this
->
attrs
);
return
this
->
op
()
->
get_num_inputs
(
this
->
attrs
);
}
}
}
}
...
...
nnvm/include/nnvm/pass_functions.h
View file @
de076999
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include <string>
#include <string>
#include <memory>
#include <memory>
#include <vector>
#include "./base.h"
#include "./base.h"
#include "./pass.h"
#include "./pass.h"
#include "./graph_attr_types.h"
#include "./graph_attr_types.h"
...
...
nnvm/src/core/graph.cc
View file @
de076999
...
@@ -66,9 +66,9 @@ IndexedGraph::IndexedGraph(const Graph &g) {
...
@@ -66,9 +66,9 @@ IndexedGraph::IndexedGraph(const Graph &g) {
for
(
size_t
nid
=
0
;
nid
<
nodes_
.
size
();
++
nid
)
{
for
(
size_t
nid
=
0
;
nid
<
nodes_
.
size
();
++
nid
)
{
nodes_
[
nid
].
inputs
=
array_view
<
NodeEntry
>
(
nodes_
[
nid
].
inputs
=
array_view
<
NodeEntry
>
(
iptr
+
inputs_rptr
[
nid
],
iptr
+
inputs_rptr
[
nid
+
1
]);
iptr
+
inputs_rptr
[
nid
],
iptr
+
inputs_rptr
[
nid
+
1
]);
if
(
nodes_
[
nid
].
source
->
op
!=
nullptr
&&
if
(
nodes_
[
nid
].
source
->
op
()
!=
nullptr
&&
fmutate_inputs
.
count
(
nodes_
[
nid
].
source
->
op
))
{
fmutate_inputs
.
count
(
nodes_
[
nid
].
source
->
op
()
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
nodes_
[
nid
].
source
->
op
](
nodes_
[
nid
].
source
->
attrs
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
nodes_
[
nid
].
source
->
op
()
](
nodes_
[
nid
].
source
->
attrs
))
{
mutable_input_nodes_
.
insert
(
nodes_
[
nid
].
inputs
[
i
].
node_id
);
mutable_input_nodes_
.
insert
(
nodes_
[
nid
].
inputs
[
i
].
node_id
);
}
}
}
}
...
...
nnvm/src/core/symbolic.cc
View file @
de076999
...
@@ -20,7 +20,7 @@ struct VariableParam {
...
@@ -20,7 +20,7 @@ struct VariableParam {
NodePtr
CreateVariableNode
(
const
std
::
string
&
name
)
{
NodePtr
CreateVariableNode
(
const
std
::
string
&
name
)
{
NodePtr
n
=
Node
::
Create
();
NodePtr
n
=
Node
::
Create
();
n
->
op
=
nullptr
;
n
->
attrs
.
op
=
nullptr
;
n
->
attrs
.
name
=
name
;
n
->
attrs
.
name
=
name
;
n
->
attrs
.
parsed
=
VariableParam
();
n
->
attrs
.
parsed
=
VariableParam
();
return
n
;
return
n
;
...
@@ -37,8 +37,8 @@ inline void UpdateNodeVersion(Node *n) {
...
@@ -37,8 +37,8 @@ inline void UpdateNodeVersion(Node *n) {
e
.
version
=
nnvm
::
get
<
VariableParam
>
(
e
.
node
->
attrs
.
parsed
).
version
;
e
.
version
=
nnvm
::
get
<
VariableParam
>
(
e
.
node
->
attrs
.
parsed
).
version
;
}
}
}
}
if
(
fmutate_inputs
.
count
(
n
->
op
)
!=
0
)
{
if
(
fmutate_inputs
.
count
(
n
->
op
()
)
!=
0
)
{
for
(
uint32_t
i
:
fmutate_inputs
[
n
->
op
](
n
->
attrs
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
n
->
op
()
](
n
->
attrs
))
{
NodeEntry
&
e
=
n
->
inputs
[
i
];
NodeEntry
&
e
=
n
->
inputs
[
i
];
CHECK
(
e
.
node
->
is_variable
())
CHECK
(
e
.
node
->
is_variable
())
<<
"Mutation target can only be Variable"
;
<<
"Mutation target can only be Variable"
;
...
@@ -96,7 +96,6 @@ Symbol Symbol::Copy() const {
...
@@ -96,7 +96,6 @@ Symbol Symbol::Copy() const {
// use DFSVisit to copy all the nodes
// use DFSVisit to copy all the nodes
DFSVisit
(
this
->
outputs
,
[
&
old_new
](
const
NodePtr
&
node
)
{
DFSVisit
(
this
->
outputs
,
[
&
old_new
](
const
NodePtr
&
node
)
{
NodePtr
np
=
Node
::
Create
();
NodePtr
np
=
Node
::
Create
();
np
->
op
=
node
->
op
;
np
->
attrs
=
node
->
attrs
;
np
->
attrs
=
node
->
attrs
;
old_new
[
node
.
get
()]
=
std
::
move
(
np
);
old_new
[
node
.
get
()]
=
std
::
move
(
np
);
});
});
...
@@ -123,7 +122,7 @@ void Symbol::Print(std::ostream &os) const {
...
@@ -123,7 +122,7 @@ void Symbol::Print(std::ostream &os) const {
if
(
outputs
[
0
].
node
->
is_variable
())
{
if
(
outputs
[
0
].
node
->
is_variable
())
{
os
<<
"Variable:"
<<
outputs
[
0
].
node
->
attrs
.
name
<<
'\n'
;
os
<<
"Variable:"
<<
outputs
[
0
].
node
->
attrs
.
name
<<
'\n'
;
}
else
{
}
else
{
os
<<
"AtomicFunctor "
<<
" Op:"
<<
outputs
[
0
].
node
->
op
->
name
<<
'\n'
;
os
<<
"AtomicFunctor "
<<
" Op:"
<<
outputs
[
0
].
node
->
op
()
->
name
<<
'\n'
;
}
}
}
else
{
}
else
{
// use DFSVisit to copy all the nodes
// use DFSVisit to copy all the nodes
...
@@ -137,7 +136,7 @@ void Symbol::Print(std::ostream &os) const {
...
@@ -137,7 +136,7 @@ void Symbol::Print(std::ostream &os) const {
os
<<
"Variable:"
<<
node
->
attrs
.
name
<<
'\n'
;
os
<<
"Variable:"
<<
node
->
attrs
.
name
<<
'\n'
;
}
else
{
}
else
{
os
<<
"--------------------
\n
"
;
os
<<
"--------------------
\n
"
;
os
<<
"Op:"
<<
node
->
op
->
name
<<
", Name="
<<
node
->
attrs
.
name
<<
'\n'
os
<<
"Op:"
<<
node
->
op
()
->
name
<<
", Name="
<<
node
->
attrs
.
name
<<
'\n'
<<
"Inputs:
\n
"
;
<<
"Inputs:
\n
"
;
for
(
size_t
i
=
0
;
i
<
node
->
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
node
->
inputs
.
size
();
++
i
)
{
const
NodeEntry
&
e
=
node
->
inputs
[
i
];
const
NodeEntry
&
e
=
node
->
inputs
[
i
];
...
@@ -196,8 +195,8 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
...
@@ -196,8 +195,8 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
DFSVisit
(
this
->
outputs
,
[
&
ret
,
&
mutable_set
,
&
vlist
](
const
NodePtr
&
node
)
{
DFSVisit
(
this
->
outputs
,
[
&
ret
,
&
mutable_set
,
&
vlist
](
const
NodePtr
&
node
)
{
if
(
node
->
is_variable
())
{
if
(
node
->
is_variable
())
{
vlist
.
push_back
(
node
.
get
());
vlist
.
push_back
(
node
.
get
());
}
else
if
(
fmutate_inputs
.
count
(
node
->
op
))
{
}
else
if
(
fmutate_inputs
.
count
(
node
->
op
()
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
node
->
op
](
node
->
attrs
)){
for
(
uint32_t
i
:
fmutate_inputs
[
node
->
op
()
](
node
->
attrs
)){
mutable_set
.
insert
(
node
->
inputs
[
i
].
node
.
get
());
mutable_set
.
insert
(
node
->
inputs
[
i
].
node
.
get
());
}
}
}
}
...
@@ -221,7 +220,7 @@ std::vector<std::string> Symbol::ListOutputNames() const {
...
@@ -221,7 +220,7 @@ std::vector<std::string> Symbol::ListOutputNames() const {
}
else
{
}
else
{
const
std
::
string
&
hname
=
head
.
node
->
attrs
.
name
;
const
std
::
string
&
hname
=
head
.
node
->
attrs
.
name
;
std
::
string
rname
;
std
::
string
rname
;
FListOutputNames
fn
=
flist_ouputs
.
get
(
head
.
node
->
op
,
nullptr
);
FListOutputNames
fn
=
flist_ouputs
.
get
(
head
.
node
->
op
()
,
nullptr
);
if
(
fn
!=
nullptr
)
{
if
(
fn
!=
nullptr
)
{
rname
=
fn
(
head
.
node
->
attrs
)[
head
.
index
];
rname
=
fn
(
head
.
node
->
attrs
)[
head
.
index
];
}
else
{
}
else
{
...
@@ -278,10 +277,10 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
...
@@ -278,10 +277,10 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
}
}
// switch to keyword argument matching
// switch to keyword argument matching
if
(
args
.
size
()
!=
n_req
)
{
if
(
args
.
size
()
!=
n_req
)
{
FListInputNames
fn
=
flist_inputs
.
get
(
n
->
op
,
nullptr
);
FListInputNames
fn
=
flist_inputs
.
get
(
n
->
op
()
,
nullptr
);
auto
arg_names
=
(
fn
==
nullptr
)
?
std
::
vector
<
std
::
string
>
{
"data"
}
:
fn
(
n
->
attrs
);
auto
arg_names
=
(
fn
==
nullptr
)
?
std
::
vector
<
std
::
string
>
{
"data"
}
:
fn
(
n
->
attrs
);
if
(
arg_names
.
size
()
!=
n_req
)
{
if
(
arg_names
.
size
()
!=
n_req
)
{
LOG
(
FATAL
)
<<
"Not enough argument to call operator "
<<
outputs
[
0
].
node
->
op
->
name
;
LOG
(
FATAL
)
<<
"Not enough argument to call operator "
<<
outputs
[
0
].
node
->
op
()
->
name
;
}
}
size_t
nmatched
=
0
;
size_t
nmatched
=
0
;
for
(
size_t
i
=
args
.
size
();
i
<
n_req
;
++
i
)
{
for
(
size_t
i
=
args
.
size
();
i
<
n_req
;
++
i
)
{
...
@@ -422,8 +421,8 @@ void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& a
...
@@ -422,8 +421,8 @@ void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& a
node
->
attrs
.
dict
[
kv
.
first
]
=
kv
.
second
;
node
->
attrs
.
dict
[
kv
.
first
]
=
kv
.
second
;
}
}
}
}
if
(
node
->
op
!=
nullptr
&&
node
->
op
->
attr_parser
!=
nullptr
)
{
if
(
node
->
op
()
!=
nullptr
&&
node
->
op
()
->
attr_parser
!=
nullptr
)
{
node
->
op
->
attr_parser
(
&
(
node
->
attrs
));
node
->
op
()
->
attr_parser
(
&
(
node
->
attrs
));
}
}
}
}
...
@@ -461,10 +460,10 @@ Symbol Symbol::CreateFunctor(const Op* op,
...
@@ -461,10 +460,10 @@ Symbol Symbol::CreateFunctor(const Op* op,
std
::
unordered_map
<
std
::
string
,
std
::
string
>
attrs
)
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
attrs
)
{
Symbol
s
;
Symbol
s
;
NodePtr
n
=
Node
::
Create
();
NodePtr
n
=
Node
::
Create
();
n
->
op
=
op
;
n
->
attrs
.
op
=
op
;
n
->
attrs
.
dict
=
std
::
move
(
attrs
);
n
->
attrs
.
dict
=
std
::
move
(
attrs
);
if
(
n
->
op
->
attr_parser
!=
nullptr
)
{
if
(
n
->
op
()
->
attr_parser
!=
nullptr
)
{
n
->
op
->
attr_parser
(
&
(
n
->
attrs
));
n
->
op
()
->
attr_parser
(
&
(
n
->
attrs
));
}
}
s
.
outputs
.
emplace_back
(
NodeEntry
{
std
::
move
(
n
),
0
,
0
});
s
.
outputs
.
emplace_back
(
NodeEntry
{
std
::
move
(
n
),
0
,
0
});
return
s
;
return
s
;
...
...
nnvm/src/pass/gradient.cc
View file @
de076999
...
@@ -20,11 +20,11 @@ NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
...
@@ -20,11 +20,11 @@ NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
return
std
::
move
(
v
[
0
]);
return
std
::
move
(
v
[
0
]);
}
else
if
(
v
.
size
()
==
0
)
{
}
else
if
(
v
.
size
()
==
0
)
{
NodePtr
zero_node
=
Node
::
Create
();
NodePtr
zero_node
=
Node
::
Create
();
zero_node
->
op
=
Op
::
Get
(
"__zero__"
);
zero_node
->
attrs
.
op
=
Op
::
Get
(
"__zero__"
);
return
NodeEntry
{
zero_node
,
0
,
0
};
return
NodeEntry
{
zero_node
,
0
,
0
};
}
else
{
}
else
{
NodePtr
sum_node
=
Node
::
Create
();
NodePtr
sum_node
=
Node
::
Create
();
sum_node
->
op
=
Op
::
Get
(
"__ewise_sum__"
);
sum_node
->
attrs
.
op
=
Op
::
Get
(
"__ewise_sum__"
);
sum_node
->
inputs
=
std
::
move
(
v
);
sum_node
->
inputs
=
std
::
move
(
v
);
return
NodeEntry
{
sum_node
,
0
,
0
};
return
NodeEntry
{
sum_node
,
0
,
0
};
}
}
...
@@ -109,7 +109,7 @@ Graph Gradient(Graph src) {
...
@@ -109,7 +109,7 @@ Graph Gradient(Graph src) {
e
.
sum
=
agg_fun
(
std
::
move
(
e
.
grads
));
e
.
sum
=
agg_fun
(
std
::
move
(
e
.
grads
));
out_agg_grads
.
push_back
(
e
.
sum
);
out_agg_grads
.
push_back
(
e
.
sum
);
}
}
std
::
vector
<
NodeEntry
>
input_grads
=
grad_fun_map
[
ptr
->
op
]
std
::
vector
<
NodeEntry
>
input_grads
=
grad_fun_map
[
ptr
->
op
()
]
(
mirror_map
.
size
()
==
0
?
ptr
:
mirror_map
.
at
(
ptr
.
get
()),
out_agg_grads
);
(
mirror_map
.
size
()
==
0
?
ptr
:
mirror_map
.
at
(
ptr
.
get
()),
out_agg_grads
);
auto
git
=
input_grads
.
begin
();
auto
git
=
input_grads
.
begin
();
for
(
auto
it
=
(
*
rit
)
->
inputs
.
begin
();
it
!=
(
*
rit
)
->
inputs
.
end
();
++
it
,
++
git
)
{
for
(
auto
it
=
(
*
rit
)
->
inputs
.
begin
();
it
!=
(
*
rit
)
->
inputs
.
end
();
++
it
,
++
git
)
{
...
...
nnvm/src/pass/infer_shape_type.cc
View file @
de076999
...
@@ -65,7 +65,7 @@ Graph InferAttr(Graph &&ret,
...
@@ -65,7 +65,7 @@ Graph InferAttr(Graph &&ret,
}
}
continue
;
continue
;
}
}
if
(
finfer_shape
.
count
(
inode
.
source
->
op
))
{
if
(
finfer_shape
.
count
(
inode
.
source
->
op
()
))
{
ishape
.
resize
(
num_inputs
,
def_value
);
ishape
.
resize
(
num_inputs
,
def_value
);
for
(
uint32_t
i
=
0
;
i
<
ishape
.
size
();
++
i
)
{
for
(
uint32_t
i
=
0
;
i
<
ishape
.
size
();
++
i
)
{
ishape
[
i
]
=
rshape
[
idx
.
entry_id
(
inode
.
inputs
[
i
])];
ishape
[
i
]
=
rshape
[
idx
.
entry_id
(
inode
.
inputs
[
i
])];
...
@@ -75,14 +75,14 @@ Graph InferAttr(Graph &&ret,
...
@@ -75,14 +75,14 @@ Graph InferAttr(Graph &&ret,
oshape
[
i
]
=
rshape
[
idx
.
entry_id
(
nid
,
i
)];
oshape
[
i
]
=
rshape
[
idx
.
entry_id
(
nid
,
i
)];
}
}
num_unknown
+=
num_unknown
+=
!
(
finfer_shape
[
inode
.
source
->
op
](
inode
.
source
->
attrs
,
&
ishape
,
&
oshape
));
!
(
finfer_shape
[
inode
.
source
->
op
()
](
inode
.
source
->
attrs
,
&
ishape
,
&
oshape
));
for
(
uint32_t
i
=
0
;
i
<
num_inputs
;
++
i
)
{
for
(
uint32_t
i
=
0
;
i
<
num_inputs
;
++
i
)
{
rshape
[
idx
.
entry_id
(
inode
.
inputs
[
i
])]
=
ishape
[
i
];
rshape
[
idx
.
entry_id
(
inode
.
inputs
[
i
])]
=
ishape
[
i
];
}
}
for
(
uint32_t
i
=
0
;
i
<
num_outputs
;
++
i
)
{
for
(
uint32_t
i
=
0
;
i
<
num_outputs
;
++
i
)
{
rshape
[
idx
.
entry_id
(
nid
,
i
)]
=
oshape
[
i
];
rshape
[
idx
.
entry_id
(
nid
,
i
)]
=
oshape
[
i
];
}
}
}
else
if
(
is_backward
.
get
(
inode
.
source
->
op
,
false
))
{
}
else
if
(
is_backward
.
get
(
inode
.
source
->
op
()
,
false
))
{
// backward operator inference.
// backward operator inference.
CHECK_GE
(
inode
.
control_deps
.
size
(),
1
)
CHECK_GE
(
inode
.
control_deps
.
size
(),
1
)
<<
"BackwardOp need to have control_deps to its forward op"
;
<<
"BackwardOp need to have control_deps to its forward op"
;
...
...
nnvm/src/pass/order_mutation.cc
View file @
de076999
...
@@ -43,8 +43,8 @@ Graph OrderMutation(const Graph& src) {
...
@@ -43,8 +43,8 @@ Graph OrderMutation(const Graph& src) {
auto
prepare
=
[
&
version_hist
,
&
old_new
]
(
const
NodePtr
&
n
)
{
auto
prepare
=
[
&
version_hist
,
&
old_new
]
(
const
NodePtr
&
n
)
{
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInputs
>
(
"FMutateInputs"
);
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInputs
>
(
"FMutateInputs"
);
std
::
vector
<
uint32_t
>
mutate_inputs
;
std
::
vector
<
uint32_t
>
mutate_inputs
;
if
(
!
n
->
is_variable
()
&&
fmutate_inputs
.
count
(
n
->
op
))
{
if
(
!
n
->
is_variable
()
&&
fmutate_inputs
.
count
(
n
->
op
()
))
{
mutate_inputs
=
fmutate_inputs
[
n
->
op
](
n
->
attrs
);
mutate_inputs
=
fmutate_inputs
[
n
->
op
()
](
n
->
attrs
);
}
}
std
::
sort
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
());
std
::
sort
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
());
...
@@ -67,7 +67,6 @@ Graph OrderMutation(const Graph& src) {
...
@@ -67,7 +67,6 @@ Graph OrderMutation(const Graph& src) {
}
}
if
(
need_repl
)
{
if
(
need_repl
)
{
NodePtr
np
=
Node
::
Create
();
NodePtr
np
=
Node
::
Create
();
np
->
op
=
n
->
op
;
np
->
attrs
=
n
->
attrs
;
np
->
attrs
=
n
->
attrs
;
old_new
[
n
.
get
()]
=
std
::
move
(
np
);
old_new
[
n
.
get
()]
=
std
::
move
(
np
);
}
}
...
@@ -101,8 +100,8 @@ Graph OrderMutation(const Graph& src) {
...
@@ -101,8 +100,8 @@ Graph OrderMutation(const Graph& src) {
// add control deps
// add control deps
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInputs
>
(
"FMutateInputs"
);
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInputs
>
(
"FMutateInputs"
);
std
::
vector
<
uint32_t
>
mutate_inputs
;
std
::
vector
<
uint32_t
>
mutate_inputs
;
if
(
fmutate_inputs
.
count
(
kv
.
first
->
op
))
{
if
(
fmutate_inputs
.
count
(
kv
.
first
->
op
()
))
{
mutate_inputs
=
fmutate_inputs
[
kv
.
first
->
op
](
kv
.
first
->
attrs
);
mutate_inputs
=
fmutate_inputs
[
kv
.
first
->
op
()
](
kv
.
first
->
attrs
);
}
}
std
::
sort
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
());
std
::
sort
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
());
...
...
nnvm/src/pass/place_device.cc
View file @
de076999
...
@@ -109,9 +109,9 @@ Graph PlaceDevice(Graph src) {
...
@@ -109,9 +109,9 @@ Graph PlaceDevice(Graph src) {
NodeEntry
{
it
->
second
,
0
,
0
});
NodeEntry
{
it
->
second
,
0
,
0
});
}
else
{
}
else
{
NodePtr
copy_node
=
Node
::
Create
();
NodePtr
copy_node
=
Node
::
Create
();
copy_node
->
op
=
copy_op
;
std
::
ostringstream
os
;
std
::
ostringstream
os
;
os
<<
inode
.
source
->
inputs
[
i
].
node
->
attrs
.
name
<<
"_"
<<
e
.
index
<<
"_copy"
;
os
<<
inode
.
source
->
inputs
[
i
].
node
->
attrs
.
name
<<
"_"
<<
e
.
index
<<
"_copy"
;
copy_node
->
attrs
.
op
=
copy_op
;
copy_node
->
attrs
.
name
=
os
.
str
();
copy_node
->
attrs
.
name
=
os
.
str
();
copy_node
->
inputs
.
push_back
(
inode
.
source
->
inputs
[
i
]);
copy_node
->
inputs
.
push_back
(
inode
.
source
->
inputs
[
i
]);
copy_map
[
copy_key
]
=
copy_node
;
copy_map
[
copy_key
]
=
copy_node
;
...
...
nnvm/src/pass/plan_memory.cc
View file @
de076999
...
@@ -168,8 +168,8 @@ Graph PlanMemory(Graph ret) {
...
@@ -168,8 +168,8 @@ Graph PlanMemory(Graph ret) {
const
auto
&
inode
=
idx
[
nid
];
const
auto
&
inode
=
idx
[
nid
];
if
(
inode
.
source
->
is_variable
())
continue
;
if
(
inode
.
source
->
is_variable
())
continue
;
// check inplace option
// check inplace option
if
(
finplace_option
.
count
(
inode
.
source
->
op
)
!=
0
)
{
if
(
finplace_option
.
count
(
inode
.
source
->
op
()
)
!=
0
)
{
auto
inplace_pairs
=
finplace_option
[
inode
.
source
->
op
](
inode
.
source
->
attrs
);
auto
inplace_pairs
=
finplace_option
[
inode
.
source
->
op
()
](
inode
.
source
->
attrs
);
for
(
auto
&
kv
:
inplace_pairs
)
{
for
(
auto
&
kv
:
inplace_pairs
)
{
uint32_t
eid_out
=
idx
.
entry_id
(
nid
,
kv
.
second
);
uint32_t
eid_out
=
idx
.
entry_id
(
nid
,
kv
.
second
);
uint32_t
eid_in
=
idx
.
entry_id
(
inode
.
inputs
[
kv
.
first
]);
uint32_t
eid_in
=
idx
.
entry_id
(
inode
.
inputs
[
kv
.
first
]);
...
...
nnvm/src/pass/saveload_json.cc
View file @
de076999
...
@@ -68,8 +68,8 @@ struct JSONNode {
...
@@ -68,8 +68,8 @@ struct JSONNode {
// function to save JSON node.
// function to save JSON node.
void
Save
(
dmlc
::
JSONWriter
*
writer
)
const
{
void
Save
(
dmlc
::
JSONWriter
*
writer
)
const
{
writer
->
BeginObject
();
writer
->
BeginObject
();
if
(
node
->
op
!=
nullptr
)
{
if
(
node
->
op
()
!=
nullptr
)
{
writer
->
WriteObjectKeyValue
(
"op"
,
node
->
op
->
name
);
writer
->
WriteObjectKeyValue
(
"op"
,
node
->
op
()
->
name
);
}
else
{
}
else
{
std
::
string
json_null
=
"null"
;
std
::
string
json_null
=
"null"
;
writer
->
WriteObjectKeyValue
(
"op"
,
json_null
);
writer
->
WriteObjectKeyValue
(
"op"
,
json_null
);
...
@@ -108,10 +108,10 @@ struct JSONNode {
...
@@ -108,10 +108,10 @@ struct JSONNode {
if
(
op_type_str
!=
"null"
)
{
if
(
op_type_str
!=
"null"
)
{
try
{
try
{
node
->
op
=
Op
::
Get
(
op_type_str
);
node
->
attrs
.
op
=
Op
::
Get
(
op_type_str
);
// rebuild attribute parser
// rebuild attribute parser
if
(
node
->
op
->
attr_parser
!=
nullptr
)
{
if
(
node
->
op
()
->
attr_parser
!=
nullptr
)
{
node
->
op
->
attr_parser
(
&
(
node
->
attrs
));
node
->
op
()
->
attr_parser
(
&
(
node
->
attrs
));
}
}
}
catch
(
const
dmlc
::
Error
&
err
)
{
}
catch
(
const
dmlc
::
Error
&
err
)
{
std
::
ostringstream
os
;
std
::
ostringstream
os
;
...
@@ -120,7 +120,7 @@ struct JSONNode {
...
@@ -120,7 +120,7 @@ struct JSONNode {
throw
dmlc
::
Error
(
os
.
str
());
throw
dmlc
::
Error
(
os
.
str
());
}
}
}
else
{
}
else
{
node
->
op
=
nullptr
;
node
->
attrs
.
op
=
nullptr
;
}
}
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment