Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
869a953a
Commit
869a953a
authored
Sep 23, 2016
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[OP] Enable register via match tag (#57)
* [OP] Enable register via match tag * more docs on usage
parent
fa5c5883
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
186 additions
and
36 deletions
+186
-36
nnvm/example/src/operator.cc
+22
-16
nnvm/include/nnvm/op.h
+124
-18
nnvm/src/core/op.cc
+40
-2
No files found.
nnvm/example/src/operator.cc
View file @
869a953a
...
@@ -84,6 +84,7 @@ NNVM_REGISTER_OP(reshape)
...
@@ -84,6 +84,7 @@ NNVM_REGISTER_OP(reshape)
NNVM_REGISTER_OP
(
cast
)
NNVM_REGISTER_OP
(
cast
)
.
describe
(
"cast source type to target"
)
.
describe
(
"cast source type to target"
)
.
set_num_inputs
(
1
)
.
set_num_inputs
(
1
)
.
include
(
"ElementwiseOpAttr"
)
.
set_attr_parser
(
.
set_attr_parser
(
[](
NodeAttrs
*
attrs
)
{
[](
NodeAttrs
*
attrs
)
{
// parse attr parser to get target attribute
// parse attr parser to get target attribute
...
@@ -92,7 +93,6 @@ NNVM_REGISTER_OP(cast)
...
@@ -92,7 +93,6 @@ NNVM_REGISTER_OP(cast)
CHECK
(
is
>>
dtype
);
CHECK
(
is
>>
dtype
);
attrs
->
parsed
=
std
::
move
(
dtype
);
attrs
->
parsed
=
std
::
move
(
dtype
);
})
})
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
set_attr
<
FInferType
>
(
.
set_attr
<
FInferType
>
(
"FInferType"
,
[](
const
NodeAttrs
&
attrs
,
"FInferType"
,
[](
const
NodeAttrs
&
attrs
,
std
::
vector
<
int
>
*
itype
,
std
::
vector
<
int
>
*
itype
,
...
@@ -101,23 +101,10 @@ NNVM_REGISTER_OP(cast)
...
@@ -101,23 +101,10 @@ NNVM_REGISTER_OP(cast)
return
true
;
return
true
;
});
});
NNVM_REGISTER_OP
(
exp
)
.
describe
(
"take exponential"
)
.
set_num_inputs
(
1
)
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
set_attr
<
FGradient
>
(
"FGradient"
,
[](
const
NodePtr
&
n
,
const
std
::
vector
<
NodeEntry
>&
ograds
)
{
return
std
::
vector
<
NodeEntry
>
{
MakeNode
(
"mul"
,
n
->
attrs
.
name
+
"_grad"
,
{
ograds
[
0
],
NodeEntry
{
n
,
0
,
0
}})
};
});
NNVM_REGISTER_OP
(
identity
)
NNVM_REGISTER_OP
(
identity
)
.
describe
(
"identity function"
)
.
describe
(
"identity function"
)
.
set_num_inputs
(
1
)
.
set_num_inputs
(
1
)
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
include
(
"ElementwiseOpAttr"
)
.
set_attr
<
FGradient
>
(
.
set_attr
<
FGradient
>
(
"FGradient"
,
[](
const
NodePtr
&
n
,
"FGradient"
,
[](
const
NodePtr
&
n
,
const
std
::
vector
<
NodeEntry
>&
ograds
)
{
const
std
::
vector
<
NodeEntry
>&
ograds
)
{
...
@@ -128,7 +115,7 @@ NNVM_REGISTER_OP(add)
...
@@ -128,7 +115,7 @@ NNVM_REGISTER_OP(add)
.
describe
(
"add two data together"
)
.
describe
(
"add two data together"
)
.
set_num_inputs
(
2
)
.
set_num_inputs
(
2
)
.
add_alias
(
"__add_symbol__"
)
.
add_alias
(
"__add_symbol__"
)
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
include
(
"ElementwiseOpAttr"
)
.
set_attr
<
FInplaceOption
>
(
"FInplaceOption"
,
InplaceIn0Out0
)
.
set_attr
<
FInplaceOption
>
(
"FInplaceOption"
,
InplaceIn0Out0
)
.
set_attr
<
FGradient
>
(
.
set_attr
<
FGradient
>
(
"FGradient"
,
[](
const
NodePtr
&
n
,
"FGradient"
,
[](
const
NodePtr
&
n
,
...
@@ -139,6 +126,7 @@ NNVM_REGISTER_OP(add)
...
@@ -139,6 +126,7 @@ NNVM_REGISTER_OP(add)
NNVM_REGISTER_OP
(
mul
)
NNVM_REGISTER_OP
(
mul
)
.
describe
(
"multiply two data together"
)
.
describe
(
"multiply two data together"
)
.
set_num_inputs
(
2
)
.
set_num_inputs
(
2
)
.
include
(
"ElementwiseOpAttr"
)
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
set_attr
<
FInplaceOption
>
(
"FInplaceOption"
,
InplaceIn0Out0
)
.
set_attr
<
FInplaceOption
>
(
"FInplaceOption"
,
InplaceIn0Out0
)
.
set_attr
<
FGradient
>
(
.
set_attr
<
FGradient
>
(
...
@@ -187,4 +175,22 @@ NNVM_REGISTER_OP(assign)
...
@@ -187,4 +175,22 @@ NNVM_REGISTER_OP(assign)
return
std
::
vector
<
uint32_t
>
{
0
};
return
std
::
vector
<
uint32_t
>
{
0
};
});
});
NNVM_REGISTER_OP_GROUP
(
ElementwiseOpAttr
)
.
set_attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
);
NNVM_REGISTER_OP
(
exp
)
.
describe
(
"take exponential"
)
.
set_num_inputs
(
1
)
.
include
(
"ElementwiseOpAttr"
)
.
set_attr
<
FGradient
>
(
"FGradient"
,
[](
const
NodePtr
&
n
,
const
std
::
vector
<
NodeEntry
>&
ograds
)
{
return
std
::
vector
<
NodeEntry
>
{
MakeNode
(
"mul"
,
n
->
attrs
.
name
+
"_grad"
,
{
ograds
[
0
],
NodeEntry
{
n
,
0
,
0
}})
};
});
}
// namespace myproject
}
// namespace myproject
nnvm/include/nnvm/op.h
View file @
869a953a
...
@@ -22,6 +22,7 @@ class Node;
...
@@ -22,6 +22,7 @@ class Node;
struct
NodeAttrs
;
struct
NodeAttrs
;
template
<
typename
ValueType
>
template
<
typename
ValueType
>
class
OpMap
;
class
OpMap
;
class
OpGroup
;
class
OpRegistryEntry
;
class
OpRegistryEntry
;
using
dmlc
::
ParamFieldInfo
;
using
dmlc
::
ParamFieldInfo
;
...
@@ -44,7 +45,13 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
...
@@ -44,7 +45,13 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
* NNVM_REGISTER_OP(add)
* NNVM_REGISTER_OP(add)
* .describe("add two inputs together")
* .describe("add two inputs together")
* .set_num_inputs(2)
* .set_num_inputs(2)
* .set_attr<OpKernel>("gpu_kernel", AddKernel);
* .set_attr<OpKernel>("OpKernel<gpu>", AddKernel)
* .include("ElementwiseOpAttr");
*
* // can register attribute by group
* // all the ops that include the group get the attribute.
* NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
* .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
*
*
* NNVM_REGISTER_OP(sub)
* NNVM_REGISTER_OP(sub)
* .describe("substract one tensor from another")
* .describe("substract one tensor from another")
...
@@ -53,7 +60,8 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
...
@@ -53,7 +60,8 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
* // Can call regster multiple times in different files
* // Can call regster multiple times in different files
* // to register different part of information
* // to register different part of information
* NNVM_REGISTER_OP(sub)
* NNVM_REGISTER_OP(sub)
* .set_attr<OpKernel>("gpu_kernel", SubKernel);
* .set_attr<OpKernel>("OpKernel<gpu>", SubKernel);
* .include("ElementwiseOpAttr");
*
*
* // get operators from registry.
* // get operators from registry.
* void my_function() {
* void my_function() {
...
@@ -65,7 +73,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
...
@@ -65,7 +73,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
*
*
* // get additional registered information,
* // get additional registered information,
* // Assume user registered a OpKernel type attribute as gpu_kernel on each operator.
* // Assume user registered a OpKernel type attribute as gpu_kernel on each operator.
* const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("
gpu_kernel
");
* const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("
OpKernel<gpu>
");
* // we can get the kernel functions by using operator as key.
* // we can get the kernel functions by using operator as key.
* auto add_kernel = kernel[add];
* auto add_kernel = kernel[add];
* auto sub_kernel = kernel[sub];
* auto sub_kernel = kernel[sub];
...
@@ -200,6 +208,23 @@ class Op {
...
@@ -200,6 +208,23 @@ class Op {
*/
*/
inline
Op
&
set_attr_parser
(
std
::
function
<
void
(
NodeAttrs
*
attrs
)
>
fn
);
// NOLINT(*)
inline
Op
&
set_attr_parser
(
std
::
function
<
void
(
NodeAttrs
*
attrs
)
>
fn
);
// NOLINT(*)
/*!
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \param plevel The priority level of this set,
* an higher priority level attribute
* will replace lower priority level attribute.
* Must be bigger than 0.
*
* Cannot set with same plevel twice in the code.
*
* \tparam ValueType The type of the value to be set.
*/
template
<
typename
ValueType
>
inline
Op
&
set_attr
(
const
std
::
string
&
attr_name
,
// NOLINT(*)
const
ValueType
&
value
,
int
plevel
=
10
);
/*!
* \brief Add another alias to this operator.
* \brief Add another alias to this operator.
* The same Op can be queried with Op::Get(alias)
* The same Op can be queried with Op::Get(alias)
* \param alias The alias of the operator.
* \param alias The alias of the operator.
...
@@ -207,14 +232,13 @@ class Op {
...
@@ -207,14 +232,13 @@ class Op {
*/
*/
Op
&
add_alias
(
const
std
::
string
&
alias
);
// NOLINT(*)
Op
&
add_alias
(
const
std
::
string
&
alias
);
// NOLINT(*)
/*!
/*!
* \brief Register additional attributes to operator.
* \brief Include all the attributes from an registered op group.
* \param attr_name The name of the attribute.
* \param group_name The name of the group.
* \param value The value to be set.
* \return reference to self.
* \tparam ValueType The type of the value to be set.
*
* \sa NNVM_REGISTER_OP_GROUP
*/
*/
template
<
typename
ValueType
>
Op
&
include
(
const
std
::
string
&
group_name
);
inline
Op
&
set_attr
(
const
std
::
string
&
attr_name
,
// NOLINT(*)
const
ValueType
&
value
);
/*!
/*!
* \brief Get an Op for a given operator name.
* \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered.
* Will raise an error if the op has not been registered.
...
@@ -235,6 +259,7 @@ class Op {
...
@@ -235,6 +259,7 @@ class Op {
private
:
private
:
template
<
typename
ValueType
>
template
<
typename
ValueType
>
friend
class
OpMap
;
friend
class
OpMap
;
friend
class
OpGroup
;
friend
class
dmlc
::
Registry
<
Op
>
;
friend
class
dmlc
::
Registry
<
Op
>
;
// Program internal unique index of operator.
// Program internal unique index of operator.
// Used to help index the program.
// Used to help index the program.
...
@@ -246,6 +271,13 @@ class Op {
...
@@ -246,6 +271,13 @@ class Op {
// update the attribute OpMap
// update the attribute OpMap
static
void
UpdateAttrMap
(
const
std
::
string
&
key
,
static
void
UpdateAttrMap
(
const
std
::
string
&
key
,
std
::
function
<
void
(
any
*
)
>
updater
);
std
::
function
<
void
(
any
*
)
>
updater
);
// add a trigger based on tag matching on certain tag attribute
// This will apply trigger on all the op such that
// include the corresponding group.
// The trigger will also be applied to all future registrations
// that calls include
static
void
AddGroupTrigger
(
const
std
::
string
&
group_name
,
std
::
function
<
void
(
Op
*
)
>
trigger
);
};
};
/*!
/*!
...
@@ -285,14 +317,44 @@ class OpMap {
...
@@ -285,14 +317,44 @@ class OpMap {
OpMap
()
=
default
;
OpMap
()
=
default
;
};
};
/*!
* \brief auxiliary data structure used to
* set attributes to a group of operators
*/
class
OpGroup
{
public
:
/*! \brief the tag key to be matched */
std
::
string
group_name
;
/*!
* \brief Register additional attributes to operator group.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \param plevel The priority level of this set,
* an higher priority level attribute
* will replace lower priority level attribute.
* Must be bigger than 0.
*
* Cannot set with same plevel twice in the code.
*
* \tparam ValueType The type of the value to be set.
*/
template
<
typename
ValueType
>
inline
OpGroup
&
set_attr
(
const
std
::
string
&
attr_name
,
// NOLINT(*)
const
ValueType
&
value
,
int
plevel
=
1
);
};
// internal macros to make
// internal macros to make
#define NNVM_REGISTER_VAR_DEF(OpName) \
#define NNVM_REGISTER_VAR_DEF(OpName) \
static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
#define NNVM_REGISTER_GVAR_DEF(TagName) \
static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName
/*!
/*!
* \def NNVM_REGISTER_OP
* \def NNVM_REGISTER_OP
* \brief Register
* \brief Register
a new operator, or set attribute of the corresponding op.
*
This macro must be used under namespace dmlc, and only used once in cc file.
*
* \param OpName The name of registry
* \param OpName The name of registry
*
*
* \code
* \code
...
@@ -308,6 +370,31 @@ class OpMap {
...
@@ -308,6 +370,31 @@ class OpMap {
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
/*!
* \def NNVM_REGISTER_OP_GROUP
* \brief Register attribute to a group of operators.
* These attributes will be registered to Op that include the group.
*
* \param GroupName The name of the group.
*
* \code
*
* NNVM_REGISTER_OP(add)
* .include("ElementwiseOpAttr");
*
* // register same attributes to all the ops that include the group
* NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
* .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
*
* NNVM_REGISTER_OP(mul)
* .include("ElementwiseOpAttr");
*
* \endcode
*/
#define NNVM_REGISTER_OP_GROUP(GroupName) \
DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = \
::nnvm::OpGroup {#GroupName}
// implementations of template functions after this.
// implementations of template functions after this.
// member function of Op
// member function of Op
template
<
typename
ValueType
>
template
<
typename
ValueType
>
...
@@ -330,9 +417,14 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
...
@@ -330,9 +417,14 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
template
<
typename
ValueType
>
template
<
typename
ValueType
>
inline
Op
&
Op
::
set_attr
(
// NOLINT(*)
inline
Op
&
Op
::
set_attr
(
// NOLINT(*)
const
std
::
string
&
attr_name
,
const
ValueType
&
value
)
{
const
std
::
string
&
attr_name
,
const
ValueType
&
value
,
int
plevel
)
{
CHECK_GT
(
plevel
,
0
)
<<
"plevel in set_attr must be greater than 0"
;
// update the attribute map of the key by creating new empty if needed.
// update the attribute map of the key by creating new empty if needed.
UpdateAttrMap
(
attr_name
,
[
this
,
attr_name
,
value
](
any
*
pmap
)
{
UpdateAttrMap
(
attr_name
,
[
this
,
attr_name
,
value
,
plevel
](
any
*
pmap
)
{
// the callback is in lockscope so is threadsafe.
// the callback is in lockscope so is threadsafe.
if
(
pmap
->
empty
())
{
if
(
pmap
->
empty
())
{
OpMap
<
ValueType
>
pm
;
OpMap
<
ValueType
>
pm
;
...
@@ -353,15 +445,18 @@ inline Op& Op::set_attr( // NOLINT(*)
...
@@ -353,15 +445,18 @@ inline Op& Op::set_attr( // NOLINT(*)
std
::
make_pair
(
ValueType
(),
0
));
std
::
make_pair
(
ValueType
(),
0
));
}
}
std
::
pair
<
ValueType
,
int
>&
p
=
vec
[
index_
];
std
::
pair
<
ValueType
,
int
>&
p
=
vec
[
index_
];
CHECK
(
p
.
second
==
0
)
CHECK
(
p
.
second
!=
plevel
)
<<
"Attribute "
<<
attr_name
<<
"Attribute "
<<
attr_name
<<
" of operator "
<<
this
->
name
<<
" of operator "
<<
this
->
name
<<
" is already registered."
;
<<
" is already registered with same plevel="
<<
plevel
;
vec
[
index_
]
=
std
::
make_pair
(
value
,
1
);
if
(
p
.
second
<
plevel
)
{
vec
[
index_
]
=
std
::
make_pair
(
value
,
plevel
);
}
});
});
return
*
this
;
return
*
this
;
}
}
inline
Op
&
Op
::
describe
(
const
std
::
string
&
descr
)
{
// NOLINT(*)
inline
Op
&
Op
::
describe
(
const
std
::
string
&
descr
)
{
// NOLINT(*)
this
->
description
=
descr
;
this
->
description
=
descr
;
return
*
this
;
return
*
this
;
...
@@ -409,7 +504,7 @@ template<typename ValueType>
...
@@ -409,7 +504,7 @@ template<typename ValueType>
inline
int
OpMap
<
ValueType
>::
count
(
const
Op
*
op
)
const
{
inline
int
OpMap
<
ValueType
>::
count
(
const
Op
*
op
)
const
{
if
(
op
==
nullptr
)
return
0
;
if
(
op
==
nullptr
)
return
0
;
const
uint32_t
idx
=
op
->
index_
;
const
uint32_t
idx
=
op
->
index_
;
return
idx
<
data_
.
size
()
?
data_
[
idx
].
second
:
0
;
return
idx
<
data_
.
size
()
?
(
data_
[
idx
].
second
!=
0
)
:
0
;
}
}
template
<
typename
ValueType
>
template
<
typename
ValueType
>
...
@@ -433,6 +528,17 @@ inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def
...
@@ -433,6 +528,17 @@ inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def
}
}
}
}
template
<
typename
ValueType
>
inline
OpGroup
&
OpGroup
::
set_attr
(
const
std
::
string
&
attr_name
,
const
ValueType
&
value
,
int
plevel
)
{
auto
trigger
=
[
attr_name
,
value
,
plevel
](
Op
*
op
)
{
op
->
set_attr
<
ValueType
>
(
attr_name
,
value
,
plevel
);
};
Op
::
AddGroupTrigger
(
group_name
,
trigger
);
return
*
this
;
}
}
// namespace nnvm
}
// namespace nnvm
#endif // NNVM_OP_H_
#endif // NNVM_OP_H_
nnvm/src/core/op.cc
View file @
869a953a
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#include <memory>
#include <memory>
#include <atomic>
#include <atomic>
#include <mutex>
#include <mutex>
#include <unordered_set>
namespace
dmlc
{
namespace
dmlc
{
// enable registry
// enable registry
...
@@ -20,11 +21,16 @@ namespace nnvm {
...
@@ -20,11 +21,16 @@ namespace nnvm {
// single manager of operator information.
// single manager of operator information.
struct
OpManager
{
struct
OpManager
{
// mutex to avoid registration from multiple threads.
// mutex to avoid registration from multiple threads.
std
::
mutex
mutex
;
// recursive is needed for trigger(which calls UpdateAttrMap)
std
::
recursive_mutex
mutex
;
// global operator counter
// global operator counter
std
::
atomic
<
int
>
op_counter
{
0
};
std
::
atomic
<
int
>
op_counter
{
0
};
// storage of additional attribute table.
// storage of additional attribute table.
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
any
>
>
attr
;
std
::
unordered_map
<
std
::
string
,
std
::
unique_ptr
<
any
>
>
attr
;
// storage of existing triggers
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
function
<
void
(
Op
*
)
>
>
>
tmap
;
// group of each operator.
std
::
vector
<
std
::
unordered_set
<
std
::
string
>
>
op_group
;
// get singleton of the
// get singleton of the
static
OpManager
*
Global
()
{
static
OpManager
*
Global
()
{
static
OpManager
inst
;
static
OpManager
inst
;
...
@@ -66,10 +72,42 @@ const any* Op::GetAttrMap(const std::string& key) {
...
@@ -66,10 +72,42 @@ const any* Op::GetAttrMap(const std::string& key) {
void
Op
::
UpdateAttrMap
(
const
std
::
string
&
key
,
void
Op
::
UpdateAttrMap
(
const
std
::
string
&
key
,
std
::
function
<
void
(
any
*
)
>
updater
)
{
std
::
function
<
void
(
any
*
)
>
updater
)
{
OpManager
*
mgr
=
OpManager
::
Global
();
OpManager
*
mgr
=
OpManager
::
Global
();
std
::
lock_guard
<
std
::
mutex
>
(
mgr
->
mutex
);
std
::
lock_guard
<
std
::
recursive_
mutex
>
(
mgr
->
mutex
);
std
::
unique_ptr
<
any
>&
value
=
mgr
->
attr
[
key
];
std
::
unique_ptr
<
any
>&
value
=
mgr
->
attr
[
key
];
if
(
value
.
get
()
==
nullptr
)
value
.
reset
(
new
any
());
if
(
value
.
get
()
==
nullptr
)
value
.
reset
(
new
any
());
if
(
updater
!=
nullptr
)
updater
(
value
.
get
());
if
(
updater
!=
nullptr
)
updater
(
value
.
get
());
}
}
void
Op
::
AddGroupTrigger
(
const
std
::
string
&
group_name
,
std
::
function
<
void
(
Op
*
)
>
trigger
)
{
OpManager
*
mgr
=
OpManager
::
Global
();
std
::
lock_guard
<
std
::
recursive_mutex
>
(
mgr
->
mutex
);
auto
&
tvec
=
mgr
->
tmap
[
group_name
];
tvec
.
push_back
(
trigger
);
auto
&
op_group
=
mgr
->
op_group
;
for
(
const
Op
*
op
:
dmlc
::
Registry
<
Op
>::
List
())
{
if
(
op
->
index_
<
op_group
.
size
()
&&
op_group
[
op
->
index_
].
count
(
group_name
)
!=
0
)
{
trigger
((
Op
*
)
op
);
// NOLINT(*)
}
}
}
Op
&
Op
::
include
(
const
std
::
string
&
group_name
)
{
OpManager
*
mgr
=
OpManager
::
Global
();
std
::
lock_guard
<
std
::
recursive_mutex
>
(
mgr
->
mutex
);
auto
it
=
mgr
->
tmap
.
find
(
group_name
);
if
(
it
!=
mgr
->
tmap
.
end
())
{
for
(
auto
&
trigger
:
it
->
second
)
{
trigger
(
this
);
}
}
auto
&
op_group
=
mgr
->
op_group
;
if
(
index_
>=
op_group
.
size
())
{
op_group
.
resize
(
index_
+
1
);
}
op_group
[
index_
].
insert
(
group_name
);
return
*
this
;
}
}
// namespace nnvm
}
// namespace nnvm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment