Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
badcdfff
Commit
badcdfff
authored
Jul 12, 2016
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Change op function pointer to std::function, enable mutation (#6)
parent
c92d63c7
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
121 additions
and
28 deletions
+121
-28
nnvm/include/nnvm/node.h
+7
-0
nnvm/include/nnvm/op.h
+9
-10
nnvm/include/nnvm/op_attr_types.h
+12
-2
nnvm/src/core/symbolic.cc
+71
-16
nnvm/src/example/operator.cc
+8
-0
nnvm/tests/python/test_symbol.py
+14
-0
No files found.
nnvm/include/nnvm/node.h
View file @
badcdfff
...
@@ -24,6 +24,13 @@ struct NodeEntry {
...
@@ -24,6 +24,13 @@ struct NodeEntry {
std
::
shared_ptr
<
Node
>
node
;
std
::
shared_ptr
<
Node
>
node
;
/*! \brief index of output from the source. */
/*! \brief index of output from the source. */
uint32_t
index
;
uint32_t
index
;
/*!
* \brief version of input Variable.
* This field can only be nonzero when this->node is a Variable node.
* version is increased by one each time a Variable get composed to a mutation Op.
* This information can be helpful to decide order of operations when sequence of mutation happens.
*/
uint32_t
version
;
};
};
/*!
/*!
...
...
nnvm/include/nnvm/op.h
View file @
badcdfff
...
@@ -101,13 +101,13 @@ class Op {
...
@@ -101,13 +101,13 @@ class Op {
* \param attrs The attribute of the node
* \param attrs The attribute of the node
* \return number of outputs.
* \return number of outputs.
*/
*/
uint32_t
(
*
get_num_outputs
)(
const
NodeAttrs
&
attrs
)
=
nullptr
;
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attrs
)
>
get_num_outputs
=
nullptr
;
/*!
/*!
* \brief get number of inputs given information about the node.
* \brief get number of inputs given information about the node.
* \param attrs The attribute of the node
* \param attrs The attribute of the node
* \return number of inputs
* \return number of inputs
*/
*/
uint32_t
(
*
get_num_inputs
)(
const
NodeAttrs
&
attrs
)
=
nullptr
;
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attrs
)
>
get_num_inputs
=
nullptr
;
/*!
/*!
* \brief Attribute parser to parse the NodeAttrs information.
* \brief Attribute parser to parse the NodeAttrs information.
*
*
...
@@ -140,8 +140,7 @@ class Op {
...
@@ -140,8 +140,7 @@ class Op {
* }
* }
* \endcode
* \endcode
*/
*/
void
(
*
attr_parser
)(
NodeAttrs
*
attrs
)
=
nullptr
;
std
::
function
<
void
(
NodeAttrs
*
attrs
)
>
attr_parser
=
nullptr
;
// function fields.
// function fields.
/*!
/*!
* \brief setter function during registration
* \brief setter function during registration
...
@@ -161,7 +160,7 @@ class Op {
...
@@ -161,7 +160,7 @@ class Op {
* \param fn The function to be set.
* \param fn The function to be set.
* \return reference to self.
* \return reference to self.
*/
*/
inline
Op
&
set_num_inputs
(
uint32_t
(
*
fn
)(
const
NodeAttrs
&
attr
)
);
// NOLINT(*)
inline
Op
&
set_num_inputs
(
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attr
)
>
fn
);
// NOLINT(*)
/*!
/*!
* \brief Set the num_outputs
* \brief Set the num_outputs
* \param n The number of outputs to be set.
* \param n The number of outputs to be set.
...
@@ -173,13 +172,13 @@ class Op {
...
@@ -173,13 +172,13 @@ class Op {
* \param fn The function to be set.
* \param fn The function to be set.
* \return reference to self.
* \return reference to self.
*/
*/
inline
Op
&
set_num_outputs
(
uint32_t
(
*
fn
)(
const
NodeAttrs
&
attr
)
);
// NOLINT(*)
inline
Op
&
set_num_outputs
(
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attr
)
>
fn
);
// NOLINT(*)
/*!
/*!
* \brief Set the attr_parser function.
* \brief Set the attr_parser function.
* \param fn The number of outputs to be set.
* \param fn The number of outputs to be set.
* \return reference to self.
* \return reference to self.
*/
*/
inline
Op
&
set_attr_parser
(
void
(
*
fn
)(
NodeAttrs
*
attrs
)
);
// NOLINT(*)
inline
Op
&
set_attr_parser
(
std
::
function
<
void
(
NodeAttrs
*
attrs
)
>
fn
);
// NOLINT(*)
/*!
/*!
* \brief Register additional attributes to operator.
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
* \param attr_name The name of the attribute.
...
@@ -342,7 +341,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
...
@@ -342,7 +341,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
return
*
this
;
return
*
this
;
}
}
inline
Op
&
Op
::
set_num_inputs
(
uint32_t
(
*
fn
)(
const
NodeAttrs
&
attr
)
)
{
// NOLINT(*)
inline
Op
&
Op
::
set_num_inputs
(
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attr
)
>
fn
)
{
// NOLINT(*)
this
->
get_num_inputs
=
fn
;
this
->
get_num_inputs
=
fn
;
return
*
this
;
return
*
this
;
}
}
...
@@ -352,12 +351,12 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
...
@@ -352,12 +351,12 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
return
*
this
;
return
*
this
;
}
}
inline
Op
&
Op
::
set_num_outputs
(
uint32_t
(
*
fn
)(
const
NodeAttrs
&
attr
)
)
{
// NOLINT(*)
inline
Op
&
Op
::
set_num_outputs
(
std
::
function
<
uint32_t
(
const
NodeAttrs
&
attr
)
>
fn
)
{
// NOLINT(*)
this
->
get_num_outputs
=
fn
;
this
->
get_num_outputs
=
fn
;
return
*
this
;
return
*
this
;
}
}
inline
Op
&
Op
::
set_attr_parser
(
void
(
*
fn
)(
NodeAttrs
*
attrs
)
)
{
// NOLINT(*)
inline
Op
&
Op
::
set_attr_parser
(
std
::
function
<
void
(
NodeAttrs
*
attrs
)
>
fn
)
{
// NOLINT(*)
this
->
attr_parser
=
fn
;
this
->
attr_parser
=
fn
;
return
*
this
;
return
*
this
;
}
}
...
...
nnvm/include/nnvm/op_attr_types.h
View file @
badcdfff
...
@@ -12,8 +12,8 @@
...
@@ -12,8 +12,8 @@
namespace
nnvm
{
namespace
nnvm
{
// These types are optional attributes in each op
// These types are optional attributes in each op
erator.
//
Some of them are needed for certain pas
s.
//
Each attribute can be required by some passe
s.
/*!
/*!
* \brief Return list of input arguments names of each operator.
* \brief Return list of input arguments names of each operator.
...
@@ -37,6 +37,16 @@ using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs&
...
@@ -37,6 +37,16 @@ using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs&
*/
*/
using
FListOutputNames
=
std
::
function
<
std
::
vector
<
std
::
string
>
(
const
NodeAttrs
&
attrs
)
>
;
using
FListOutputNames
=
std
::
function
<
std
::
vector
<
std
::
string
>
(
const
NodeAttrs
&
attrs
)
>
;
/*!
* \brief Check whether operator will mutate k-th input.
* \param index The input index
* \return Whether this operator will mutate index-th input.
*
* \note Register under "FMutateInput", default return false
* FMutateInputs enables mutation order handling correctly.
*/
using
FMutateInput
=
std
::
function
<
bool
(
const
NodeAttrs
&
attrs
,
uint32_t
index
)
>
;
}
// namespace nnvm
}
// namespace nnvm
#endif // NNVM_OP_ATTR_TYPES_H_
#endif // NNVM_OP_ATTR_TYPES_H_
nnvm/src/core/symbolic.cc
View file @
badcdfff
...
@@ -13,6 +13,43 @@ namespace symbol_constants {
...
@@ -13,6 +13,43 @@ namespace symbol_constants {
const
char
*
kNamespaceSeparator
=
"_"
;
const
char
*
kNamespaceSeparator
=
"_"
;
}
// namespace symbol_constants
}
// namespace symbol_constants
// auxililary version attribute in variable.
struct
VariableParam
{
uint32_t
version
{
0
};
};
std
::
shared_ptr
<
Node
>
CreateVariableNode
(
const
std
::
string
&
name
)
{
std
::
shared_ptr
<
Node
>
n
=
Node
::
Create
();
n
->
op
=
nullptr
;
n
->
attrs
.
name
=
name
;
n
->
attrs
.
parsed
=
VariableParam
();
return
n
;
}
// scan over a node's input, update the version to latest
// If the node's op mutates a certain input variable,
// The version of that varaible will increase
// version is used to implicitly order the mutation sequences
inline
void
UpdateNodeVersion
(
Node
*
n
)
{
static
auto
&
fmutate_inputs
=
Op
::
GetAttr
<
FMutateInput
>
(
"FMutateInput"
);
for
(
NodeEntry
&
e
:
n
->
inputs
)
{
if
(
e
.
node
->
is_variable
())
{
e
.
version
=
nnvm
::
get
<
VariableParam
>
(
e
.
node
->
attrs
.
parsed
).
version
;
}
}
if
(
fmutate_inputs
.
count
(
n
->
op
)
!=
0
)
{
FMutateInput
fmutate
=
fmutate_inputs
[
n
->
op
];
for
(
uint32_t
i
=
0
;
i
<
n
->
inputs
.
size
();
++
i
)
{
if
(
fmutate
(
n
->
attrs
,
i
))
{
NodeEntry
&
e
=
n
->
inputs
[
i
];
CHECK
(
e
.
node
->
is_variable
())
<<
"Mutation target can only be Variable"
;
// increase the version of the variable.
++
nnvm
::
get
<
VariableParam
>
(
e
.
node
->
attrs
.
parsed
).
version
;
}
}
}
}
inline
std
::
string
DefaultVarName
(
const
std
::
string
&
op_name
,
inline
std
::
string
DefaultVarName
(
const
std
::
string
&
op_name
,
const
std
::
string
&
arg_name
)
{
const
std
::
string
&
arg_name
)
{
...
@@ -67,13 +104,13 @@ Symbol Symbol::Copy() const {
...
@@ -67,13 +104,13 @@ Symbol Symbol::Copy() const {
for
(
const
auto
&
kv
:
old_new
)
{
for
(
const
auto
&
kv
:
old_new
)
{
for
(
const
NodeEntry
&
e
:
kv
.
first
->
inputs
)
{
for
(
const
NodeEntry
&
e
:
kv
.
first
->
inputs
)
{
Node
*
ptr
=
e
.
node
.
get
();
Node
*
ptr
=
e
.
node
.
get
();
kv
.
second
->
inputs
.
emplace_back
(
NodeEntry
{
old_new
[
ptr
],
e
.
index
});
kv
.
second
->
inputs
.
emplace_back
(
NodeEntry
{
old_new
[
ptr
],
e
.
index
,
e
.
version
});
}
}
}
}
// set the head
// set the head
Symbol
ret
;
Symbol
ret
;
for
(
const
NodeEntry
&
e
:
outputs
)
{
for
(
const
NodeEntry
&
e
:
outputs
)
{
ret
.
outputs
.
emplace_back
(
NodeEntry
{
old_new
[
e
.
node
.
get
()],
e
.
index
});
ret
.
outputs
.
emplace_back
(
NodeEntry
{
old_new
[
e
.
node
.
get
()],
e
.
index
,
e
.
version
});
}
}
return
ret
;
return
ret
;
}
}
...
@@ -95,8 +132,14 @@ void Symbol::Print(std::ostream &os) const {
...
@@ -95,8 +132,14 @@ void Symbol::Print(std::ostream &os) const {
os
<<
"Name: "
<<
node
->
attrs
.
name
<<
" Op:"
<<
node
->
op
->
name
<<
'\n'
os
<<
"Name: "
<<
node
->
attrs
.
name
<<
" Op:"
<<
node
->
op
->
name
<<
'\n'
<<
"Inputs:
\n
"
;
<<
"Inputs:
\n
"
;
for
(
size_t
i
=
0
;
i
<
node
->
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
node
->
inputs
.
size
();
++
i
)
{
os
<<
"
\t
arg["
<<
i
<<
"]="
<<
node
->
inputs
[
i
].
node
->
attrs
.
name
const
NodeEntry
&
e
=
node
->
inputs
[
i
];
<<
'('
<<
node
->
inputs
[
i
].
index
<<
")
\n
"
;
os
<<
"
\t
arg["
<<
i
<<
"]="
<<
e
.
node
->
attrs
.
name
<<
'('
<<
e
.
index
<<
")"
;
if
(
e
.
node
->
is_variable
())
{
os
<<
" version="
<<
e
.
version
<<
'\n'
;
}
else
{
os
<<
'\n'
;
}
}
}
os
<<
"Attrs:
\n
"
;
os
<<
"Attrs:
\n
"
;
for
(
auto
&
kv
:
node
->
attrs
.
dict
)
{
for
(
auto
&
kv
:
node
->
attrs
.
dict
)
{
...
@@ -163,6 +206,8 @@ std::vector<std::string> Symbol::ListOutputs() const {
...
@@ -163,6 +206,8 @@ std::vector<std::string> Symbol::ListOutputs() const {
void
Symbol
::
Compose
(
const
std
::
vector
<
Symbol
>&
args
,
void
Symbol
::
Compose
(
const
std
::
vector
<
Symbol
>&
args
,
const
std
::
unordered_map
<
std
::
string
,
Symbol
>&
kwargs
,
const
std
::
unordered_map
<
std
::
string
,
Symbol
>&
kwargs
,
const
std
::
string
&
name
)
{
const
std
::
string
&
name
)
{
static
auto
&
flist_inputs
=
Op
::
GetAttr
<
FListInputNames
>
(
"FListInputNames"
);
CHECK_EQ
(
outputs
.
size
(),
1
)
CHECK_EQ
(
outputs
.
size
(),
1
)
<<
"Only composition of value function is supported currently"
;
<<
"Only composition of value function is supported currently"
;
CHECK
(
!
outputs
[
0
].
node
->
is_variable
())
<<
"Variable cannot be composed"
;
CHECK
(
!
outputs
[
0
].
node
->
is_variable
())
<<
"Variable cannot be composed"
;
...
@@ -193,7 +238,6 @@ void Symbol::Compose(const std::vector<Symbol>& args,
...
@@ -193,7 +238,6 @@ void Symbol::Compose(const std::vector<Symbol>& args,
}
}
// switch to keyword argument matching
// switch to keyword argument matching
if
(
args
.
size
()
!=
n_req
)
{
if
(
args
.
size
()
!=
n_req
)
{
static
auto
&
flist_inputs
=
Op
::
GetAttr
<
FListInputNames
>
(
"FListInputNames"
);
FListInputNames
fn
=
flist_inputs
.
get
(
n
->
op
,
nullptr
);
FListInputNames
fn
=
flist_inputs
.
get
(
n
->
op
,
nullptr
);
auto
arg_names
=
(
fn
==
nullptr
)
?
std
::
vector
<
std
::
string
>
{
"data"
}
:
fn
(
n
->
attrs
);
auto
arg_names
=
(
fn
==
nullptr
)
?
std
::
vector
<
std
::
string
>
{
"data"
}
:
fn
(
n
->
attrs
);
if
(
arg_names
.
size
()
!=
n_req
)
{
if
(
arg_names
.
size
()
!=
n_req
)
{
...
@@ -206,8 +250,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
...
@@ -206,8 +250,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
n
->
inputs
[
i
]
=
it
->
second
.
outputs
[
0
];
n
->
inputs
[
i
]
=
it
->
second
.
outputs
[
0
];
++
nmatched
;
++
nmatched
;
}
else
{
}
else
{
n
->
inputs
[
i
]
=
NodeEntry
{
Node
::
Create
(),
0
};
n
->
inputs
[
i
]
=
NodeEntry
{
n
->
inputs
[
i
].
node
->
attrs
.
name
=
DefaultVarName
(
name
,
arg_names
[
i
])
;
CreateVariableNode
(
DefaultVarName
(
name
,
arg_names
[
i
])),
0
,
0
}
;
}
}
}
}
...
@@ -226,6 +270,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
...
@@ -226,6 +270,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
n
->
inputs
.
push_back
(
s
.
outputs
[
0
]);
n
->
inputs
.
push_back
(
s
.
outputs
[
0
]);
}
}
}
}
UpdateNodeVersion
(
n
);
}
else
{
}
else
{
// general composition
// general composition
CHECK_EQ
(
args
.
size
(),
0
)
CHECK_EQ
(
args
.
size
(),
0
)
...
@@ -253,25 +298,32 @@ void Symbol::Compose(const std::vector<Symbol>& args,
...
@@ -253,25 +298,32 @@ void Symbol::Compose(const std::vector<Symbol>& args,
DFSVisit
(
this
->
outputs
,
find_replace_map
);
DFSVisit
(
this
->
outputs
,
find_replace_map
);
if
(
nmatched
==
kwargs
.
size
()
&&
arg_counter
<
args
.
size
())
{
if
(
nmatched
==
kwargs
.
size
()
&&
arg_counter
<
args
.
size
())
{
std
::
vector
<
Node
*>
update_nodes
;
std
::
vector
<
std
::
pair
<
NodeEntry
*
,
const
NodeEntry
*>
>
replace_plan
;
std
::
vector
<
std
::
pair
<
NodeEntry
*
,
const
NodeEntry
*>
>
replace_plan
;
auto
find_replace_plan
=
[
&
replace_map
,
&
replace_plan
]
auto
find_replace_plan
=
[
&
replace_map
,
&
replace_plan
,
&
update_nodes
]
(
const
std
::
shared_ptr
<
Node
>
&
node
)
{
(
const
std
::
shared_ptr
<
Node
>
&
node
)
{
// visit all the childs, find possible replacement
// visit all the childs, find possible replacement
bool
repl
=
false
;
for
(
size_t
i
=
0
;
i
<
node
->
inputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
node
->
inputs
.
size
();
++
i
)
{
NodeEntry
*
e
=
&
(
node
->
inputs
[
i
]);
NodeEntry
*
e
=
&
(
node
->
inputs
[
i
]);
if
(
e
->
node
->
is_variable
())
{
if
(
e
->
node
->
is_variable
())
{
auto
iter
=
replace_map
.
find
(
e
->
node
.
get
());
auto
iter
=
replace_map
.
find
(
e
->
node
.
get
());
if
(
iter
!=
replace_map
.
end
())
{
if
(
iter
!=
replace_map
.
end
())
{
replace_plan
.
push_back
(
std
::
make_pair
(
e
,
iter
->
second
));
replace_plan
.
push_back
(
std
::
make_pair
(
e
,
iter
->
second
));
repl
=
true
;
}
}
}
}
}
}
if
(
repl
)
update_nodes
.
push_back
(
node
.
get
());
};
};
DFSVisit
(
this
->
outputs
,
find_replace_plan
);
DFSVisit
(
this
->
outputs
,
find_replace_plan
);
for
(
const
auto
&
kv
:
replace_plan
)
{
for
(
const
auto
&
kv
:
replace_plan
)
{
*
(
kv
.
first
)
=
*
(
kv
.
second
);
*
(
kv
.
first
)
=
*
(
kv
.
second
);
}
}
for
(
Node
*
n
:
update_nodes
)
{
UpdateNodeVersion
(
n
);
}
}
else
{
}
else
{
std
::
vector
<
std
::
string
>
keys
=
GetKeys
(
kwargs
);
std
::
vector
<
std
::
string
>
keys
=
GetKeys
(
kwargs
);
std
::
vector
<
std
::
string
>
arg_names
=
ListArguments
();
std
::
vector
<
std
::
string
>
arg_names
=
ListArguments
();
...
@@ -303,9 +355,15 @@ Symbol Symbol::GetInternals() const {
...
@@ -303,9 +355,15 @@ Symbol Symbol::GetInternals() const {
Symbol
ret
;
Symbol
ret
;
DFSVisit
(
this
->
outputs
,
[
&
ret
](
const
std
::
shared_ptr
<
Node
>&
node
)
{
DFSVisit
(
this
->
outputs
,
[
&
ret
](
const
std
::
shared_ptr
<
Node
>&
node
)
{
Node
*
n
=
node
.
get
();
Node
*
n
=
node
.
get
();
if
(
n
->
is_variable
())
{
// grab version from variable.
VariableParam
&
param
=
nnvm
::
get
<
VariableParam
>
(
n
->
attrs
.
parsed
);
ret
.
outputs
.
emplace_back
(
NodeEntry
{
node
,
0
,
param
.
version
});
}
else
{
uint32_t
nout
=
n
->
num_outputs
();
uint32_t
nout
=
n
->
num_outputs
();
for
(
uint32_t
i
=
0
;
i
<
nout
;
++
i
)
{
for
(
uint32_t
i
=
0
;
i
<
nout
;
++
i
)
{
ret
.
outputs
.
emplace_back
(
NodeEntry
{
node
,
i
});
ret
.
outputs
.
emplace_back
(
NodeEntry
{
node
,
i
,
0
});
}
}
}
});
});
return
ret
;
return
ret
;
...
@@ -325,7 +383,7 @@ void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& a
...
@@ -325,7 +383,7 @@ void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& a
}
}
}
}
if
(
node
->
op
!=
nullptr
&&
node
->
op
->
attr_parser
!=
nullptr
)
{
if
(
node
->
op
!=
nullptr
&&
node
->
op
->
attr_parser
!=
nullptr
)
{
(
*
node
->
op
->
attr_parser
)
(
&
(
node
->
attrs
));
node
->
op
->
attr_parser
(
&
(
node
->
attrs
));
}
}
}
}
...
@@ -366,9 +424,9 @@ Symbol Symbol::CreateFunctor(const Op* op,
...
@@ -366,9 +424,9 @@ Symbol Symbol::CreateFunctor(const Op* op,
n
->
op
=
op
;
n
->
op
=
op
;
n
->
attrs
.
dict
=
std
::
move
(
attrs
);
n
->
attrs
.
dict
=
std
::
move
(
attrs
);
if
(
n
->
op
->
attr_parser
!=
nullptr
)
{
if
(
n
->
op
->
attr_parser
!=
nullptr
)
{
(
*
n
->
op
->
attr_parser
)
(
&
(
n
->
attrs
));
n
->
op
->
attr_parser
(
&
(
n
->
attrs
));
}
}
s
.
outputs
.
emplace_back
(
NodeEntry
{
std
::
move
(
n
),
0
});
s
.
outputs
.
emplace_back
(
NodeEntry
{
std
::
move
(
n
),
0
,
0
});
return
s
;
return
s
;
}
}
...
@@ -382,10 +440,7 @@ Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
...
@@ -382,10 +440,7 @@ Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
Symbol
Symbol
::
CreateVariable
(
const
std
::
string
&
name
)
{
Symbol
Symbol
::
CreateVariable
(
const
std
::
string
&
name
)
{
Symbol
s
;
Symbol
s
;
std
::
shared_ptr
<
Node
>
n
=
Node
::
Create
();
s
.
outputs
.
emplace_back
(
NodeEntry
{
CreateVariableNode
(
name
),
0
,
0
});
n
->
op
=
nullptr
;
n
->
attrs
.
name
=
name
;
s
.
outputs
.
emplace_back
(
NodeEntry
{
std
::
move
(
n
),
0
});
return
s
;
return
s
;
}
}
...
...
nnvm/src/example/operator.cc
View file @
badcdfff
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include <utility>
#include <utility>
using
nnvm
::
FListInputNames
;
using
nnvm
::
FListInputNames
;
using
nnvm
::
FMutateInput
;
using
nnvm
::
NodeAttrs
;
using
nnvm
::
NodeAttrs
;
NNVM_REGISTER_OP
(
add
)
NNVM_REGISTER_OP
(
add
)
...
@@ -29,3 +30,10 @@ NNVM_REGISTER_OP(conv2d)
...
@@ -29,3 +30,10 @@ NNVM_REGISTER_OP(conv2d)
NNVM_REGISTER_OP
(
add
)
NNVM_REGISTER_OP
(
add
)
.
attr
<
std
::
string
>
(
"nick_name"
,
"plus"
);
.
attr
<
std
::
string
>
(
"nick_name"
,
"plus"
);
NNVM_REGISTER_OP
(
assign
)
.
set_num_inputs
(
2
)
.
set_num_outputs
(
1
)
.
attr
<
FMutateInput
>
(
"FMutateInput"
,
[](
const
NodeAttrs
&
attrs
,
uint32_t
index
)
{
return
index
==
0
;
});
nnvm/tests/python/test_symbol.py
View file @
badcdfff
...
@@ -24,6 +24,20 @@ def test_default_input():
...
@@ -24,6 +24,20 @@ def test_default_input():
except
NNVMError
:
except
NNVMError
:
pass
pass
def
test_mutate_input
():
x
=
sym
.
Variable
(
'x'
)
y
=
sym
.
conv2d
(
data
=
x
,
name
=
'conv'
)
z
=
sym
.
assign
(
x
,
y
)
t
=
sym
.
add
(
z
,
x
)
try
:
z
=
sym
.
assign
(
z
,
z
)
assert
False
except
NNVMError
:
pass
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_default_input
()
test_default_input
()
test_compose
()
test_compose
()
test_mutate_input
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment