Commit 6ebbea5d by Tianqi Chen

[API] Change attr to explicit name set_attr (#46)

parent d0cc035c
...@@ -66,7 +66,7 @@ NNVM_REGISTER_OP(reshape) ...@@ -66,7 +66,7 @@ NNVM_REGISTER_OP(reshape)
CHECK(is >> target); CHECK(is >> target);
attrs->parsed = std::move(target); attrs->parsed = std::move(target);
}) })
.attr<FInferShape>( .set_attr<FInferShape>(
"FInferShape", [] (const NodeAttrs& attrs, "FInferShape", [] (const NodeAttrs& attrs,
std::vector<TShape> *ishape, std::vector<TShape> *ishape,
std::vector<TShape> *oshape) { std::vector<TShape> *oshape) {
...@@ -78,7 +78,7 @@ NNVM_REGISTER_OP(reshape) ...@@ -78,7 +78,7 @@ NNVM_REGISTER_OP(reshape)
<< "Reshape op: source target shape mismatch"; << "Reshape op: source target shape mismatch";
return true; return true;
}) })
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0); .set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0);
NNVM_REGISTER_OP(cast) NNVM_REGISTER_OP(cast)
...@@ -92,8 +92,8 @@ NNVM_REGISTER_OP(cast) ...@@ -92,8 +92,8 @@ NNVM_REGISTER_OP(cast)
CHECK(is >> dtype); CHECK(is >> dtype);
attrs->parsed = std::move(dtype); attrs->parsed = std::move(dtype);
}) })
.attr<FInferShape>("FInferShape", SameShape) .set_attr<FInferShape>("FInferShape", SameShape)
.attr<FInferType>( .set_attr<FInferType>(
"FInferType", [](const NodeAttrs& attrs, "FInferType", [](const NodeAttrs& attrs,
std::vector<int> *itype, std::vector<int> *itype,
std::vector<int> *otype) { std::vector<int> *otype) {
...@@ -104,8 +104,8 @@ NNVM_REGISTER_OP(cast) ...@@ -104,8 +104,8 @@ NNVM_REGISTER_OP(cast)
NNVM_REGISTER_OP(exp) NNVM_REGISTER_OP(exp)
.describe("take exponential") .describe("take exponential")
.set_num_inputs(1) .set_num_inputs(1)
.attr<FInferShape>("FInferShape", SameShape) .set_attr<FInferShape>("FInferShape", SameShape)
.attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) { const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{ return std::vector<NodeEntry>{
...@@ -117,8 +117,8 @@ NNVM_REGISTER_OP(exp) ...@@ -117,8 +117,8 @@ NNVM_REGISTER_OP(exp)
NNVM_REGISTER_OP(identity) NNVM_REGISTER_OP(identity)
.describe("identity function") .describe("identity function")
.set_num_inputs(1) .set_num_inputs(1)
.attr<FInferShape>("FInferShape", SameShape) .set_attr<FInferShape>("FInferShape", SameShape)
.attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) { const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{ograds[0]}; return std::vector<NodeEntry>{ograds[0]};
...@@ -128,9 +128,9 @@ NNVM_REGISTER_OP(add) ...@@ -128,9 +128,9 @@ NNVM_REGISTER_OP(add)
.describe("add two data together") .describe("add two data together")
.set_num_inputs(2) .set_num_inputs(2)
.add_alias("__add_symbol__") .add_alias("__add_symbol__")
.attr<FInferShape>("FInferShape", SameShape) .set_attr<FInferShape>("FInferShape", SameShape)
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0) .set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){ const std::vector<NodeEntry>& ograds){
return std::vector<NodeEntry>{ograds[0], ograds[0]}; return std::vector<NodeEntry>{ograds[0], ograds[0]};
...@@ -139,9 +139,9 @@ NNVM_REGISTER_OP(add) ...@@ -139,9 +139,9 @@ NNVM_REGISTER_OP(add)
NNVM_REGISTER_OP(mul) NNVM_REGISTER_OP(mul)
.describe("multiply two data together") .describe("multiply two data together")
.set_num_inputs(2) .set_num_inputs(2)
.attr<FInferShape>("FInferShape", SameShape) .set_attr<FInferShape>("FInferShape", SameShape)
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0) .set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.attr<FGradient>( .set_attr<FGradient>(
"FGradient", [](const NodePtr& n, "FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){ const std::vector<NodeEntry>& ograds){
return std::vector<NodeEntry>{ return std::vector<NodeEntry>{
...@@ -167,23 +167,23 @@ NNVM_REGISTER_OP(__one__) ...@@ -167,23 +167,23 @@ NNVM_REGISTER_OP(__one__)
NNVM_REGISTER_OP(cross_device_copy) NNVM_REGISTER_OP(cross_device_copy)
.describe("Copy data across device.") .describe("Copy data across device.")
.set_num_inputs(1) .set_num_inputs(1)
.attr<FInferShape>("FInferShape", SameShape); .set_attr<FInferShape>("FInferShape", SameShape);
NNVM_REGISTER_OP(conv2d) NNVM_REGISTER_OP(conv2d)
.describe("take conv of input") .describe("take conv of input")
.set_num_inputs(2) .set_num_inputs(2)
.attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) { .set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "weight"}; return std::vector<std::string>{"data", "weight"};
}); });
NNVM_REGISTER_OP(add) NNVM_REGISTER_OP(add)
.attr<std::string>("nick_name", "plus"); .set_attr<std::string>("nick_name", "plus");
NNVM_REGISTER_OP(assign) NNVM_REGISTER_OP(assign)
.set_num_inputs(2) .set_num_inputs(2)
.set_num_outputs(1) .set_num_outputs(1)
.attr<FMutateInputs>("FMutateInputs", [](const NodeAttrs& attrs) { .set_attr<FMutateInputs>("FMutateInputs", [](const NodeAttrs& attrs) {
return std::vector<uint32_t>{0}; return std::vector<uint32_t>{0};
}); });
......
...@@ -44,7 +44,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max(); ...@@ -44,7 +44,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
* NNVM_REGISTER_OP(add) * NNVM_REGISTER_OP(add)
* .describe("add two inputs together") * .describe("add two inputs together")
* .set_num_inputs(2) * .set_num_inputs(2)
* .attr<OpKernel>("gpu_kernel", AddKernel); * .set_attr<OpKernel>("gpu_kernel", AddKernel);
* *
* NNVM_REGISTER_OP(sub) * NNVM_REGISTER_OP(sub)
* .describe("substract one tensor from another") * .describe("substract one tensor from another")
...@@ -53,7 +53,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max(); ...@@ -53,7 +53,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
* // Can call regster multiple times in different files * // Can call regster multiple times in different files
* // to register different part of information * // to register different part of information
* NNVM_REGISTER_OP(sub) * NNVM_REGISTER_OP(sub)
* .attr<OpKernel>("gpu_kernel", SubKernel); * .set_attr<OpKernel>("gpu_kernel", SubKernel);
* *
* // get operators from registry. * // get operators from registry.
* void my_function() { * void my_function() {
...@@ -213,8 +213,8 @@ class Op { ...@@ -213,8 +213,8 @@ class Op {
* \tparam ValueType The type of the value to be set. * \tparam ValueType The type of the value to be set.
*/ */
template<typename ValueType> template<typename ValueType>
inline Op& attr(const std::string& attr_name, // NOLINT(*) inline Op& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value); const ValueType& value);
/*! /*!
* \brief Get an Op for a given operator name. * \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered. * Will raise an error if the op has not been registered.
...@@ -300,7 +300,7 @@ class OpMap { ...@@ -300,7 +300,7 @@ class OpMap {
* NNVM_REGISTER_OP(add) * NNVM_REGISTER_OP(add)
* .describe("add two inputs together") * .describe("add two inputs together")
* .set_num_inputs(2) * .set_num_inputs(2)
* .attr<OpKernel>("gpu_kernel", AddKernel); * .set_attr<OpKernel>("gpu_kernel", AddKernel);
* *
* \endcode * \endcode
*/ */
...@@ -329,7 +329,7 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) { ...@@ -329,7 +329,7 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
} }
template<typename ValueType> template<typename ValueType>
inline Op& Op::attr( // NOLINT(*) inline Op& Op::set_attr( // NOLINT(*)
const std::string& attr_name, const ValueType& value) { const std::string& attr_name, const ValueType& value) {
// update the attribute map of the key by creating new empty if needed. // update the attribute map of the key by creating new empty if needed.
UpdateAttrMap(attr_name, [this, attr_name, value](any* pmap) { UpdateAttrMap(attr_name, [this, attr_name, value](any* pmap) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment