Commit 999dd1ef by yuruofeifei Committed by Tianqi Chen

[GRADIENT] Add backward operator to enable backward graph (#276)

* Update docs

* Add backward operator to enable backward graph

* Fix testing

* Refactor top level1 test code

* Fix format

* Test

* Added zeros ones op

* Register fill_like operator

* Fix unit test
parent e36fb360
...@@ -38,6 +38,8 @@ This level enables fully connected multi-layer perceptron. ...@@ -38,6 +38,8 @@ This level enables fully connected multi-layer perceptron.
nnvm.symbol.elemwise_sub nnvm.symbol.elemwise_sub
nnvm.symbol.elemwise_mul nnvm.symbol.elemwise_mul
nnvm.symbol.elemwise_div nnvm.symbol.elemwise_div
nnvm.symbol.fill
nnvm.symbol.fill_like
nnvm.symbol.flatten nnvm.symbol.flatten
nnvm.symbol.concatenate nnvm.symbol.concatenate
nnvm.symbol.expand_dims nnvm.symbol.expand_dims
...@@ -111,6 +113,8 @@ Detailed Definitions ...@@ -111,6 +113,8 @@ Detailed Definitions
.. autofunction:: nnvm.symbol.elemwise_sub .. autofunction:: nnvm.symbol.elemwise_sub
.. autofunction:: nnvm.symbol.elemwise_mul .. autofunction:: nnvm.symbol.elemwise_mul
.. autofunction:: nnvm.symbol.elemwise_div .. autofunction:: nnvm.symbol.elemwise_div
.. autofunction:: nnvm.symbol.fill
.. autofunction:: nnvm.symbol.fill_like
.. autofunction:: nnvm.symbol.flatten .. autofunction:: nnvm.symbol.flatten
.. autofunction:: nnvm.symbol.concatenate .. autofunction:: nnvm.symbol.concatenate
.. autofunction:: nnvm.symbol.expand_dims .. autofunction:: nnvm.symbol.expand_dims
......
...@@ -62,21 +62,24 @@ enum TypeFlag { ...@@ -62,21 +62,24 @@ enum TypeFlag {
kUint64 = 10, kUint64 = 10,
}; };
#define DMLC_DECLARE_DTYPE_FIELD(name) \
DMLC_DECLARE_FIELD(name) \
.add_enum("float16", kFloat16) \
.add_enum("float32", kFloat32) \
.add_enum("float64", kFloat64) \
.add_enum("uint8", kUint8) \
.add_enum("uint16", kUint16) \
.add_enum("uint32", kUint32) \
.add_enum("uint64", kUint64) \
.add_enum("int8", kInt8) \
.add_enum("int16", kInt16) \
.add_enum("int32", kInt32) \
.add_enum("int64", kInt64)
struct CastParam : public dmlc::Parameter<CastParam> { struct CastParam : public dmlc::Parameter<CastParam> {
int dtype; int dtype;
DMLC_DECLARE_PARAMETER(CastParam) { DMLC_DECLARE_PARAMETER(CastParam) {
DMLC_DECLARE_FIELD(dtype) DMLC_DECLARE_DTYPE_FIELD(dtype)
.add_enum("float16", kFloat16)
.add_enum("float32", kFloat32)
.add_enum("float64", kFloat64)
.add_enum("uint8", kUint8)
.add_enum("uint16", kUint16)
.add_enum("uint32", kUint32)
.add_enum("uint64", kUint64)
.add_enum("int8", kInt8)
.add_enum("int16", kInt16)
.add_enum("int32", kInt32)
.add_enum("int64", kInt64)
.describe("Output data type."); .describe("Output data type.");
} }
}; };
...@@ -155,6 +158,19 @@ struct ReduceParam : public dmlc::Parameter<ReduceParam> { ...@@ -155,6 +158,19 @@ struct ReduceParam : public dmlc::Parameter<ReduceParam> {
} }
}; };
struct InitOpParam : public dmlc::Parameter<InitOpParam> {
TShape shape;
int dtype;
double value;
DMLC_DECLARE_PARAMETER(InitOpParam) {
DMLC_DECLARE_FIELD(shape).set_default(TShape());
DMLC_DECLARE_DTYPE_FIELD(dtype).set_default(kFloat32)
.describe("Target data type.");
DMLC_DECLARE_FIELD(value).describe("Value to fill");
}
};
} // namespace top } // namespace top
} // namespace nnvm } // namespace nnvm
......
...@@ -7,7 +7,7 @@ class OpPattern(object): ...@@ -7,7 +7,7 @@ class OpPattern(object):
See Also See Also
-------- --------
top.tag : Contains explaination of the tag type. top.tag : Contains explanation of the tag type.
""" """
# Elementwise operator # Elementwise operator
ELEMWISE = 0 ELEMWISE = 0
......
...@@ -97,6 +97,16 @@ inline bool ElemwiseType(const NodeAttrs& attrs, ...@@ -97,6 +97,16 @@ inline bool ElemwiseType(const NodeAttrs& attrs,
.add_argument("data", "Tensor", "The input tensor.") .add_argument("data", "Tensor", "The input tensor.")
#define NNVM_REGISTER_INIT_OP(name) \
NNVM_REGISTER_OP(name) \
.set_num_inputs(0) \
.set_num_outputs(1) \
.set_attr_parser(ParamParser<InitOpParam>) \
.add_arguments(InitOpParam::__FIELDS__()) \
.set_attr<FInferShape>("FInferShape", ZeroShape) \
.set_attr<FInferType>("FInferType", ZeroType)
#define NNVM_REGISTER_ELEMWISE_BINARY_OP(name) \ #define NNVM_REGISTER_ELEMWISE_BINARY_OP(name) \
NNVM_REGISTER_OP(name) \ NNVM_REGISTER_OP(name) \
.set_num_inputs(2) \ .set_num_inputs(2) \
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <dmlc/parameter.h> #include <dmlc/parameter.h>
#include <nnvm/top/tensor.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_set> #include <unordered_set>
...@@ -16,7 +17,7 @@ namespace nnvm { ...@@ -16,7 +17,7 @@ namespace nnvm {
namespace top { namespace top {
/*! /*!
* \brief Parse keyword arguments as PType arguments and save to parsed * \brief Parse keyword arguments as PType arguments and save to parsed
* \tparam PType the arameter type. * \tparam PType the parameter type.
* \param attrs The attributes. * \param attrs The attributes.
*/ */
template<typename PType> template<typename PType>
...@@ -202,6 +203,28 @@ inline std::string attr_assign_error_msg(const NodeAttrs& attrs, ...@@ -202,6 +203,28 @@ inline std::string attr_assign_error_msg(const NodeAttrs& attrs,
} \ } \
} }
/*!
* \brief macro assign rhs shape to lhs
* Use macro so we can see the error file more clearly
* \param lhs lhs shape
* \param rhs rhs shape
*/
#define SHAPE_ASSIGN(lhs, rhs) \
if ((lhs).ndim() == 0) (lhs) = (rhs); \
else \
CHECK_EQ(lhs, rhs) << "shape inference inconsistent"; \
/*!
* \brief macro assign rhs type to lhs
* Use macro so we can see the error file more clearly
* \param lhs lhs type
* \param rhs rhs type
*/
#define DTYPE_ASSIGN(lhs, rhs) \
if ((lhs) == -1) (lhs) = (rhs); \
else \
CHECK_EQ(lhs, rhs) << "type inference inconsistent"; \
// simply return the shape as same // simply return the shape as same
inline bool SameShape(const NodeAttrs& attrs, inline bool SameShape(const NodeAttrs& attrs,
std::vector<TShape> *ishape, std::vector<TShape> *ishape,
...@@ -216,6 +239,28 @@ inline bool SameShape(const NodeAttrs& attrs, ...@@ -216,6 +239,28 @@ inline bool SameShape(const NodeAttrs& attrs,
return true; return true;
} }
// return shape from node attrs
inline bool ZeroShape(const NodeAttrs& attrs,
std::vector<TShape> *ishape,
std::vector<TShape> *oshape) {
const TShape& ts = dmlc::get<InitOpParam>(attrs.parsed).shape;
if (ts.ndim() != 0) {
SHAPE_ASSIGN(oshape->at(0), ts);
return true;
} else {
return false;
}
}
// return type from node attrs
inline bool ZeroType(const NodeAttrs& attrs,
std::vector<int> *iattr,
std::vector<int> *oattr) {
int dtype = dmlc::get<InitOpParam>(attrs.parsed).dtype;
DTYPE_ASSIGN(oattr->at(0), dtype);
return true;
}
} // namespace top } // namespace top
} // namespace nnvm } // namespace nnvm
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment