Commit 0081ad9a by Tianqi Chen

[Pass] Finish infershape testcase (#16)

parent bd20bfd8
...@@ -23,7 +23,7 @@ namespace nnvm { ...@@ -23,7 +23,7 @@ namespace nnvm {
* \param src The graph to be transformed. * \param src The graph to be transformed.
* \return The generated graph. * \return The generated graph.
*/ */
typedef std::function<Graph (const Graph& src)> PassFunction; typedef std::function<Graph (Graph src)> PassFunction;
/*! /*!
* \brief Apply a series of pass transformations on g. * \brief Apply a series of pass transformations on g.
...@@ -31,7 +31,7 @@ typedef std::function<Graph (const Graph& src)> PassFunction; ...@@ -31,7 +31,7 @@ typedef std::function<Graph (const Graph& src)> PassFunction;
* \param pass The name of pass to be applied. * \param pass The name of pass to be applied.
* \return The transformed graph * \return The transformed graph
*/ */
Graph ApplyPass(const Graph& src, Graph ApplyPass(Graph src,
const std::vector<std::string>& pass); const std::vector<std::string>& pass);
/*! /*!
......
/*!
* Copyright (c) 2016 by Contributors
* \file pass_functions.h
* \brief Pass functions that simply redirect the calls to ApplyPass
*
* This file serves as documentation on how to use functions implemented in "src/pass".
* It is totally optional to add these functions when you add a new pass, since
* ApplyPass can be directly called.
*/
#ifndef NNVM_PASS_FUNCTIONS_H_
#define NNVM_PASS_FUNCTIONS_H_
#include <string>
#include <memory>
#include "./base.h"
#include "./pass.h"
#include "./graph_attr_types.h"
namespace nnvm {
namespace pass {
/*!
* \brief Load a graph from JSON string, redirects to "LoadJSON" pass.
* \param json_str The json string.
* \return Loaded graph.
*/
inline Graph LoadJSON(const std::string& json_str) {
Graph ret;
ret.attrs["json"] = std::make_shared<any>(json_str);
return ApplyPass(ret, {"LoadJSON"});
}
/*!
* \brief Save a graph to json, redirects to "SaveJSON" pass.
* \param graph The to be saved.
* \return The json string.
*/
inline std::string SaveJSON(Graph graph) {
Graph ret = ApplyPass(std::move(graph), {"SaveJSON"});
return ret.GetAttr<std::string>("json");
}
/*!
* \brief Add control flow dependencies between nodes
* To correctly order mutation and read to resolve
* write after read problem and read after write problems.
* \param src source graph
* \return A graph that added control flow dependencies.
*/
inline Graph OrderMutation(Graph src) {
return ApplyPass(std::move(src), {"OrderMutation"});
}
/*!
* \brief Infer shapes in the graph given the information.
* \param graph source graph
* \param shape_args The shapes of aruguments to the graph.
* \param shape_attr_key The key to the node attribute that can indicate shape.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id
*/
inline Graph InferShape(Graph graph,
ShapeVector shape_args = {},
std::string shape_attr_key = "") {
if (shape_args.size() != 0) {
graph.attrs["shape_args"] = std::make_shared<any>(std::move(shape_args));
}
if (shape_attr_key.length() != 0) {
graph.attrs["shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key));
}
return ApplyPass(std::move(graph), {"InferShape"});
}
} // namespace pass
} // namespace nnvm
#endif // NNVM_PASS_FUNCTIONS_H_
...@@ -22,7 +22,7 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) { ...@@ -22,7 +22,7 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) {
return nullptr; return nullptr;
} }
Graph ApplyPass(const Graph& src, Graph ApplyPass(Graph g,
const std::vector<std::string>& pass) { const std::vector<std::string>& pass) {
std::vector<const PassFunctionReg*> fpass; std::vector<const PassFunctionReg*> fpass;
for (auto& name : pass) { for (auto& name : pass) {
...@@ -32,11 +32,9 @@ Graph ApplyPass(const Graph& src, ...@@ -32,11 +32,9 @@ Graph ApplyPass(const Graph& src,
fpass.push_back(reg); fpass.push_back(reg);
} }
Graph g;
const Graph* s = &src;
for (auto r : fpass) { for (auto r : fpass) {
for (auto& dep : r->graph_attr_dependency) { for (auto& dep : r->graph_attr_dependency) {
if (s->attrs.count(dep) == 0) { if (g.attrs.count(dep) == 0) {
auto* pass_dep = FindPassDep(dep); auto* pass_dep = FindPassDep(dep);
std::string msg; std::string msg;
if (pass_dep != nullptr) { if (pass_dep != nullptr) {
...@@ -48,8 +46,7 @@ Graph ApplyPass(const Graph& src, ...@@ -48,8 +46,7 @@ Graph ApplyPass(const Graph& src,
<< msg; << msg;
} }
} }
g = r->body(*s); g = r->body(std::move(g));
s = &g;
} }
return g; return g;
} }
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <nnvm/base.h> #include <nnvm/base.h>
#include <nnvm/op.h> #include <nnvm/op.h>
#include <nnvm/op_attr_types.h> #include <nnvm/op_attr_types.h>
#include <nnvm/node.h>
#include <nnvm/graph_attr_types.h> #include <nnvm/graph_attr_types.h>
#include <utility> #include <utility>
...@@ -30,6 +31,31 @@ inline bool SameShape(const NodeAttrs& attrs, ...@@ -30,6 +31,31 @@ inline bool SameShape(const NodeAttrs& attrs,
return true; return true;
} }
// simple demonstration of reshape.
NNVM_REGISTER_OP(reshape)
.describe("reshape source to target shape")
.set_num_inputs(1)
.set_attr_parser(
[](NodeAttrs* attrs) {
// parse attr parser to get target attribute
TShape target;
std::istringstream is(attrs->dict.at("target"));
CHECK(is >> target);
attrs->parsed = std::move(target);
})
.attr<FInferShape>(
"FInferShape", [] (const NodeAttrs& attrs,
array_view<TShape*> ishape,
array_view<TShape*> oshape) {
// get parsed attribute
const TShape& target = nnvm::get<TShape>(attrs.parsed);
*oshape[0] = target;
if (ishape[0]->ndim() == 0) return false;
CHECK_EQ(ishape[0]->Size(), target.Size())
<< "Reshape op: source target shape mismatch";
return true;
});
NNVM_REGISTER_OP(add) NNVM_REGISTER_OP(add)
.describe("add two data together") .describe("add two data together")
.set_num_inputs(2) .set_num_inputs(2)
......
...@@ -10,19 +10,42 @@ ...@@ -10,19 +10,42 @@
namespace nnvm { namespace nnvm {
namespace pass { namespace pass {
Graph InferShape(const Graph& src) { Graph InferShape(Graph ret) {
Graph ret = src;
const IndexedGraph& idx = ret.indexed_graph(); const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape = Op::GetAttr<FInferShape>("FInferShape"); static auto& finfer_shape = Op::GetAttr<FInferShape>("FInferShape");
// reshape shape vector // reshape shape vector
ShapeVector rshape(idx.num_node_entries()); ShapeVector rshape(idx.num_node_entries());
if (ret.attrs.count("shape_args") != 0) {
const ShapeVector& shape_args = ret.GetAttr<ShapeVector>("shape_args");
CHECK_LE(shape_args.size(), idx.arg_nodes().size())
<< "shape args is more than number of arguments";
for (size_t i = 0; i < shape_args.size(); ++i) {
rshape[idx.entry_id(idx.arg_nodes()[i], 0)] = shape_args[i];
}
}
std::string shape_attr_key;
if (ret.attrs.count("shape_attr_key") != 0) {
shape_attr_key = ret.GetAttr<std::string>("shape_attr_key");
}
// temp space for shape inference. // temp space for shape inference.
std::vector<TShape*> ishape, oshape; std::vector<TShape*> ishape, oshape;
// number of completed nodes // number of completed nodes
size_t num_known = 0; size_t num_known = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid]; const auto& inode = idx[nid];
if (inode.source->is_variable()) continue; if (inode.source->is_variable()) {
if (shape_attr_key.length() != 0) {
auto it = inode.source->attrs.dict.find(shape_attr_key);
if (it != inode.source->attrs.dict.end()) {
CHECK_EQ(inode.source->num_outputs(), 1);
std::istringstream is(it->second);
CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid shape attribute";
}
}
continue;
}
ishape.resize(inode.inputs.size()); ishape.resize(inode.inputs.size());
for (uint32_t i = 0; i < ishape.size(); ++i) { for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = &rshape[idx.entry_id(inode.inputs[i])]; ishape[i] = &rshape[idx.entry_id(inode.inputs[i])];
...@@ -43,5 +66,13 @@ Graph InferShape(const Graph& src) { ...@@ -43,5 +66,13 @@ Graph InferShape(const Graph& src) {
return ret; return ret;
} }
NNVM_REGISTER_PASS(InferShape)
.describe("Infer the shape of each node entries.")
.set_body(InferShape)
.set_change_graph(false)
.provide_graph_attr("shape");
DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape);
} // namespace pass } // namespace pass
} // namespace nnvm } // namespace nnvm
/*! /*!
* Copyright (c) 2016 by Contributors * Copyright (c) 2016 by Contributors
* \file saveload_json.cc * \file order_mutation.cc
* \brief Add control flow dependencies between nodes * \brief Add control flow dependencies between nodes
* To correctly order mutation and read to resolve * To correctly order mutation and read to resolve
* write after read problem and read after write problems. * write after read problem and read after write problems.
......
...@@ -149,7 +149,7 @@ struct JSONGraph { ...@@ -149,7 +149,7 @@ struct JSONGraph {
}; };
// Load a graph from JSON file. // Load a graph from JSON file.
Graph LoadJSON(const Graph& src) { Graph LoadJSON(Graph src) {
CHECK_NE(src.attrs.count("json"), 0) CHECK_NE(src.attrs.count("json"), 0)
<< "Load JSON require json to be presented."; << "Load JSON require json to be presented.";
const std::string &json_str = const std::string &json_str =
...@@ -188,7 +188,7 @@ Graph LoadJSON(const Graph& src) { ...@@ -188,7 +188,7 @@ Graph LoadJSON(const Graph& src) {
} }
// save a graph to json // save a graph to json
Graph SaveJSON(const Graph& src) { Graph SaveJSON(Graph src) {
JSONGraph jgraph; JSONGraph jgraph;
std::unordered_map<Node*, uint32_t> node2index; std::unordered_map<Node*, uint32_t> node2index;
jgraph.node_row_ptr.push_back(0); jgraph.node_row_ptr.push_back(0);
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <nnvm/tuple.h> #include <nnvm/tuple.h>
#include <nnvm/c_api.h> #include <nnvm/c_api.h>
#include <nnvm/graph_attr_types.h> #include <nnvm/graph_attr_types.h>
#include <nnvm/pass_functions.h>
#include <dmlc/timer.h> #include <dmlc/timer.h>
#include <string> #include <string>
......
...@@ -35,7 +35,23 @@ def test_order_mutation_pass(): ...@@ -35,7 +35,23 @@ def test_order_mutation_pass():
assert nindex['add1'] in jnodes[nindex['assign']]['control_deps'] assert nindex['add1'] in jnodes[nindex['assign']]['control_deps']
assert jnodes[nindex['assign']]['inputs'][0][2] == 1 assert jnodes[nindex['assign']]['inputs'][0][2] == 1
def test_infer_shape():
x = sym.Variable('x', shape=(4, 2))
y = sym.add(x, x, name='add1')
y = sym.reshape(y, target=(2, 4), name="reshape1")
g = graph.create(y)
g._set_json_attr("shape_attr_key", "shape")
g = g.apply('InferShape')
jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
jnodes = jgraph['nodes']
jnode_row_ptr = jgraph['node_row_ptr']
nindex = {n['name']: i for i, n in enumerate(jnodes)}
assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4]
assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2]
if __name__ == "__main__": if __name__ == "__main__":
test_order_mutation_pass() test_order_mutation_pass()
test_graph_json_attr() test_graph_json_attr()
test_json_pass() test_json_pass()
test_infer_shape()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment