Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
486249e8
Commit
486249e8
authored
Aug 09, 2016
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Update mutate function (#23)
parent
16a6db3a
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
59 additions
and
52 deletions
+59
-52
nnvm/docs/Doxyfile
+2
-2
nnvm/example/src/operator.cc
+3
-3
nnvm/include/nnvm/graph.h
+4
-4
nnvm/include/nnvm/op_attr_types.h
+3
-4
nnvm/include/nnvm/pass_functions.h
+12
-12
nnvm/src/core/graph.cc
+1
-1
nnvm/src/core/symbolic.cc
+4
-10
nnvm/src/pass/infer_shape_type.cc
+7
-7
nnvm/src/pass/order_mutation.cc
+23
-9
No files found.
nnvm/docs/Doxyfile
View file @
486249e8
...
...
@@ -32,7 +32,7 @@ DOXYFILE_ENCODING = UTF-8
# title of most generated pages and in a few other places.
# The default value is: My Project.
PROJECT_NAME = "
mxnngraph
"
PROJECT_NAME = "
nnvm
"
# The PROJECT_NUMBER tag can be used to enter a project or revision number. This
# could be handy for archiving the generated documentation or if some version
...
...
@@ -753,7 +753,7 @@ WARN_LOGFILE =
# spaces.
# Note: If this tag is empty the current directory is searched.
INPUT = include
INPUT = include
/nnvm
# This tag can be used to specify the character encoding of the source files
# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
...
...
nnvm/example/src/operator.cc
View file @
486249e8
...
...
@@ -11,7 +11,7 @@
namespace
myproject
{
using
nnvm
::
FListInputNames
;
using
nnvm
::
FMutateInput
;
using
nnvm
::
FMutateInput
s
;
using
nnvm
::
FInferShape
;
using
nnvm
::
FInferType
;
using
nnvm
::
FInplaceOption
;
...
...
@@ -119,8 +119,8 @@ NNVM_REGISTER_OP(add)
NNVM_REGISTER_OP
(
assign
)
.
set_num_inputs
(
2
)
.
set_num_outputs
(
1
)
.
attr
<
FMutateInput
>
(
"FMutateInput"
,
[](
const
NodeAttrs
&
attrs
,
uint32_t
index
)
{
return
index
==
0
;
.
attr
<
FMutateInput
s
>
(
"FMutateInputs"
,
[](
const
NodeAttrs
&
attrs
)
{
return
std
::
vector
<
uint32_t
>
{
0
}
;
});
}
// namespace myproject
nnvm/include/nnvm/graph.h
View file @
486249e8
...
...
@@ -144,8 +144,8 @@ class IndexedGraph {
return
nodes_
[
node_id
(
node
)];
}
/*! \return list of argument nodes */
inline
const
std
::
vector
<
uint32_t
>&
arg
_nodes
()
const
{
return
arg
_nodes_
;
inline
const
std
::
vector
<
uint32_t
>&
input
_nodes
()
const
{
return
input
_nodes_
;
}
/*! \return list of output entries */
inline
const
std
::
vector
<
NodeEntry
>&
outputs
()
const
{
...
...
@@ -161,8 +161,8 @@ class IndexedGraph {
explicit
IndexedGraph
(
const
Graph
&
other
);
// node pointers in CSR structure.
std
::
vector
<
Node
>
nodes_
;
// index to
argumen
t nodes
std
::
vector
<
uint32_t
>
arg
_nodes_
;
// index to
inpu
t nodes
std
::
vector
<
uint32_t
>
input
_nodes_
;
// space to store the outputs entries
std
::
vector
<
NodeEntry
>
outputs_
;
// mapping from node to index.
...
...
nnvm/include/nnvm/op_attr_types.h
View file @
486249e8
...
...
@@ -43,13 +43,12 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs
/*!
* \brief Check whether operator will mutate k-th input.
* \param attrs The attributes of the node.
* \param index The input index
* \return Whether this operator will mutate index-th input.
* \return list of input indices it mutates.
*
* \note Register under "FMutateInput", default return false
* \note Register under "FMutateInput
s
", default return false
* FMutateInputs enables mutation order handling correctly.
*/
using
FMutateInput
=
std
::
function
<
bool
(
const
NodeAttrs
&
attrs
,
uint32_t
index
)
>
;
using
FMutateInput
s
=
std
::
function
<
std
::
vector
<
uint32_t
>
(
const
NodeAttrs
&
attrs
)
>
;
/*!
* \brief Inference function of certain type.
...
...
nnvm/include/nnvm/pass_functions.h
View file @
486249e8
...
...
@@ -54,16 +54,16 @@ inline Graph OrderMutation(Graph src) {
/*!
* \brief Infer shapes in the graph given the information.
* \param graph source graph
* \param shape_
arg
s The shapes of aruguments to the graph.
* \param shape_
input
s The shapes of aruguments to the graph.
* \param shape_attr_key The key to the node attribute that can indicate shape.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id
*/
inline
Graph
InferShape
(
Graph
graph
,
ShapeVector
shape_
arg
s
=
{},
ShapeVector
shape_
input
s
=
{},
std
::
string
shape_attr_key
=
""
)
{
if
(
shape_
arg
s
.
size
()
!=
0
)
{
graph
.
attrs
[
"shape_
args"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
shape_arg
s
));
if
(
shape_
input
s
.
size
()
!=
0
)
{
graph
.
attrs
[
"shape_
inputs"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
shape_input
s
));
}
if
(
shape_attr_key
.
length
()
!=
0
)
{
graph
.
attrs
[
"shape_attr_key"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
shape_attr_key
));
...
...
@@ -74,19 +74,19 @@ inline Graph InferShape(Graph graph,
/*!
* \brief Infer types in the graph given the information.
* \param graph source graph
* \param
shape_args The shapes of arugumen
ts to the graph.
* \param
sha
pe_attr_key The key to the node attribute that can indicate shape.
* \param
dtype_inputs The shapes of inpu
ts to the graph.
* \param
dty
pe_attr_key The key to the node attribute that can indicate shape.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id
*/
inline
Graph
InferType
(
Graph
graph
,
DTypeVector
type_arg
s
=
{},
std
::
string
type_attr_key
=
""
)
{
if
(
type_arg
s
.
size
()
!=
0
)
{
graph
.
attrs
[
"dtype_
args"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
type_arg
s
));
DTypeVector
dtype_input
s
=
{},
std
::
string
d
type_attr_key
=
""
)
{
if
(
dtype_input
s
.
size
()
!=
0
)
{
graph
.
attrs
[
"dtype_
inputs"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
dtype_input
s
));
}
if
(
type_attr_key
.
length
()
!=
0
)
{
graph
.
attrs
[
"dtype_attr_key"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
type_attr_key
));
if
(
d
type_attr_key
.
length
()
!=
0
)
{
graph
.
attrs
[
"dtype_attr_key"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
d
type_attr_key
));
}
return
ApplyPass
(
std
::
move
(
graph
),
{
"InferType"
});
}
...
...
nnvm/src/core/graph.cc
View file @
486249e8
...
...
@@ -30,7 +30,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
nodes_
.
emplace_back
(
std
::
move
(
new_node
));
// arg_nodes_
if
(
n
->
is_variable
())
{
arg
_nodes_
.
push_back
(
nid
);
input
_nodes_
.
push_back
(
nid
);
}
// node2index_
node2index_
[
n
.
get
()]
=
nid
;
...
...
nnvm/src/core/symbolic.cc
View file @
486249e8
...
...
@@ -31,16 +31,14 @@ NodePtr CreateVariableNode(const std::string& name) {
// The version of that varaible will increase
// version is used to implicitly order the mutation sequences
inline
void
UpdateNodeVersion
(
Node
*
n
)
{
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInput
>
(
"FMutateInput
"
);
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInput
s
>
(
"FMutateInputs
"
);
for
(
NodeEntry
&
e
:
n
->
inputs
)
{
if
(
e
.
node
->
is_variable
())
{
e
.
version
=
nnvm
::
get
<
VariableParam
>
(
e
.
node
->
attrs
.
parsed
).
version
;
}
}
if
(
fmutate_inputs
.
count
(
n
->
op
)
!=
0
)
{
FMutateInput
fmutate
=
fmutate_inputs
[
n
->
op
];
for
(
uint32_t
i
=
0
;
i
<
n
->
inputs
.
size
();
++
i
)
{
if
(
fmutate
(
n
->
attrs
,
i
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
n
->
op
](
n
->
attrs
))
{
NodeEntry
&
e
=
n
->
inputs
[
i
];
CHECK
(
e
.
node
->
is_variable
())
<<
"Mutation target can only be Variable"
;
...
...
@@ -48,7 +46,6 @@ inline void UpdateNodeVersion(Node *n) {
e
.
version
=
++
nnvm
::
get
<
VariableParam
>
(
e
.
node
->
attrs
.
parsed
).
version
;
}
}
}
}
inline
std
::
string
DefaultVarName
(
const
std
::
string
&
op_name
,
...
...
@@ -192,18 +189,15 @@ std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
}
else
{
std
::
unordered_set
<
Node
*>
mutable_set
;
std
::
vector
<
Node
*>
vlist
;
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInput
>
(
"FMutateInput
"
);
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInput
s
>
(
"FMutateInputs
"
);
DFSVisit
(
this
->
outputs
,
[
&
ret
,
&
mutable_set
,
&
vlist
](
const
NodePtr
&
node
)
{
if
(
node
->
is_variable
())
{
vlist
.
push_back
(
node
.
get
());
}
else
if
(
fmutate_inputs
.
count
(
node
->
op
))
{
FMutateInput
fmutate
=
fmutate_inputs
[
node
->
op
];
for
(
uint32_t
i
=
0
;
i
<
node
->
inputs
.
size
();
++
i
)
{
if
(
fmutate
(
node
->
attrs
,
i
))
{
for
(
uint32_t
i
:
fmutate_inputs
[
node
->
op
](
node
->
attrs
)){
mutable_set
.
insert
(
node
->
inputs
[
i
].
node
.
get
());
}
}
}
});
for
(
Node
*
node
:
vlist
)
{
if
((
option
==
kReadOnlyArgs
&&
mutable_set
.
count
(
node
)
==
0
)
||
...
...
nnvm/src/pass/infer_shape_type.cc
View file @
486249e8
...
...
@@ -15,7 +15,7 @@ template<typename AttrType, typename IsNone>
Graph
InferAttr
(
Graph
&&
ret
,
const
AttrType
def_value
,
const
char
*
infer_name
,
const
char
*
arg
_name
,
const
char
*
input
_name
,
const
char
*
attr_key_name
,
const
char
*
attr_name
,
const
char
*
unknown_name
,
...
...
@@ -29,15 +29,15 @@ Graph InferAttr(Graph &&ret,
// reshape shape vector
AttrVector
rshape
(
idx
.
num_node_entries
(),
def_value
);
if
(
ret
.
attrs
.
count
(
arg
_name
)
!=
0
)
{
const
AttrVector
&
shape_args
=
ret
.
GetAttr
<
AttrVector
>
(
arg
_name
);
CHECK_LE
(
shape_args
.
size
(),
idx
.
arg
_nodes
().
size
())
if
(
ret
.
attrs
.
count
(
input
_name
)
!=
0
)
{
const
AttrVector
&
shape_args
=
ret
.
GetAttr
<
AttrVector
>
(
input
_name
);
CHECK_LE
(
shape_args
.
size
(),
idx
.
input
_nodes
().
size
())
<<
"shape args is more than number of arguments"
;
for
(
size_t
i
=
0
;
i
<
shape_args
.
size
();
++
i
)
{
rshape
[
idx
.
entry_id
(
idx
.
arg
_nodes
()[
i
],
0
)]
=
shape_args
[
i
];
rshape
[
idx
.
entry_id
(
idx
.
input
_nodes
()[
i
],
0
)]
=
shape_args
[
i
];
}
// erase the provided arguments
ret
.
attrs
.
erase
(
arg
_name
);
ret
.
attrs
.
erase
(
input
_name
);
}
std
::
string
shape_attr_key
;
if
(
ret
.
attrs
.
count
(
attr_key_name
)
!=
0
)
{
...
...
@@ -113,7 +113,7 @@ NNVM_REGISTER_PASS(InferType)
.
set_body
([](
Graph
ret
)
{
return
InferAttr
<
int
>
(
std
::
move
(
ret
),
0
,
"FInferType"
,
"dtype_
arg
s"
,
"dtype_attr_key"
,
"FInferType"
,
"dtype_
input
s"
,
"dtype_attr_key"
,
"dtype"
,
"dtype_num_unknown_nodes"
,
[](
const
int
t
)
{
return
t
==
-
1
;
});
})
...
...
nnvm/src/pass/order_mutation.cc
View file @
486249e8
...
...
@@ -21,6 +21,13 @@ inline T get_with_default(const std::unordered_map<Node*, T> &map,
return
def
;
}
inline
bool
IsMutate
(
const
std
::
vector
<
uint32_t
>&
mutate_inputs
,
uint32_t
i
)
{
if
(
mutate_inputs
.
size
()
==
0
)
return
false
;
auto
it
=
std
::
lower_bound
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
(),
i
);
return
(
it
!=
mutate_inputs
.
end
())
&&
(
*
it
==
i
);
}
Graph
OrderMutation
(
const
Graph
&
src
)
{
std
::
unordered_map
<
Node
*
,
std
::
vector
<
NodeEntry
>
>
version_hist
;
DFSVisit
(
src
.
outputs
,
[
&
version_hist
](
const
NodePtr
&
n
)
{
...
...
@@ -37,7 +44,13 @@ Graph OrderMutation(const Graph& src) {
// start preparing for remapping the nodes.
std
::
unordered_map
<
Node
*
,
NodePtr
>
old_new
;
auto
prepare
=
[
&
version_hist
,
&
old_new
]
(
const
NodePtr
&
n
)
{
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInput
>
(
"FMutateInput"
);
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInputs
>
(
"FMutateInputs"
);
std
::
vector
<
uint32_t
>
mutate_inputs
;
if
(
!
n
->
is_variable
()
&&
fmutate_inputs
.
count
(
n
->
op
))
{
mutate_inputs
=
fmutate_inputs
[
n
->
op
](
n
->
attrs
);
}
std
::
sort
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
());
bool
need_repl
=
false
;
for
(
size_t
i
=
0
;
i
<
n
->
inputs
.
size
();
++
i
)
{
const
NodeEntry
&
e
=
n
->
inputs
[
i
];
...
...
@@ -46,9 +59,7 @@ Graph OrderMutation(const Graph& src) {
auto
it
=
version_hist
.
find
(
e
.
node
.
get
());
if
(
it
!=
version_hist
.
end
())
{
std
::
vector
<
NodeEntry
>&
vec
=
it
->
second
;
uint32_t
is_mutate
=
fmutate_inputs
.
count
(
n
->
op
)
?
fmutate_inputs
[
n
->
op
](
n
->
attrs
,
i
)
:
0
;
vec
.
emplace_back
(
NodeEntry
{
n
,
is_mutate
,
e
.
version
});
vec
.
emplace_back
(
NodeEntry
{
n
,
IsMutate
(
mutate_inputs
,
i
),
e
.
version
});
}
}
else
{
if
(
old_new
.
count
(
e
.
node
.
get
())
!=
0
)
need_repl
=
true
;
...
...
@@ -91,18 +102,21 @@ Graph OrderMutation(const Graph& src) {
get_with_default
(
old_new
,
p
.
get
(),
p
));
}
// add control deps
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInput
>
(
"FMutateInput"
);
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInputs
>
(
"FMutateInputs"
);
std
::
vector
<
uint32_t
>
mutate_inputs
;
if
(
fmutate_inputs
.
count
(
kv
.
first
->
op
))
{
mutate_inputs
=
fmutate_inputs
[
kv
.
first
->
op
](
kv
.
first
->
attrs
);
}
std
::
sort
(
mutate_inputs
.
begin
(),
mutate_inputs
.
end
());
for
(
size_t
i
=
0
;
i
<
kv
.
first
->
inputs
.
size
();
++
i
)
{
const
NodeEntry
&
e
=
kv
.
first
->
inputs
[
i
];
if
(
e
.
node
->
is_variable
()
&&
version_hist
.
count
(
e
.
node
.
get
())
!=
0
)
{
FMutateInput
fmutate
=
fmutate_inputs
.
get
(
kv
.
first
->
op
,
nullptr
);
uint32_t
is_mutate
=
(
fmutate
==
nullptr
)
?
0
:
fmutate
(
kv
.
first
->
attrs
,
i
);
std
::
vector
<
NodeEntry
>&
vec
=
version_hist
.
at
(
e
.
node
.
get
());
auto
it
=
std
::
lower_bound
(
vec
.
begin
(),
vec
.
end
(),
NodeEntry
{
nullptr
,
1
,
e
.
version
},
comparator
);
if
(
is_mutate
!=
0
)
{
if
(
IsMutate
(
mutate_inputs
,
i
)
)
{
int
read_dep
=
0
;
while
(
it
!=
vec
.
begin
())
{
--
it
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment