Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
badcdfff
Commit
badcdfff
authored
8 years ago
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Change op function pointer to std::function, enable mutation (#6)
parent
c92d63c7
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
121 additions
and
28 deletions
+121
-28
nnvm/include/nnvm/node.h
+7
-0
nnvm/include/nnvm/op.h
+9
-10
nnvm/include/nnvm/op_attr_types.h
+12
-2
nnvm/src/core/symbolic.cc
+71
-16
nnvm/src/example/operator.cc
+8
-0
nnvm/tests/python/test_symbol.py
+14
-0
No files found.
nnvm/include/nnvm/node.h
View file @
badcdfff
...
...
@@ -24,6 +24,13 @@ struct NodeEntry {
std
::
shared_ptr
<
Node
>
node
;
/*! \brief index of output from the source. */
uint32_t
index
;
/*!
* \brief version of input Variable.
* This field can only be nonzero when this->node is a Variable node.
* version is increased by one each time a Variable get composed to a mutation Op.
* This information can be helpful to decide order of operations when sequence of mutation happens.
*/
uint32_t
version
;
};
/*!
...
...
This diff is collapsed.
Click to expand it.
nnvm/include/nnvm/op.h
View file @
badcdfff
...
...
@@ -101,13 +101,13 @@ class Op {
* \param attrs The attribute of the node
* \return number of outputs.
*/
uint32_t
(
*
get_num_outputs
)(
const
NodeAttrs
&
attrs
)
=
nullptr
;
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attrs
)
>
get_num_outputs
=
nullptr
;
/*!
* \brief get number of inputs given information about the node.
* \param attrs The attribute of the node
* \return number of inputs
*/
uint32_t
(
*
get_num_inputs
)(
const
NodeAttrs
&
attrs
)
=
nullptr
;
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attrs
)
>
get_num_inputs
=
nullptr
;
/*!
* \brief Attribute parser to parse the NodeAttrs information.
*
...
...
@@ -140,8 +140,7 @@ class Op {
* }
* \endcode
*/
void
(
*
attr_parser
)(
NodeAttrs
*
attrs
)
=
nullptr
;
std
::
function
<
void
(
NodeAttrs
*
attrs
)
>
attr_parser
=
nullptr
;
// function fields.
/*!
* \brief setter function during registration
...
...
@@ -161,7 +160,7 @@ class Op {
* \param fn The function to be set.
* \return reference to self.
*/
inline
Op
&
set_num_inputs
(
uint32_t
(
*
fn
)(
const
NodeAttrs
&
attr
)
);
// NOLINT(*)
inline
Op
&
set_num_inputs
(
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attr
)
>
fn
);
// NOLINT(*)
/*!
* \brief Set the num_outputs
* \param n The number of outputs to be set.
...
...
@@ -173,13 +172,13 @@ class Op {
* \param fn The function to be set.
* \return reference to self.
*/
inline
Op
&
set_num_outputs
(
uint32_t
(
*
fn
)(
const
NodeAttrs
&
attr
)
);
// NOLINT(*)
inline
Op
&
set_num_outputs
(
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attr
)
>
fn
);
// NOLINT(*)
/*!
* \brief Set the attr_parser function.
* \param fn The number of outputs to be set.
* \return reference to self.
*/
inline
Op
&
set_attr_parser
(
void
(
*
fn
)(
NodeAttrs
*
attrs
)
);
// NOLINT(*)
inline
Op
&
set_attr_parser
(
std
::
function
<
void
(
NodeAttrs
*
attrs
)
>
fn
);
// NOLINT(*)
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
...
...
@@ -342,7 +341,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
return
*
this
;
}
inline
Op
&
Op
::
set_num_inputs
(
uint32_t
(
*
fn
)(
const
NodeAttrs
&
attr
)
)
{
// NOLINT(*)
inline
Op
&
Op
::
set_num_inputs
(
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attr
)
>
fn
)
{
// NOLINT(*)
this
->
get_num_inputs
=
fn
;
return
*
this
;
}
...
...
@@ -352,12 +351,12 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
return
*
this
;
}
inline
Op
&
Op
::
set_num_outputs
(
uint32_t
(
*
fn
)(
const
NodeAttrs
&
attr
)
)
{
// NOLINT(*)
inline
Op
&
Op
::
set_num_outputs
(
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attr
)
>
fn
)
{
// NOLINT(*)
this
->
get_num_outputs
=
fn
;
return
*
this
;
}
inline
Op
&
Op
::
set_attr_parser
(
void
(
*
fn
)(
NodeAttrs
*
attrs
)
)
{
// NOLINT(*)
inline
Op
&
Op
::
set_attr_parser
(
std
::
function
<
void
(
NodeAttrs
*
attrs
)
>
fn
)
{
// NOLINT(*)
this
->
attr_parser
=
fn
;
return
*
this
;
}
...
...
This diff is collapsed.
Click to expand it.
nnvm/include/nnvm/op_attr_types.h
View file @
badcdfff
...
...
@@ -12,8 +12,8 @@
namespace
nnvm
{
// These types are optional attributes in each op
//
Some of them are needed for certain pas
s.
// These types are optional attributes in each op
erator.
//
Each attribute can be required by some passe
s.
/*!
* \brief Return list of input arguments names of each operator.
...
...
@@ -37,6 +37,16 @@ using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs&
*/
using
FListOutputNames
=
std
::
function
<
std
::
vector
<
std
::
string
>
(
const
NodeAttrs
&
attrs
)
>
;
/*!
* \brief Check whether operator will mutate k-th input.
* \param index The input index
* \return Whether this operator will mutate index-th input.
*
* \note Register under "FMutateInput", default return false
* FMutateInputs enables mutation order handling correctly.
*/
using
FMutateInput
=
std
::
function
<
bool
(
const
NodeAttrs
&
attrs
,
uint32_t
index
)
>
;
}
// namespace nnvm
#endif // NNVM_OP_ATTR_TYPES_H_
This diff is collapsed.
Click to expand it.
nnvm/src/core/symbolic.cc
View file @
badcdfff
...
...
@@ -13,6 +13,43 @@ namespace symbol_constants {
const
char
*
kNamespaceSeparator
=
"_"
;
}
// namespace symbol_constants
// auxililary version attribute in variable.
struct
VariableParam
{
uint32_t
version
{
0
};
};
std
::
shared_ptr
<
Node
>
CreateVariableNode
(
const
std
::
string
&
name
)
{
std
::
shared_ptr
<
Node
>
n
=
Node
::
Create
();
n
->
op
=
nullptr
;
n
->
attrs
.
name
=
name
;
n
->
attrs
.
parsed
=
VariableParam
();
return
n
;
}
// scan over a node's input, update the version to latest
// If the node's op mutates a certain input variable,
// The version of that varaible will increase
// version is used to implicitly order the mutation sequences
inline
void
UpdateNodeVersion
(
Node
*
n
)
{
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInput
>
(
"FMutateInput"
);
for
(
NodeEntry
&
e
:
n
->
inputs
)
{
if
(
e
.
node
->
is_variable
())
{
e
.
version
=
nnvm
::
get
<
VariableParam
>
(
e
.
node
->
attrs
.
parsed
).
version
;
}
}
if
(
fmutate_inputs
.
count
(
n
->
op
)
!=
0
)
{
FMutateInput
fmutate
=
fmutate_inputs
[
n
->
op
];
for
(
uint32_t
i
=
0
;
i
<
n
->
inputs
.
size
();
++
i
)
{
if
(
fmutate
(
n
->
attrs
,
i
))
{
NodeEntry
&
e
=
n
->
inputs
[
i
];
CHECK
(
e
.
node
->
is_variable
())
<<
"Mutation target can only be Variable"
;
// increase the version of the variable.
++
nnvm
::
get
<
VariableParam
>
(
e
.
node
->
attrs
.
parsed
).
version
;
}
}
}
}
inline
std
::
string
DefaultVarName
(
const
std
::
string
&
op_name
,
const
std
::
string
&
arg_name
)
{
...
...
@@ -67,13 +104,13 @@ Symbol Symbol::Copy() const {
for
(
const
auto
&
kv
:
old_new
)
{
for
(
const
NodeEntry
&
e
:
kv
.
first
->
inputs
)
{
Node
*
ptr
=
e
.
node
.
get
();
kv
.
second
->
inputs
.
emplace_back
(
NodeEntry
{
old_new
[
ptr
],
e
.
index
});
kv
.
second
->
inputs
.
emplace_back
(
NodeEntry
{
old_new
[
ptr
],
e
.
index
,
e
.
version
});
}
}
// set the head
Symbol
ret
;
for
(
const
NodeEntry
&
e
:
outputs
)
{
ret
.
outputs
.
emplace_back
(
NodeEntry
{
old_new
[
e
.
node
.
get
()],
e
.
index
});
ret
.
outputs
.
emplace_back
(
NodeEntry
{
old_new
[
e
.
node
.
get
()],
e
.
index
,
e
.
version
});
}
return
ret
;
}
...
...
@@ -95,8 +132,14 @@ void Symbol::Print(std::ostream &os) const {
os
<<
"Name: "
<<
node
->
attrs
.
name
<<
" Op:"
<<
node
->
op
->
name
<<
'\n'
<<
"Inputs:
\n
"
;
for
(
size_t
i
=
0
;
i
<
node
->
inputs
.
size
();
++
i
)
{
os
<<
"
\t
arg["
<<
i
<<
"]="
<<
node
->
inputs
[
i
].
node
->
attrs
.
name
<<
'('
<<
node
->
inputs
[
i
].
index
<<
")
\n
"
;
const
NodeEntry
&
e
=
node
->
inputs
[
i
];
os
<<
"
\t
arg["
<<
i
<<
"]="
<<
e
.
node
->
attrs
.
name
<<
'('
<<
e
.
index
<<
")"
;
if
(
e
.
node
->
is_variable
())
{
os
<<
" version="
<<
e
.
version
<<
'\n'
;
}
else
{
os
<<
'\n'
;
}
}
os
<<
"Attrs:
\n
"
;
for
(
auto
&
kv
:
node
->
attrs
.
dict
)
{
...
...
@@ -163,6 +206,8 @@ std::vector<std::string> Symbol::ListOutputs() const {
void
Symbol
::
Compose
(
const
std
::
vector
<
Symbol
>&
args
,
const
std
::
unordered_map
<
std
::
string
,
Symbol
>&
kwargs
,
const
std
::
string
&
name
)
{
static
auto
&
flist_inputs
=
Op
::
GetAttr
<
FListInputNames
>
(
"FListInputNames"
);
CHECK_EQ
(
outputs
.
size
(),
1
)
<<
"Only composition of value function is supported currently"
;
CHECK
(
!
outputs
[
0
].
node
->
is_variable
())
<<
"Variable cannot be composed"
;
...
...
@@ -193,7 +238,6 @@ void Symbol::Compose(const std::vector<Symbol>& args,
}
// switch to keyword argument matching
if
(
args
.
size
()
!=
n_req
)
{
static
auto
&
flist_inputs
=
Op
::
GetAttr
<
FListInputNames
>
(
"FListInputNames"
);
FListInputNames
fn
=
flist_inputs
.
get
(
n
->
op
,
nullptr
);
auto
arg_names
=
(
fn
==
nullptr
)
?
std
::
vector
<
std
::
string
>
{
"data"
}
:
fn
(
n
->
attrs
);
if
(
arg_names
.
size
()
!=
n_req
)
{
...
...
@@ -206,8 +250,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
n
->
inputs
[
i
]
=
it
->
second
.
outputs
[
0
];
++
nmatched
;
}
else
{
n
->
inputs
[
i
]
=
NodeEntry
{
Node
::
Create
(),
0
};
n
->
inputs
[
i
].
node
->
attrs
.
name
=
DefaultVarName
(
name
,
arg_names
[
i
])
;
n
->
inputs
[
i
]
=
NodeEntry
{
CreateVariableNode
(
DefaultVarName
(
name
,
arg_names
[
i
])),
0
,
0
}
;
}
}
...
...
@@ -226,6 +270,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
n
->
inputs
.
push_back
(
s
.
outputs
[
0
]);
}
}
UpdateNodeVersion
(
n
);
}
else
{
// general composition
CHECK_EQ
(
args
.
size
(),
0
)
...
...
@@ -253,25 +298,32 @@ void Symbol::Compose(const std::vector<Symbol>& args,
DFSVisit
(
this
->
outputs
,
find_replace_map
);
if
(
nmatched
==
kwargs
.
size
()
&&
arg_counter
<
args
.
size
())
{
std
::
vector
<
Node
*>
update_nodes
;
std
::
vector
<
std
::
pair
<
NodeEntry
*
,
const
NodeEntry
*>
>
replace_plan
;
auto
find_replace_plan
=
[
&
replace_map
,
&
replace_plan
]
auto
find_replace_plan
=
[
&
replace_map
,
&
replace_plan
,
&
update_nodes
]
(
const
std
::
shared_ptr
<
Node
>
&
node
)
{
// visit all the childs, find possible replacement
bool
repl
=
false
;
for
(
size_t
i
=
0
;
i
<
node
->
inputs
.
size
();
++
i
)
{
NodeEntry
*
e
=
&
(
node
->
inputs
[
i
]);
if
(
e
->
node
->
is_variable
())
{
auto
iter
=
replace_map
.
find
(
e
->
node
.
get
());
if
(
iter
!=
replace_map
.
end
())
{
replace_plan
.
push_back
(
std
::
make_pair
(
e
,
iter
->
second
));
repl
=
true
;
}
}
}
if
(
repl
)
update_nodes
.
push_back
(
node
.
get
());
};
DFSVisit
(
this
->
outputs
,
find_replace_plan
);
for
(
const
auto
&
kv
:
replace_plan
)
{
*
(
kv
.
first
)
=
*
(
kv
.
second
);
}
for
(
Node
*
n
:
update_nodes
)
{
UpdateNodeVersion
(
n
);
}
}
else
{
std
::
vector
<
std
::
string
>
keys
=
GetKeys
(
kwargs
);
std
::
vector
<
std
::
string
>
arg_names
=
ListArguments
();
...
...
@@ -303,9 +355,15 @@ Symbol Symbol::GetInternals() const {
Symbol
ret
;
DFSVisit
(
this
->
outputs
,
[
&
ret
](
const
std
::
shared_ptr
<
Node
>&
node
)
{
Node
*
n
=
node
.
get
();
if
(
n
->
is_variable
())
{
// grab version from variable.
VariableParam
&
param
=
nnvm
::
get
<
VariableParam
>
(
n
->
attrs
.
parsed
);
ret
.
outputs
.
emplace_back
(
NodeEntry
{
node
,
0
,
param
.
version
});
}
else
{
uint32_t
nout
=
n
->
num_outputs
();
for
(
uint32_t
i
=
0
;
i
<
nout
;
++
i
)
{
ret
.
outputs
.
emplace_back
(
NodeEntry
{
node
,
i
});
ret
.
outputs
.
emplace_back
(
NodeEntry
{
node
,
i
,
0
});
}
}
});
return
ret
;
...
...
@@ -325,7 +383,7 @@ void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& a
}
}
if
(
node
->
op
!=
nullptr
&&
node
->
op
->
attr_parser
!=
nullptr
)
{
(
*
node
->
op
->
attr_parser
)
(
&
(
node
->
attrs
));
node
->
op
->
attr_parser
(
&
(
node
->
attrs
));
}
}
...
...
@@ -366,9 +424,9 @@ Symbol Symbol::CreateFunctor(const Op* op,
n
->
op
=
op
;
n
->
attrs
.
dict
=
std
::
move
(
attrs
);
if
(
n
->
op
->
attr_parser
!=
nullptr
)
{
(
*
n
->
op
->
attr_parser
)
(
&
(
n
->
attrs
));
n
->
op
->
attr_parser
(
&
(
n
->
attrs
));
}
s
.
outputs
.
emplace_back
(
NodeEntry
{
std
::
move
(
n
),
0
});
s
.
outputs
.
emplace_back
(
NodeEntry
{
std
::
move
(
n
),
0
,
0
});
return
s
;
}
...
...
@@ -382,10 +440,7 @@ Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
Symbol
Symbol
::
CreateVariable
(
const
std
::
string
&
name
)
{
Symbol
s
;
std
::
shared_ptr
<
Node
>
n
=
Node
::
Create
();
n
->
op
=
nullptr
;
n
->
attrs
.
name
=
name
;
s
.
outputs
.
emplace_back
(
NodeEntry
{
std
::
move
(
n
),
0
});
s
.
outputs
.
emplace_back
(
NodeEntry
{
CreateVariableNode
(
name
),
0
,
0
});
return
s
;
}
...
...
This diff is collapsed.
Click to expand it.
nnvm/src/example/operator.cc
View file @
badcdfff
...
...
@@ -6,6 +6,7 @@
#include <utility>
using
nnvm
::
FListInputNames
;
using
nnvm
::
FMutateInput
;
using
nnvm
::
NodeAttrs
;
NNVM_REGISTER_OP
(
add
)
...
...
@@ -29,3 +30,10 @@ NNVM_REGISTER_OP(conv2d)
NNVM_REGISTER_OP
(
add
)
.
attr
<
std
::
string
>
(
"nick_name"
,
"plus"
);
NNVM_REGISTER_OP
(
assign
)
.
set_num_inputs
(
2
)
.
set_num_outputs
(
1
)
.
attr
<
FMutateInput
>
(
"FMutateInput"
,
[](
const
NodeAttrs
&
attrs
,
uint32_t
index
)
{
return
index
==
0
;
});
This diff is collapsed.
Click to expand it.
nnvm/tests/python/test_symbol.py
View file @
badcdfff
...
...
@@ -24,6 +24,20 @@ def test_default_input():
except
NNVMError
:
pass
def
test_mutate_input
():
x
=
sym
.
Variable
(
'x'
)
y
=
sym
.
conv2d
(
data
=
x
,
name
=
'conv'
)
z
=
sym
.
assign
(
x
,
y
)
t
=
sym
.
add
(
z
,
x
)
try
:
z
=
sym
.
assign
(
z
,
z
)
assert
False
except
NNVMError
:
pass
if
__name__
==
"__main__"
:
test_default_input
()
test_compose
()
test_mutate_input
()
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment