Commit 6ebbea5d by Tianqi Chen

[API] Change attr to explicit name set_attr (#46)

parent d0cc035c
......@@ -66,7 +66,7 @@ NNVM_REGISTER_OP(reshape)
CHECK(is >> target);
attrs->parsed = std::move(target);
})
.attr<FInferShape>(
.set_attr<FInferShape>(
"FInferShape", [] (const NodeAttrs& attrs,
std::vector<TShape> *ishape,
std::vector<TShape> *oshape) {
......@@ -78,7 +78,7 @@ NNVM_REGISTER_OP(reshape)
<< "Reshape op: source target shape mismatch";
return true;
})
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0);
.set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0);
NNVM_REGISTER_OP(cast)
......@@ -92,8 +92,8 @@ NNVM_REGISTER_OP(cast)
CHECK(is >> dtype);
attrs->parsed = std::move(dtype);
})
.attr<FInferShape>("FInferShape", SameShape)
.attr<FInferType>(
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FInferType>(
"FInferType", [](const NodeAttrs& attrs,
std::vector<int> *itype,
std::vector<int> *otype) {
......@@ -104,8 +104,8 @@ NNVM_REGISTER_OP(cast)
NNVM_REGISTER_OP(exp)
.describe("take exponential")
.set_num_inputs(1)
.attr<FInferShape>("FInferShape", SameShape)
.attr<FGradient>(
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{
......@@ -117,8 +117,8 @@ NNVM_REGISTER_OP(exp)
NNVM_REGISTER_OP(identity)
.describe("identity function")
.set_num_inputs(1)
.attr<FInferShape>("FInferShape", SameShape)
.attr<FGradient>(
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds) {
return std::vector<NodeEntry>{ograds[0]};
......@@ -128,9 +128,9 @@ NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
.add_alias("__add_symbol__")
.attr<FInferShape>("FInferShape", SameShape)
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.attr<FGradient>(
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
return std::vector<NodeEntry>{ograds[0], ograds[0]};
......@@ -139,9 +139,9 @@ NNVM_REGISTER_OP(add)
NNVM_REGISTER_OP(mul)
.describe("multiply two data together")
.set_num_inputs(2)
.attr<FInferShape>("FInferShape", SameShape)
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.attr<FGradient>(
.set_attr<FInferShape>("FInferShape", SameShape)
.set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.set_attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
return std::vector<NodeEntry>{
......@@ -167,23 +167,23 @@ NNVM_REGISTER_OP(__one__)
NNVM_REGISTER_OP(cross_device_copy)
.describe("Copy data across device.")
.set_num_inputs(1)
.attr<FInferShape>("FInferShape", SameShape);
.set_attr<FInferShape>("FInferShape", SameShape);
NNVM_REGISTER_OP(conv2d)
.describe("take conv of input")
.set_num_inputs(2)
.attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "weight"};
});
NNVM_REGISTER_OP(add)
.attr<std::string>("nick_name", "plus");
.set_attr<std::string>("nick_name", "plus");
NNVM_REGISTER_OP(assign)
.set_num_inputs(2)
.set_num_outputs(1)
.attr<FMutateInputs>("FMutateInputs", [](const NodeAttrs& attrs) {
.set_attr<FMutateInputs>("FMutateInputs", [](const NodeAttrs& attrs) {
return std::vector<uint32_t>{0};
});
......
......@@ -44,7 +44,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
* NNVM_REGISTER_OP(add)
* .describe("add two inputs together")
* .set_num_inputs(2)
* .attr<OpKernel>("gpu_kernel", AddKernel);
* .set_attr<OpKernel>("gpu_kernel", AddKernel);
*
* NNVM_REGISTER_OP(sub)
* .describe("substract one tensor from another")
......@@ -53,7 +53,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
* // Can call regster multiple times in different files
* // to register different part of information
* NNVM_REGISTER_OP(sub)
* .attr<OpKernel>("gpu_kernel", SubKernel);
* .set_attr<OpKernel>("gpu_kernel", SubKernel);
*
* // get operators from registry.
* void my_function() {
......@@ -213,8 +213,8 @@ class Op {
* \tparam ValueType The type of the value to be set.
*/
template<typename ValueType>
inline Op& attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value);
inline Op& set_attr(const std::string& attr_name, // NOLINT(*)
const ValueType& value);
/*!
* \brief Get an Op for a given operator name.
* Will raise an error if the op has not been registered.
......@@ -300,7 +300,7 @@ class OpMap {
* NNVM_REGISTER_OP(add)
* .describe("add two inputs together")
* .set_num_inputs(2)
* .attr<OpKernel>("gpu_kernel", AddKernel);
* .set_attr<OpKernel>("gpu_kernel", AddKernel);
*
* \endcode
*/
......@@ -329,7 +329,7 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
}
template<typename ValueType>
inline Op& Op::attr( // NOLINT(*)
inline Op& Op::set_attr( // NOLINT(*)
const std::string& attr_name, const ValueType& value) {
// update the attribute map of the key by creating new empty if needed.
UpdateAttrMap(attr_name, [this, attr_name, value](any* pmap) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment