Commit 0081ad9a by Tianqi Chen

[Pass] Finish infershape testcase (#16)

parent bd20bfd8
......@@ -23,7 +23,7 @@ namespace nnvm {
* \param src The graph to be transformed.
* \return The generated graph.
*/
typedef std::function<Graph (const Graph& src)> PassFunction;
typedef std::function<Graph (Graph src)> PassFunction;
/*!
* \brief Apply a series of pass transformations on g.
......@@ -31,7 +31,7 @@ typedef std::function<Graph (const Graph& src)> PassFunction;
* \param pass The name of pass to be applied.
* \return The transformed graph
*/
Graph ApplyPass(const Graph& src,
Graph ApplyPass(Graph src,
const std::vector<std::string>& pass);
/*!
......
/*!
* Copyright (c) 2016 by Contributors
* \file pass_functions.h
* \brief Pass functions that simply redirect the calls to ApplyPass
*
* This file serves as documentation on how to use functions implemented in "src/pass".
* It is totally optional to add these functions when you add a new pass, since
* ApplyPass can be directly called.
*/
#ifndef NNVM_PASS_FUNCTIONS_H_
#define NNVM_PASS_FUNCTIONS_H_
#include <string>
#include <memory>
#include "./base.h"
#include "./pass.h"
#include "./graph_attr_types.h"
namespace nnvm {
namespace pass {
/*!
* \brief Load a graph from JSON string, redirects to "LoadJSON" pass.
* \param json_str The json string.
* \return Loaded graph.
*/
inline Graph LoadJSON(const std::string& json_str) {
Graph ret;
ret.attrs["json"] = std::make_shared<any>(json_str);
return ApplyPass(ret, {"LoadJSON"});
}
/*!
* \brief Save a graph to json, redirects to "SaveJSON" pass.
* \param graph The to be saved.
* \return The json string.
*/
inline std::string SaveJSON(Graph graph) {
Graph ret = ApplyPass(std::move(graph), {"SaveJSON"});
return ret.GetAttr<std::string>("json");
}
/*!
* \brief Add control flow dependencies between nodes
* To correctly order mutation and read to resolve
* write after read problem and read after write problems.
* \param src source graph
* \return A graph that added control flow dependencies.
*/
inline Graph OrderMutation(Graph src) {
return ApplyPass(std::move(src), {"OrderMutation"});
}
/*!
* \brief Infer shapes in the graph given the information.
* \param graph source graph
* \param shape_args The shapes of aruguments to the graph.
* \param shape_attr_key The key to the node attribute that can indicate shape.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id
*/
inline Graph InferShape(Graph graph,
ShapeVector shape_args = {},
std::string shape_attr_key = "") {
if (shape_args.size() != 0) {
graph.attrs["shape_args"] = std::make_shared<any>(std::move(shape_args));
}
if (shape_attr_key.length() != 0) {
graph.attrs["shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key));
}
return ApplyPass(std::move(graph), {"InferShape"});
}
} // namespace pass
} // namespace nnvm
#endif // NNVM_PASS_FUNCTIONS_H_
......@@ -22,7 +22,7 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) {
return nullptr;
}
Graph ApplyPass(const Graph& src,
Graph ApplyPass(Graph g,
const std::vector<std::string>& pass) {
std::vector<const PassFunctionReg*> fpass;
for (auto& name : pass) {
......@@ -32,11 +32,9 @@ Graph ApplyPass(const Graph& src,
fpass.push_back(reg);
}
Graph g;
const Graph* s = &src;
for (auto r : fpass) {
for (auto& dep : r->graph_attr_dependency) {
if (s->attrs.count(dep) == 0) {
if (g.attrs.count(dep) == 0) {
auto* pass_dep = FindPassDep(dep);
std::string msg;
if (pass_dep != nullptr) {
......@@ -48,8 +46,7 @@ Graph ApplyPass(const Graph& src,
<< msg;
}
}
g = r->body(*s);
s = &g;
g = r->body(std::move(g));
}
return g;
}
......
......@@ -4,6 +4,7 @@
#include <nnvm/base.h>
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/node.h>
#include <nnvm/graph_attr_types.h>
#include <utility>
......@@ -30,6 +31,31 @@ inline bool SameShape(const NodeAttrs& attrs,
return true;
}
// simple demonstration of reshape.
NNVM_REGISTER_OP(reshape)
.describe("reshape source to target shape")
.set_num_inputs(1)
.set_attr_parser(
[](NodeAttrs* attrs) {
// parse attr parser to get target attribute
TShape target;
std::istringstream is(attrs->dict.at("target"));
CHECK(is >> target);
attrs->parsed = std::move(target);
})
.attr<FInferShape>(
"FInferShape", [] (const NodeAttrs& attrs,
array_view<TShape*> ishape,
array_view<TShape*> oshape) {
// get parsed attribute
const TShape& target = nnvm::get<TShape>(attrs.parsed);
*oshape[0] = target;
if (ishape[0]->ndim() == 0) return false;
CHECK_EQ(ishape[0]->Size(), target.Size())
<< "Reshape op: source target shape mismatch";
return true;
});
NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
......
......@@ -10,19 +10,42 @@
namespace nnvm {
namespace pass {
Graph InferShape(const Graph& src) {
Graph ret = src;
Graph InferShape(Graph ret) {
const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape = Op::GetAttr<FInferShape>("FInferShape");
// reshape shape vector
ShapeVector rshape(idx.num_node_entries());
if (ret.attrs.count("shape_args") != 0) {
const ShapeVector& shape_args = ret.GetAttr<ShapeVector>("shape_args");
CHECK_LE(shape_args.size(), idx.arg_nodes().size())
<< "shape args is more than number of arguments";
for (size_t i = 0; i < shape_args.size(); ++i) {
rshape[idx.entry_id(idx.arg_nodes()[i], 0)] = shape_args[i];
}
}
std::string shape_attr_key;
if (ret.attrs.count("shape_attr_key") != 0) {
shape_attr_key = ret.GetAttr<std::string>("shape_attr_key");
}
// temp space for shape inference.
std::vector<TShape*> ishape, oshape;
// number of completed nodes
size_t num_known = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
if (inode.source->is_variable()) continue;
if (inode.source->is_variable()) {
if (shape_attr_key.length() != 0) {
auto it = inode.source->attrs.dict.find(shape_attr_key);
if (it != inode.source->attrs.dict.end()) {
CHECK_EQ(inode.source->num_outputs(), 1);
std::istringstream is(it->second);
CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid shape attribute";
}
}
continue;
}
ishape.resize(inode.inputs.size());
for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = &rshape[idx.entry_id(inode.inputs[i])];
......@@ -43,5 +66,13 @@ Graph InferShape(const Graph& src) {
return ret;
}
NNVM_REGISTER_PASS(InferShape)
.describe("Infer the shape of each node entries.")
.set_body(InferShape)
.set_change_graph(false)
.provide_graph_attr("shape");
DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape);
} // namespace pass
} // namespace nnvm
/*!
* Copyright (c) 2016 by Contributors
* \file saveload_json.cc
* \file order_mutation.cc
* \brief Add control flow dependencies between nodes
* To correctly order mutation and read to resolve
* write after read problem and read after write problems.
......
......@@ -149,7 +149,7 @@ struct JSONGraph {
};
// Load a graph from JSON file.
Graph LoadJSON(const Graph& src) {
Graph LoadJSON(Graph src) {
CHECK_NE(src.attrs.count("json"), 0)
<< "Load JSON require json to be presented.";
const std::string &json_str =
......@@ -188,7 +188,7 @@ Graph LoadJSON(const Graph& src) {
}
// save a graph to json
Graph SaveJSON(const Graph& src) {
Graph SaveJSON(Graph src) {
JSONGraph jgraph;
std::unordered_map<Node*, uint32_t> node2index;
jgraph.node_row_ptr.push_back(0);
......
......@@ -4,6 +4,7 @@
#include <nnvm/tuple.h>
#include <nnvm/c_api.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass_functions.h>
#include <dmlc/timer.h>
#include <string>
......
......@@ -35,7 +35,23 @@ def test_order_mutation_pass():
assert nindex['add1'] in jnodes[nindex['assign']]['control_deps']
assert jnodes[nindex['assign']]['inputs'][0][2] == 1
def test_infer_shape():
x = sym.Variable('x', shape=(4, 2))
y = sym.add(x, x, name='add1')
y = sym.reshape(y, target=(2, 4), name="reshape1")
g = graph.create(y)
g._set_json_attr("shape_attr_key", "shape")
g = g.apply('InferShape')
jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
jnodes = jgraph['nodes']
jnode_row_ptr = jgraph['node_row_ptr']
nindex = {n['name']: i for i, n in enumerate(jnodes)}
assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4]
assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2]
if __name__ == "__main__":
test_order_mutation_pass()
test_graph_json_attr()
test_json_pass()
test_infer_shape()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment