Commit 2c0b79ae by Minjie Wang Committed by Tianqi Chen

ApplyPass -> ApplyPasses; Refactored infer pass; (#43)

* ApplyPass -> ApplyPasses; Refactored infer pass;

* lint fix
parent 24f1999c
......@@ -329,16 +329,16 @@ NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle,
const char* key,
SymbolHandle list);
/*!
* \brief Apply pass on the src graph.
* \brief Apply passes on the src graph.
* \param src The source graph handle.
* \param num_pass The number of pass to be applied.
* \param pass_names The names of the pass.
* \param dst The result graph.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphApplyPass(GraphHandle src,
nn_uint num_pass,
const char** pass_names,
GraphHandle *dst);
NNVM_DLL int NNGraphApplyPasses(GraphHandle src,
nn_uint num_pass,
const char** pass_names,
GraphHandle *dst);
#endif // NNVM_C_API_H_
......@@ -179,11 +179,11 @@ class IndexedGraph {
* \param other The source graph.
*/
explicit IndexedGraph(const Graph& other);
// node pointers in CSR structure.
// Node pointers in CSR structure.
std::vector<Node> nodes_;
// index all to input nodes
// Index to all input nodes.
std::vector<uint32_t> input_nodes_;
// index to mutable input nodes
// Index to all mutable input nodes.
std::unordered_set<uint32_t> mutable_input_nodes_;
// space to store the outputs entries
std::vector<NodeEntry> outputs_;
......
......@@ -18,7 +18,7 @@ namespace nnvm {
* \note Stored under ret.attrs["json"], provided by Pass "SaveJSON"
* \code
* Graph ret = ApplyPass(src_graph, {"SaveJSON"});
* Graph ret = ApplyPass(src_graph, "SaveJSON");
* const JSONString& json = ret.GetAttr<JSONString>("shape");
* \endcode
*/
......@@ -29,7 +29,7 @@ using JSONString = std::string;
* \note Stored under graph.attrs["shape"], provided by Pass "InferShape"
*
* \code
* Graph g = ApplyPass(src_graph, {"InferShape"});
* Graph g = ApplyPass(src_graph, "InferShape");
* const ShapeVector& shapes = g.GetAttr<ShapeVector>("shape");
* // get shape by entry id
* TShape entry_shape = shapes[g.indexed_graph().entry_id(my_entry)];
......@@ -44,7 +44,7 @@ using ShapeVector = std::vector<TShape>;
* \note Stored under graph.attrs["dtype"], provided by Pass "InferType"
*
* \code
* Graph g = ApplyPass(src_graph, {"InferType"});
* Graph g = ApplyPass(src_graph, "InferType");
* const DTypeVector& types = g.GetAttr<DTypeVector>("dtype");
* // get shape by entry id
* int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)];
......@@ -59,7 +59,7 @@ using DTypeVector = std::vector<int>;
* \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice"
*
* \code
* Graph g = ApplyPass(src_graph, {"PlaceDevice"});
* Graph g = ApplyPass(src_graph, "PlaceDevice");
* const &device = g.GetAttr<DeviceVector>("device");
* // get device by node_id
* int device_type = device[g.indexed_graph().node_id(my_node)];
......@@ -83,7 +83,7 @@ using DeviceAssignMap = std::unordered_map<std::string, int>;
* If the storage id is -1 then the storage is not assigned.
*
* \code
* Graph g = ApplyPass(src_graph, {"PlanMemory"});
* Graph g = ApplyPass(src_graph, "PlanMemory");
* const &storage = g.GetAttr<StorageVector>("storage");
* // get storage id by entry
* int storage_id = storage[g.indexed_graph().entry_id(my_entry)];
......
......@@ -29,11 +29,22 @@ typedef std::function<Graph (Graph src)> PassFunction;
/*!
* \brief Apply a series of pass transformations on the input graph.
* \param src The graph to be transformed.
* \param passes A list of pass names to be applied.
* \return The transformed graph
*/
Graph ApplyPasses(Graph src,
const std::vector<std::string>& passes);
/*!
* \brief Apply one pass to the graph.
* \param src The graph to be transformed.
* \param pass The name of pass to be applied.
* \return The transformed graph.
*/
Graph ApplyPass(Graph src,
const std::vector<std::string>& pass);
inline Graph ApplyPass(Graph src, const std::string& pass) {
return ApplyPasses(src, {pass});
}
/*!
* \brief Registry entry for DataIterator factory functions.
......
......@@ -28,7 +28,7 @@ namespace pass {
inline Graph LoadJSON(const std::string& json_str) {
Graph ret;
ret.attrs["json"] = std::make_shared<any>(json_str);
return ApplyPass(ret, {"LoadJSON"});
return ApplyPass(ret, "LoadJSON");
}
/*!
......@@ -37,7 +37,7 @@ inline Graph LoadJSON(const std::string& json_str) {
* \return The json string.
*/
inline std::string SaveJSON(Graph graph) {
Graph ret = ApplyPass(std::move(graph), {"SaveJSON"});
Graph ret = ApplyPass(std::move(graph), "SaveJSON");
return ret.GetAttr<std::string>("json");
}
......@@ -52,7 +52,7 @@ inline std::string SaveJSON(Graph graph) {
* \return A graph with proper control flow dependencies added.
*/
inline Graph OrderMutation(Graph src) {
return ApplyPass(std::move(src), {"OrderMutation"});
return ApplyPass(std::move(src), "OrderMutation");
}
/*!
......@@ -73,7 +73,7 @@ inline Graph InferShape(Graph graph,
if (shape_attr_key.length() != 0) {
graph.attrs["shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key));
}
return ApplyPass(std::move(graph), {"InferShape"});
return ApplyPass(std::move(graph), "InferShape");
}
/*!
......@@ -94,7 +94,7 @@ inline Graph InferType(Graph graph,
if (dtype_attr_key.length() != 0) {
graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(dtype_attr_key));
}
return ApplyPass(std::move(graph), {"InferType"});
return ApplyPass(std::move(graph), "InferType");
}
/*!
......@@ -118,7 +118,7 @@ inline Graph PlaceDevice(Graph graph,
graph.attrs["device_group_attr_key"] = std::make_shared<any>(std::move(device_group_attr_key));
graph.attrs["device_assign_map"] = std::make_shared<any>(std::move(device_assign_map));
graph.attrs["device_copy_op"] = std::make_shared<any>(std::move(device_copy_op));
return ApplyPass(std::move(graph), {"PlaceDevice"});
return ApplyPass(std::move(graph), "PlaceDevice");
}
/*!
......@@ -149,7 +149,7 @@ inline Graph Gradient(
graph.attrs["grad_mirror_fun"] = std::make_shared<any>(mirror_fun);
}
return ApplyPass(std::move(graph), {"Gradient"});
return ApplyPass(std::move(graph), "Gradient");
}
} // namespace pass
......
......@@ -113,7 +113,7 @@ class Graph(object):
cpass = c_array(ctypes.c_char_p, [c_str(key) for key in passes])
ghandle = GraphHandle()
npass = nn_uint(len(passes))
check_call(_LIB.NNGraphApplyPass(self.handle, npass, cpass, ctypes.byref(ghandle)))
check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle)))
return Graph(ghandle)
......
......@@ -82,17 +82,17 @@ int NNGraphGetJSONAttr(GraphHandle handle,
API_END();
}
int NNGraphApplyPass(GraphHandle src,
nn_uint num_pass,
const char** pass_names,
GraphHandle *dst) {
int NNGraphApplyPasses(GraphHandle src,
nn_uint num_pass,
const char** pass_names,
GraphHandle *dst) {
Graph* g = new Graph();
API_BEGIN();
std::vector<std::string> vpass;
for (nn_uint i = 0; i < num_pass; ++i) {
vpass.emplace_back(std::string(pass_names[i]));
}
*g = ApplyPass(*static_cast<Graph*>(src), vpass);
*g = ApplyPasses(*static_cast<Graph*>(src), vpass);
*dst = g;
API_END_HANDLE_ERROR(delete g);
}
......@@ -22,8 +22,8 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) {
return nullptr;
}
Graph ApplyPass(Graph g,
const std::vector<std::string>& pass) {
Graph ApplyPasses(Graph g,
const std::vector<std::string>& pass) {
std::vector<const PassFunctionReg*> fpass;
for (auto& name : pass) {
auto* reg = dmlc::Registry<PassFunctionReg>::Find(name);
......
......@@ -13,7 +13,7 @@ namespace {
template<typename AttrType, typename IsNone>
Graph InferAttr(Graph &&ret,
const AttrType def_value,
const AttrType default_val,
const char* infer_name,
const char* input_name,
const char* attr_key_name,
......@@ -23,16 +23,16 @@ Graph InferAttr(Graph &&ret,
using AttrVector = std::vector<AttrType>;
const IndexedGraph& idx = ret.indexed_graph();
static auto& finfer_shape =
Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
Op::GetAttr<FInferNodeEntryAttr<AttrType>>(infer_name);
static auto& backward_map =
Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
// reshape shape vector
AttrVector rshape(idx.num_node_entries(), def_value);
AttrVector rshape(idx.num_node_entries(), default_val);
if (ret.attrs.count(input_name) != 0) {
const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
CHECK_LE(shape_args.size(), idx.input_nodes().size())
<< "shape args is more than number of arguments";
<< "More provided shapes than number of arguments.";
for (size_t i = 0; i < shape_args.size(); ++i) {
rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
}
......@@ -46,36 +46,41 @@ Graph InferAttr(Graph &&ret,
ret.attrs.erase(attr_key_name);
}
// temp space for shape inference.
// Temp space for shape inference.
std::vector<AttrType> ishape, oshape;
// number of completed nodes
size_t num_unknown = 0;
for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
const auto& inode = idx[nid];
uint32_t num_inputs = inode.inputs.size();
uint32_t num_outputs = inode.source->num_outputs();
const uint32_t num_inputs = inode.inputs.size();
const uint32_t num_outputs = inode.source->num_outputs();
if (inode.source->is_variable()) {
if (shape_attr_key.length() != 0 && fis_none(rshape[idx.entry_id(nid, 0)])) {
// Variable node. No operator. Only one output entry.
CHECK(inode.source->op() == nullptr);
CHECK_EQ(num_outputs, 1);
const uint32_t out_ent_id = idx.entry_id(nid, 0);
if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
auto it = inode.source->attrs.dict.find(shape_attr_key);
if (it != inode.source->attrs.dict.end()) {
CHECK_EQ(num_outputs, 1);
std::istringstream is(it->second);
CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid attribute";
CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
}
}
continue;
}
if (finfer_shape.count(inode.source->op())) {
ishape.resize(num_inputs, def_value);
} else if (finfer_shape.count(inode.source->op())) {
// Forward operator inference.
ishape.resize(num_inputs, default_val);
for (uint32_t i = 0; i < ishape.size(); ++i) {
ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
}
oshape.resize(num_outputs, def_value);
oshape.resize(num_outputs, default_val);
for (uint32_t i = 0; i < oshape.size(); ++i) {
oshape[i] = rshape[idx.entry_id(nid, i)];
}
num_unknown +=
!(finfer_shape[inode.source->op()](inode.source->attrs, &ishape, &oshape));
// Call inference function of the operator.
bool forward_known = finfer_shape[inode.source->op()](
inode.source->attrs, &ishape, &oshape);
num_unknown += !forward_known;
// Save to the result map.
for (uint32_t i = 0; i < num_inputs; ++i) {
rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
}
......@@ -83,10 +88,12 @@ Graph InferAttr(Graph &&ret,
rshape[idx.entry_id(nid, i)] = oshape[i];
}
} else if (backward_map.count(inode.source->op())) {
// backward operator inference.
// Backward operator inference.
CHECK_GE(inode.control_deps.size(), 1)
<< "BackwardOp need to have control_deps to its forward op";
const auto& fnode = idx[inode.control_deps[0]];
const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
// Inference the outputs of backward operator (equal to the inputs
// of its corresponding forward operator).
std::vector<uint32_t> out_map =
backward_map[inode.source->op()](inode.source->attrs);
bool known = true;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment