Commit 2c0b79ae by Minjie Wang Committed by Tianqi Chen

ApplyPass -> ApplyPasses; Refactored infer pass; (#43)

* ApplyPass -> ApplyPasses; Refactored infer pass;

* lint fix
parent 24f1999c
...@@ -329,16 +329,16 @@ NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle, ...@@ -329,16 +329,16 @@ NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle,
const char* key, const char* key,
SymbolHandle list); SymbolHandle list);
/*! /*!
* \brief Apply pass on the src graph. * \brief Apply passes on the src graph.
* \param src The source graph handle. * \param src The source graph handle.
* \param num_pass The number of pass to be applied. * \param num_pass The number of pass to be applied.
* \param pass_names The names of the pass. * \param pass_names The names of the pass.
* \param dst The result graph. * \param dst The result graph.
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
NNVM_DLL int NNGraphApplyPass(GraphHandle src, NNVM_DLL int NNGraphApplyPasses(GraphHandle src,
nn_uint num_pass, nn_uint num_pass,
const char** pass_names, const char** pass_names,
GraphHandle *dst); GraphHandle *dst);
#endif // NNVM_C_API_H_ #endif // NNVM_C_API_H_
...@@ -179,11 +179,11 @@ class IndexedGraph { ...@@ -179,11 +179,11 @@ class IndexedGraph {
* \param other The source graph. * \param other The source graph.
*/ */
explicit IndexedGraph(const Graph& other); explicit IndexedGraph(const Graph& other);
// node pointers in CSR structure. // Node pointers in CSR structure.
std::vector<Node> nodes_; std::vector<Node> nodes_;
// index all to input nodes // Index to all input nodes.
std::vector<uint32_t> input_nodes_; std::vector<uint32_t> input_nodes_;
// index to mutable input nodes // Index to all mutable input nodes.
std::unordered_set<uint32_t> mutable_input_nodes_; std::unordered_set<uint32_t> mutable_input_nodes_;
// space to store the outputs entries // space to store the outputs entries
std::vector<NodeEntry> outputs_; std::vector<NodeEntry> outputs_;
......
...@@ -18,7 +18,7 @@ namespace nnvm { ...@@ -18,7 +18,7 @@ namespace nnvm {
* \note Stored under ret.attrs["json"], provided by Pass "SaveJSON" * \note Stored under ret.attrs["json"], provided by Pass "SaveJSON"
* \code * \code
* Graph ret = ApplyPass(src_graph, {"SaveJSON"}); * Graph ret = ApplyPass(src_graph, "SaveJSON");
* const JSONString& json = ret.GetAttr<JSONString>("shape"); * const JSONString& json = ret.GetAttr<JSONString>("shape");
* \endcode * \endcode
*/ */
...@@ -29,7 +29,7 @@ using JSONString = std::string; ...@@ -29,7 +29,7 @@ using JSONString = std::string;
* \note Stored under graph.attrs["shape"], provided by Pass "InferShape" * \note Stored under graph.attrs["shape"], provided by Pass "InferShape"
* *
* \code * \code
* Graph g = ApplyPass(src_graph, {"InferShape"}); * Graph g = ApplyPass(src_graph, "InferShape");
* const ShapeVector& shapes = g.GetAttr<ShapeVector>("shape"); * const ShapeVector& shapes = g.GetAttr<ShapeVector>("shape");
* // get shape by entry id * // get shape by entry id
* TShape entry_shape = shapes[g.indexed_graph().entry_id(my_entry)]; * TShape entry_shape = shapes[g.indexed_graph().entry_id(my_entry)];
...@@ -44,7 +44,7 @@ using ShapeVector = std::vector<TShape>; ...@@ -44,7 +44,7 @@ using ShapeVector = std::vector<TShape>;
* \note Stored under graph.attrs["dtype"], provided by Pass "InferType" * \note Stored under graph.attrs["dtype"], provided by Pass "InferType"
* *
* \code * \code
* Graph g = ApplyPass(src_graph, {"InferType"}); * Graph g = ApplyPass(src_graph, "InferType");
* const DTypeVector& types = g.GetAttr<DTypeVector>("dtype"); * const DTypeVector& types = g.GetAttr<DTypeVector>("dtype");
* // get shape by entry id * // get shape by entry id
* int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)]; * int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)];
...@@ -59,7 +59,7 @@ using DTypeVector = std::vector<int>; ...@@ -59,7 +59,7 @@ using DTypeVector = std::vector<int>;
* \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice" * \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice"
* *
* \code * \code
* Graph g = ApplyPass(src_graph, {"PlaceDevice"}); * Graph g = ApplyPass(src_graph, "PlaceDevice");
* const &device = g.GetAttr<DeviceVector>("device"); * const &device = g.GetAttr<DeviceVector>("device");
* // get device by node_id * // get device by node_id
* int device_type = device[g.indexed_graph().node_id(my_node)]; * int device_type = device[g.indexed_graph().node_id(my_node)];
...@@ -83,7 +83,7 @@ using DeviceAssignMap = std::unordered_map<std::string, int>; ...@@ -83,7 +83,7 @@ using DeviceAssignMap = std::unordered_map<std::string, int>;
* If the storage id is -1 then the storage is not assigned. * If the storage id is -1 then the storage is not assigned.
* *
* \code * \code
* Graph g = ApplyPass(src_graph, {"PlanMemory"}); * Graph g = ApplyPass(src_graph, "PlanMemory");
* const &storage = g.GetAttr<StorageVector>("storage"); * const &storage = g.GetAttr<StorageVector>("storage");
* // get storage id by entry * // get storage id by entry
* int storage_id = storage[g.indexed_graph().entry_id(my_entry)]; * int storage_id = storage[g.indexed_graph().entry_id(my_entry)];
......
...@@ -29,11 +29,22 @@ typedef std::function<Graph (Graph src)> PassFunction; ...@@ -29,11 +29,22 @@ typedef std::function<Graph (Graph src)> PassFunction;
/*! /*!
* \brief Apply a series of pass transformations on the input graph. * \brief Apply a series of pass transformations on the input graph.
* \param src The graph to be transformed. * \param src The graph to be transformed.
* \param passes A list of pass names to be applied.
* \return The transformed graph
*/
Graph ApplyPasses(Graph src,
const std::vector<std::string>& passes);
/*!
* \brief Apply one pass to the graph.
* \param src The graph to be transformed.
* \param pass The name of pass to be applied. * \param pass The name of pass to be applied.
* \return The transformed graph. * \return The transformed graph.
*/ */
Graph ApplyPass(Graph src, inline Graph ApplyPass(Graph src, const std::string& pass) {
const std::vector<std::string>& pass); return ApplyPasses(src, {pass});
}
/*! /*!
* \brief Registry entry for DataIterator factory functions. * \brief Registry entry for DataIterator factory functions.
......
...@@ -28,7 +28,7 @@ namespace pass { ...@@ -28,7 +28,7 @@ namespace pass {
inline Graph LoadJSON(const std::string& json_str) { inline Graph LoadJSON(const std::string& json_str) {
Graph ret; Graph ret;
ret.attrs["json"] = std::make_shared<any>(json_str); ret.attrs["json"] = std::make_shared<any>(json_str);
return ApplyPass(ret, {"LoadJSON"}); return ApplyPass(ret, "LoadJSON");
} }
/*! /*!
...@@ -37,7 +37,7 @@ inline Graph LoadJSON(const std::string& json_str) { ...@@ -37,7 +37,7 @@ inline Graph LoadJSON(const std::string& json_str) {
* \return The json string. * \return The json string.
*/ */
inline std::string SaveJSON(Graph graph) { inline std::string SaveJSON(Graph graph) {
Graph ret = ApplyPass(std::move(graph), {"SaveJSON"}); Graph ret = ApplyPass(std::move(graph), "SaveJSON");
return ret.GetAttr<std::string>("json"); return ret.GetAttr<std::string>("json");
} }
...@@ -52,7 +52,7 @@ inline std::string SaveJSON(Graph graph) { ...@@ -52,7 +52,7 @@ inline std::string SaveJSON(Graph graph) {
* \return A graph with proper control flow dependencies added. * \return A graph with proper control flow dependencies added.
*/ */
inline Graph OrderMutation(Graph src) { inline Graph OrderMutation(Graph src) {
return ApplyPass(std::move(src), {"OrderMutation"}); return ApplyPass(std::move(src), "OrderMutation");
} }
/*! /*!
...@@ -73,7 +73,7 @@ inline Graph InferShape(Graph graph, ...@@ -73,7 +73,7 @@ inline Graph InferShape(Graph graph,
if (shape_attr_key.length() != 0) { if (shape_attr_key.length() != 0) {
graph.attrs["shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key)); graph.attrs["shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key));
} }
return ApplyPass(std::move(graph), {"InferShape"}); return ApplyPass(std::move(graph), "InferShape");
} }
/*! /*!
...@@ -94,7 +94,7 @@ inline Graph InferType(Graph graph, ...@@ -94,7 +94,7 @@ inline Graph InferType(Graph graph,
if (dtype_attr_key.length() != 0) { if (dtype_attr_key.length() != 0) {
graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(dtype_attr_key)); graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(dtype_attr_key));
} }
return ApplyPass(std::move(graph), {"InferType"}); return ApplyPass(std::move(graph), "InferType");
} }
/*! /*!
...@@ -118,7 +118,7 @@ inline Graph PlaceDevice(Graph graph, ...@@ -118,7 +118,7 @@ inline Graph PlaceDevice(Graph graph,
graph.attrs["device_group_attr_key"] = std::make_shared<any>(std::move(device_group_attr_key)); graph.attrs["device_group_attr_key"] = std::make_shared<any>(std::move(device_group_attr_key));
graph.attrs["device_assign_map"] = std::make_shared<any>(std::move(device_assign_map)); graph.attrs["device_assign_map"] = std::make_shared<any>(std::move(device_assign_map));
graph.attrs["device_copy_op"] = std::make_shared<any>(std::move(device_copy_op)); graph.attrs["device_copy_op"] = std::make_shared<any>(std::move(device_copy_op));
return ApplyPass(std::move(graph), {"PlaceDevice"}); return ApplyPass(std::move(graph), "PlaceDevice");
} }
/*! /*!
...@@ -149,7 +149,7 @@ inline Graph Gradient( ...@@ -149,7 +149,7 @@ inline Graph Gradient(
graph.attrs["grad_mirror_fun"] = std::make_shared<any>(mirror_fun); graph.attrs["grad_mirror_fun"] = std::make_shared<any>(mirror_fun);
} }
return ApplyPass(std::move(graph), {"Gradient"}); return ApplyPass(std::move(graph), "Gradient");
} }
} // namespace pass } // namespace pass
......
...@@ -113,7 +113,7 @@ class Graph(object): ...@@ -113,7 +113,7 @@ class Graph(object):
cpass = c_array(ctypes.c_char_p, [c_str(key) for key in passes]) cpass = c_array(ctypes.c_char_p, [c_str(key) for key in passes])
ghandle = GraphHandle() ghandle = GraphHandle()
npass = nn_uint(len(passes)) npass = nn_uint(len(passes))
check_call(_LIB.NNGraphApplyPass(self.handle, npass, cpass, ctypes.byref(ghandle))) check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle)))
return Graph(ghandle) return Graph(ghandle)
......
...@@ -82,17 +82,17 @@ int NNGraphGetJSONAttr(GraphHandle handle, ...@@ -82,17 +82,17 @@ int NNGraphGetJSONAttr(GraphHandle handle,
API_END(); API_END();
} }
int NNGraphApplyPass(GraphHandle src, int NNGraphApplyPasses(GraphHandle src,
nn_uint num_pass, nn_uint num_pass,
const char** pass_names, const char** pass_names,
GraphHandle *dst) { GraphHandle *dst) {
Graph* g = new Graph(); Graph* g = new Graph();
API_BEGIN(); API_BEGIN();
std::vector<std::string> vpass; std::vector<std::string> vpass;
for (nn_uint i = 0; i < num_pass; ++i) { for (nn_uint i = 0; i < num_pass; ++i) {
vpass.emplace_back(std::string(pass_names[i])); vpass.emplace_back(std::string(pass_names[i]));
} }
*g = ApplyPass(*static_cast<Graph*>(src), vpass); *g = ApplyPasses(*static_cast<Graph*>(src), vpass);
*dst = g; *dst = g;
API_END_HANDLE_ERROR(delete g); API_END_HANDLE_ERROR(delete g);
} }
...@@ -22,8 +22,8 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) { ...@@ -22,8 +22,8 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) {
return nullptr; return nullptr;
} }
Graph ApplyPass(Graph g, Graph ApplyPasses(Graph g,
const std::vector<std::string>& pass) { const std::vector<std::string>& pass) {
std::vector<const PassFunctionReg*> fpass; std::vector<const PassFunctionReg*> fpass;
for (auto& name : pass) { for (auto& name : pass) {
auto* reg = dmlc::Registry<PassFunctionReg>::Find(name); auto* reg = dmlc::Registry<PassFunctionReg>::Find(name);
......
...@@ -13,7 +13,7 @@ namespace { ...@@ -13,7 +13,7 @@ namespace {
template<typename AttrType, typename IsNone> template<typename AttrType, typename IsNone>
Graph InferAttr(Graph &&ret, Graph InferAttr(Graph &&ret,
const AttrType def_value, const AttrType default_val,
const char* infer_name, const char* infer_name,
const char* input_name, const char* input_name,
const char* attr_key_name, const char* attr_key_name,
...@@ -23,16 +23,16 @@ Graph InferAttr(Graph &&ret, ...@@ -23,16 +23,16 @@ Graph InferAttr(Graph &&ret,
using AttrVector = std::vector<AttrType>; using AttrVector = std::vector<AttrType>;
const IndexedGraph& idx = ret.indexed_graph(); const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape = static auto& finfer_shape =
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name); Op::GetAttr<FInferNodeEntryAttr<AttrType>>(infer_name);
static auto& backward_map = static auto& backward_map =
Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex"); Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
// reshape shape vector // reshape shape vector
AttrVector rshape(idx.num_node_entries(), def_value); AttrVector rshape(idx.num_node_entries(), default_val);
if (ret.attrs.count(input_name) != 0) { if (ret.attrs.count(input_name) != 0) {
const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name); const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
CHECK_LE(shape_args.size(), idx.input_nodes().size()) CHECK_LE(shape_args.size(), idx.input_nodes().size())
<< "shape args is more than number of arguments"; << "More provided shapes than number of arguments.";
for (size_t i = 0; i < shape_args.size(); ++i) { for (size_t i = 0; i < shape_args.size(); ++i) {
rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i]; rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
} }
...@@ -46,36 +46,41 @@ Graph InferAttr(Graph &&ret, ...@@ -46,36 +46,41 @@ Graph InferAttr(Graph &&ret,
ret.attrs.erase(attr_key_name); ret.attrs.erase(attr_key_name);
} }
// temp space for shape inference. // Temp space for shape inference.
std::vector<AttrType> ishape, oshape; std::vector<AttrType> ishape, oshape;
// number of completed nodes // number of completed nodes
size_t num_unknown = 0; size_t num_unknown = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid]; const auto& inode = idx[nid];
uint32_t num_inputs = inode.inputs.size(); const uint32_t num_inputs = inode.inputs.size();
uint32_t num_outputs = inode.source->num_outputs(); const uint32_t num_outputs = inode.source->num_outputs();
if (inode.source->is_variable()) { if (inode.source->is_variable()) {
if (shape_attr_key.length() != 0 && fis_none(rshape[idx.entry_id(nid, 0)])) { // Variable node. No operator. Only one output entry.
CHECK(inode.source->op() == nullptr);
CHECK_EQ(num_outputs, 1);
const uint32_t out_ent_id = idx.entry_id(nid, 0);
if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
auto it = inode.source->attrs.dict.find(shape_attr_key); auto it = inode.source->attrs.dict.find(shape_attr_key);
if (it != inode.source->attrs.dict.end()) { if (it != inode.source->attrs.dict.end()) {
CHECK_EQ(num_outputs, 1);
std::istringstream is(it->second); std::istringstream is(it->second);
CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid attribute"; CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
} }
} }
continue; } else if (finfer_shape.count(inode.source->op())) {
} // Forward operator inference.
if (finfer_shape.count(inode.source->op())) { ishape.resize(num_inputs, default_val);
ishape.resize(num_inputs, def_value);
for (uint32_t i = 0; i < ishape.size(); ++i) { for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = rshape[idx.entry_id(inode.inputs[i])]; ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
} }
oshape.resize(num_outputs, def_value); oshape.resize(num_outputs, default_val);
for (uint32_t i = 0; i < oshape.size(); ++i) { for (uint32_t i = 0; i < oshape.size(); ++i) {
oshape[i] = rshape[idx.entry_id(nid, i)]; oshape[i] = rshape[idx.entry_id(nid, i)];
} }
num_unknown += // Call inference function of the operator.
!(finfer_shape[inode.source->op()](inode.source->attrs, &ishape, &oshape)); bool forward_known = finfer_shape[inode.source->op()](
inode.source->attrs, &ishape, &oshape);
num_unknown += !forward_known;
// Save to the result map.
for (uint32_t i = 0; i < num_inputs; ++i) { for (uint32_t i = 0; i < num_inputs; ++i) {
rshape[idx.entry_id(inode.inputs[i])] = ishape[i]; rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
} }
...@@ -83,10 +88,12 @@ Graph InferAttr(Graph &&ret, ...@@ -83,10 +88,12 @@ Graph InferAttr(Graph &&ret,
rshape[idx.entry_id(nid, i)] = oshape[i]; rshape[idx.entry_id(nid, i)] = oshape[i];
} }
} else if (backward_map.count(inode.source->op())) { } else if (backward_map.count(inode.source->op())) {
// backward operator inference. // Backward operator inference.
CHECK_GE(inode.control_deps.size(), 1) CHECK_GE(inode.control_deps.size(), 1)
<< "BackwardOp need to have control_deps to its forward op"; << "BackwardOp need to have control_deps to its forward op";
const auto& fnode = idx[inode.control_deps[0]]; const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
// Inference the outputs of backward operator (equal to the inputs
// of its corresponding forward operator).
std::vector<uint32_t> out_map = std::vector<uint32_t> out_map =
backward_map[inode.source->op()](inode.source->attrs); backward_map[inode.source->op()](inode.source->attrs);
bool known = true; bool known = true;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment