Commit 869a953a by Tianqi Chen

[OP] Enable register via match tag (#57)

* [OP] Enable register via match tag

* more docs on usage
parent fa5c5883
......@@ -84,6 +84,7 @@ NNVM_REGISTER_OP(reshape)
NNVM_REGISTER_OP(cast)
.describe("cast source type to target")
.set_num_inputs(1)
.include("ElementwiseOpAttr")
.set_attr_parser(
[](NodeAttrs* attrs) {
// parse attr parser to get target attribute
......@@ -92,7 +93,6 @@ NNVM_REGISTER_OP(cast)
CHECK(is >> dtype);
attrs->parsed = std::move(dtype);
})
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FInferType>(
"FInferType", [](const NodeAttrs& attrs,
std::vector<int> *itype,
......@@ -101,23 +101,10 @@ NNVM_REGISTER_OP(cast)
return true;
});
NNVM_REGISTER_OP(exp)
.describe("take exponential")
.set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{
MakeNode("mul", n->attrs.name + "_grad",
{ograds[0], NodeEntry{n, 0, 0}})
};
});
NNVM_REGISTER_OP(identity)
.describe("identity function")
.set_num_inputs(1)
.set_attr<FInferShape>("FInferShape", SameShape)
.include("ElementwiseOpAttr")
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
......@@ -128,7 +115,7 @@ NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
.add_alias("__add_symbol__")
.set_attr<FInferShape>("FInferShape", SameShape)
.include("ElementwiseOpAttr")
.set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
......@@ -139,6 +126,7 @@ NNVM_REGISTER_OP(add)
NNVM_REGISTER_OP(mul)
.describe("multiply two data together")
.set_num_inputs(2)
.include("ElementwiseOpAttr")
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.set_attr<FGradient>(
......@@ -187,4 +175,22 @@ NNVM_REGISTER_OP(assign)
return std::vector<uint32_t>{0};
});
NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
.set_attr<FInferShape>("FInferShape", SameShape);
NNVM_REGISTER_OP(exp)
.describe("take exponential")
.set_num_inputs(1)
.include("ElementwiseOpAttr")
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{
MakeNode("mul", n->attrs.name + "_grad",
{ograds[0], NodeEntry{n, 0, 0}})
};
});
} // namespace myproject
......@@ -22,6 +22,7 @@ class Node;
struct NodeAttrs;
template<typename ValueType>
class OpMap;
class OpGroup;
class OpRegistryEntry;
using dmlc::ParamFieldInfo;
......@@ -44,7 +45,13 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
* NNVM_REGISTER_OP(add)
* .describe("add two inputs together")
* .set_num_inputs(2)
* .set_attr<OpKernel>("gpu_kernel", AddKernel);
* .set_attr<OpKernel>("OpKernel<gpu>", AddKernel)
* .include("ElementwiseOpAttr");
*
* // can register attribute by group
* // all the ops that include the group get the attribute.
* NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
* .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
*
* NNVM_REGISTER_OP(sub)
* .describe("substract one tensor from another")
......@@ -53,7 +60,8 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
* // Can call regster multiple times in different files
* // to register different part of information
* NNVM_REGISTER_OP(sub)
* .set_attr<OpKernel>("gpu_kernel", SubKernel);
* .set_attr<OpKernel>("OpKernel<gpu>", SubKernel);
* .include("ElementwiseOpAttr");
*
* // get operators from registry.
* void my_function() {
......@@ -65,7 +73,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
*
* // get additional registered information,
* // Assume user registered a OpKernel type attribute as gpu_kernel on each operator.
* const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("gpu_kernel");
* const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("OpKernel<gpu>");
* // we can get the kernel functions by using operator as key.
* auto add_kernel = kernel[add];
* auto sub_kernel = kernel[sub];
......@@ -200,6 +208,23 @@ class Op {
*/
inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*)
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \param plevel The priority level of this set,
* an higher priority level attribute
* will replace lower priority level attribute.
* Must be bigger than 0.
*
* Cannot set with same plevel twice in the code.
*
* \tparam ValueType The type of the value to be set.
*/
template<typename ValueType>
inline Op& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value,
int plevel = 10);
/*!
* \brief Add another alias to this operator.
* The same Op can be queried with Op::Get(alias)
* \param alias The alias of the operator.
......@@ -207,14 +232,13 @@ class Op {
*/
Op& add_alias(const std::string& alias); // NOLINT(*)
/*!
* \brief Register additional attributes to operator.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \tparam ValueType The type of the value to be set.
* \brief Include all the attributes from an registered op group.
* \param group_name The name of the group.
* \return reference to self.
*
* \sa NNVM_REGISTER_OP_GROUP
*/
template<typename ValueType>
inline Op& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value);
Op& include(const std::string& group_name);
/*!
* \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered.
......@@ -235,6 +259,7 @@ class Op {
private:
template<typename ValueType>
friend class OpMap;
friend class OpGroup;
friend class dmlc::Registry<Op>;
// Program internal unique index of operator.
// Used to help index the program.
......@@ -246,6 +271,13 @@ class Op {
// update the attribute OpMap
static void UpdateAttrMap(const std::string& key,
std::function<void(any*)> updater);
// add a trigger based on tag matching on certain tag attribute
// This will apply trigger on all the op such that
// include the corresponding group.
// The trigger will also be applied to all future registrations
// that calls include
static void AddGroupTrigger(const std::string& group_name,
std::function<void(Op*)> trigger);
};
/*!
......@@ -285,14 +317,44 @@ class OpMap {
OpMap() = default;
};
/*!
* \brief auxiliary data structure used to
* set attributes to a group of operators
*/
class OpGroup {
public:
/*! \brief the tag key to be matched */
std::string group_name;
/*!
* \brief Register additional attributes to operator group.
* \param attr_name The name of the attribute.
* \param value The value to be set.
* \param plevel The priority level of this set,
* an higher priority level attribute
* will replace lower priority level attribute.
* Must be bigger than 0.
*
* Cannot set with same plevel twice in the code.
*
* \tparam ValueType The type of the value to be set.
*/
template<typename ValueType>
inline OpGroup& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value,
int plevel = 1);
};
// internal macros to make
#define NNVM_REGISTER_VAR_DEF(OpName) \
static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
#define NNVM_REGISTER_GVAR_DEF(TagName) \
static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName
/*!
* \def NNVM_REGISTER_OP
* \brief Register
* This macro must be used under namespace dmlc, and only used once in cc file.
* \brief Register a new operator, or set attribute of the corresponding op.
*
* \param OpName The name of registry
*
* \code
......@@ -308,6 +370,31 @@ class OpMap {
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
/*!
* \def NNVM_REGISTER_OP_GROUP
* \brief Register attribute to a group of operators.
* These attributes will be registered to Op that include the group.
*
* \param GroupName The name of the group.
*
* \code
*
* NNVM_REGISTER_OP(add)
* .include("ElementwiseOpAttr");
*
* // register same attributes to all the ops that include the group
* NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
* .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
*
* NNVM_REGISTER_OP(mul)
* .include("ElementwiseOpAttr");
*
* \endcode
*/
#define NNVM_REGISTER_OP_GROUP(GroupName) \
DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = \
::nnvm::OpGroup {#GroupName}
// implementations of template functions after this.
// member function of Op
template<typename ValueType>
......@@ -330,9 +417,14 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
template<typename ValueType>
inline Op& Op::set_attr( // NOLINT(*)
const std::string& attr_name, const ValueType& value) {
const std::string& attr_name,
const ValueType& value,
int plevel) {
CHECK_GT(plevel, 0)
<< "plevel in set_attr must be greater than 0";
// update the attribute map of the key by creating new empty if needed.
UpdateAttrMap(attr_name, [this, attr_name, value](any* pmap) {
UpdateAttrMap(attr_name,
[this, attr_name, value, plevel](any* pmap) {
// the callback is in lockscope so is threadsafe.
if (pmap->empty()) {
OpMap<ValueType> pm;
......@@ -353,15 +445,18 @@ inline Op& Op::set_attr( // NOLINT(*)
std::make_pair(ValueType(), 0));
}
std::pair<ValueType, int>& p = vec[index_];
CHECK(p.second == 0)
CHECK(p.second != plevel)
<< "Attribute " << attr_name
<< " of operator " << this->name
<< " is already registered.";
vec[index_] = std::make_pair(value, 1);
<< " is already registered with same plevel=" << plevel;
if (p.second < plevel) {
vec[index_] = std::make_pair(value, plevel);
}
});
return *this;
}
inline Op& Op::describe(const std::string& descr) { // NOLINT(*)
this->description = descr;
return *this;
......@@ -409,7 +504,7 @@ template<typename ValueType>
inline int OpMap<ValueType>::count(const Op* op) const {
if (op == nullptr) return 0;
const uint32_t idx = op->index_;
return idx < data_.size() ? data_[idx].second : 0;
return idx < data_.size() ? (data_[idx].second != 0) : 0;
}
template<typename ValueType>
......@@ -433,6 +528,17 @@ inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def
}
}
template<typename ValueType>
inline OpGroup& OpGroup::set_attr(const std::string& attr_name,
const ValueType& value,
int plevel) {
auto trigger = [attr_name, value, plevel](Op* op) {
op->set_attr<ValueType>(attr_name, value, plevel);
};
Op::AddGroupTrigger(group_name, trigger);
return *this;
}
} // namespace nnvm
#endif // NNVM_OP_H_
......@@ -9,6 +9,7 @@
#include <memory>
#include <atomic>
#include <mutex>
#include <unordered_set>
namespace dmlc {
// enable registry
......@@ -20,11 +21,16 @@ namespace nnvm {
// single manager of operator information.
struct OpManager {
// mutex to avoid registration from multiple threads.
std::mutex mutex;
// recursive is needed for trigger(which calls UpdateAttrMap)
std::recursive_mutex mutex;
// global operator counter
std::atomic<int> op_counter{0};
// storage of additional attribute table.
std::unordered_map<std::string, std::unique_ptr<any> > attr;
// storage of existing triggers
std::unordered_map<std::string, std::vector<std::function<void(Op*)> > > tmap;
// group of each operator.
std::vector<std::unordered_set<std::string> > op_group;
// get singleton of the
static OpManager* Global() {
static OpManager inst;
......@@ -66,10 +72,42 @@ const any* Op::GetAttrMap(const std::string& key) {
void Op::UpdateAttrMap(const std::string& key,
std::function<void(any*)> updater) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex>(mgr->mutex);
std::lock_guard<std::recursive_mutex>(mgr->mutex);
std::unique_ptr<any>& value = mgr->attr[key];
if (value.get() == nullptr) value.reset(new any());
if (updater != nullptr) updater(value.get());
}
void Op::AddGroupTrigger(const std::string& group_name,
std::function<void(Op*)> trigger) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::recursive_mutex>(mgr->mutex);
auto& tvec = mgr->tmap[group_name];
tvec.push_back(trigger);
auto& op_group = mgr->op_group;
for (const Op* op : dmlc::Registry<Op>::List()) {
if (op->index_ < op_group.size() &&
op_group[op->index_].count(group_name) != 0) {
trigger((Op*)op); // NOLINT(*)
}
}
}
Op& Op::include(const std::string& group_name) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::recursive_mutex>(mgr->mutex);
auto it = mgr->tmap.find(group_name);
if (it != mgr->tmap.end()) {
for (auto& trigger : it->second) {
trigger(this);
}
}
auto& op_group = mgr->op_group;
if (index_ >= op_group.size()) {
op_group.resize(index_ + 1);
}
op_group[index_].insert(group_name);
return *this;
}
} // namespace nnvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment