Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
badcdfff
Commit
badcdfff
authored
Jul 12, 2016
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Change op function pointer to std::function, enable mutation (#6)
parent
c92d63c7
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
123 additions
and
30 deletions
+123
-30
nnvm/include/nnvm/node.h
+7
-0
nnvm/include/nnvm/op.h
+9
-10
nnvm/include/nnvm/op_attr_types.h
+12
-2
nnvm/src/core/symbolic.cc
+73
-18
nnvm/src/example/operator.cc
+8
-0
nnvm/tests/python/test_symbol.py
+14
-0
No files found.
nnvm/include/nnvm/node.h
View file @
badcdfff
...
...
@@ -24,6 +24,13 @@ struct NodeEntry {
std
::
shared_ptr
<
Node
>
node
;
/*! \brief index of output from the source. */
uint32_t
index
;
/*!
* \brief version of input Variable.
* This field can only be nonzero when this->node is a Variable node.
* version is increased by one each time a Variable get composed to a mutation Op.
* This information can be helpful to decide order of operations when sequence of mutation happens.
*/
uint32_t
version
;
};
/*!
...
...
nnvm/include/nnvm/op.h
View file @
badcdfff
...
...
@@ -101,13 +101,13 @@ class Op {
* \param attrs The attribute of the node
* \return number of outputs.
*/
uint32_t
(
*
get_num_outputs
)(
const
NodeAttrs
&
attrs
)
=
nullptr
;
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attrs
)
>
get_num_outputs
=
nullptr
;
/*!
* \brief get number of inputs given information about the node.
* \param attrs The attribute of the node
* \return number of inputs
*/
uint32_t
(
*
get_num_inputs
)(
const
NodeAttrs
&
attrs
)
=
nullptr
;
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attrs
)
>
get_num_inputs
=
nullptr
;
/*!
* \brief Attribute parser to parse the NodeAttrs information.
*
...
...
@@ -140,8 +140,7 @@ class Op {
* }
* \endcode
*/
void
(
*
attr_parser
)(
NodeAttrs
*
attrs
)
=
nullptr
;
std
::
function
<
void
(
NodeAttrs
*
attrs
)
>
attr_parser
=
nullptr
;
// function fields.
/*!
* \brief setter function during registration
...
...
@@ -161,7 +160,7 @@ class Op {
* \param fn The function to be set.
* \return reference to self.
*/
inline
Op
&
set_num_inputs
(
uint32_t
(
*
fn
)(
const
NodeAttrs
&
attr
)
);
// NOLINT(*)
inline
Op
&
set_num_inputs
(
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attr
)
>
fn
);
// NOLINT(*)
/*!
* \brief Set the num_outputs
* \param n The number of outputs to be set.
...
...
@@ -173,13 +172,13 @@ class Op {
* \param fn The function to be set.
* \return reference to self.
*/
inline
Op
&
set_num_outputs
(
uint32_t
(
*
fn
)(
const
NodeAttrs
&
attr
)
);
// NOLINT(*)
inline
Op
&
set_num_outputs
(
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attr
)
>
fn
);
// NOLINT(*)
/*!
* \brief Set the attr_parser function.
* \param fn The number of outputs to be set.
* \return reference to self.
*/
inline
Op
&
set_attr_parser
(
void
(
*
fn
)(
NodeAttrs
*
attrs
)
);
// NOLINT(*)
inline
Op
&
set_attr_parser
(
std
::
function
<
void
(
NodeAttrs
*
attrs
)
>
fn
);
// NOLINT(*)
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
...
...
@@ -342,7 +341,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
return
*
this
;
}
inline
Op
&
Op
::
set_num_inputs
(
uint32_t
(
*
fn
)(
const
NodeAttrs
&
attr
)
)
{
// NOLINT(*)
inline
Op
&
Op
::
set_num_inputs
(
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attr
)
>
fn
)
{
// NOLINT(*)
this
->
get_num_inputs
=
fn
;
return
*
this
;
}
...
...
@@ -352,12 +351,12 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
return
*
this
;
}
inline
Op
&
Op
::
set_num_outputs
(
uint32_t
(
*
fn
)(
const
NodeAttrs
&
attr
)
)
{
// NOLINT(*)
inline
Op
&
Op
::
set_num_outputs
(
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attr
)
>
fn
)
{
// NOLINT(*)
this
->
get_num_outputs
=
fn
;
return
*
this
;
}
inline
Op
&
Op
::
set_attr_parser
(
void
(
*
fn
)(
NodeAttrs
*
attrs
)
)
{
// NOLINT(*)
inline
Op
&
Op
::
set_attr_parser
(
std
::
function
<
void
(
NodeAttrs
*
attrs
)
>
fn
)
{
// NOLINT(*)
this
->
attr_parser
=
fn
;
return
*
this
;
}
...
...
nnvm/include/nnvm/op_attr_types.h
View file @
badcdfff
...
...
@@ -12,8 +12,8 @@
namespace
nnvm
{
// These types are optional attributes in each op
//
Some of them are needed for certain pas
s.
// These types are optional attributes in each op
erator.
//
Each attribute can be required by some passe
s.
/*!
* \brief Return list of input arguments names of each operator.
...
...
@@ -37,6 +37,16 @@ using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs&
*/
using
FListOutputNames
=
std
::
function
<
std
::
vector
<
std
::
string
>
(
const
NodeAttrs
&
attrs
)
>
;
/*!
* \brief Check whether operator will mutate k-th input.
* \param index The input index
* \return Whether this operator will mutate index-th input.
*
* \note Register under "FMutateInput", default return false
* FMutateInputs enables mutation order handling correctly.
*/
using
FMutateInput
=
std
::
function
<
bool
(
const
NodeAttrs
&
attrs
,
uint32_t
index
)
>
;
}
// namespace nnvm
#endif // NNVM_OP_ATTR_TYPES_H_
nnvm/src/core/symbolic.cc
View file @
badcdfff
...
...
@@ -13,6 +13,43 @@ namespace symbol_constants {
const
char
*
kNamespaceSeparator
=
"_"
;
}
// namespace symbol_constants
// auxililary version attribute in variable.
struct
VariableParam
{
uint32_t
version
{
0
};
};
std
::
shared_ptr
<
Node
>
CreateVariableNode
(
const
std
::
string
&
name
)
{
std
::
shared_ptr
<
Node
>
n
=
Node
::
Create
();
n
->
op
=
nullptr
;
n
->
attrs
.
name
=
name
;
n
->
attrs
.
parsed
=
VariableParam
();
return
n
;
}
// scan over a node's input, update the version to latest
// If the node's op mutates a certain input variable,
// The version of that varaible will increase
// version is used to implicitly order the mutation sequences
inline
void
UpdateNodeVersion
(
Node
*
n
)
{
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInput
>
(
"FMutateInput"
);
for
(
NodeEntry
&
e
:
n
->
inputs
)
{
if
(
e
.
node
->
is_variable
())
{
e
.
version
=
nnvm
::
get
<
VariableParam
>
(
e
.
node
->
attrs
.
parsed
).
version
;
}
}
if
(
fmutate_inputs
.
count
(
n
->
op
)
!=
0
)
{
FMutateInput
fmutate
=
fmutate_inputs
[
n
->
op
];
for
(
uint32_t
i
=
0
;
i
<
n
->
inputs
.
size
();
++
i
)
{
if
(
fmutate
(
n
->
attrs
,
i
))
{
NodeEntry
&
e
=
n
->
inputs
[
i
];
CHECK
(
e
.
node
->
is_variable
())
<<
"Mutation target can only be Variable"
;
// increase the version of the variable.
++
nnvm
::
get
<
VariableParam
>
(
e
.
node
->
attrs
.
parsed
).
version
;
}
}
}
}
inline
std
::
string
DefaultVarName
(
const
std
::
string
&
op_name
,
const
std
::
string
&
arg_name
)
{
...
...
@@ -67,13 +104,13 @@ Symbol Symbol::Copy() const {
for
(
const
auto
&
kv
:
old_new
)
{
for
(
const
NodeEntry
&
e
:
kv
.
first
->
inputs
)
{
Node
*
ptr
=
e
.
node
.
get
();
kv
.
second
->
inputs
.
emplace_back
(
NodeEntry
{
old_new
[
ptr
],
e
.
index
});
kv
.
second
->
inputs
.
emplace_back
(
NodeEntry
{
old_new
[
ptr
],
e
.
index
,
e
.
version
});
}
}
// set the head
Symbol
ret
;
for
(
const
NodeEntry
&
e
:
outputs
)
{
ret
.
outputs
.
emplace_back
(
NodeEntry
{
old_new
[
e
.
node
.
get
()],
e
.
index
});
ret
.
outputs
.
emplace_back
(
NodeEntry
{
old_new
[
e
.
node
.
get
()],
e
.
index
,
e
.
version
});
}
return
ret
;
}
...
...
@@ -95,8 +132,14 @@ void Symbol::Print(std::ostream &os) const {
os
<<
"Name: "
<<
node
->
attrs
.
name
<<
" Op:"
<<
node
->
op
->
name
<<
'\n'
<<
"Inputs:
\n
"
;
for
(
size_t
i
=
0
;
i
<
node
->
inputs
.
size
();
++
i
)
{
os
<<
"
\t
arg["
<<
i
<<
"]="
<<
node
->
inputs
[
i
].
node
->
attrs
.
name
<<
'('
<<
node
->
inputs
[
i
].
index
<<
")
\n
"
;
const
NodeEntry
&
e
=
node
->
inputs
[
i
];
os
<<
"
\t
arg["
<<
i
<<
"]="
<<
e
.
node
->
attrs
.
name
<<
'('
<<
e
.
index
<<
")"
;
if
(
e
.
node
->
is_variable
())
{
os
<<
" version="
<<
e
.
version
<<
'\n'
;
}
else
{
os
<<
'\n'
;
}
}
os
<<
"Attrs:
\n
"
;
for
(
auto
&
kv
:
node
->
attrs
.
dict
)
{
...
...
@@ -163,6 +206,8 @@ std::vector<std::string> Symbol::ListOutputs() const {
void
Symbol
::
Compose
(
const
std
::
vector
<
Symbol
>&
args
,
const
std
::
unordered_map
<
std
::
string
,
Symbol
>&
kwargs
,
const
std
::
string
&
name
)
{
static
auto
&
flist_inputs
=
Op
::
GetAttr
<
FListInputNames
>
(
"FListInputNames"
);
CHECK_EQ
(
outputs
.
size
(),
1
)
<<
"Only composition of value function is supported currently"
;
CHECK
(
!
outputs
[
0
].
node
->
is_variable
())
<<
"Variable cannot be composed"
;
...
...
@@ -193,7 +238,6 @@ void Symbol::Compose(const std::vector<Symbol>& args,
}
// switch to keyword argument matching
if
(
args
.
size
()
!=
n_req
)
{
static
auto
&
flist_inputs
=
Op
::
GetAttr
<
FListInputNames
>
(
"FListInputNames"
);
FListInputNames
fn
=
flist_inputs
.
get
(
n
->
op
,
nullptr
);
auto
arg_names
=
(
fn
==
nullptr
)
?
std
::
vector
<
std
::
string
>
{
"data"
}
:
fn
(
n
->
attrs
);
if
(
arg_names
.
size
()
!=
n_req
)
{
...
...
@@ -206,8 +250,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
n
->
inputs
[
i
]
=
it
->
second
.
outputs
[
0
];
++
nmatched
;
}
else
{
n
->
inputs
[
i
]
=
NodeEntry
{
Node
::
Create
(),
0
};
n
->
inputs
[
i
].
node
->
attrs
.
name
=
DefaultVarName
(
name
,
arg_names
[
i
])
;
n
->
inputs
[
i
]
=
NodeEntry
{
CreateVariableNode
(
DefaultVarName
(
name
,
arg_names
[
i
])),
0
,
0
}
;
}
}
...
...
@@ -226,6 +270,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
n
->
inputs
.
push_back
(
s
.
outputs
[
0
]);
}
}
UpdateNodeVersion
(
n
);
}
else
{
// general composition
CHECK_EQ
(
args
.
size
(),
0
)
...
...
@@ -253,25 +298,32 @@ void Symbol::Compose(const std::vector<Symbol>& args,
DFSVisit
(
this
->
outputs
,
find_replace_map
);
if
(
nmatched
==
kwargs
.
size
()
&&
arg_counter
<
args
.
size
())
{
std
::
vector
<
Node
*>
update_nodes
;
std
::
vector
<
std
::
pair
<
NodeEntry
*
,
const
NodeEntry
*>
>
replace_plan
;
auto
find_replace_plan
=
[
&
replace_map
,
&
replace_plan
]
auto
find_replace_plan
=
[
&
replace_map
,
&
replace_plan
,
&
update_nodes
]
(
const
std
::
shared_ptr
<
Node
>
&
node
)
{
// visit all the childs, find possible replacement
bool
repl
=
false
;
for
(
size_t
i
=
0
;
i
<
node
->
inputs
.
size
();
++
i
)
{
NodeEntry
*
e
=
&
(
node
->
inputs
[
i
]);
if
(
e
->
node
->
is_variable
())
{
auto
iter
=
replace_map
.
find
(
e
->
node
.
get
());
if
(
iter
!=
replace_map
.
end
())
{
replace_plan
.
push_back
(
std
::
make_pair
(
e
,
iter
->
second
));
repl
=
true
;
}
}
}
if
(
repl
)
update_nodes
.
push_back
(
node
.
get
());
};
DFSVisit
(
this
->
outputs
,
find_replace_plan
);
for
(
const
auto
&
kv
:
replace_plan
)
{
*
(
kv
.
first
)
=
*
(
kv
.
second
);
}
for
(
Node
*
n
:
update_nodes
)
{
UpdateNodeVersion
(
n
);
}
}
else
{
std
::
vector
<
std
::
string
>
keys
=
GetKeys
(
kwargs
);
std
::
vector
<
std
::
string
>
arg_names
=
ListArguments
();
...
...
@@ -303,9 +355,15 @@ Symbol Symbol::GetInternals() const {
Symbol
ret
;
DFSVisit
(
this
->
outputs
,
[
&
ret
](
const
std
::
shared_ptr
<
Node
>&
node
)
{
Node
*
n
=
node
.
get
();
uint32_t
nout
=
n
->
num_outputs
();
for
(
uint32_t
i
=
0
;
i
<
nout
;
++
i
)
{
ret
.
outputs
.
emplace_back
(
NodeEntry
{
node
,
i
});
if
(
n
->
is_variable
())
{
// grab version from variable.
VariableParam
&
param
=
nnvm
::
get
<
VariableParam
>
(
n
->
attrs
.
parsed
);
ret
.
outputs
.
emplace_back
(
NodeEntry
{
node
,
0
,
param
.
version
});
}
else
{
uint32_t
nout
=
n
->
num_outputs
();
for
(
uint32_t
i
=
0
;
i
<
nout
;
++
i
)
{
ret
.
outputs
.
emplace_back
(
NodeEntry
{
node
,
i
,
0
});
}
}
});
return
ret
;
...
...
@@ -325,7 +383,7 @@ void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& a
}
}
if
(
node
->
op
!=
nullptr
&&
node
->
op
->
attr_parser
!=
nullptr
)
{
(
*
node
->
op
->
attr_parser
)
(
&
(
node
->
attrs
));
node
->
op
->
attr_parser
(
&
(
node
->
attrs
));
}
}
...
...
@@ -366,9 +424,9 @@ Symbol Symbol::CreateFunctor(const Op* op,
n
->
op
=
op
;
n
->
attrs
.
dict
=
std
::
move
(
attrs
);
if
(
n
->
op
->
attr_parser
!=
nullptr
)
{
(
*
n
->
op
->
attr_parser
)
(
&
(
n
->
attrs
));
n
->
op
->
attr_parser
(
&
(
n
->
attrs
));
}
s
.
outputs
.
emplace_back
(
NodeEntry
{
std
::
move
(
n
),
0
});
s
.
outputs
.
emplace_back
(
NodeEntry
{
std
::
move
(
n
),
0
,
0
});
return
s
;
}
...
...
@@ -382,10 +440,7 @@ Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
Symbol
Symbol
::
CreateVariable
(
const
std
::
string
&
name
)
{
Symbol
s
;
std
::
shared_ptr
<
Node
>
n
=
Node
::
Create
();
n
->
op
=
nullptr
;
n
->
attrs
.
name
=
name
;
s
.
outputs
.
emplace_back
(
NodeEntry
{
std
::
move
(
n
),
0
});
s
.
outputs
.
emplace_back
(
NodeEntry
{
CreateVariableNode
(
name
),
0
,
0
});
return
s
;
}
...
...
nnvm/src/example/operator.cc
View file @
badcdfff
...
...
@@ -6,6 +6,7 @@
#include <utility>
using
nnvm
::
FListInputNames
;
using
nnvm
::
FMutateInput
;
using
nnvm
::
NodeAttrs
;
NNVM_REGISTER_OP
(
add
)
...
...
@@ -29,3 +30,10 @@ NNVM_REGISTER_OP(conv2d)
NNVM_REGISTER_OP
(
add
)
.
attr
<
std
::
string
>
(
"nick_name"
,
"plus"
);
NNVM_REGISTER_OP
(
assign
)
.
set_num_inputs
(
2
)
.
set_num_outputs
(
1
)
.
attr
<
FMutateInput
>
(
"FMutateInput"
,
[](
const
NodeAttrs
&
attrs
,
uint32_t
index
)
{
return
index
==
0
;
});
nnvm/tests/python/test_symbol.py
View file @
badcdfff
...
...
@@ -24,6 +24,20 @@ def test_default_input():
except
NNVMError
:
pass
def
test_mutate_input
():
x
=
sym
.
Variable
(
'x'
)
y
=
sym
.
conv2d
(
data
=
x
,
name
=
'conv'
)
z
=
sym
.
assign
(
x
,
y
)
t
=
sym
.
add
(
z
,
x
)
try
:
z
=
sym
.
assign
(
z
,
z
)
assert
False
except
NNVMError
:
pass
if
__name__
==
"__main__"
:
test_default_input
()
test_compose
()
test_mutate_input
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment