Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
de076999
Commit
de076999
authored
Aug 29, 2016
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[NODE] Move op inside node attribute (#30)
parent
ac070f83
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
58 additions
and
54 deletions
+58
-54
nnvm/example/src/operator.cc
+1
-1
nnvm/include/nnvm/node.h
+19
-14
nnvm/include/nnvm/pass_functions.h
+1
-0
nnvm/src/core/graph.cc
+3
-3
nnvm/src/core/symbolic.cc
+15
-16
nnvm/src/pass/gradient.cc
+3
-3
nnvm/src/pass/infer_shape_type.cc
+3
-3
nnvm/src/pass/order_mutation.cc
+4
-5
nnvm/src/pass/place_device.cc
+1
-1
nnvm/src/pass/plan_memory.cc
+2
-2
nnvm/src/pass/saveload_json.cc
+6
-6
No files found.
nnvm/example/src/operator.cc
View file @
de076999
...
...
@@ -46,7 +46,7 @@ inline NodeEntry MakeNode(const char* op_name,
std
::
string
node_name
,
std
::
vector
<
NodeEntry
>
inputs
)
{
NodePtr
p
=
Node
::
Create
();
p
->
op
=
nnvm
::
Op
::
Get
(
op_name
);
p
->
attrs
.
op
=
nnvm
::
Op
::
Get
(
op_name
);
p
->
attrs
.
name
=
std
::
move
(
node_name
);
p
->
inputs
=
std
::
move
(
inputs
);
return
NodeEntry
{
p
,
0
,
0
};
...
...
nnvm/include/nnvm/node.h
View file @
de076999
...
...
@@ -46,6 +46,11 @@ struct NodeEntry {
* Usually are additional parameters like axis,
*/
struct
NodeAttrs
{
/*!
* \brief The operator this node uses.
* For place holder variable, op == nullptr.
*/
const
Op
*
op
{
nullptr
};
/*! \brief name of the node */
std
::
string
name
;
/*! \brief Vector representation of positional attributes */
...
...
@@ -65,11 +70,8 @@ struct NodeAttrs {
*/
class
Node
{
public
:
/*!
* \brief The operator this node uses.
* For place holder variable, op == nullptr.
*/
const
Op
*
op
{
nullptr
};
/*! \brief The attributes in the node. */
NodeAttrs
attrs
;
/*! \brief inputs to this node */
std
::
vector
<
NodeEntry
>
inputs
;
/*!
...
...
@@ -77,10 +79,10 @@ class Node {
* Gives operation must be performed before this operation.
*/
std
::
vector
<
NodePtr
>
control_deps
;
/*! \brief The attributes in the node. */
NodeAttrs
attrs
;
/*! \brief destructor of node */
~
Node
();
/*! \return operator in this node */
inline
const
Op
*
op
()
const
;
/*!
* \brief return whether node is placeholder variable.
* This is equivalent to op == nullptr
...
...
@@ -99,25 +101,28 @@ class Node {
};
// implementation of functions.
inline
const
Op
*
Node
::
op
()
const
{
return
this
->
attrs
.
op
;
}
inline
bool
Node
::
is_variable
()
const
{
return
this
->
op
==
nullptr
;
return
this
->
op
()
==
nullptr
;
}
inline
uint32_t
Node
::
num_outputs
()
const
{
if
(
is_variable
())
return
1
;
if
(
this
->
op
->
get_num_outputs
==
nullptr
)
{
return
this
->
op
->
num_outputs
;
if
(
this
->
op
()
->
get_num_outputs
==
nullptr
)
{
return
this
->
op
()
->
num_outputs
;
}
else
{
return
this
->
op
->
get_num_outputs
(
this
->
attrs
);
return
this
->
op
()
->
get_num_outputs
(
this
->
attrs
);
}
}
inline
uint32_t
Node
::
num_inputs
()
const
{
if
(
is_variable
())
return
1
;
if
(
this
->
op
->
get_num_inputs
==
nullptr
)
{
return
this
->
op
->
num_inputs
;
if
(
this
->
op
()
->
get_num_inputs
==
nullptr
)
{
return
this
->
op
()
->
num_inputs
;
}
else
{
return
this
->
op
->
get_num_inputs
(
this
->
attrs
);
return
this
->
op
()
->
get_num_inputs
(
this
->
attrs
);
}
}
...
...
nnvm/include/nnvm/pass_functions.h
View file @
de076999
...
...
@@ -12,6 +12,7 @@
#include <string>
#include <memory>
#include <vector>
#include "./base.h"
#include "./pass.h"
#include "./graph_attr_types.h"
...
...
nnvm/src/core/graph.cc
View file @
de076999
...
...
@@ -66,9 +66,9 @@ IndexedGraph::IndexedGraph(const Graph &g) {
for
(
size_t
nid
=
0
;
nid
<
nodes_
.
size
();
++
nid
)
{
nodes_
[
nid
].
inputs
=
array_view
<
NodeEntry
>
(
iptr
+
inputs_rptr
[
nid
],
iptr
+
inputs_rptr
[
nid
+
1
]);
if
(
nodes_
[
nid
].
source
->
op
!=
nullptr
&&
fmutate_inputs
.
count
(
nodes_
[
nid
].
source
->
op
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
nodes_
[
nid
].
source
->
op
](
nodes_
[
nid
].
source
->
attrs
))
{
if
(
nodes_
[
nid
].
source
->
op
()
!=
nullptr
&&
fmutate_inputs
.
count
(
nodes_
[
nid
].
source
->
op
()
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
nodes_
[
nid
].
source
->
op
()
](
nodes_
[
nid
].
source
->
attrs
))
{
mutable_input_nodes_
.
insert
(
nodes_
[
nid
].
inputs
[
i
].
node_id
);
}
}
...
...
nnvm/src/core/symbolic.cc
View file @
de076999
...
...
@@ -20,7 +20,7 @@ struct VariableParam {
NodePtr
CreateVariableNode
(
const
std
::
string
&
name
)
{
NodePtr
n
=
Node
::
Create
();
n
->
op
=
nullptr
;
n
->
attrs
.
op
=
nullptr
;
n
->
attrs
.
name
=
name
;
n
->
attrs
.
parsed
=
VariableParam
();
return
n
;
...
...
@@ -37,8 +37,8 @@ inline void UpdateNodeVersion(Node *n) {
e
.
version
=
nnvm
::
get
<
VariableParam
>
(
e
.
node
->
attrs
.
parsed
).
version
;
}
}
if
(
fmutate_inputs
.
count
(
n
->
op
)
!=
0
)
{
for
(
uint32_t
i
:
fmutate_inputs
[
n
->
op
](
n
->
attrs
))
{
if
(
fmutate_inputs
.
count
(
n
->
op
()
)
!=
0
)
{
for
(
uint32_t
i
:
fmutate_inputs
[
n
->
op
()
](
n
->
attrs
))
{
NodeEntry
&
e
=
n
->
inputs
[
i
];
CHECK
(
e
.
node
->
is_variable
())
<<
"Mutation target can only be Variable"
;
...
...
@@ -96,7 +96,6 @@ Symbol Symbol::Copy() const {
// use DFSVisit to copy all the nodes
DFSVisit
(
this
->
outputs
,
[
&
old_new
](
const
NodePtr
&
node
)
{
NodePtr
np
=
Node
::
Create
();
np
->
op
=
node
->
op
;
np
->
attrs
=
node
->
attrs
;
old_new
[
node
.
get
()]
=
std
::
move
(
np
);
});
...
...
@@ -123,7 +122,7 @@ void Symbol::Print(std::ostream &os) const {
if
(
outputs
[
0
].
node
->
is_variable
())
{
os
<<
"Variable:"
<<
outputs
[
0
].
node
->
attrs
.
name
<<
'\n'
;
}
else
{
os
<<
"AtomicFunctor "
<<
" Op:"
<<
outputs
[
0
].
node
->
op
->
name
<<
'\n'
;
os
<<
"AtomicFunctor "
<<
" Op:"
<<
outputs
[
0
].
node
->
op
()
->
name
<<
'\n'
;
}
}
else
{
// use DFSVisit to copy all the nodes
...
...
@@ -137,7 +136,7 @@ void Symbol::Print(std::ostream &os) const {
os
<<
"Variable:"
<<
node
->
attrs
.
name
<<
'\n'
;
}
else
{
os
<<
"--------------------
\n
"
;
os
<<
"Op:"
<<
node
->
op
->
name
<<
", Name="
<<
node
->
attrs
.
name
<<
'\n'
os
<<
"Op:"
<<
node
->
op
()
->
name
<<
", Name="
<<
node
->
attrs
.
name
<<
'\n'
<<
"Inputs:
\n
"
;
for
(
size_t
i
=
0
;
i
<
node
->
inputs
.
size
();
++
i
)
{
const
NodeEntry
&
e
=
node
->
inputs
[
i
];
...
...
@@ -196,8 +195,8 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
DFSVisit
(
this
->
outputs
,
[
&
ret
,
&
mutable_set
,
&
vlist
](
const
NodePtr
&
node
)
{
if
(
node
->
is_variable
())
{
vlist
.
push_back
(
node
.
get
());
}
else
if
(
fmutate_inputs
.
count
(
node
->
op
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
node
->
op
](
node
->
attrs
)){
}
else
if
(
fmutate_inputs
.
count
(
node
->
op
()
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
node
->
op
()
](
node
->
attrs
)){
mutable_set
.
insert
(
node
->
inputs
[
i
].
node
.
get
());
}
}
...
...
@@ -221,7 +220,7 @@ std::vector<std::string> Symbol::ListOutputNames() const {
}
else
{
const
std
::
string
&
hname
=
head
.
node
->
attrs
.
name
;
std
::
string
rname
;
FListOutputNames
fn
=
flist_ouputs
.
get
(
head
.
node
->
op
,
nullptr
);
FListOutputNames
fn
=
flist_ouputs
.
get
(
head
.
node
->
op
()
,
nullptr
);
if
(
fn
!=
nullptr
)
{
rname
=
fn
(
head
.
node
->
attrs
)[
head
.
index
];
}
else
{
...
...
@@ -278,10 +277,10 @@ void Symbol::Compose(const array_view<const Symbol*>& args,
}
// switch to keyword argument matching
if
(
args
.
size
()
!=
n_req
)
{
FListInputNames
fn
=
flist_inputs
.
get
(
n
->
op
,
nullptr
);
FListInputNames
fn
=
flist_inputs
.
get
(
n
->
op
()
,
nullptr
);
auto
arg_names
=
(
fn
==
nullptr
)
?
std
::
vector
<
std
::
string
>
{
"data"
}
:
fn
(
n
->
attrs
);
if
(
arg_names
.
size
()
!=
n_req
)
{
LOG
(
FATAL
)
<<
"Not enough argument to call operator "
<<
outputs
[
0
].
node
->
op
->
name
;
LOG
(
FATAL
)
<<
"Not enough argument to call operator "
<<
outputs
[
0
].
node
->
op
()
->
name
;
}
size_t
nmatched
=
0
;
for
(
size_t
i
=
args
.
size
();
i
<
n_req
;
++
i
)
{
...
...
@@ -422,8 +421,8 @@ void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& a
node
->
attrs
.
dict
[
kv
.
first
]
=
kv
.
second
;
}
}
if
(
node
->
op
!=
nullptr
&&
node
->
op
->
attr_parser
!=
nullptr
)
{
node
->
op
->
attr_parser
(
&
(
node
->
attrs
));
if
(
node
->
op
()
!=
nullptr
&&
node
->
op
()
->
attr_parser
!=
nullptr
)
{
node
->
op
()
->
attr_parser
(
&
(
node
->
attrs
));
}
}
...
...
@@ -461,10 +460,10 @@ Symbol Symbol::CreateFunctor(const Op* op,
std
::
unordered_map
<
std
::
string
,
std
::
string
>
attrs
)
{
Symbol
s
;
NodePtr
n
=
Node
::
Create
();
n
->
op
=
op
;
n
->
attrs
.
op
=
op
;
n
->
attrs
.
dict
=
std
::
move
(
attrs
);
if
(
n
->
op
->
attr_parser
!=
nullptr
)
{
n
->
op
->
attr_parser
(
&
(
n
->
attrs
));
if
(
n
->
op
()
->
attr_parser
!=
nullptr
)
{
n
->
op
()
->
attr_parser
(
&
(
n
->
attrs
));
}
s
.
outputs
.
emplace_back
(
NodeEntry
{
std
::
move
(
n
),
0
,
0
});
return
s
;
...
...
nnvm/src/pass/gradient.cc
View file @
de076999
...
...
@@ -20,11 +20,11 @@ NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
return
std
::
move
(
v
[
0
]);
}
else
if
(
v
.
size
()
==
0
)
{
NodePtr
zero_node
=
Node
::
Create
();
zero_node
->
op
=
Op
::
Get
(
"__zero__"
);
zero_node
->
attrs
.
op
=
Op
::
Get
(
"__zero__"
);
return
NodeEntry
{
zero_node
,
0
,
0
};
}
else
{
NodePtr
sum_node
=
Node
::
Create
();
sum_node
->
op
=
Op
::
Get
(
"__ewise_sum__"
);
sum_node
->
attrs
.
op
=
Op
::
Get
(
"__ewise_sum__"
);
sum_node
->
inputs
=
std
::
move
(
v
);
return
NodeEntry
{
sum_node
,
0
,
0
};
}
...
...
@@ -109,7 +109,7 @@ Graph Gradient(Graph src) {
e
.
sum
=
agg_fun
(
std
::
move
(
e
.
grads
));
out_agg_grads
.
push_back
(
e
.
sum
);
}
std
::
vector
<
NodeEntry
>
input_grads
=
grad_fun_map
[
ptr
->
op
]
std
::
vector
<
NodeEntry
>
input_grads
=
grad_fun_map
[
ptr
->
op
()
]
(
mirror_map
.
size
()
==
0
?
ptr
:
mirror_map
.
at
(
ptr
.
get
()),
out_agg_grads
);
auto
git
=
input_grads
.
begin
();
for
(
auto
it
=
(
*
rit
)
->
inputs
.
begin
();
it
!=
(
*
rit
)
->
inputs
.
end
();
++
it
,
++
git
)
{
...
...
nnvm/src/pass/infer_shape_type.cc
View file @
de076999
...
...
@@ -65,7 +65,7 @@ Graph InferAttr(Graph &&ret,
}
continue
;
}
if
(
finfer_shape
.
count
(
inode
.
source
->
op
))
{
if
(
finfer_shape
.
count
(
inode
.
source
->
op
()
))
{
ishape
.
resize
(
num_inputs
,
def_value
);
for
(
uint32_t
i
=
0
;
i
<
ishape
.
size
();
++
i
)
{
ishape
[
i
]
=
rshape
[
idx
.
entry_id
(
inode
.
inputs
[
i
])];
...
...
@@ -75,14 +75,14 @@ Graph InferAttr(Graph &&ret,
oshape
[
i
]
=
rshape
[
idx
.
entry_id
(
nid
,
i
)];
}
num_unknown
+=
!
(
finfer_shape
[
inode
.
source
->
op
](
inode
.
source
->
attrs
,
&
ishape
,
&
oshape
));
!
(
finfer_shape
[
inode
.
source
->
op
()
](
inode
.
source
->
attrs
,
&
ishape
,
&
oshape
));
for
(
uint32_t
i
=
0
;
i
<
num_inputs
;
++
i
)
{
rshape
[
idx
.
entry_id
(
inode
.
inputs
[
i
])]
=
ishape
[
i
];
}
for
(
uint32_t
i
=
0
;
i
<
num_outputs
;
++
i
)
{
rshape
[
idx
.
entry_id
(
nid
,
i
)]
=
oshape
[
i
];
}
}
else
if
(
is_backward
.
get
(
inode
.
source
->
op
,
false
))
{
}
else
if
(
is_backward
.
get
(
inode
.
source
->
op
()
,
false
))
{
// backward operator inference.
CHECK_GE
(
inode
.
control_deps
.
size
(),
1
)
<<
"BackwardOp need to have control_deps to its forward op"
;
...
...
nnvm/src/pass/order_mutation.cc
View file @
de076999
...
...
@@ -43,8 +43,8 @@ Graph OrderMutation(const Graph& src) {
auto
prepare
=
[
&
version_hist
,
&
old_new
]
(
const
NodePtr
&
n
)
{
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInputs
>
(
"FMutateInputs"
);
std
::
vector
<
uint32_t
>
mutate_inputs
;
if
(
!
n
->
is_variable
()
&&
fmutate_inputs
.
count
(
n
->
op
))
{
mutate_inputs
=
fmutate_inputs
[
n
->
op
](
n
->
attrs
);
if
(
!
n
->
is_variable
()
&&
fmutate_inputs
.
count
(
n
->
op
()
))
{
mutate_inputs
=
fmutate_inputs
[
n
->
op
()
](
n
->
attrs
);
}
std
::
sort
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
());
...
...
@@ -67,7 +67,6 @@ Graph OrderMutation(const Graph& src) {
}
if
(
need_repl
)
{
NodePtr
np
=
Node
::
Create
();
np
->
op
=
n
->
op
;
np
->
attrs
=
n
->
attrs
;
old_new
[
n
.
get
()]
=
std
::
move
(
np
);
}
...
...
@@ -101,8 +100,8 @@ Graph OrderMutation(const Graph& src) {
// add control deps
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInputs
>
(
"FMutateInputs"
);
std
::
vector
<
uint32_t
>
mutate_inputs
;
if
(
fmutate_inputs
.
count
(
kv
.
first
->
op
))
{
mutate_inputs
=
fmutate_inputs
[
kv
.
first
->
op
](
kv
.
first
->
attrs
);
if
(
fmutate_inputs
.
count
(
kv
.
first
->
op
()
))
{
mutate_inputs
=
fmutate_inputs
[
kv
.
first
->
op
()
](
kv
.
first
->
attrs
);
}
std
::
sort
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
());
...
...
nnvm/src/pass/place_device.cc
View file @
de076999
...
...
@@ -109,9 +109,9 @@ Graph PlaceDevice(Graph src) {
NodeEntry
{
it
->
second
,
0
,
0
});
}
else
{
NodePtr
copy_node
=
Node
::
Create
();
copy_node
->
op
=
copy_op
;
std
::
ostringstream
os
;
os
<<
inode
.
source
->
inputs
[
i
].
node
->
attrs
.
name
<<
"_"
<<
e
.
index
<<
"_copy"
;
copy_node
->
attrs
.
op
=
copy_op
;
copy_node
->
attrs
.
name
=
os
.
str
();
copy_node
->
inputs
.
push_back
(
inode
.
source
->
inputs
[
i
]);
copy_map
[
copy_key
]
=
copy_node
;
...
...
nnvm/src/pass/plan_memory.cc
View file @
de076999
...
...
@@ -168,8 +168,8 @@ Graph PlanMemory(Graph ret) {
const
auto
&
inode
=
idx
[
nid
];
if
(
inode
.
source
->
is_variable
())
continue
;
// check inplace option
if
(
finplace_option
.
count
(
inode
.
source
->
op
)
!=
0
)
{
auto
inplace_pairs
=
finplace_option
[
inode
.
source
->
op
](
inode
.
source
->
attrs
);
if
(
finplace_option
.
count
(
inode
.
source
->
op
()
)
!=
0
)
{
auto
inplace_pairs
=
finplace_option
[
inode
.
source
->
op
()
](
inode
.
source
->
attrs
);
for
(
auto
&
kv
:
inplace_pairs
)
{
uint32_t
eid_out
=
idx
.
entry_id
(
nid
,
kv
.
second
);
uint32_t
eid_in
=
idx
.
entry_id
(
inode
.
inputs
[
kv
.
first
]);
...
...
nnvm/src/pass/saveload_json.cc
View file @
de076999
...
...
@@ -68,8 +68,8 @@ struct JSONNode {
// function to save JSON node.
void
Save
(
dmlc
::
JSONWriter
*
writer
)
const
{
writer
->
BeginObject
();
if
(
node
->
op
!=
nullptr
)
{
writer
->
WriteObjectKeyValue
(
"op"
,
node
->
op
->
name
);
if
(
node
->
op
()
!=
nullptr
)
{
writer
->
WriteObjectKeyValue
(
"op"
,
node
->
op
()
->
name
);
}
else
{
std
::
string
json_null
=
"null"
;
writer
->
WriteObjectKeyValue
(
"op"
,
json_null
);
...
...
@@ -108,10 +108,10 @@ struct JSONNode {
if
(
op_type_str
!=
"null"
)
{
try
{
node
->
op
=
Op
::
Get
(
op_type_str
);
node
->
attrs
.
op
=
Op
::
Get
(
op_type_str
);
// rebuild attribute parser
if
(
node
->
op
->
attr_parser
!=
nullptr
)
{
node
->
op
->
attr_parser
(
&
(
node
->
attrs
));
if
(
node
->
op
()
->
attr_parser
!=
nullptr
)
{
node
->
op
()
->
attr_parser
(
&
(
node
->
attrs
));
}
}
catch
(
const
dmlc
::
Error
&
err
)
{
std
::
ostringstream
os
;
...
...
@@ -120,7 +120,7 @@ struct JSONNode {
throw
dmlc
::
Error
(
os
.
str
());
}
}
else
{
node
->
op
=
nullptr
;
node
->
attrs
.
op
=
nullptr
;
}
}
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment