Commit 8ffa4ac3 by Tianqi Chen

[Pass] enable infer type (#17)

parent 0081ad9a
......@@ -39,6 +39,21 @@ using JSONString = std::string;
*/
using ShapeVector = std::vector<TShape>;
/*!
* \brief The result holder of type of each NodeEntry in the graph.
* \note Stored under graph.attrs["dtype"], provided by Pass "InferType"
*
* \code
* Graph g = ApplyPass(src_graph, {"InferType"});
* const DTypeVector& types = g.GetAttr<ShapeVector>("dtype");
* // get shape by entry id
* int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)];
* \endcode
*
* \sa FInferType
*/
using DTypeVector = std::vector<int>;
} // namespace nnvm
#endif // NNVM_GRAPH_ATTR_TYPES_H_
......@@ -51,24 +51,34 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs
using FMutateInput = std::function<bool (const NodeAttrs& attrs, uint32_t index)>;
/*!
* \brief Inference function of certain type.
* \tparam AttrType The type of the attribute to be infered.
* \return whether all attributes are inferred.
*/
template<typename AttrType>
using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs,
array_view<AttrType*> in_attrs,
array_view<AttrType*> out_attrs)>;
/*!
* \brief Shape inference function.
* Update the shapes given the input shape information.
* TShape.ndim() == 0 means the shape is still unknown.
*
* \param attrs The attributes of the node.
* \param in_shapes Array of shapes from the inputs.
* \param out_shapes Array of shapes from the outputs.
*
* \return Whether all the shapes are known.
*
* \note Register under "FInferShape",
* by default do not update any shapes.
*
* FInferShape is needed by shape inference
*/
using FInferShape = std::function<bool (const NodeAttrs& attrs,
array_view<TShape*> in_shapes,
array_view<TShape*> out_shapes)>;
using FInferShape = FInferNodeEntryAttr<TShape>;
/*!
* \brief Type inference function.
* Update the type given the known type information.
*
* \note Register under "FInferType",
* by default set all the output types to 0.
*/
using FInferType = FInferNodeEntryAttr<int>;
} // namespace nnvm
......
......@@ -71,6 +71,26 @@ inline Graph InferShape(Graph graph,
return ApplyPass(std::move(graph), {"InferShape"});
}
/*!
* \brief Infer types in the graph given the information.
* \param graph source graph
* \param shape_args The shapes of aruguments to the graph.
* \param shape_attr_key The key to the node attribute that can indicate shape.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id
*/
inline Graph InferType(Graph graph,
DTypeVector type_args = {},
std::string type_attr_key = "") {
if (type_args.size() != 0) {
graph.attrs["type_args"] = std::make_shared<any>(std::move(type_args));
}
if (type_attr_key.length() != 0) {
graph.attrs["type_attr_key"] = std::make_shared<any>(std::move(type_attr_key));
}
return ApplyPass(std::move(graph), {"InferType"});
}
} // namespace pass
} // namespace nnvm
#endif // NNVM_PASS_FUNCTIONS_H_
......@@ -13,6 +13,7 @@ namespace myproject {
using nnvm::FListInputNames;
using nnvm::FMutateInput;
using nnvm::FInferShape;
using nnvm::FInferType;
using nnvm::NodeAttrs;
using nnvm::TShape;
using nnvm::array_view;
......@@ -56,6 +57,28 @@ NNVM_REGISTER_OP(reshape)
return true;
});
NNVM_REGISTER_OP(cast)
.describe("cast source type to target")
.set_num_inputs(1)
.set_attr_parser(
[](NodeAttrs* attrs) {
// parse attr parser to get target attribute
int dtype;
std::istringstream is(attrs->dict.at("dtype"));
CHECK(is >> dtype);
attrs->parsed = std::move(dtype);
})
.attr<FInferShape>("FInferShape", SameShape)
.attr<FInferType>(
"FInferType", [](const NodeAttrs& attrs,
array_view<int*> itype,
array_view<int*> otype) {
*otype[0] = nnvm::get<int>(attrs.parsed);
return true;
});
NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
......
/*!
* Copyright (c) 2016 by Contributors
* \file infer_shape.cc
* \brief Inference the shapes given
* \brief Inference the shapes given existin information.
*/
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
......@@ -10,14 +10,23 @@
namespace nnvm {
namespace pass {
Graph InferShape(Graph ret) {
template<typename AttrType>
Graph InferAttr(Graph &&ret,
const AttrType def_value,
const char* infer_name,
const char* arg_name,
const char* attr_key_name,
const char* attr_name,
const char* known_name) {
using AttrVector = std::vector<AttrType>;
const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape = Op::GetAttr<FInferShape>("FInferShape");
static auto& finfer_shape =
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
// reshape shape vector
ShapeVector rshape(idx.num_node_entries());
AttrVector rshape(idx.num_node_entries(), def_value);
if (ret.attrs.count("shape_args") != 0) {
const ShapeVector& shape_args = ret.GetAttr<ShapeVector>("shape_args");
if (ret.attrs.count(arg_name) != 0) {
const AttrVector& shape_args = ret.GetAttr<AttrVector>(arg_name);
CHECK_LE(shape_args.size(), idx.arg_nodes().size())
<< "shape args is more than number of arguments";
for (size_t i = 0; i < shape_args.size(); ++i) {
......@@ -25,12 +34,12 @@ Graph InferShape(Graph ret) {
}
}
std::string shape_attr_key;
if (ret.attrs.count("shape_attr_key") != 0) {
shape_attr_key = ret.GetAttr<std::string>("shape_attr_key");
if (ret.attrs.count(attr_key_name) != 0) {
shape_attr_key = ret.GetAttr<std::string>(attr_key_name);
}
// temp space for shape inference.
std::vector<TShape*> ishape, oshape;
std::vector<AttrType*> ishape, oshape;
// number of completed nodes
size_t num_known = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
......@@ -41,7 +50,7 @@ Graph InferShape(Graph ret) {
if (it != inode.source->attrs.dict.end()) {
CHECK_EQ(inode.source->num_outputs(), 1);
std::istringstream is(it->second);
CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid shape attribute";
CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid attribute";
}
}
continue;
......@@ -60,19 +69,44 @@ Graph InferShape(Graph ret) {
}
}
// set the shapes
ret.attrs["shape"] = std::make_shared<any>(std::move(rshape));
ret.attrs[attr_name] = std::make_shared<any>(std::move(rshape));
// number of nodes who knows the shape.
ret.attrs["shape_num_known_nodes"] = std::make_shared<any>(num_known);
ret.attrs[known_name] = std::make_shared<any>(num_known);
return ret;
}
NNVM_REGISTER_PASS(InferShape)
.describe("Infer the shape of each node entries.")
.set_body(InferShape)
.set_body([](Graph ret) {
return InferAttr<TShape>(
std::move(ret),
TShape(),
"FInferShape",
"shape_args",
"shape_attr_key",
"shape",
"shape_num_known_nodes");
})
.set_change_graph(false)
.provide_graph_attr("shape");
NNVM_REGISTER_PASS(InferType)
.describe("Infer the dtype of each node entries.")
.set_body([](Graph ret) {
return InferAttr<int>(
std::move(ret),
0,
"FInferType",
"dtype_args",
"dtype_attr_key",
"dtype",
"dtype_num_known_nodes");
})
.set_change_graph(false)
.provide_graph_attr("dtype");
DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape);
DMLC_JSON_ENABLE_ANY(DTypeVector, list_int);
} // namespace pass
} // namespace nnvm
......@@ -49,9 +49,37 @@ def test_infer_shape():
assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4]
assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2]
def test_infer_shape():
x = sym.Variable('x', shape=(4, 2))
y = sym.add(x, x, name='add1')
y = sym.reshape(y, target=(2, 4), name="reshape1")
g = graph.create(y)
g._set_json_attr("shape_attr_key", "shape")
g = g.apply('InferShape')
jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
jnodes = jgraph['nodes']
jnode_row_ptr = jgraph['node_row_ptr']
nindex = {n['name']: i for i, n in enumerate(jnodes)}
assert g.json_attr('shape')[jnode_row_ptr[nindex["reshape1"]]] == [2, 4]
assert g.json_attr('shape')[jnode_row_ptr[nindex["add1"]]] == [4, 2]
def test_infer_type():
x = sym.Variable('x')
y = sym.add(x, x, name='add1')
y = sym.cast(y, dtype=1, name="cast1")
g = graph.create(y)
g = g.apply('InferType')
jgraph = json.loads(g.apply('SaveJSON').json_attr('json'))
jnodes = jgraph['nodes']
jnode_row_ptr = jgraph['node_row_ptr']
nindex = {n['name']: i for i, n in enumerate(jnodes)}
assert g.json_attr('dtype')[jnode_row_ptr[nindex["cast1"]]] == 1
assert g.json_attr('dtype')[jnode_row_ptr[nindex["add1"]]] == 0
if __name__ == "__main__":
test_order_mutation_pass()
test_graph_json_attr()
test_json_pass()
test_infer_shape()
test_infer_type()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment