Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
8bc7c3e4
Commit
8bc7c3e4
authored
Sep 09, 2016
by
Minjie Wang
Committed by
Tianqi Chen
May 29, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix comments while reading the codes. (#42)
parent
67179f78
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
116 additions
and
84 deletions
+116
-84
nnvm/include/nnvm/pass.h
+20
-16
nnvm/include/nnvm/pass_functions.h
+33
-22
nnvm/include/nnvm/symbolic.h
+59
-42
nnvm/include/nnvm/tuple.h
+4
-4
No files found.
nnvm/include/nnvm/pass.h
View file @
8bc7c3e4
...
@@ -14,11 +14,12 @@
...
@@ -14,11 +14,12 @@
namespace
nnvm
{
namespace
nnvm
{
/*!
/*!
* \brief A PassFunction is a basic "Operator on Graph"
* \brief A PassFunction is an "Operator on Graph".
* It takes a source graph
* It takes a source graph and return a graph that may or may
* not be the same as the input one.
*
*
* A pass function can either change the graph structure
of g
,
* A pass function can either change the graph structure
(thus
,
* generating a new Graph, or add new attributes to the graph.
* generating a new Graph
)
, or add new attributes to the graph.
*
*
* \param src The graph to be transformed.
* \param src The graph to be transformed.
* \return The generated graph.
* \return The generated graph.
...
@@ -26,10 +27,10 @@ namespace nnvm {
...
@@ -26,10 +27,10 @@ namespace nnvm {
typedef
std
::
function
<
Graph
(
Graph
src
)
>
PassFunction
;
typedef
std
::
function
<
Graph
(
Graph
src
)
>
PassFunction
;
/*!
/*!
* \brief Apply a series of pass transformations on
g
.
* \brief Apply a series of pass transformations on
the input graph
.
* \param src The graph to be transformed.
* \param src The graph to be transformed.
* \param pass The name of pass to be applied.
* \param pass The name of pass to be applied.
* \return The transformed graph
* \return The transformed graph
.
*/
*/
Graph
ApplyPass
(
Graph
src
,
Graph
ApplyPass
(
Graph
src
,
const
std
::
vector
<
std
::
string
>&
pass
);
const
std
::
vector
<
std
::
string
>&
pass
);
...
@@ -52,36 +53,39 @@ struct PassFunctionReg
...
@@ -52,36 +53,39 @@ struct PassFunctionReg
/*! \brief generated targets of graph attributes */
/*! \brief generated targets of graph attributes */
std
::
vector
<
std
::
string
>
graph_attr_targets
;
std
::
vector
<
std
::
string
>
graph_attr_targets
;
/*!
/*!
* \brief
s
et whether this pass will change graph structure.
* \brief
S
et whether this pass will change graph structure.
* \param v
the value to set
* \param v
If true, the pass will change graph structure.
* \return
r
eference to self.
* \return
R
eference to self.
*/
*/
PassFunctionReg
&
set_change_graph
(
bool
v
)
{
// NOLINT(*)
PassFunctionReg
&
set_change_graph
(
bool
v
)
{
// NOLINT(*)
change_graph
=
v
;
change_graph
=
v
;
return
*
this
;
return
*
this
;
}
}
/*!
/*!
* \brief Declare this pass require operator attribute attr_name to be available.
* \brief Declare that this pass will generate the given graph attribute name
* \param attr_name Name of the attribute.
* once it is applied on the graph.
* \return reference to self.
* \param attr_name Name of the graph attribute.
* \return Reference to self.
*/
*/
PassFunctionReg
&
provide_graph_attr
(
const
std
::
string
&
attr_name
)
{
// NOLINT(*)
PassFunctionReg
&
provide_graph_attr
(
const
std
::
string
&
attr_name
)
{
// NOLINT(*)
graph_attr_targets
.
push_back
(
attr_name
);
graph_attr_targets
.
push_back
(
attr_name
);
return
*
this
;
return
*
this
;
}
}
/*!
/*!
* \brief declare this pass require operator attribute attr_name to be available.
* \brief Declare this pass requires the given operator attribute to be
* available before being applied on the graph.
* \param attr_name Name of the attribute.
* \param attr_name Name of the attribute.
* \return
r
eference to self.
* \return
R
eference to self.
*/
*/
PassFunctionReg
&
depend_op_attr
(
const
std
::
string
&
attr_name
)
{
// NOLINT(*)
PassFunctionReg
&
depend_op_attr
(
const
std
::
string
&
attr_name
)
{
// NOLINT(*)
op_attr_dependency
.
push_back
(
attr_name
);
op_attr_dependency
.
push_back
(
attr_name
);
return
*
this
;
return
*
this
;
}
}
/*!
/*!
* \brief declare this pass require graph attribute attr_name to be available.
* \brief Declare this pass requires the given graph attribute to be
* available before being applied on the graph.
* \param attr_name Name of the attribute.
* \param attr_name Name of the attribute.
* \return
r
eference to self.
* \return
R
eference to self.
*/
*/
PassFunctionReg
&
depend_graph_attr
(
const
std
::
string
&
attr_name
)
{
// NOLINT(*)
PassFunctionReg
&
depend_graph_attr
(
const
std
::
string
&
attr_name
)
{
// NOLINT(*)
graph_attr_dependency
.
push_back
(
attr_name
);
graph_attr_dependency
.
push_back
(
attr_name
);
...
...
nnvm/include/nnvm/pass_functions.h
View file @
8bc7c3e4
...
@@ -33,7 +33,7 @@ inline Graph LoadJSON(const std::string& json_str) {
...
@@ -33,7 +33,7 @@ inline Graph LoadJSON(const std::string& json_str) {
/*!
/*!
* \brief Save a graph to json, redirects to "SaveJSON" pass.
* \brief Save a graph to json, redirects to "SaveJSON" pass.
* \param graph The
to be saved
.
* \param graph The
graph to be saved as json format
.
* \return The json string.
* \return The json string.
*/
*/
inline
std
::
string
SaveJSON
(
Graph
graph
)
{
inline
std
::
string
SaveJSON
(
Graph
graph
)
{
...
@@ -42,11 +42,14 @@ inline std::string SaveJSON(Graph graph) {
...
@@ -42,11 +42,14 @@ inline std::string SaveJSON(Graph graph) {
}
}
/*!
/*!
* \brief Add control flow dependencies between nodes
* \brief Add control flow dependencies between nodes.
* To correctly order mutation and read to resolve
*
* write after read problem and read after write problems.
* This function will enforce the correct order between
* \param src source graph
* write (mutable operators) and read (immutable operators)
* \return A graph that added control flow dependencies.
* to sovle write-after-read and read-after-write problems.
*
* \param src The input graph.
* \return A graph with proper control flow dependencies added.
*/
*/
inline
Graph
OrderMutation
(
Graph
src
)
{
inline
Graph
OrderMutation
(
Graph
src
)
{
return
ApplyPass
(
std
::
move
(
src
),
{
"OrderMutation"
});
return
ApplyPass
(
std
::
move
(
src
),
{
"OrderMutation"
});
...
@@ -54,11 +57,12 @@ inline Graph OrderMutation(Graph src) {
...
@@ -54,11 +57,12 @@ inline Graph OrderMutation(Graph src) {
/*!
/*!
* \brief Infer shapes in the graph given the information.
* \brief Infer shapes in the graph given the information.
* \param graph source graph
* \param graph The input graph.
* \param shape_inputs The shapes of aruguments to the graph.
* \param shape_inputs The shapes of input symbols to the graph.
* \param shape_attr_key The key to the node attribute that can indicate shape.
* \param shape_attr_key The key to the node attribute that can indicate shape. This is
* the place where manual hint for shapes could be injected.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
*
The index of ShapeVector is given by graph.indexed_graph().entry_id
*
The index of ShapeVector is given by graph.indexed_graph().entry_id.
*/
*/
inline
Graph
InferShape
(
Graph
graph
,
inline
Graph
InferShape
(
Graph
graph
,
ShapeVector
shape_inputs
,
ShapeVector
shape_inputs
,
...
@@ -74,11 +78,12 @@ inline Graph InferShape(Graph graph,
...
@@ -74,11 +78,12 @@ inline Graph InferShape(Graph graph,
/*!
/*!
* \brief Infer types in the graph given the information.
* \brief Infer types in the graph given the information.
* \param graph source graph
* \param graph The input graph.
* \param dtype_inputs The shapes of inputs to the graph.
* \param dtype_inputs The types of input symbols to the graph.
* \param dtype_attr_key The key to the node attribute that can indicate shape.
* \param dtype_attr_key The key to the node attribute that can indicate types. This is
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* the place where manual hint for types could be injected.
* The index of ShapeVector is given by graph.indexed_graph().entry_id
* \return A graph with new attribute "dtype" containing inferred type of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id.
*/
*/
inline
Graph
InferType
(
Graph
graph
,
inline
Graph
InferType
(
Graph
graph
,
DTypeVector
dtype_inputs
,
DTypeVector
dtype_inputs
,
...
@@ -93,10 +98,16 @@ inline Graph InferType(Graph graph,
...
@@ -93,10 +98,16 @@ inline Graph InferType(Graph graph,
}
}
/*!
/*!
* \brief Place the devices
* \brief Place the devices for each operator in the graph.
* \param graph source graph
*
* \param device_group_attr_key The attribute name for hinting the device group.
* Current device placement is quite simple. Each operator is assigned to a "group" (stored
* \param device_assign_map The assignment map of device
* in `device_group_attr_key` attribute). Each group is assigned to a device (stored in
* `device_assign_map` attribute). Operators will be placed to the device assigned to its
* group. Copy operators will be injected if cross device reference happens.
*
* \param graph The input graph.
* \param device_group_attr_key The attribute name for hints of device group.
* \param device_assign_map The assignment map of device.
* \param device_copy_op The name of copy op to be inserted when cross device copy happened.
* \param device_copy_op The name of copy op to be inserted when cross device copy happened.
* \return A graph with new attribute "device", cotaining device information of each node.
* \return A graph with new attribute "device", cotaining device information of each node.
*/
*/
...
@@ -112,13 +123,13 @@ inline Graph PlaceDevice(Graph graph,
...
@@ -112,13 +123,13 @@ inline Graph PlaceDevice(Graph graph,
/*!
/*!
* \brief Get the gradient graph whose outputs are gradients of xs wrt to ys.
* \brief Get the gradient graph whose outputs are gradients of xs wrt to ys.
* \param graph
source graph
* \param graph
The input graph.
* \param ys The entries we want to take gradient from.
* \param ys The entries we want to take gradient from.
* \param xs The input to take gradient with respect to.
* \param xs The input to take gradient with respect to.
* \param ys_out_grad The symbol for additional gradient to be propagate back to y.
* \param ys_out_grad The symbol for additional gradient to be propagate back to y.
* \param aggregate_fun
aggregation function applied to aggregate the inputs
* \param aggregate_fun
Aggregation function applied to aggregate the inputs.
* \param mirror_fun Optional mirror function to do mirror optimization and save memory.
* \param mirror_fun Optional mirror function to do mirror optimization and save memory.
* \return A new graph, whose outputs correspond
s
to inputs of xs.
* \return A new graph, whose outputs correspond to inputs of xs.
*/
*/
inline
Graph
Gradient
(
inline
Graph
Gradient
(
Graph
graph
,
Graph
graph
,
...
...
nnvm/include/nnvm/symbolic.h
View file @
8bc7c3e4
...
@@ -19,7 +19,13 @@
...
@@ -19,7 +19,13 @@
namespace
nnvm
{
namespace
nnvm
{
/*!
/*!
* \brief Symbol is used to represent the
* \brief Symbol is help class used to represent the operator node in Graph.
*
* Symbol acts as an interface for building graphs from different components
* like Variable, Functor and Group. Symbol is also exported to python front-end
* (while Graph is not) to enable quick test and deployment. Conceptually,
* symbol is the final operation of a graph and thus including all the information
* required (the graph) to evaluate its output value.
*/
*/
class
Symbol
{
class
Symbol
{
public
:
public
:
...
@@ -47,42 +53,46 @@ class Symbol {
...
@@ -47,42 +53,46 @@ class Symbol {
std
::
vector
<
NodeEntry
>
outputs
;
std
::
vector
<
NodeEntry
>
outputs
;
/*!
/*!
* \brief
copy the symbol
* \brief
Copy the symbol.
* \return
a deep copy of the symbolic graph
.
* \return
A deep copy of this symbol
.
*/
*/
Symbol
Copy
()
const
;
Symbol
Copy
()
const
;
/*!
/*!
* \brief
p
rint the symbol info to output stream.
* \brief
P
rint the symbol info to output stream.
* \param os
the output stream we like to print to
* \param os
The output stream to print to.
*/
*/
void
Print
(
std
::
ostream
&
os
)
const
;
// NOLINT(*)
void
Print
(
std
::
ostream
&
os
)
const
;
// NOLINT(*)
/*!
/*!
* \brief
get the index
th element from the returned tuple.
* \brief
Get the index-
th element from the returned tuple.
* \param index
index of multi output
* \param index
Index of multi output.
* \return
t
he symbol corresponds to the indexed element.
* \return
T
he symbol corresponds to the indexed element.
*/
*/
Symbol
operator
[]
(
size_t
index
)
const
;
Symbol
operator
[]
(
size_t
index
)
const
;
/*!
/*!
* \brief List the input variable nodes
* \brief List the input variable nodes.
* \param option The options to list the arguments.
*
* The order of the returned list is the same as the order of the input list to `operator()`.
*
*
*
The position of the returned list also corresponds to calling position in operator()
*
\param option The options to list the arguments.
* \return
t
he arguments list of this symbol, they can be either named or unnamed (empty string).
* \return
T
he arguments list of this symbol, they can be either named or unnamed (empty string).
* \sa ListInputOption
* \sa ListInputOption
*/
*/
std
::
vector
<
NodePtr
>
ListInputs
(
ListInputOption
option
)
const
;
std
::
vector
<
NodePtr
>
ListInputs
(
ListInputOption
option
)
const
;
/*!
/*!
* \brief List the input names.
* \brief List the input names.
* \param option The options to list the arguments.
*
*
* The position of the returned list also corresponds to calling position in operator()
* The order of the returned list is the same as the order of the input list to `operator()`.
* \return the arguments list of this symbol, they can be either named or unnamed (empty string).
*
* \param option The options to list the arguments.
* \return The arguments list of this symbol, they can be either named or unnamed (empty string).
* \sa ListInputOption
* \sa ListInputOption
*/
*/
std
::
vector
<
std
::
string
>
ListInputNames
(
ListInputOption
option
)
const
;
std
::
vector
<
std
::
string
>
ListInputNames
(
ListInputOption
option
)
const
;
/*!
/*!
* \brief List the names of outputs for this symbol.
* \brief List the names of outputs for this symbol.
* For normal operators, it is usually symbol node name + "_output"
*
* For normal operators, it is usually symbol node name + "_output".
*
* \return get the descriptions of outputs for this symbol.
* \return get the descriptions of outputs for this symbol.
*/
*/
std
::
vector
<
std
::
string
>
ListOutputNames
()
const
;
std
::
vector
<
std
::
string
>
ListOutputNames
()
const
;
...
@@ -92,28 +102,30 @@ class Symbol {
...
@@ -92,28 +102,30 @@ class Symbol {
*
*
* The rest of the symbols will remain the same name.
* The rest of the symbols will remain the same name.
*
*
* \param args
positional arguments
* \param args
Positional arguments.
* \param kwargs
keyword arguments for the symbol
* \param kwargs
Keyword arguments for the symbol.
* \param name
n
ame of returned symbol.
* \param name
N
ame of returned symbol.
*/
*/
void
Compose
(
const
array_view
<
const
Symbol
*>&
args
,
void
Compose
(
const
array_view
<
const
Symbol
*>&
args
,
const
std
::
unordered_map
<
std
::
string
,
const
Symbol
*>&
kwargs
,
const
std
::
unordered_map
<
std
::
string
,
const
Symbol
*>&
kwargs
,
const
std
::
string
&
name
);
const
std
::
string
&
name
);
/*!
/*!
* \brief Apply the symbol as a function, compose with arguments
* \brief Apply the symbol as a function, compose with arguments
*
* This is equivalent to Copy then Compose.
* This is equivalent to Copy then Compose.
* \param args positional arguments for the symbol
*
* \param kwargs keyword arguments for the symbol
* \param args Positional arguments for the symbol.
* \param name name of returned symbol.
* \param kwargs Keyword arguments for the symbol.
* \return a new Symbol which is the composition of current symbol with its arguments
* \param name Name of returned symbol.
* \return A new Symbol which is the composition of current symbol with its arguments.
*/
*/
Symbol
operator
()
(
const
array_view
<
const
Symbol
*>&
args
,
Symbol
operator
()
(
const
array_view
<
const
Symbol
*>&
args
,
const
std
::
unordered_map
<
std
::
string
,
const
Symbol
*>&
kwargs
,
const
std
::
unordered_map
<
std
::
string
,
const
Symbol
*>&
kwargs
,
const
std
::
string
&
name
)
const
;
const
std
::
string
&
name
)
const
;
/*!
/*!
* \brief Add control flow depenencies to
operators involved
in symbols.
* \brief Add control flow depenencies to
the operators
in symbols.
*
For grouped sybmbol, an error will be raised.
*
*
This mutate
current symbolic Node.
*
For grouped symbol, an error will be raised. This mutates
current symbolic Node.
*
*
* \param src The symbols to depend on.
* \param src The symbols to depend on.
*/
*/
...
@@ -121,38 +133,43 @@ class Symbol {
...
@@ -121,38 +133,43 @@ class Symbol {
/*
/*
* \brief Get all the internal nodes of the symbol.
* \brief Get all the internal nodes of the symbol.
* \return symbol A new symbol whose output contains all the outputs of the symbols
* \return symbol A new symbol whose output contains all the outputs of the symbols
*
I
ncluding input variables and intermediate outputs.
*
i
ncluding input variables and intermediate outputs.
*/
*/
Symbol
GetInternals
()
const
;
Symbol
GetInternals
()
const
;
/*!
/*!
* \brief set additional attributes to current node.
* \brief Set additional attributes to current node.
*
* This only works for symbol with outputs from single operators.
* This only works for symbol with outputs from single operators.
* For grouped sy
b
mbol, an error will be raised.
* For grouped symbol, an error will be raised.
*
*
* This function mutate the node's symbol and is not recommended.
* This function mutate
s
the node's symbol and is not recommended.
*
*
* \param attrs The attributes to set.
* \param attrs The attributes to set.
*/
*/
void
SetAttrs
(
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
attrs
);
void
SetAttrs
(
const
std
::
vector
<
std
::
pair
<
std
::
string
,
std
::
string
>
>&
attrs
);
/*!
/*!
* \brief Get attributes from the symbol.
* \brief Get attributes from the symbol.
*
* This only works for symbol with outputs from single operators.
* This only works for symbol with outputs from single operators.
* For grouped sybmbol, an error will be raised.
* For grouped symbol, an error will be raised.
*
* \param key Key of the attribute. When key == "name", it returns the name attirbute.
* \param key Key of the attribute. When key == "name", it returns the name attirbute.
* \param out
t
he output value of the attribute.
* \param out
T
he output value of the attribute.
* \return true
if the attribute exists, false if the attribute do
not exist.
* \return true
If the attribute exists, false if the attribute does
not exist.
*/
*/
bool
GetAttr
(
const
std
::
string
&
key
,
std
::
string
*
out
)
const
;
bool
GetAttr
(
const
std
::
string
&
key
,
std
::
string
*
out
)
const
;
/*!
/*!
* \brief Get attribute dictionary from the symbol.
* \brief Get attribute dictionary from the symbol.
* For grouped sybmbol, an error will be raised.
*
* \param option If recursive is set, the attributes of all children are retrieved,
* For grouped symbol, an error will be raised.
*
* \param option If recursive flag is set, the attributes of all children are retrieved.
* The name of symbol will be pre-pended to each key.
* The name of symbol will be pre-pended to each key.
* \return The created attribute.
* \return The created attribute.
*/
*/
std
::
unordered_map
<
std
::
string
,
std
::
string
>
ListAttrs
(
ListAttrOption
option
)
const
;
std
::
unordered_map
<
std
::
string
,
std
::
string
>
ListAttrs
(
ListAttrOption
option
)
const
;
/*!
/*!
* \brief
c
reate symbolic functor(AtomicSymbol) by given operator and attributes.
* \brief
C
reate symbolic functor(AtomicSymbol) by given operator and attributes.
* \param op The operator.
* \param op The operator.
* \param attrs The additional attributes.
* \param attrs The additional attributes.
* \return Symbol that can be used to call compose further.
* \return Symbol that can be used to call compose further.
...
@@ -160,15 +177,15 @@ class Symbol {
...
@@ -160,15 +177,15 @@ class Symbol {
static
Symbol
CreateFunctor
(
const
Op
*
op
,
static
Symbol
CreateFunctor
(
const
Op
*
op
,
std
::
unordered_map
<
std
::
string
,
std
::
string
>
attrs
);
std
::
unordered_map
<
std
::
string
,
std
::
string
>
attrs
);
/*!
/*!
* \brief
create variable symbol node
* \brief
Create symbol node representing variable.
* \param name
name of the variable
* \param name
Name of the variable.
* \return
the new variable
* \return
The symbol.
*/
*/
static
Symbol
CreateVariable
(
const
std
::
string
&
name
);
static
Symbol
CreateVariable
(
const
std
::
string
&
name
);
/*!
/*!
* \brief
create equivalence of symbol by grouping the symbols together
* \brief
Create equivalence of symbol by grouping the symbols together.
* \param symbols
list of symbols
* \param symbols
A list of symbols to be grouped.
* \return
the grouped symbol
* \return
The grouped symbol.
*/
*/
static
Symbol
CreateGroup
(
const
std
::
vector
<
Symbol
>&
symbols
);
static
Symbol
CreateGroup
(
const
std
::
vector
<
Symbol
>&
symbols
);
};
};
...
...
nnvm/include/nnvm/tuple.h
View file @
8bc7c3e4
...
@@ -19,11 +19,11 @@ namespace nnvm {
...
@@ -19,11 +19,11 @@ namespace nnvm {
typedef
uint32_t
index_t
;
typedef
uint32_t
index_t
;
/*!
/*!
* \brief A dynamic sized array data strcuture
* \brief A dynamic sized array data strcuture that is optimized for storing
* that is optimized for storing small number of elements with same type.
* small number of elements with same type.
* Data will be stored in stack when number of elements is small.
*
*
* It is suitable to hold Shape of Tensor.
* Data will be stored in stack when number of elements is small.
* It is suitable to hold shape of Tensor.
*
*
* \tparam ValueType The type of data stored inside tuple.
* \tparam ValueType The type of data stored inside tuple.
* \sa TShape
* \sa TShape
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment