Commit cf02f5c9 authored by tqchen, committed by Tianqi Chen

[Pass] Enable BackwardOp

parent 8ffa4ac3
# NNVM: Build deep learning system by parts
-NNVM is not a deep learning library. It is a modular, lightweight library to
+NNVM is not a deep learning library. It is a modular, decentralized and lightweight library to
help build deep learning libraries efficiently.

## What is it
@@ -8,7 +8,7 @@ help build deep learning libraries efficiently.
While most deep learning systems offer end to end solutions,
it is interesting to ask if we can actually assemble a deep learning system by parts.
The goal is to enable hackers to customize optimizations, target platforms and the set of operators they care about.
-We believe that the modular system is an interesting direction.
+We believe that the decentralized modular system is an interesting direction.
The hope is that effective parts can be assembled together just like you assemble your own desktops.
So the customized deep learning solution can be minimax: minimum in terms of dependencies,
while maximizing the users' needs.
@@ -18,7 +18,7 @@ computation graph optimization such as memory reduction, device allocation,
operator fusion, while being agnostic to the operator
interface definition and how operators are executed.
NNVM is inspired by LLVM, aiming to be an intermediate representation library
-for neural nets and computation graphs in general.
+for neural net and computation graph generation and optimization.

## Deep learning system by parts
...
@@ -80,6 +80,19 @@ using FInferShape = FInferNodeEntryAttr<TShape>;
 */
using FInferType = FInferNodeEntryAttr<int>;

+/*!
+ * \brief Whether this op is an explicit backward operator.
+ *
+ * If TIsBackwardOp is set to true:
+ * - The first entry in the node's control_deps points to the corresponding forward operator.
+ * - The outputs of this operator correspond one-to-one to the inputs of the forward op.
+ *
+ * \note Registered under "TIsBackwardOp", defaults to false.
+ *
+ * This makes shape/type inference easier for backward operators of ops such as slice and reduction.
+ */
+using TIsBackwardOp = bool;

} // namespace nnvm
#endif  // NNVM_OP_ATTR_TYPES_H_
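As an illustration of how this attribute is meant to be used, a backward operator could be tagged at registration time roughly as follows. This is a minimal sketch: the operator name and its input/output counts are hypothetical and not part of this commit.

```cpp
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>

// Hypothetical registration sketch: mark an op as an explicit backward operator.
// Its node is expected to carry the forward node as control_deps[0], and its
// outputs line up one-to-one with the forward op's inputs.
NNVM_REGISTER_OP(_backward_example)
.describe("Backward pass of a hypothetical example op")
.set_num_inputs(1)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackwardOp>("TIsBackwardOp", true);
```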
@@ -135,7 +135,7 @@ class Symbol {
   * \return Symbol that can be used to call compose further.
   */
  static Symbol CreateFunctor(const Op* op,
-                             std::unordered_map<std::string, std::string>&& attrs);
+                             std::unordered_map<std::string, std::string> attrs);
  /*!
   * \brief create variable symbol node
   * \param name name of the variable
...
@@ -437,7 +437,7 @@ std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption op
}

Symbol Symbol::CreateFunctor(const Op* op,
-                             std::unordered_map<std::string, std::string>&& attrs) {
+                             std::unordered_map<std::string, std::string> attrs) {
  Symbol s;
  NodePtr n = Node::Create();
  n->op = op;
...
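The switch from an rvalue-reference parameter to pass-by-value lets callers hand CreateFunctor a named attribute map and reuse it across calls, which is exactly what the updated benchmark below does. A minimal usage sketch, assuming an operator has been registered elsewhere under the hypothetical name "add":

```cpp
#include <string>
#include <unordered_map>
#include <nnvm/op.h>
#include <nnvm/symbolic.h>

int main() {
  // Assumes an op named "add" has been registered elsewhere (hypothetical here).
  const nnvm::Op* add = nnvm::Op::Get("add");
  // With the by-value signature, the same kwargs map can be passed repeatedly
  // without std::move or building a temporary map on every call.
  std::unordered_map<std::string, std::string> kwargs;
  nnvm::Symbol s1 = nnvm::Symbol::CreateFunctor(add, kwargs);
  nnvm::Symbol s2 = nnvm::Symbol::CreateFunctor(add, kwargs);
  return 0;
}
```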
@@ -10,18 +10,21 @@
namespace nnvm {
namespace pass {

-template<typename AttrType>
+template<typename AttrType, typename IsNone>
Graph InferAttr(Graph &&ret,
                const AttrType def_value,
                const char* infer_name,
                const char* arg_name,
                const char* attr_key_name,
                const char* attr_name,
-                const char* known_name) {
+                const char* known_name,
+                IsNone fis_none) {
  using AttrVector = std::vector<AttrType>;
  const IndexedGraph& idx = ret.indexed_graph();
  static auto& finfer_shape =
      Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
+  static auto& is_backward =
+      Op::GetAttr<TIsBackwardOp>("TIsBackwardOp");
  // reshape shape vector
  AttrVector rshape(idx.num_node_entries(), def_value);
@@ -66,6 +69,19 @@ Graph InferAttr(Graph &&ret,
    if (finfer_shape.count(inode.source->op)) {
      num_known +=
          finfer_shape[inode.source->op](inode.source->attrs, ishape, oshape);
+    } else if (is_backward.get(inode.source->op, false)) {
+      // Backward operator inference: copy attributes from the forward node's inputs.
+      CHECK_GE(inode.control_deps.size(), 1)
+          << "BackwardOp need to have control_deps to its forward op";
+      const auto& fnode = idx[inode.control_deps[0]];
+      CHECK_EQ(fnode.inputs.size(), inode.source->num_outputs())
+          << "BackwardOp need to correspond to the forward node";
+      bool known = true;
+      for (size_t i = 0; i < fnode.inputs.size(); ++i) {
+        // The i-th output of the backward op takes the attribute of the
+        // i-th input of the forward op.
+        *oshape[i] = rshape[idx.entry_id(fnode.inputs[i])];
+        if (fis_none(*oshape[i])) known = false;
+      }
+      num_known += known;
    }
  }
  // set the shapes
@@ -79,13 +95,10 @@ NNVM_REGISTER_PASS(InferShape)
.describe("Infer the shape of each node entries.")
.set_body([](Graph ret) {
    return InferAttr<TShape>(
-        std::move(ret),
-        TShape(),
-        "FInferShape",
-        "shape_args",
-        "shape_attr_key",
-        "shape",
-        "shape_num_known_nodes");
+        std::move(ret), TShape(),
+        "FInferShape", "shape_args", "shape_attr_key",
+        "shape", "shape_num_known_nodes",
+        [](const TShape& s) { return s.ndim() == 0; });
  })
.set_change_graph(false)
.provide_graph_attr("shape");
@@ -94,13 +107,10 @@ NNVM_REGISTER_PASS(InferType)
.describe("Infer the dtype of each node entries.")
.set_body([](Graph ret) {
    return InferAttr<int>(
-        std::move(ret),
-        0,
-        "FInferType",
-        "dtype_args",
-        "dtype_attr_key",
-        "dtype",
-        "dtype_num_known_nodes");
+        std::move(ret), 0,
+        "FInferType", "dtype_args", "dtype_attr_key",
+        "dtype", "dtype_num_known_nodes",
+        [](const int t) { return t == -1; });
  })
.set_change_graph(false)
.provide_graph_attr("dtype");
...
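For orientation, these registered passes are normally driven through nnvm's generic pass interface. The sketch below shows roughly how InferShape might be invoked on a graph; it assumes nnvm::ApplyPass from nnvm/pass.h is available and that "shape_args" carries one TShape per input variable, neither of which is spelled out in this diff.

```cpp
#include <memory>
#include <utility>
#include <vector>
#include <nnvm/base.h>
#include <nnvm/graph.h>
#include <nnvm/pass.h>
#include <nnvm/tuple.h>

// Sketch only: run shape inference over a graph (assumptions noted above).
nnvm::Graph RunInferShape(nnvm::Graph g, std::vector<nnvm::TShape> input_shapes) {
  // "shape_args" is the arg_name the InferShape pass reads its input shapes from.
  g.attrs["shape_args"] = std::make_shared<nnvm::any>(std::move(input_shapes));
  g = nnvm::ApplyPass(std::move(g), "InferShape");
  // On success the pass provides "shape" (one TShape per node entry) and
  // "shape_num_known_nodes" (how many nodes were fully inferred).
  return g;
}
```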
@@ -14,12 +14,13 @@ void test_speed() {
  size_t rep = 1000;
  size_t n = 1000;
  std::unordered_map<std::string, const nnvm::Symbol*> tmp;
+  std::unordered_map<std::string, std::string> kwargs;
  std::vector<const nnvm::Symbol*> vec{2};
  std::string name = "xx";
  for (size_t t = 0; t < rep; ++t) {
    nnvm::Symbol s = nnvm::Symbol::CreateVariable("x");
    for (size_t i = 0; i < n; ++i) {
-      nnvm::Symbol nw = nnvm::Symbol::CreateFunctor(add, {});
+      nnvm::Symbol nw = nnvm::Symbol::CreateFunctor(add, kwargs);
      vec[0] = &s;
      vec[1] = &s;
      tmp.clear();
...