Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
869a953a
Commit
869a953a
authored
Sep 23, 2016
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[OP] Enable register via match tag (#57)
* [OP] Enable register via match tag * more docs on usage
parent
fa5c5883
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
187 additions
and
37 deletions
+187
-37
nnvm/example/src/operator.cc
+22
-16
nnvm/include/nnvm/op.h
+125
-19
nnvm/src/core/op.cc
+40
-2
No files found.
nnvm/example/src/operator.cc
View file @
869a953a
...
...
@@ -84,6 +84,7 @@ NNVM_REGISTER_OP(reshape)
NNVM_REGISTER_OP
(
cast
)
.
describe
(
"cast source type to target"
)
.
set_num_inputs
(
1
)
.
include
(
"ElementwiseOpAttr"
)
.
set_attr_parser
(
[](
NodeAttrs
*
attrs
)
{
// parse attr parser to get target attribute
...
...
@@ -92,7 +93,6 @@ NNVM_REGISTER_OP(cast)
CHECK
(
is
>>
dtype
);
attrs
->
parsed
=
std
::
move
(
dtype
);
})
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
set_attr
<
FInferType
>
(
"FInferType"
,
[](
const
NodeAttrs
&
attrs
,
std
::
vector
<
int
>
*
itype
,
...
...
@@ -101,23 +101,10 @@ NNVM_REGISTER_OP(cast)
return
true
;
});
NNVM_REGISTER_OP
(
exp
)
.
describe
(
"take exponential"
)
.
set_num_inputs
(
1
)
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
set_attr
<
FGradient
>
(
"FGradient"
,
[](
const
NodePtr
&
n
,
const
std
::
vector
<
NodeEntry
>&
ograds
)
{
return
std
::
vector
<
NodeEntry
>
{
MakeNode
(
"mul"
,
n
->
attrs
.
name
+
"_grad"
,
{
ograds
[
0
],
NodeEntry
{
n
,
0
,
0
}})
};
});
NNVM_REGISTER_OP
(
identity
)
.
describe
(
"identity function"
)
.
set_num_inputs
(
1
)
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
include
(
"ElementwiseOpAttr"
)
.
set_attr
<
FGradient
>
(
"FGradient"
,
[](
const
NodePtr
&
n
,
const
std
::
vector
<
NodeEntry
>&
ograds
)
{
...
...
@@ -128,7 +115,7 @@ NNVM_REGISTER_OP(add)
.
describe
(
"add two data together"
)
.
set_num_inputs
(
2
)
.
add_alias
(
"__add_symbol__"
)
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
include
(
"ElementwiseOpAttr"
)
.
set_attr
<
FInplaceOption
>
(
"FInplaceOption"
,
InplaceIn0Out0
)
.
set_attr
<
FGradient
>
(
"FGradient"
,
[](
const
NodePtr
&
n
,
...
...
@@ -139,6 +126,7 @@ NNVM_REGISTER_OP(add)
NNVM_REGISTER_OP
(
mul
)
.
describe
(
"multiply two data together"
)
.
set_num_inputs
(
2
)
.
include
(
"ElementwiseOpAttr"
)
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
set_attr
<
FInplaceOption
>
(
"FInplaceOption"
,
InplaceIn0Out0
)
.
set_attr
<
FGradient
>
(
...
...
@@ -187,4 +175,22 @@ NNVM_REGISTER_OP(assign)
return
std
::
vector
<
uint32_t
>
{
0
};
});
NNVM_REGISTER_OP_GROUP
(
ElementwiseOpAttr
)
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
);
NNVM_REGISTER_OP
(
exp
)
.
describe
(
"take exponential"
)
.
set_num_inputs
(
1
)
.
include
(
"ElementwiseOpAttr"
)
.
set_attr
<
FGradient
>
(
"FGradient"
,
[](
const
NodePtr
&
n
,
const
std
::
vector
<
NodeEntry
>&
ograds
)
{
return
std
::
vector
<
NodeEntry
>
{
MakeNode
(
"mul"
,
n
->
attrs
.
name
+
"_grad"
,
{
ograds
[
0
],
NodeEntry
{
n
,
0
,
0
}})
};
});
}
// namespace myproject
nnvm/include/nnvm/op.h
View file @
869a953a
...
...
@@ -22,6 +22,7 @@ class Node;
struct
NodeAttrs
;
template
<
typename
ValueType
>
class
OpMap
;
class
OpGroup
;
class
OpRegistryEntry
;
using
dmlc
::
ParamFieldInfo
;
...
...
@@ -44,7 +45,13 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
* NNVM_REGISTER_OP(add)
* .describe("add two inputs together")
* .set_num_inputs(2)
* .set_attr<OpKernel>("gpu_kernel", AddKernel);
* .set_attr<OpKernel>("OpKernel<gpu>", AddKernel)
* .include("ElementwiseOpAttr");
*
* // can register attribute by group
* // all the ops that include the group get the attribute.
* NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
* .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
*
* NNVM_REGISTER_OP(sub)
* .describe("substract one tensor from another")
...
...
@@ -53,7 +60,8 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
* // Can call regster multiple times in different files
* // to register different part of information
* NNVM_REGISTER_OP(sub)
* .set_attr<OpKernel>("gpu_kernel", SubKernel);
* .set_attr<OpKernel>("OpKernel<gpu>", SubKernel);
* .include("ElementwiseOpAttr");
*
* // get operators from registry.
* void my_function() {
...
...
@@ -65,7 +73,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
*
* // get additional registered information,
* // Assume user registered a OpKernel type attribute as gpu_kernel on each operator.
* const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("
gpu_kernel
");
* const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("
OpKernel<gpu>
");
* // we can get the kernel functions by using operator as key.
* auto add_kernel = kernel[add];
* auto sub_kernel = kernel[sub];
...
...
@@ -200,6 +208,23 @@ class Op {
*/
inline
Op
&
set_attr_parser
(
std
::
function
<
void
(
NodeAttrs
*
attrs
)
>
fn
);
// NOLINT(*)
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \param plevel The priority level of this set,
* an higher priority level attribute
* will replace lower priority level attribute.
* Must be bigger than 0.
*
* Cannot set with same plevel twice in the code.
*
* \tparam ValueType The type of the value to be set.
*/
template
<
typename
ValueType
>
inline
Op
&
set_attr
(
const
std
::
string
&
attr_name
,
// NOLINT(*)
const
ValueType
&
value
,
int
plevel
=
10
);
/*!
* \brief Add another alias to this operator.
* The same Op can be queried with Op::Get(alias)
* \param alias The alias of the operator.
...
...
@@ -207,14 +232,13 @@ class Op {
*/
Op
&
add_alias
(
const
std
::
string
&
alias
);
// NOLINT(*)
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \tparam ValueType The type of the value to be set.
* \brief Include all the attributes from an registered op group.
* \param group_name The name of the group.
* \return reference to self.
*
* \sa NNVM_REGISTER_OP_GROUP
*/
template
<
typename
ValueType
>
inline
Op
&
set_attr
(
const
std
::
string
&
attr_name
,
// NOLINT(*)
const
ValueType
&
value
);
Op
&
include
(
const
std
::
string
&
group_name
);
/*!
* \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered.
...
...
@@ -235,6 +259,7 @@ class Op {
private
:
template
<
typename
ValueType
>
friend
class
OpMap
;
friend
class
OpGroup
;
friend
class
dmlc
::
Registry
<
Op
>
;
// Program internal unique index of operator.
// Used to help index the program.
...
...
@@ -246,6 +271,13 @@ class Op {
// update the attribute OpMap
static
void
UpdateAttrMap
(
const
std
::
string
&
key
,
std
::
function
<
void
(
any
*
)
>
updater
);
// add a trigger based on tag matching on certain tag attribute
// This will apply trigger on all the op such that
// include the corresponding group.
// The trigger will also be applied to all future registrations
// that calls include
static
void
AddGroupTrigger
(
const
std
::
string
&
group_name
,
std
::
function
<
void
(
Op
*
)
>
trigger
);
};
/*!
...
...
@@ -285,14 +317,44 @@ class OpMap {
OpMap
()
=
default
;
};
/*!
* \brief auxiliary data structure used to
* set attributes to a group of operators
*/
class
OpGroup
{
public
:
/*! \brief the tag key to be matched */
std
::
string
group_name
;
/*!
* \brief Register additional attributes to operator group.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \param plevel The priority level of this set,
* an higher priority level attribute
* will replace lower priority level attribute.
* Must be bigger than 0.
*
* Cannot set with same plevel twice in the code.
*
* \tparam ValueType The type of the value to be set.
*/
template
<
typename
ValueType
>
inline
OpGroup
&
set_attr
(
const
std
::
string
&
attr_name
,
// NOLINT(*)
const
ValueType
&
value
,
int
plevel
=
1
);
};
// internal macros to make
#define NNVM_REGISTER_VAR_DEF(OpName) \
#define NNVM_REGISTER_VAR_DEF(OpName)
\
static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
#define NNVM_REGISTER_GVAR_DEF(TagName) \
static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName
/*!
* \def NNVM_REGISTER_OP
* \brief Register
*
This macro must be used under namespace dmlc, and only used once in cc file.
* \brief Register
a new operator, or set attribute of the corresponding op.
*
* \param OpName The name of registry
*
* \code
...
...
@@ -308,6 +370,31 @@ class OpMap {
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
/*!
* \def NNVM_REGISTER_OP_GROUP
* \brief Register attribute to a group of operators.
* These attributes will be registered to Op that include the group.
*
* \param GroupName The name of the group.
*
* \code
*
* NNVM_REGISTER_OP(add)
* .include("ElementwiseOpAttr");
*
* // register same attributes to all the ops that include the group
* NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
* .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
*
* NNVM_REGISTER_OP(mul)
* .include("ElementwiseOpAttr");
*
* \endcode
*/
#define NNVM_REGISTER_OP_GROUP(GroupName) \
DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = \
::nnvm::OpGroup {#GroupName}
// implementations of template functions after this.
// member function of Op
template
<
typename
ValueType
>
...
...
@@ -330,9 +417,14 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
template
<
typename
ValueType
>
inline
Op
&
Op
::
set_attr
(
// NOLINT(*)
const
std
::
string
&
attr_name
,
const
ValueType
&
value
)
{
const
std
::
string
&
attr_name
,
const
ValueType
&
value
,
int
plevel
)
{
CHECK_GT
(
plevel
,
0
)
<<
"plevel in set_attr must be greater than 0"
;
// update the attribute map of the key by creating new empty if needed.
UpdateAttrMap
(
attr_name
,
[
this
,
attr_name
,
value
](
any
*
pmap
)
{
UpdateAttrMap
(
attr_name
,
[
this
,
attr_name
,
value
,
plevel
](
any
*
pmap
)
{
// the callback is in lockscope so is threadsafe.
if
(
pmap
->
empty
())
{
OpMap
<
ValueType
>
pm
;
...
...
@@ -353,15 +445,18 @@ inline Op& Op::set_attr( // NOLINT(*)
std
::
make_pair
(
ValueType
(),
0
));
}
std
::
pair
<
ValueType
,
int
>&
p
=
vec
[
index_
];
CHECK
(
p
.
second
==
0
)
CHECK
(
p
.
second
!=
plevel
)
<<
"Attribute "
<<
attr_name
<<
" of operator "
<<
this
->
name
<<
" is already registered."
;
vec
[
index_
]
=
std
::
make_pair
(
value
,
1
);
<<
" is already registered with same plevel="
<<
plevel
;
if
(
p
.
second
<
plevel
)
{
vec
[
index_
]
=
std
::
make_pair
(
value
,
plevel
);
}
});
return
*
this
;
}
inline
Op
&
Op
::
describe
(
const
std
::
string
&
descr
)
{
// NOLINT(*)
this
->
description
=
descr
;
return
*
this
;
...
...
@@ -409,7 +504,7 @@ template<typename ValueType>
inline
int
OpMap
<
ValueType
>::
count
(
const
Op
*
op
)
const
{
if
(
op
==
nullptr
)
return
0
;
const
uint32_t
idx
=
op
->
index_
;
return
idx
<
data_
.
size
()
?
data_
[
idx
].
second
:
0
;
return
idx
<
data_
.
size
()
?
(
data_
[
idx
].
second
!=
0
)
:
0
;
}
template
<
typename
ValueType
>
...
...
@@ -433,6 +528,17 @@ inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def
}
}
template
<
typename
ValueType
>
inline
OpGroup
&
OpGroup
::
set_attr
(
const
std
::
string
&
attr_name
,
const
ValueType
&
value
,
int
plevel
)
{
auto
trigger
=
[
attr_name
,
value
,
plevel
](
Op
*
op
)
{
op
->
set_attr
<
ValueType
>
(
attr_name
,
value
,
plevel
);
};
Op
::
AddGroupTrigger
(
group_name
,
trigger
);
return
*
this
;
}
}
// namespace nnvm
#endif // NNVM_OP_H_
nnvm/src/core/op.cc
View file @
869a953a
...
...
@@ -9,6 +9,7 @@
#include <memory>
#include <atomic>
#include <mutex>
#include <unordered_set>
namespace
dmlc
{
// enable registry
...
...
@@ -20,11 +21,16 @@ namespace nnvm {
// single manager of operator information.
struct
OpManager
{
// mutex to avoid registration from multiple threads.
std
::
mutex
mutex
;
// recursive is needed for trigger(which calls UpdateAttrMap)
std
::
recursive_mutex
mutex
;
// global operator counter
std
::
atomic
<
int
>
op_counter
{
0
};
// storage of additional attribute table.
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
any
>
>
attr
;
// storage of existing triggers
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
function
<
void
(
Op
*
)
>
>
>
tmap
;
// group of each operator.
std
::
vector
<
std
::
unordered_set
<
std
::
string
>
>
op_group
;
// get singleton of the
static
OpManager
*
Global
()
{
static
OpManager
inst
;
...
...
@@ -66,10 +72,42 @@ const any* Op::GetAttrMap(const std::string& key) {
void
Op
::
UpdateAttrMap
(
const
std
::
string
&
key
,
std
::
function
<
void
(
any
*
)
>
updater
)
{
OpManager
*
mgr
=
OpManager
::
Global
();
std
::
lock_guard
<
std
::
mutex
>
(
mgr
->
mutex
);
std
::
lock_guard
<
std
::
recursive_
mutex
>
(
mgr
->
mutex
);
std
::
unique_ptr
<
any
>&
value
=
mgr
->
attr
[
key
];
if
(
value
.
get
()
==
nullptr
)
value
.
reset
(
new
any
());
if
(
updater
!=
nullptr
)
updater
(
value
.
get
());
}
void
Op
::
AddGroupTrigger
(
const
std
::
string
&
group_name
,
std
::
function
<
void
(
Op
*
)
>
trigger
)
{
OpManager
*
mgr
=
OpManager
::
Global
();
std
::
lock_guard
<
std
::
recursive_mutex
>
(
mgr
->
mutex
);
auto
&
tvec
=
mgr
->
tmap
[
group_name
];
tvec
.
push_back
(
trigger
);
auto
&
op_group
=
mgr
->
op_group
;
for
(
const
Op
*
op
:
dmlc
::
Registry
<
Op
>::
List
())
{
if
(
op
->
index_
<
op_group
.
size
()
&&
op_group
[
op
->
index_
].
count
(
group_name
)
!=
0
)
{
trigger
((
Op
*
)
op
);
// NOLINT(*)
}
}
}
Op
&
Op
::
include
(
const
std
::
string
&
group_name
)
{
OpManager
*
mgr
=
OpManager
::
Global
();
std
::
lock_guard
<
std
::
recursive_mutex
>
(
mgr
->
mutex
);
auto
it
=
mgr
->
tmap
.
find
(
group_name
);
if
(
it
!=
mgr
->
tmap
.
end
())
{
for
(
auto
&
trigger
:
it
->
second
)
{
trigger
(
this
);
}
}
auto
&
op_group
=
mgr
->
op_group
;
if
(
index_
>=
op_group
.
size
())
{
op_group
.
resize
(
index_
+
1
);
}
op_group
[
index_
].
insert
(
group_name
);
return
*
this
;
}
}
// namespace nnvm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment