Commit 999dd1ef by yuruofeifei, committed by Tianqi Chen

[GRADIENT] Add backward operator to enable backward graph (#276)

* Update docs

* Add backward operator to enable backward graph

* Fix testing

* Refactor top level1 test code

* Fix format

* Test

* Add zeros and ones ops

* Register fill_like operator

* Fix unit test
parent e36fb360
@@ -38,6 +38,8 @@ This level enables fully connected multi-layer perceptron.
nnvm.symbol.elemwise_sub
nnvm.symbol.elemwise_mul
nnvm.symbol.elemwise_div
+nnvm.symbol.fill
+nnvm.symbol.fill_like
nnvm.symbol.flatten
nnvm.symbol.concatenate
nnvm.symbol.expand_dims
@@ -111,6 +113,8 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.elemwise_sub
.. autofunction:: nnvm.symbol.elemwise_mul
.. autofunction:: nnvm.symbol.elemwise_div
+.. autofunction:: nnvm.symbol.fill
+.. autofunction:: nnvm.symbol.fill_like
.. autofunction:: nnvm.symbol.flatten
.. autofunction:: nnvm.symbol.concatenate
.. autofunction:: nnvm.symbol.expand_dims
@@ -62,21 +62,24 @@ enum TypeFlag {
  kUint64 = 10,
};

+#define DMLC_DECLARE_DTYPE_FIELD(name)       \
+  DMLC_DECLARE_FIELD(name)                   \
+  .add_enum("float16", kFloat16)             \
+  .add_enum("float32", kFloat32)             \
+  .add_enum("float64", kFloat64)             \
+  .add_enum("uint8", kUint8)                 \
+  .add_enum("uint16", kUint16)               \
+  .add_enum("uint32", kUint32)               \
+  .add_enum("uint64", kUint64)               \
+  .add_enum("int8", kInt8)                   \
+  .add_enum("int16", kInt16)                 \
+  .add_enum("int32", kInt32)                 \
+  .add_enum("int64", kInt64)
struct CastParam : public dmlc::Parameter<CastParam> {
  int dtype;
  DMLC_DECLARE_PARAMETER(CastParam) {
-    DMLC_DECLARE_FIELD(dtype)
-    .add_enum("float16", kFloat16)
-    .add_enum("float32", kFloat32)
-    .add_enum("float64", kFloat64)
-    .add_enum("uint8", kUint8)
-    .add_enum("uint16", kUint16)
-    .add_enum("uint32", kUint32)
-    .add_enum("uint64", kUint64)
-    .add_enum("int8", kInt8)
-    .add_enum("int16", kInt16)
-    .add_enum("int32", kInt32)
-    .add_enum("int64", kInt64)
+    DMLC_DECLARE_DTYPE_FIELD(dtype)
    .describe("Output data type.");
  }
};
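
For a sense of what the new macro saves, here is a minimal sketch (illustration only; MyOpParam does not appear in this commit) of declaring a dtype field in some other parameter struct:

// Hypothetical parameter struct; only DMLC_DECLARE_DTYPE_FIELD is from this diff.
struct MyOpParam : public dmlc::Parameter<MyOpParam> {
  int dtype;
  DMLC_DECLARE_PARAMETER(MyOpParam) {
    // One line now stands in for the eleven add_enum calls removed above.
    DMLC_DECLARE_DTYPE_FIELD(dtype)
    .set_default(kFloat32)
    .describe("Data type of the output tensor.");
  }
};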
@@ -155,6 +158,19 @@ struct ReduceParam : public dmlc::Parameter<ReduceParam> {
  }
};

+struct InitOpParam : public dmlc::Parameter<InitOpParam> {
+  TShape shape;
+  int dtype;
+  double value;
+  DMLC_DECLARE_PARAMETER(InitOpParam) {
+    DMLC_DECLARE_FIELD(shape).set_default(TShape());
+    DMLC_DECLARE_DTYPE_FIELD(dtype).set_default(kFloat32)
+    .describe("Target data type.");
+    DMLC_DECLARE_FIELD(value).describe("Value to fill with.");
+  }
+};

}  // namespace top
}  // namespace nnvm
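
A rough sketch of how these fields behave when parsed (an assumption based on dmlc::Parameter semantics, not code from the diff):

// Illustration only: parsing InitOpParam from string attributes,
// the way ParamParser<InitOpParam> does with a node's attribute dict.
std::map<std::string, std::string> kwargs = {{"value", "0"}};
InitOpParam param;
param.Init(kwargs);
// shape falls back to its default TShape() (unknown, ndim == 0) and
// dtype to kFloat32; omitting "value" would throw, since it has no default.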
@@ -7,7 +7,7 @@ class OpPattern(object):
    See Also
    --------
-   top.tag : Contains explaination of the tag type.
+   top.tag : Contains explanation of the tag type.
    """
    # Elementwise operator
    ELEMWISE = 0
@@ -97,6 +97,16 @@ inline bool ElemwiseType(const NodeAttrs& attrs,
.add_argument("data", "Tensor", "The input tensor.")

+#define NNVM_REGISTER_INIT_OP(name)                  \
+  NNVM_REGISTER_OP(name)                             \
+  .set_num_inputs(0)                                 \
+  .set_num_outputs(1)                                \
+  .set_attr_parser(ParamParser<InitOpParam>)         \
+  .add_arguments(InitOpParam::__FIELDS__())          \
+  .set_attr<FInferShape>("FInferShape", ZeroShape)   \
+  .set_attr<FInferType>("FInferType", ZeroType)

#define NNVM_REGISTER_ELEMWISE_BINARY_OP(name) \
  NNVM_REGISTER_OP(name) \
  .set_num_inputs(2) \
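
With this macro, registering an init operator reduces to a couple of lines. A hedged sketch (the actual zeros/ones/fill registrations live elsewhere in this commit and may differ):

// Sketch only: registering an init op such as fill with the new macro;
// the description string here is invented for illustration.
NNVM_REGISTER_INIT_OP(fill)
.describe("Fill the output tensor with the given value.");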
@@ -8,6 +8,7 @@
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <nnvm/top/tensor.h>
+#include <string>
#include <vector>
#include <unordered_set>
@@ -16,7 +17,7 @@ namespace nnvm {
namespace top {

/*!
 * \brief Parse keyword arguments as PType arguments and save to parsed
- * \tparam PType the arameter type.
+ * \tparam PType the parameter type.
 * \param attrs The attributes.
 */
template<typename PType>
@@ -202,6 +203,28 @@ inline std::string attr_assign_error_msg(const NodeAttrs& attrs,
  } \
}

+/*!
+ * \brief Macro to assign the rhs shape to the lhs, checking that they
+ *  agree if the lhs shape is already known.
+ *  A macro is used so the CHECK error reports the caller's file.
+ * \param lhs lhs shape
+ * \param rhs rhs shape
+ */
+#define SHAPE_ASSIGN(lhs, rhs)                                    \
+  if ((lhs).ndim() == 0) (lhs) = (rhs);                           \
+  else                                                            \
+    CHECK_EQ(lhs, rhs) << "shape inference inconsistent";         \
+
+/*!
+ * \brief Macro to assign the rhs type to the lhs, checking that they
+ *  agree if the lhs type is already known.
+ *  A macro is used so the CHECK error reports the caller's file.
+ * \param lhs lhs type
+ * \param rhs rhs type
+ */
+#define DTYPE_ASSIGN(lhs, rhs)                                    \
+  if ((lhs) == -1) (lhs) = (rhs);                                 \
+  else                                                            \
+    CHECK_EQ(lhs, rhs) << "type inference inconsistent";          \
+
// simply return the same shape
inline bool SameShape(const NodeAttrs& attrs,
                      std::vector<TShape> *ishape,
@@ -216,6 +239,28 @@ inline bool SameShape(const NodeAttrs& attrs,
  return true;
}

+// infer the output shape from the shape field stored in node attrs
+inline bool ZeroShape(const NodeAttrs& attrs,
+                      std::vector<TShape> *ishape,
+                      std::vector<TShape> *oshape) {
+  const TShape& ts = dmlc::get<InitOpParam>(attrs.parsed).shape;
+  if (ts.ndim() != 0) {
+    SHAPE_ASSIGN(oshape->at(0), ts);
+    return true;
+  } else {
+    return false;
+  }
+}
+
+// infer the output type from the dtype field stored in node attrs
+inline bool ZeroType(const NodeAttrs& attrs,
+                     std::vector<int> *iattr,
+                     std::vector<int> *oattr) {
+  int dtype = dmlc::get<InitOpParam>(attrs.parsed).dtype;
+  DTYPE_ASSIGN(oattr->at(0), dtype);
+  return true;
+}

}  // namespace top
}  // namespace nnvm
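
Putting the pieces together, a minimal sketch (an assumption, not part of the diff; it presumes the nnvm::top headers and namespace) of how the new hooks resolve an init op's output shape and type:

// Illustration only: exercising ZeroShape/ZeroType by hand.
InitOpParam param;
param.Init(std::map<std::string, std::string>{
    {"shape", "(2,3)"}, {"dtype", "float32"}, {"value", "1"}});
NodeAttrs attrs;
attrs.parsed = std::move(param);

std::vector<TShape> ishape, oshape(1);  // output shape starts unknown (ndim == 0)
ZeroShape(attrs, &ishape, &oshape);     // SHAPE_ASSIGN fills oshape[0] with (2,3)

std::vector<int> itype, otype(1, -1);   // output type starts unknown (-1)
ZeroType(attrs, &itype, &otype);        // DTYPE_ASSIGN fills otype[0] with kFloat32
// A later pass seeing a different shape or dtype for the same slot would
// trip the CHECK_EQ inside SHAPE_ASSIGN / DTYPE_ASSIGN.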