Commit 8bc7c3e4 by Minjie Wang Committed by Tianqi Chen

Fix comments while reading the codes. (#42)

parent 67179f78
...@@ -14,11 +14,12 @@ ...@@ -14,11 +14,12 @@
namespace nnvm { namespace nnvm {
/*! /*!
* \brief A PassFunction is a basic "Operator on Graph" * \brief A PassFunction is an "Operator on Graph".
* It takes a source graph * It takes a source graph and return a graph that may or may
* not be the same as the input one.
* *
* A pass function can either change the graph structure of g, * A pass function can either change the graph structure (thus,
* generating a new Graph, or add new attributes to the graph. * generating a new Graph), or add new attributes to the graph.
* *
* \param src The graph to be transformed. * \param src The graph to be transformed.
* \return The generated graph. * \return The generated graph.
...@@ -26,10 +27,10 @@ namespace nnvm { ...@@ -26,10 +27,10 @@ namespace nnvm {
typedef std::function<Graph (Graph src)> PassFunction; typedef std::function<Graph (Graph src)> PassFunction;
/*! /*!
* \brief Apply a series of pass transformations on g. * \brief Apply a series of pass transformations on the input graph.
* \param src The graph to be transformed. * \param src The graph to be transformed.
* \param pass The name of pass to be applied. * \param pass The name of pass to be applied.
* \return The transformed graph * \return The transformed graph.
*/ */
Graph ApplyPass(Graph src, Graph ApplyPass(Graph src,
const std::vector<std::string>& pass); const std::vector<std::string>& pass);
...@@ -52,36 +53,39 @@ struct PassFunctionReg ...@@ -52,36 +53,39 @@ struct PassFunctionReg
/*! \brief generated targets of graph attributes */ /*! \brief generated targets of graph attributes */
std::vector<std::string> graph_attr_targets; std::vector<std::string> graph_attr_targets;
/*! /*!
* \brief set whether this pass will change graph structure. * \brief Set whether this pass will change graph structure.
* \param v the value to set * \param v If true, the pass will change graph structure.
* \return reference to self. * \return Reference to self.
*/ */
PassFunctionReg& set_change_graph(bool v) { // NOLINT(*) PassFunctionReg& set_change_graph(bool v) { // NOLINT(*)
change_graph = v; change_graph = v;
return *this; return *this;
} }
/*! /*!
* \brief Declare this pass require operator attribute attr_name to be available. * \brief Declare that this pass will generate the given graph attribute name
* \param attr_name Name of the attribute. * once it is applied on the graph.
* \return reference to self. * \param attr_name Name of the graph attribute.
* \return Reference to self.
*/ */
PassFunctionReg& provide_graph_attr(const std::string& attr_name) { // NOLINT(*) PassFunctionReg& provide_graph_attr(const std::string& attr_name) { // NOLINT(*)
graph_attr_targets.push_back(attr_name); graph_attr_targets.push_back(attr_name);
return *this; return *this;
} }
/*! /*!
* \brief declare this pass require operator attribute attr_name to be available. * \brief Declare this pass requires the given operator attribute to be
* available before being applied on the graph.
* \param attr_name Name of the attribute. * \param attr_name Name of the attribute.
* \return reference to self. * \return Reference to self.
*/ */
PassFunctionReg& depend_op_attr(const std::string& attr_name) { // NOLINT(*) PassFunctionReg& depend_op_attr(const std::string& attr_name) { // NOLINT(*)
op_attr_dependency.push_back(attr_name); op_attr_dependency.push_back(attr_name);
return *this; return *this;
} }
/*! /*!
* \brief declare this pass require graph attribute attr_name to be available. * \brief Declare this pass requires the given graph attribute to be
* available before being applied on the graph.
* \param attr_name Name of the attribute. * \param attr_name Name of the attribute.
* \return reference to self. * \return Reference to self.
*/ */
PassFunctionReg& depend_graph_attr(const std::string& attr_name) { // NOLINT(*) PassFunctionReg& depend_graph_attr(const std::string& attr_name) { // NOLINT(*)
graph_attr_dependency.push_back(attr_name); graph_attr_dependency.push_back(attr_name);
......
...@@ -33,7 +33,7 @@ inline Graph LoadJSON(const std::string& json_str) { ...@@ -33,7 +33,7 @@ inline Graph LoadJSON(const std::string& json_str) {
/*! /*!
* \brief Save a graph to json, redirects to "SaveJSON" pass. * \brief Save a graph to json, redirects to "SaveJSON" pass.
* \param graph The to be saved. * \param graph The graph to be saved as json format.
* \return The json string. * \return The json string.
*/ */
inline std::string SaveJSON(Graph graph) { inline std::string SaveJSON(Graph graph) {
...@@ -42,11 +42,14 @@ inline std::string SaveJSON(Graph graph) { ...@@ -42,11 +42,14 @@ inline std::string SaveJSON(Graph graph) {
} }
/*! /*!
* \brief Add control flow dependencies between nodes * \brief Add control flow dependencies between nodes.
* To correctly order mutation and read to resolve *
* write after read problem and read after write problems. * This function will enforce the correct order between
* \param src source graph * write (mutable operators) and read (immutable operators)
* \return A graph that added control flow dependencies. * to sovle write-after-read and read-after-write problems.
*
* \param src The input graph.
* \return A graph with proper control flow dependencies added.
*/ */
inline Graph OrderMutation(Graph src) { inline Graph OrderMutation(Graph src) {
return ApplyPass(std::move(src), {"OrderMutation"}); return ApplyPass(std::move(src), {"OrderMutation"});
...@@ -54,11 +57,12 @@ inline Graph OrderMutation(Graph src) { ...@@ -54,11 +57,12 @@ inline Graph OrderMutation(Graph src) {
/*! /*!
* \brief Infer shapes in the graph given the information. * \brief Infer shapes in the graph given the information.
* \param graph source graph * \param graph The input graph.
* \param shape_inputs The shapes of aruguments to the graph. * \param shape_inputs The shapes of input symbols to the graph.
* \param shape_attr_key The key to the node attribute that can indicate shape. * \param shape_attr_key The key to the node attribute that can indicate shape. This is
* the place where manual hint for shapes could be injected.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry. * \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id * The index of ShapeVector is given by graph.indexed_graph().entry_id.
*/ */
inline Graph InferShape(Graph graph, inline Graph InferShape(Graph graph,
ShapeVector shape_inputs, ShapeVector shape_inputs,
...@@ -74,11 +78,12 @@ inline Graph InferShape(Graph graph, ...@@ -74,11 +78,12 @@ inline Graph InferShape(Graph graph,
/*! /*!
* \brief Infer types in the graph given the information. * \brief Infer types in the graph given the information.
* \param graph source graph * \param graph The input graph.
* \param dtype_inputs The shapes of inputs to the graph. * \param dtype_inputs The types of input symbols to the graph.
* \param dtype_attr_key The key to the node attribute that can indicate shape. * \param dtype_attr_key The key to the node attribute that can indicate types. This is
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry. * the place where manual hint for types could be injected.
* The index of ShapeVector is given by graph.indexed_graph().entry_id * \return A graph with new attribute "dtype" containing inferred type of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id.
*/ */
inline Graph InferType(Graph graph, inline Graph InferType(Graph graph,
DTypeVector dtype_inputs, DTypeVector dtype_inputs,
...@@ -93,10 +98,16 @@ inline Graph InferType(Graph graph, ...@@ -93,10 +98,16 @@ inline Graph InferType(Graph graph,
} }
/*! /*!
* \brief Place the devices * \brief Place the devices for each operator in the graph.
* \param graph source graph *
* \param device_group_attr_key The attribute name for hinting the device group. * Current device placement is quite simple. Each operator is assigned to a "group" (stored
* \param device_assign_map The assignment map of device * in `device_group_attr_key` attribute). Each group is assigned to a device (stored in
* `device_assign_map` attribute). Operators will be placed to the device assigned to its
* group. Copy operators will be injected if cross device reference happens.
*
* \param graph The input graph.
* \param device_group_attr_key The attribute name for hints of device group.
* \param device_assign_map The assignment map of device.
* \param device_copy_op The name of copy op to be inserted when cross device copy happened. * \param device_copy_op The name of copy op to be inserted when cross device copy happened.
* \return A graph with new attribute "device", cotaining device information of each node. * \return A graph with new attribute "device", cotaining device information of each node.
*/ */
...@@ -112,13 +123,13 @@ inline Graph PlaceDevice(Graph graph, ...@@ -112,13 +123,13 @@ inline Graph PlaceDevice(Graph graph,
/*! /*!
* \brief Get the gradient graph whose outputs are gradients of xs wrt to ys. * \brief Get the gradient graph whose outputs are gradients of xs wrt to ys.
* \param graph source graph * \param graph The input graph.
* \param ys The entries we want to take gradient from. * \param ys The entries we want to take gradient from.
* \param xs The input to take gradient with respect to. * \param xs The input to take gradient with respect to.
* \param ys_out_grad The symbol for additional gradient to be propagate back to y. * \param ys_out_grad The symbol for additional gradient to be propagate back to y.
* \param aggregate_fun aggregation function applied to aggregate the inputs * \param aggregate_fun Aggregation function applied to aggregate the inputs.
* \param mirror_fun Optional mirror function to do mirror optimization and save memory. * \param mirror_fun Optional mirror function to do mirror optimization and save memory.
* \return A new graph, whose outputs corresponds to inputs of xs. * \return A new graph, whose outputs correspond to inputs of xs.
*/ */
inline Graph Gradient( inline Graph Gradient(
Graph graph, Graph graph,
......
...@@ -19,7 +19,13 @@ ...@@ -19,7 +19,13 @@
namespace nnvm { namespace nnvm {
/*! /*!
* \brief Symbol is used to represent the * \brief Symbol is help class used to represent the operator node in Graph.
*
* Symbol acts as an interface for building graphs from different components
* like Variable, Functor and Group. Symbol is also exported to python front-end
* (while Graph is not) to enable quick test and deployment. Conceptually,
* symbol is the final operation of a graph and thus including all the information
* required (the graph) to evaluate its output value.
*/ */
class Symbol { class Symbol {
public: public:
...@@ -47,42 +53,46 @@ class Symbol { ...@@ -47,42 +53,46 @@ class Symbol {
std::vector<NodeEntry> outputs; std::vector<NodeEntry> outputs;
/*! /*!
* \brief copy the symbol * \brief Copy the symbol.
* \return a deep copy of the symbolic graph. * \return A deep copy of this symbol.
*/ */
Symbol Copy() const; Symbol Copy() const;
/*! /*!
* \brief print the symbol info to output stream. * \brief Print the symbol info to output stream.
* \param os the output stream we like to print to * \param os The output stream to print to.
*/ */
void Print(std::ostream &os) const; // NOLINT(*) void Print(std::ostream &os) const; // NOLINT(*)
/*! /*!
* \brief get the index th element from the returned tuple. * \brief Get the index-th element from the returned tuple.
* \param index index of multi output * \param index Index of multi output.
* \return the symbol corresponds to the indexed element. * \return The symbol corresponds to the indexed element.
*/ */
Symbol operator[] (size_t index) const; Symbol operator[] (size_t index) const;
/*! /*!
* \brief List the input variable nodes * \brief List the input variable nodes.
* \param option The options to list the arguments. *
* The order of the returned list is the same as the order of the input list to `operator()`.
* *
* The position of the returned list also corresponds to calling position in operator() * \param option The options to list the arguments.
* \return the arguments list of this symbol, they can be either named or unnamed (empty string). * \return The arguments list of this symbol, they can be either named or unnamed (empty string).
* \sa ListInputOption * \sa ListInputOption
*/ */
std::vector<NodePtr> ListInputs(ListInputOption option) const; std::vector<NodePtr> ListInputs(ListInputOption option) const;
/*! /*!
* \brief List the input names. * \brief List the input names.
* \param option The options to list the arguments.
* *
* The position of the returned list also corresponds to calling position in operator() * The order of the returned list is the same as the order of the input list to `operator()`.
* \return the arguments list of this symbol, they can be either named or unnamed (empty string). *
* \param option The options to list the arguments.
* \return The arguments list of this symbol, they can be either named or unnamed (empty string).
* \sa ListInputOption * \sa ListInputOption
*/ */
std::vector<std::string> ListInputNames(ListInputOption option) const; std::vector<std::string> ListInputNames(ListInputOption option) const;
/*! /*!
* \brief List the names of outputs for this symbol. * \brief List the names of outputs for this symbol.
* For normal operators, it is usually symbol node name + "_output" *
* For normal operators, it is usually symbol node name + "_output".
*
* \return get the descriptions of outputs for this symbol. * \return get the descriptions of outputs for this symbol.
*/ */
std::vector<std::string> ListOutputNames() const; std::vector<std::string> ListOutputNames() const;
...@@ -92,28 +102,30 @@ class Symbol { ...@@ -92,28 +102,30 @@ class Symbol {
* *
* The rest of the symbols will remain the same name. * The rest of the symbols will remain the same name.
* *
* \param args positional arguments * \param args Positional arguments.
* \param kwargs keyword arguments for the symbol * \param kwargs Keyword arguments for the symbol.
* \param name name of returned symbol. * \param name Name of returned symbol.
*/ */
void Compose(const array_view<const Symbol*>& args, void Compose(const array_view<const Symbol*>& args,
const std::unordered_map<std::string, const Symbol*>& kwargs, const std::unordered_map<std::string, const Symbol*>& kwargs,
const std::string& name); const std::string& name);
/*! /*!
* \brief Apply the symbol as a function, compose with arguments * \brief Apply the symbol as a function, compose with arguments
*
* This is equivalent to Copy then Compose. * This is equivalent to Copy then Compose.
* \param args positional arguments for the symbol *
* \param kwargs keyword arguments for the symbol * \param args Positional arguments for the symbol.
* \param name name of returned symbol. * \param kwargs Keyword arguments for the symbol.
* \return a new Symbol which is the composition of current symbol with its arguments * \param name Name of returned symbol.
* \return A new Symbol which is the composition of current symbol with its arguments.
*/ */
Symbol operator () (const array_view<const Symbol*>& args, Symbol operator () (const array_view<const Symbol*>& args,
const std::unordered_map<std::string, const Symbol*>& kwargs, const std::unordered_map<std::string, const Symbol*>& kwargs,
const std::string& name) const; const std::string& name) const;
/*! /*!
* \brief Add control flow depenencies to operators involved in symbols. * \brief Add control flow depenencies to the operators in symbols.
* For grouped sybmbol, an error will be raised. *
* This mutate current symbolic Node. * For grouped symbol, an error will be raised. This mutates current symbolic Node.
* *
* \param src The symbols to depend on. * \param src The symbols to depend on.
*/ */
...@@ -121,38 +133,43 @@ class Symbol { ...@@ -121,38 +133,43 @@ class Symbol {
/* /*
* \brief Get all the internal nodes of the symbol. * \brief Get all the internal nodes of the symbol.
* \return symbol A new symbol whose output contains all the outputs of the symbols * \return symbol A new symbol whose output contains all the outputs of the symbols
* Including input variables and intermediate outputs. * including input variables and intermediate outputs.
*/ */
Symbol GetInternals() const; Symbol GetInternals() const;
/*! /*!
* \brief set additional attributes to current node. * \brief Set additional attributes to current node.
*
* This only works for symbol with outputs from single operators. * This only works for symbol with outputs from single operators.
* For grouped sybmbol, an error will be raised. * For grouped symbol, an error will be raised.
* *
* This function mutate the node's symbol and is not recommended. * This function mutates the node's symbol and is not recommended.
* *
* \param attrs The attributes to set. * \param attrs The attributes to set.
*/ */
void SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs); void SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs);
/*! /*!
* \brief Get attributes from the symbol. * \brief Get attributes from the symbol.
*
* This only works for symbol with outputs from single operators. * This only works for symbol with outputs from single operators.
* For grouped sybmbol, an error will be raised. * For grouped symbol, an error will be raised.
*
* \param key Key of the attribute. When key == "name", it returns the name attirbute. * \param key Key of the attribute. When key == "name", it returns the name attirbute.
* \param out the output value of the attribute. * \param out The output value of the attribute.
* \return true if the attribute exists, false if the attribute do not exist. * \return true If the attribute exists, false if the attribute does not exist.
*/ */
bool GetAttr(const std::string& key, std::string* out) const; bool GetAttr(const std::string& key, std::string* out) const;
/*! /*!
* \brief Get attribute dictionary from the symbol. * \brief Get attribute dictionary from the symbol.
* For grouped sybmbol, an error will be raised. *
* \param option If recursive is set, the attributes of all children are retrieved, * For grouped symbol, an error will be raised.
*
* \param option If recursive flag is set, the attributes of all children are retrieved.
* The name of symbol will be pre-pended to each key. * The name of symbol will be pre-pended to each key.
* \return The created attribute. * \return The created attribute.
*/ */
std::unordered_map<std::string, std::string> ListAttrs(ListAttrOption option) const; std::unordered_map<std::string, std::string> ListAttrs(ListAttrOption option) const;
/*! /*!
* \brief create symbolic functor(AtomicSymbol) by given operator and attributes. * \brief Create symbolic functor(AtomicSymbol) by given operator and attributes.
* \param op The operator. * \param op The operator.
* \param attrs The additional attributes. * \param attrs The additional attributes.
* \return Symbol that can be used to call compose further. * \return Symbol that can be used to call compose further.
...@@ -160,15 +177,15 @@ class Symbol { ...@@ -160,15 +177,15 @@ class Symbol {
static Symbol CreateFunctor(const Op* op, static Symbol CreateFunctor(const Op* op,
std::unordered_map<std::string, std::string> attrs); std::unordered_map<std::string, std::string> attrs);
/*! /*!
* \brief create variable symbol node * \brief Create symbol node representing variable.
* \param name name of the variable * \param name Name of the variable.
* \return the new variable * \return The symbol.
*/ */
static Symbol CreateVariable(const std::string& name); static Symbol CreateVariable(const std::string& name);
/*! /*!
* \brief create equivalence of symbol by grouping the symbols together * \brief Create equivalence of symbol by grouping the symbols together.
* \param symbols list of symbols * \param symbols A list of symbols to be grouped.
* \return the grouped symbol * \return The grouped symbol.
*/ */
static Symbol CreateGroup(const std::vector<Symbol>& symbols); static Symbol CreateGroup(const std::vector<Symbol>& symbols);
}; };
......
...@@ -19,11 +19,11 @@ namespace nnvm { ...@@ -19,11 +19,11 @@ namespace nnvm {
typedef uint32_t index_t; typedef uint32_t index_t;
/*! /*!
* \brief A dynamic sized array data strcuture * \brief A dynamic sized array data strcuture that is optimized for storing
* that is optimized for storing small number of elements with same type. * small number of elements with same type.
* Data will be stored in stack when number of elements is small.
* *
* It is suitable to hold Shape of Tensor. * Data will be stored in stack when number of elements is small.
* It is suitable to hold shape of Tensor.
* *
* \tparam ValueType The type of data stored inside tuple. * \tparam ValueType The type of data stored inside tuple.
* \sa TShape * \sa TShape
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment