Commit 869a953a by Tianqi Chen

[OP] Enable register via match tag (#57)

* [OP] Enable register via match tag

* more docs on usage
parent fa5c5883
...@@ -84,6 +84,7 @@ NNVM_REGISTER_OP(reshape) ...@@ -84,6 +84,7 @@ NNVM_REGISTER_OP(reshape)
NNVM_REGISTER_OP(cast) NNVM_REGISTER_OP(cast)
.describe("cast source type to target") .describe("cast source type to target")
.set_num_inputs(1) .set_num_inputs(1)
.include("ElementwiseOpAttr")
.set_attr_parser( .set_attr_parser(
[](NodeAttrs* attrs) { [](NodeAttrs* attrs) {
// parse attr parser to get target attribute // parse attr parser to get target attribute
...@@ -92,7 +93,6 @@ NNVM_REGISTER_OP(cast) ...@@ -92,7 +93,6 @@ NNVM_REGISTER_OP(cast)
CHECK(is >> dtype); CHECK(is >> dtype);
attrs->parsed = std::move(dtype); attrs->parsed = std::move(dtype);
}) })
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FInferType>( .set_attr<FInferType>(
"FInferType", [](const NodeAttrs& attrs, "FInferType", [](const NodeAttrs& attrs,
std::vector<int> *itype, std::vector<int> *itype,
...@@ -101,23 +101,10 @@ NNVM_REGISTER_OP(cast) ...@@ -101,23 +101,10 @@ NNVM_REGISTER_OP(cast)
return true; return true;
}); });
NNVM_REGISTER_OP(exp)
.describe("take exponential")
.set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{
MakeNode("mul", n->attrs.name + "_grad",
{ograds[0], NodeEntry{n, 0, 0}})
};
});
NNVM_REGISTER_OP(identity) NNVM_REGISTER_OP(identity)
.describe("identity function") .describe("identity function")
.set_num_inputs(1) .set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", SameShape) .include("ElementwiseOpAttr")
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) { const std::vector<NodeEntry>& ograds) {
...@@ -128,7 +115,7 @@ NNVM_REGISTER_OP(add) ...@@ -128,7 +115,7 @@ NNVM_REGISTER_OP(add)
.describe("add two data together") .describe("add two data together")
.set_num_inputs(2) .set_num_inputs(2)
.add_alias("__add_symbol__") .add_alias("__add_symbol__")
.set_attr<FInferShape>("FInferShape", SameShape) .include("ElementwiseOpAttr")
.set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0) .set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.set_attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
...@@ -139,6 +126,7 @@ NNVM_REGISTER_OP(add) ...@@ -139,6 +126,7 @@ NNVM_REGISTER_OP(add)
NNVM_REGISTER_OP(mul) NNVM_REGISTER_OP(mul)
.describe("multiply two data together") .describe("multiply two data together")
.set_num_inputs(2) .set_num_inputs(2)
.include("ElementwiseOpAttr")
.set_attr<FInferShape>("FInferShape", SameShape) .set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0) .set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.set_attr<FGradient>( .set_attr<FGradient>(
...@@ -187,4 +175,22 @@ NNVM_REGISTER_OP(assign) ...@@ -187,4 +175,22 @@ NNVM_REGISTER_OP(assign)
return std::vector<uint32_t>{0}; return std::vector<uint32_t>{0};
}); });
NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
.set_attr<FInferShape>("FInferShape", SameShape);
NNVM_REGISTER_OP(exp)
.describe("take exponential")
.set_num_inputs(1)
.include("ElementwiseOpAttr")
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{
MakeNode("mul", n->attrs.name + "_grad",
{ograds[0], NodeEntry{n, 0, 0}})
};
});
} // namespace myproject } // namespace myproject
...@@ -22,6 +22,7 @@ class Node; ...@@ -22,6 +22,7 @@ class Node;
struct NodeAttrs; struct NodeAttrs;
template<typename ValueType> template<typename ValueType>
class OpMap; class OpMap;
class OpGroup;
class OpRegistryEntry; class OpRegistryEntry;
using dmlc::ParamFieldInfo; using dmlc::ParamFieldInfo;
...@@ -44,7 +45,13 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max(); ...@@ -44,7 +45,13 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
* NNVM_REGISTER_OP(add) * NNVM_REGISTER_OP(add)
* .describe("add two inputs together") * .describe("add two inputs together")
* .set_num_inputs(2) * .set_num_inputs(2)
* .set_attr<OpKernel>("gpu_kernel", AddKernel); * .set_attr<OpKernel>("OpKernel<gpu>", AddKernel)
* .include("ElementwiseOpAttr");
*
* // can register attribute by group
* // all the ops that include the group get the attribute.
* NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
* .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
* *
* NNVM_REGISTER_OP(sub) * NNVM_REGISTER_OP(sub)
* .describe("substract one tensor from another") * .describe("substract one tensor from another")
...@@ -53,7 +60,8 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max(); ...@@ -53,7 +60,8 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
* // Can call regster multiple times in different files * // Can call regster multiple times in different files
* // to register different part of information * // to register different part of information
* NNVM_REGISTER_OP(sub) * NNVM_REGISTER_OP(sub)
* .set_attr<OpKernel>("gpu_kernel", SubKernel); * .set_attr<OpKernel>("OpKernel<gpu>", SubKernel);
* .include("ElementwiseOpAttr");
* *
* // get operators from registry. * // get operators from registry.
* void my_function() { * void my_function() {
...@@ -65,7 +73,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max(); ...@@ -65,7 +73,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
* *
* // get additional registered information, * // get additional registered information,
* // Assume user registered a OpKernel type attribute as gpu_kernel on each operator. * // Assume user registered a OpKernel type attribute as gpu_kernel on each operator.
* const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("gpu_kernel"); * const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("OpKernel<gpu>");
* // we can get the kernel functions by using operator as key. * // we can get the kernel functions by using operator as key.
* auto add_kernel = kernel[add]; * auto add_kernel = kernel[add];
* auto sub_kernel = kernel[sub]; * auto sub_kernel = kernel[sub];
...@@ -200,6 +208,23 @@ class Op { ...@@ -200,6 +208,23 @@ class Op {
*/ */
inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*) inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*)
/*! /*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \param plevel The priority level of this set,
* an higher priority level attribute
* will replace lower priority level attribute.
* Must be bigger than 0.
*
* Cannot set with same plevel twice in the code.
*
* \tparam ValueType The type of the value to be set.
*/
template<typename ValueType>
inline Op& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value,
int plevel = 10);
/*!
* \brief Add another alias to this operator. * \brief Add another alias to this operator.
* The same Op can be queried with Op::Get(alias) * The same Op can be queried with Op::Get(alias)
* \param alias The alias of the operator. * \param alias The alias of the operator.
...@@ -207,14 +232,13 @@ class Op { ...@@ -207,14 +232,13 @@ class Op {
*/ */
Op& add_alias(const std::string& alias); // NOLINT(*) Op& add_alias(const std::string& alias); // NOLINT(*)
/*! /*!
* \brief Register additional attributes to operator. * \brief Include all the attributes from an registered op group.
* \param attr_name The name of the attribute. * \param group_name The name of the group.
* \param value The value to be set. * \return reference to self.
* \tparam ValueType The type of the value to be set. *
* \sa NNVM_REGISTER_OP_GROUP
*/ */
template<typename ValueType> Op& include(const std::string& group_name);
inline Op& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value);
/*! /*!
* \brief Get an Op for a given operator name. * \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered. * Will raise an error if the op has not been registered.
...@@ -235,6 +259,7 @@ class Op { ...@@ -235,6 +259,7 @@ class Op {
private: private:
template<typename ValueType> template<typename ValueType>
friend class OpMap; friend class OpMap;
friend class OpGroup;
friend class dmlc::Registry<Op>; friend class dmlc::Registry<Op>;
// Program internal unique index of operator. // Program internal unique index of operator.
// Used to help index the program. // Used to help index the program.
...@@ -246,6 +271,13 @@ class Op { ...@@ -246,6 +271,13 @@ class Op {
// update the attribute OpMap // update the attribute OpMap
static void UpdateAttrMap(const std::string& key, static void UpdateAttrMap(const std::string& key,
std::function<void(any*)> updater); std::function<void(any*)> updater);
// add a trigger based on tag matching on certain tag attribute
// This will apply trigger on all the op such that
// include the corresponding group.
// The trigger will also be applied to all future registrations
// that calls include
static void AddGroupTrigger(const std::string& group_name,
std::function<void(Op*)> trigger);
}; };
/*! /*!
...@@ -285,14 +317,44 @@ class OpMap { ...@@ -285,14 +317,44 @@ class OpMap {
OpMap() = default; OpMap() = default;
}; };
/*!
* \brief auxiliary data structure used to
* set attributes to a group of operators
*/
class OpGroup {
public:
/*! \brief the tag key to be matched */
std::string group_name;
/*!
* \brief Register additional attributes to operator group.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \param plevel The priority level of this set,
* an higher priority level attribute
* will replace lower priority level attribute.
* Must be bigger than 0.
*
* Cannot set with same plevel twice in the code.
*
* \tparam ValueType The type of the value to be set.
*/
template<typename ValueType>
inline OpGroup& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value,
int plevel = 1);
};
// internal macros to make // internal macros to make
#define NNVM_REGISTER_VAR_DEF(OpName) \ #define NNVM_REGISTER_VAR_DEF(OpName) \
static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
#define NNVM_REGISTER_GVAR_DEF(TagName) \
static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName
/*! /*!
* \def NNVM_REGISTER_OP * \def NNVM_REGISTER_OP
* \brief Register * \brief Register a new operator, or set attribute of the corresponding op.
* This macro must be used under namespace dmlc, and only used once in cc file. *
* \param OpName The name of registry * \param OpName The name of registry
* *
* \code * \code
...@@ -308,6 +370,31 @@ class OpMap { ...@@ -308,6 +370,31 @@ class OpMap {
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName) ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
/*!
* \def NNVM_REGISTER_OP_GROUP
* \brief Register attribute to a group of operators.
* These attributes will be registered to Op that include the group.
*
* \param GroupName The name of the group.
*
* \code
*
* NNVM_REGISTER_OP(add)
* .include("ElementwiseOpAttr");
*
* // register same attributes to all the ops that include the group
* NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
* .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
*
* NNVM_REGISTER_OP(mul)
* .include("ElementwiseOpAttr");
*
* \endcode
*/
#define NNVM_REGISTER_OP_GROUP(GroupName) \
DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = \
::nnvm::OpGroup {#GroupName}
// implementations of template functions after this. // implementations of template functions after this.
// member function of Op // member function of Op
template<typename ValueType> template<typename ValueType>
...@@ -330,9 +417,14 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) { ...@@ -330,9 +417,14 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
template<typename ValueType> template<typename ValueType>
inline Op& Op::set_attr( // NOLINT(*) inline Op& Op::set_attr( // NOLINT(*)
const std::string& attr_name, const ValueType& value) { const std::string& attr_name,
const ValueType& value,
int plevel) {
CHECK_GT(plevel, 0)
<< "plevel in set_attr must be greater than 0";
// update the attribute map of the key by creating new empty if needed. // update the attribute map of the key by creating new empty if needed.
UpdateAttrMap(attr_name, [this, attr_name, value](any* pmap) { UpdateAttrMap(attr_name,
[this, attr_name, value, plevel](any* pmap) {
// the callback is in lockscope so is threadsafe. // the callback is in lockscope so is threadsafe.
if (pmap->empty()) { if (pmap->empty()) {
OpMap<ValueType> pm; OpMap<ValueType> pm;
...@@ -353,15 +445,18 @@ inline Op& Op::set_attr( // NOLINT(*) ...@@ -353,15 +445,18 @@ inline Op& Op::set_attr( // NOLINT(*)
std::make_pair(ValueType(), 0)); std::make_pair(ValueType(), 0));
} }
std::pair<ValueType, int>& p = vec[index_]; std::pair<ValueType, int>& p = vec[index_];
CHECK(p.second == 0) CHECK(p.second != plevel)
<< "Attribute " << attr_name << "Attribute " << attr_name
<< " of operator " << this->name << " of operator " << this->name
<< " is already registered."; << " is already registered with same plevel=" << plevel;
vec[index_] = std::make_pair(value, 1); if (p.second < plevel) {
vec[index_] = std::make_pair(value, plevel);
}
}); });
return *this; return *this;
} }
inline Op& Op::describe(const std::string& descr) { // NOLINT(*) inline Op& Op::describe(const std::string& descr) { // NOLINT(*)
this->description = descr; this->description = descr;
return *this; return *this;
...@@ -409,7 +504,7 @@ template<typename ValueType> ...@@ -409,7 +504,7 @@ template<typename ValueType>
inline int OpMap<ValueType>::count(const Op* op) const { inline int OpMap<ValueType>::count(const Op* op) const {
if (op == nullptr) return 0; if (op == nullptr) return 0;
const uint32_t idx = op->index_; const uint32_t idx = op->index_;
return idx < data_.size() ? data_[idx].second : 0; return idx < data_.size() ? (data_[idx].second != 0) : 0;
} }
template<typename ValueType> template<typename ValueType>
...@@ -433,6 +528,17 @@ inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def ...@@ -433,6 +528,17 @@ inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def
} }
} }
template<typename ValueType>
inline OpGroup& OpGroup::set_attr(const std::string& attr_name,
const ValueType& value,
int plevel) {
auto trigger = [attr_name, value, plevel](Op* op) {
op->set_attr<ValueType>(attr_name, value, plevel);
};
Op::AddGroupTrigger(group_name, trigger);
return *this;
}
} // namespace nnvm } // namespace nnvm
#endif // NNVM_OP_H_ #endif // NNVM_OP_H_
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
#include <memory> #include <memory>
#include <atomic> #include <atomic>
#include <mutex> #include <mutex>
#include <unordered_set>
namespace dmlc { namespace dmlc {
// enable registry // enable registry
...@@ -20,11 +21,16 @@ namespace nnvm { ...@@ -20,11 +21,16 @@ namespace nnvm {
// single manager of operator information. // single manager of operator information.
struct OpManager { struct OpManager {
// mutex to avoid registration from multiple threads. // mutex to avoid registration from multiple threads.
std::mutex mutex; // recursive is needed for trigger(which calls UpdateAttrMap)
std::recursive_mutex mutex;
// global operator counter // global operator counter
std::atomic<int> op_counter{0}; std::atomic<int> op_counter{0};
// storage of additional attribute table. // storage of additional attribute table.
std::unordered_map<std::string, std::unique_ptr<any> > attr; std::unordered_map<std::string, std::unique_ptr<any> > attr;
// storage of existing triggers
std::unordered_map<std::string, std::vector<std::function<void(Op*)> > > tmap;
// group of each operator.
std::vector<std::unordered_set<std::string> > op_group;
// get singleton of the // get singleton of the
static OpManager* Global() { static OpManager* Global() {
static OpManager inst; static OpManager inst;
...@@ -66,10 +72,42 @@ const any* Op::GetAttrMap(const std::string& key) { ...@@ -66,10 +72,42 @@ const any* Op::GetAttrMap(const std::string& key) {
void Op::UpdateAttrMap(const std::string& key, void Op::UpdateAttrMap(const std::string& key,
std::function<void(any*)> updater) { std::function<void(any*)> updater) {
OpManager* mgr = OpManager::Global(); OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex>(mgr->mutex); std::lock_guard<std::recursive_mutex>(mgr->mutex);
std::unique_ptr<any>& value = mgr->attr[key]; std::unique_ptr<any>& value = mgr->attr[key];
if (value.get() == nullptr) value.reset(new any()); if (value.get() == nullptr) value.reset(new any());
if (updater != nullptr) updater(value.get()); if (updater != nullptr) updater(value.get());
} }
void Op::AddGroupTrigger(const std::string& group_name,
std::function<void(Op*)> trigger) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::recursive_mutex>(mgr->mutex);
auto& tvec = mgr->tmap[group_name];
tvec.push_back(trigger);
auto& op_group = mgr->op_group;
for (const Op* op : dmlc::Registry<Op>::List()) {
if (op->index_ < op_group.size() &&
op_group[op->index_].count(group_name) != 0) {
trigger((Op*)op); // NOLINT(*)
}
}
}
Op& Op::include(const std::string& group_name) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::recursive_mutex>(mgr->mutex);
auto it = mgr->tmap.find(group_name);
if (it != mgr->tmap.end()) {
for (auto& trigger : it->second) {
trigger(this);
}
}
auto& op_group = mgr->op_group;
if (index_ >= op_group.size()) {
op_group.resize(index_ + 1);
}
op_group[index_].insert(group_name);
return *this;
}
} // namespace nnvm } // namespace nnvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment