Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
bd20bfd8
Commit
bd20bfd8
authored
Jul 18, 2016
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Pass] Check in infershape, move indexedgraph to graph.h (#15)
parent
94ae677a
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
307 additions
and
112 deletions
+307
-112
nnvm/include/nnvm/graph.h
+149
-0
nnvm/include/nnvm/graph_attr_types.h
+25
-108
nnvm/include/nnvm/op_attr_types.h
+23
-0
nnvm/include/nnvm/tuple.h
+20
-1
nnvm/src/core/graph.cc
+8
-1
nnvm/src/example/operator.cc
+27
-2
nnvm/src/pass/infer_shape.cc
+47
-0
nnvm/src/pass/order_mutation.cc
+2
-0
nnvm/src/pass/saveload_json.cc
+6
-0
No files found.
nnvm/include/nnvm/graph.h
View file @
bd20bfd8
...
@@ -18,6 +18,8 @@
...
@@ -18,6 +18,8 @@
namespace
nnvm
{
namespace
nnvm
{
class
IndexedGraph
;
/*!
/*!
* \brief Symbolic computation graph.
* \brief Symbolic computation graph.
* This is the intermediate representation for optimization pass.
* This is the intermediate representation for optimization pass.
...
@@ -32,6 +34,145 @@ class Graph {
...
@@ -32,6 +34,145 @@ class Graph {
* and can be shared across multiple Instance of graph
* and can be shared across multiple Instance of graph
*/
*/
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
const
any
>
>
attrs
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
const
any
>
>
attrs
;
/*!
* \brief Get the attribute from attrs.
* \param attr_name the name of the attribute
* \return the reference to corresponding attribute
* \tparam T the type of the attribute.
*/
template
<
typename
T
>
inline
const
T
&
GetAttr
(
const
std
::
string
&
attr_name
);
/*!
* \brief get a indexed graph of current graph, if not exist, create it on demand
* \return The indexed graph.
* \sa IndexedGraph
*/
const
IndexedGraph
&
indexed_graph
();
private
:
// internal structure of indexed graph
std
::
shared_ptr
<
const
IndexedGraph
>
indexed_graph_
;
};
/*!
* \brief Auxililary data structure to index a graph.
* It maps Nodes in the graph to consecutive integers node_id.
* It also maps IndexedGraph::NodeEntry to consecutive integer entry_id.
* This allows storing properties of Node and NodeEntry into
* compact vector and quickly access them without resorting to hashmap.
*
* The node_id and entry_rptr are the same as the JSON graph produced by SaveJSON Pass.
*/
class
IndexedGraph
{
public
:
/*! \brief represents a data in the graph */
struct
NodeEntry
{
/*! \brief the source node id in the computation graph */
uint32_t
node_id
;
/*! \brief index of output from the source. */
uint32_t
index
;
/*!
* \brief compare equality
* \param other the other entry to compare
* \return whether two entries equals to each other
*/
inline
bool
operator
==
(
const
NodeEntry
&
other
)
const
{
return
node_id
==
other
.
node_id
&&
index
==
other
.
index
;
}
};
/*! \brief Node data structure in IndexedGraph */
struct
Node
{
/*! \brief pointer to the source node */
const
nnvm
::
Node
*
source
;
/*! \brief inputs to the node */
array_view
<
NodeEntry
>
inputs
;
/*! \brief control flow dependencies to the node */
array_view
<
uint32_t
>
control_deps
;
};
/*! \return number of nodes in the graph */
inline
size_t
num_nodes
()
const
{
return
nodes_
.
size
();
}
/*! \return total number of NodeEntry in the graph */
inline
size_t
num_node_entries
()
const
{
return
entry_rptr_
.
back
();
}
/*!
* \brief Get a unique entry id between 0 to num_node_entries()
* for a given IndexedGraph::NodeEntry
* \param node_id The node index
* \param index the output index
* \return the unique index.
*/
inline
uint32_t
entry_id
(
uint32_t
node_id
,
uint32_t
index
)
const
{
return
entry_rptr_
[
node_id
]
+
index
;
}
/*!
* \brief Get a unique entry id between 0 to num_node_entries()
* for a given IndexedGraph::NodeEntry
* \param e The entry to query for index.
* \return the unique index.
*/
inline
uint32_t
entry_id
(
const
NodeEntry
&
e
)
const
{
return
entry_rptr_
[
e
.
node_id
]
+
e
.
index
;
}
/*!
* \brief Get a unique entry id between 0 to num_node_entries()
* for a given NodeEntry.
* \param e The entry to query for index.
* \return the unique index.
*/
inline
uint32_t
entry_id
(
const
nnvm
::
NodeEntry
&
e
)
const
{
return
entry_rptr_
[
node_id
(
e
.
node
.
get
())]
+
e
.
index
;
}
/*!
* \brief Get the corresponding node id for a given Node in the IndexedGraph.
* \param node The Node to query for index.
* \return the node index.
*/
inline
uint32_t
node_id
(
const
nnvm
::
Node
*
node
)
const
{
return
node2index_
.
at
(
node
);
}
/*!
* \brief Get the corresponding Node structure for a given node_id.
* \param node_id The node id
* \return const reference to the corresponding IndexedGraph::Node
*/
inline
const
Node
&
operator
[](
uint32_t
node_id
)
const
{
return
nodes_
[
node_id
];
}
/*!
* \brief Get the corresponding Node structure
* \param node The pointer to the Node structure
* \return const reference to the corresponding IndexedGraph::Node
*/
inline
const
Node
&
operator
[](
const
nnvm
::
Node
*
node
)
const
{
return
nodes_
[
node_id
(
node
)];
}
/*! \return list of argument nodes */
inline
const
std
::
vector
<
uint32_t
>&
arg_nodes
()
const
{
return
arg_nodes_
;
}
private
:
friend
class
Graph
;
/*!
* \brief Constructor an IndexedGraph from normal Graph
* \param other The source graph.
*/
explicit
IndexedGraph
(
const
Graph
&
other
);
// node pointers in CSR structure.
std
::
vector
<
Node
>
nodes_
;
// index to argument nodes
std
::
vector
<
uint32_t
>
arg_nodes_
;
// mapping from node to index.
std
::
unordered_map
<
const
nnvm
::
Node
*
,
uint32_t
>
node2index_
;
// CSR pointer of node entries
std
::
vector
<
size_t
>
entry_rptr_
;
// space to store input entries of each
std
::
vector
<
NodeEntry
>
input_entries_
;
// control flow dependencies
std
::
vector
<
uint32_t
>
control_deps_
;
};
};
/*!
/*!
...
@@ -45,6 +186,14 @@ template<typename FVisit>
...
@@ -45,6 +186,14 @@ template<typename FVisit>
inline
void
DFSVisit
(
const
std
::
vector
<
NodeEntry
>&
heads
,
FVisit
fvisit
);
inline
void
DFSVisit
(
const
std
::
vector
<
NodeEntry
>&
heads
,
FVisit
fvisit
);
// inline function implementations
// inline function implementations
template
<
typename
T
>
inline
const
T
&
Graph
::
GetAttr
(
const
std
::
string
&
attr_name
)
{
auto
it
=
attrs
.
find
(
attr_name
);
CHECK
(
it
!=
attrs
.
end
())
<<
"Cannot find attribute "
<<
attr_name
<<
" in the graph"
;
return
nnvm
::
get
<
T
>
(
*
it
->
second
);
}
template
<
typename
GNode
,
typename
HashType
,
template
<
typename
GNode
,
typename
HashType
,
typename
FVisit
,
typename
HashFunc
,
typename
FVisit
,
typename
HashFunc
,
typename
InDegree
,
typename
GetInput
>
typename
InDegree
,
typename
GetInput
>
...
...
nnvm/include/nnvm/graph_attr_types.h
View file @
bd20bfd8
...
@@ -7,120 +7,37 @@
...
@@ -7,120 +7,37 @@
#define NNVM_GRAPH_ATTR_TYPES_H_
#define NNVM_GRAPH_ATTR_TYPES_H_
#include <vector>
#include <vector>
#include <
unordered_map
>
#include <
string
>
#include "./
graph
.h"
#include "./
tuple
.h"
namespace
nnvm
{
namespace
nnvm
{
/*!
/*!
* \brief Auxililary data structure to index a graph.
* \brief The result holder of JSON serializer
* It maps Nodes in the graph to consecutive integers node_id.
*
* It also maps IndexedGraph::NodeEntry to consecutive integer entry_id.
* \note Stored under ret.attrs["json"], provided by Pass "SaveJSON"
* This allows storing properties of Node and NodeEntry into
* compact vector and quickly access them without resorting to hashmap.
* \code
*/
* Graph ret = ApplyPass(src_graph, {"SaveJSON"});
struct
IndexedGraph
{
* const JSONString& json = ret.GetAttr<JSONString>("shape");
public
:
* \endcode
/*! \brief represents a data in the graph */
struct
NodeEntry
{
/*! \brief the source node id in the computation graph */
uint32_t
node_id
;
/*! \brief index of output from the source. */
uint32_t
index
;
/*!
* \brief compare equality
* \param other the other entry to compare
* \return whether two entries equals to each other
*/
inline
bool
operator
==
(
const
NodeEntry
&
other
)
const
{
return
node_id
==
other
.
node_id
&&
index
==
other
.
index
;
}
};
/*! \brief Node data structure in IndexedGraph */
struct
Node
{
/*! \brief pointer to the source node */
const
nnvm
::
Node
*
source
;
/*! \brief inputs to the node */
array_view
<
NodeEntry
>
inputs
;
/*! \brief control flow dependencies to the node */
array_view
<
uint32_t
>
control_deps
;
};
/*! \return number of nodes in the graph */
inline
size_t
num_nodes
()
const
{
return
nodes_
.
size
();
}
/*! \return total number of NodeEntry in the graph */
inline
size_t
num_node_entries
()
const
{
return
entry_rptr_
.
back
();
}
/*!
* \brief Get a unique entry id between 0 to num_node_entries()
* for a given IndexedGraph::NodeEntry
* \param e The entry to query for index.
* \return the unique index.
*/
inline
uint32_t
entry_id
(
const
NodeEntry
&
e
)
const
{
return
entry_rptr_
[
e
.
node_id
]
+
e
.
index
;
}
/*!
* \brief Get a unique entry id between 0 to num_node_entries()
* for a given NodeEntry.
* \param e The entry to query for index.
* \return the unique index.
*/
inline
uint32_t
entry_id
(
const
nnvm
::
NodeEntry
&
e
)
const
{
return
entry_rptr_
[
node_id
(
e
.
node
.
get
())]
+
e
.
index
;
}
/*!
* \brief Get the corresponding node id for a given Node in the IndexedGraph.
* \param node The Node to query for index.
* \return the node index.
*/
inline
uint32_t
node_id
(
const
nnvm
::
Node
*
node
)
const
{
return
node2index_
.
at
(
node
);
}
/*!
* \brief Get the corresponding Node structure for a given node_id.
* \param node_id The node id
* \return const reference to the corresponding IndexedGraph::Node
*/
inline
const
Node
&
operator
[](
uint32_t
node_id
)
const
{
return
nodes_
[
node_id
];
}
/*!
* \brief Get the corresponding Node structure
* \param node The pointer to the Node structure
* \return const reference to the corresponding IndexedGraph::Node
*/
inline
const
Node
&
operator
[](
const
nnvm
::
Node
*
node
)
const
{
return
nodes_
[
node_id
(
node
)];
}
/*! \return list of argument nodes */
inline
const
std
::
vector
<
uint32_t
>&
arg_nodes
()
const
{
return
arg_nodes_
;
}
/*!
* \brief Constructor an IndexedGraph from normal Graph
* \param other The source graph.
*/
*/
explicit
IndexedGraph
(
const
Graph
&
other
);
using
JSONString
=
std
::
string
;
// disallow copy assign
IndexedGraph
(
const
IndexedGraph
&
other
)
=
delete
;
private
:
/*!
// node pointers in CSR structure
.
* \brief The result holder of shape of each NodeEntry in the graph
.
std
::
vector
<
Node
>
nodes_
;
* \note Stored under graph.attrs["shape"], provided by Pass "InferShape"
// index to argument nodes
*
std
::
vector
<
uint32_t
>
arg_nodes_
;
* \code
// mapping from node to index.
* Graph g = ApplyPass(src_graph, {"InferShape"});
std
::
unordered_map
<
const
nnvm
::
Node
*
,
uint32_t
>
node2index_
;
* const ShapeVector& shapes = g.GetAttr<ShapeVector>("shape")
;
// CSR pointer of node entries
* // get shape by entry id
std
::
vector
<
size_t
>
entry_rptr_
;
* TShape entry_shape = shapes[g.indexed_graph().entry_id(my_entry)]
;
// space to store input entries of each
* \endcode
std
::
vector
<
NodeEntry
>
input_entries_
;
*
// control flow dependencies
* \sa FInferShape
std
::
vector
<
uint32_t
>
control_deps_
;
*/
}
;
using
ShapeVector
=
std
::
vector
<
TShape
>
;
}
// namespace nnvm
}
// namespace nnvm
...
...
nnvm/include/nnvm/op_attr_types.h
View file @
bd20bfd8
...
@@ -9,6 +9,8 @@
...
@@ -9,6 +9,8 @@
#include <vector>
#include <vector>
#include <string>
#include <string>
#include <functional>
#include <functional>
#include "./base.h"
#include "./tuple.h"
namespace
nnvm
{
namespace
nnvm
{
...
@@ -39,6 +41,7 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs
...
@@ -39,6 +41,7 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs
/*!
/*!
* \brief Check whether operator will mutate k-th input.
* \brief Check whether operator will mutate k-th input.
* \param attrs The attributes of the node.
* \param index The input index
* \param index The input index
* \return Whether this operator will mutate index-th input.
* \return Whether this operator will mutate index-th input.
*
*
...
@@ -47,6 +50,26 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs
...
@@ -47,6 +50,26 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs
*/
*/
using
FMutateInput
=
std
::
function
<
bool
(
const
NodeAttrs
&
attrs
,
uint32_t
index
)
>
;
using
FMutateInput
=
std
::
function
<
bool
(
const
NodeAttrs
&
attrs
,
uint32_t
index
)
>
;
/*!
* \brief Shape inference function.
* Update the shapes given the input shape information.
* TShape.ndim() == 0 means the shape is still unknown.
*
* \param attrs The attributes of the node.
* \param in_shapes Array of shapes from the inputs.
* \param out_shapes Array of shapes from the outputs.
*
* \return Whether all the shapes are known.
*
* \note Register under "FInferShape",
* by default do not update any shapes.
*
* FInferShape is needed by shape inference
*/
using
FInferShape
=
std
::
function
<
bool
(
const
NodeAttrs
&
attrs
,
array_view
<
TShape
*>
in_shapes
,
array_view
<
TShape
*>
out_shapes
)
>
;
}
// namespace nnvm
}
// namespace nnvm
#endif // NNVM_OP_ATTR_TYPES_H_
#endif // NNVM_OP_ATTR_TYPES_H_
nnvm/include/nnvm/tuple.h
View file @
bd20bfd8
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
#include <type_traits>
#include <type_traits>
#include <algorithm>
#include <algorithm>
#include <iostream>
#include <iostream>
#include "./base.h"
namespace
nnvm
{
namespace
nnvm
{
...
@@ -179,7 +180,23 @@ class Tuple {
...
@@ -179,7 +180,23 @@ class Tuple {
inline
const
ValueType
&
operator
[](
index_t
i
)
const
{
inline
const
ValueType
&
operator
[](
index_t
i
)
const
{
return
begin
()[
i
];
return
begin
()[
i
];
}
}
/*!
* \brief Save Tuple to JSON.
* \param writer JSONWriter
*/
inline
void
Save
(
dmlc
::
JSONWriter
*
writer
)
const
{
std
::
vector
<
ValueType
>
tmp
(
begin
(),
end
());
writer
->
Write
(
tmp
);
}
/*!
* \brief Load Tuple from JSON.
* \param reader JSONReader
*/
inline
void
Load
(
dmlc
::
JSONReader
*
reader
)
{
std
::
vector
<
ValueType
>
tmp
;
reader
->
Read
(
&
tmp
);
this
->
assign
(
tmp
.
begin
(),
tmp
.
end
());
}
/*!
/*!
* \brief allow output string of tuple to ostream
* \brief allow output string of tuple to ostream
* \param os the output stream
* \param os the output stream
...
@@ -287,6 +304,8 @@ class TShape : public Tuple<index_t> {
...
@@ -287,6 +304,8 @@ class TShape : public Tuple<index_t> {
public
:
public
:
// inheritate other constructors from Tuple
// inheritate other constructors from Tuple
using
Tuple
<
index_t
>::
Tuple
;
using
Tuple
<
index_t
>::
Tuple
;
/*! \brief default constructor */
TShape
()
=
default
;
/*!
/*!
* \brief copy constructor of TShape
* \brief copy constructor of TShape
* \param s source shape.
* \param s source shape.
...
...
nnvm/src/core/graph
_attr_types
.cc
→
nnvm/src/core/graph.cc
View file @
bd20bfd8
...
@@ -3,11 +3,18 @@
...
@@ -3,11 +3,18 @@
* \file graph_attr_types.cc
* \file graph_attr_types.cc
* \brief Graph node data structure.
* \brief Graph node data structure.
*/
*/
#include <nnvm/graph
_attr_types
.h>
#include <nnvm/graph.h>
#include <limits>
#include <limits>
namespace
nnvm
{
namespace
nnvm
{
const
IndexedGraph
&
Graph
::
indexed_graph
()
{
if
(
indexed_graph_
==
nullptr
)
{
indexed_graph_
.
reset
(
new
IndexedGraph
(
*
this
));
}
return
*
indexed_graph_
;
}
// implement constructor from graph
// implement constructor from graph
IndexedGraph
::
IndexedGraph
(
const
Graph
&
g
)
{
IndexedGraph
::
IndexedGraph
(
const
Graph
&
g
)
{
entry_rptr_
.
push_back
(
0
);
entry_rptr_
.
push_back
(
0
);
...
...
nnvm/src/example/operator.cc
View file @
bd20bfd8
// Copyright (c) 2016 by Contributors
// Copyright (c) 2016 by Contributors
// This is an example on how we can register operator information to NNVM
// This is an example on how we can register operator information to NNVM
#include <nnvm/base.h>
#include <nnvm/op.h>
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <utility>
#include <utility>
namespace
myproject
{
using
nnvm
::
FListInputNames
;
using
nnvm
::
FListInputNames
;
using
nnvm
::
FMutateInput
;
using
nnvm
::
FMutateInput
;
using
nnvm
::
FInferShape
;
using
nnvm
::
NodeAttrs
;
using
nnvm
::
NodeAttrs
;
using
nnvm
::
TShape
;
using
nnvm
::
array_view
;
// simply return the shape as same
inline
bool
SameShape
(
const
NodeAttrs
&
attrs
,
array_view
<
TShape
*>
ishape
,
array_view
<
TShape
*>
oshape
)
{
if
(
ishape
.
size
()
==
0
||
ishape
[
0
]
->
ndim
()
==
0
)
return
false
;
for
(
TShape
*
pshape
:
oshape
)
{
*
pshape
=
*
ishape
[
0
];
}
for
(
TShape
*
pshape
:
ishape
)
{
*
pshape
=
*
ishape
[
0
];
}
return
true
;
}
NNVM_REGISTER_OP
(
add
)
NNVM_REGISTER_OP
(
add
)
.
describe
(
"add two data together"
)
.
describe
(
"add two data together"
)
.
set_num_inputs
(
2
);
.
set_num_inputs
(
2
)
.
attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
);
NNVM_REGISTER_OP
(
__add_symbol__
)
NNVM_REGISTER_OP
(
__add_symbol__
)
.
describe
(
"Alias of add"
)
.
describe
(
"Alias of add"
)
...
@@ -20,7 +42,8 @@ NNVM_REGISTER_OP(__add_symbol__)
...
@@ -20,7 +42,8 @@ NNVM_REGISTER_OP(__add_symbol__)
NNVM_REGISTER_OP
(
exp
)
NNVM_REGISTER_OP
(
exp
)
.
describe
(
"take exponmential"
)
.
describe
(
"take exponmential"
)
.
set_num_inputs
(
1
)
.
set_num_inputs
(
1
)
.
attr
(
"inplace_pair"
,
std
::
make_pair
(
0
,
0
));
.
attr
(
"inplace_pair"
,
std
::
make_pair
(
0
,
0
))
.
attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
);
NNVM_REGISTER_OP
(
conv2d
)
NNVM_REGISTER_OP
(
conv2d
)
...
@@ -39,3 +62,5 @@ NNVM_REGISTER_OP(assign)
...
@@ -39,3 +62,5 @@ NNVM_REGISTER_OP(assign)
.
attr
<
FMutateInput
>
(
"FMutateInput"
,
[](
const
NodeAttrs
&
attrs
,
uint32_t
index
)
{
.
attr
<
FMutateInput
>
(
"FMutateInput"
,
[](
const
NodeAttrs
&
attrs
,
uint32_t
index
)
{
return
index
==
0
;
return
index
==
0
;
});
});
}
// namespace myproject
nnvm/src/pass/infer_shape.cc
0 → 100644
View file @
bd20bfd8
/*!
* Copyright (c) 2016 by Contributors
* \file infer_shape.cc
* \brief Inference the shapes given
*/
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
namespace
nnvm
{
namespace
pass
{
Graph
InferShape
(
const
Graph
&
src
)
{
Graph
ret
=
src
;
const
IndexedGraph
&
idx
=
ret
.
indexed_graph
();
static
auto
&
finfer_shape
=
Op
::
GetAttr
<
FInferShape
>
(
"FInferShape"
);
// reshape shape vector
ShapeVector
rshape
(
idx
.
num_node_entries
());
// temp space for shape inference.
std
::
vector
<
TShape
*>
ishape
,
oshape
;
// number of completed nodes
size_t
num_known
=
0
;
for
(
uint32_t
nid
=
0
;
nid
<
idx
.
num_nodes
();
++
nid
)
{
const
auto
&
inode
=
idx
[
nid
];
if
(
inode
.
source
->
is_variable
())
continue
;
ishape
.
resize
(
inode
.
inputs
.
size
());
for
(
uint32_t
i
=
0
;
i
<
ishape
.
size
();
++
i
)
{
ishape
[
i
]
=
&
rshape
[
idx
.
entry_id
(
inode
.
inputs
[
i
])];
}
oshape
.
resize
(
inode
.
source
->
num_outputs
());
for
(
uint32_t
i
=
0
;
i
<
oshape
.
size
();
++
i
)
{
oshape
[
i
]
=
&
rshape
[
idx
.
entry_id
(
nid
,
i
)];
}
if
(
finfer_shape
.
count
(
inode
.
source
->
op
))
{
num_known
+=
finfer_shape
[
inode
.
source
->
op
](
inode
.
source
->
attrs
,
ishape
,
oshape
);
}
}
// set the shapes
ret
.
attrs
[
"shape"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
rshape
));
// number of nodes who knows the shape.
ret
.
attrs
[
"shape_num_known_nodes"
]
=
std
::
make_shared
<
any
>
(
num_known
);
return
ret
;
}
}
// namespace pass
}
// namespace nnvm
nnvm/src/pass/order_mutation.cc
View file @
bd20bfd8
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#include <nnvm/op_attr_types.h>
#include <nnvm/op_attr_types.h>
namespace
nnvm
{
namespace
nnvm
{
namespace
pass
{
template
<
typename
T
>
template
<
typename
T
>
inline
T
get_with_default
(
const
std
::
unordered_map
<
Node
*
,
T
>
&
map
,
inline
T
get_with_default
(
const
std
::
unordered_map
<
Node
*
,
T
>
&
map
,
...
@@ -139,4 +140,5 @@ NNVM_REGISTER_PASS(OrderMutation)
...
@@ -139,4 +140,5 @@ NNVM_REGISTER_PASS(OrderMutation)
.
set_body
(
OrderMutation
)
.
set_body
(
OrderMutation
)
.
set_change_graph
(
true
);
.
set_change_graph
(
true
);
}
// namespace pass
}
// namespace nnvm
}
// namespace nnvm
nnvm/src/pass/saveload_json.cc
View file @
bd20bfd8
...
@@ -120,6 +120,7 @@ struct JSONNode {
...
@@ -120,6 +120,7 @@ struct JSONNode {
struct
JSONGraph
{
struct
JSONGraph
{
std
::
vector
<
JSONNode
>
nodes
;
std
::
vector
<
JSONNode
>
nodes
;
std
::
vector
<
uint32_t
>
arg_nodes
;
std
::
vector
<
uint32_t
>
arg_nodes
;
std
::
vector
<
uint32_t
>
node_row_ptr
;
std
::
vector
<
JSONNode
::
Entry
>
heads
;
std
::
vector
<
JSONNode
::
Entry
>
heads
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
const
any
>
>
attrs
;
std
::
unordered_map
<
std
::
string
,
std
::
shared_ptr
<
const
any
>
>
attrs
;
...
@@ -127,6 +128,7 @@ struct JSONGraph {
...
@@ -127,6 +128,7 @@ struct JSONGraph {
writer
->
BeginObject
();
writer
->
BeginObject
();
writer
->
WriteObjectKeyValue
(
"nodes"
,
nodes
);
writer
->
WriteObjectKeyValue
(
"nodes"
,
nodes
);
writer
->
WriteObjectKeyValue
(
"arg_nodes"
,
arg_nodes
);
writer
->
WriteObjectKeyValue
(
"arg_nodes"
,
arg_nodes
);
writer
->
WriteObjectKeyValue
(
"node_row_ptr"
,
node_row_ptr
);
writer
->
WriteObjectKeyValue
(
"heads"
,
heads
);
writer
->
WriteObjectKeyValue
(
"heads"
,
heads
);
if
(
attrs
.
size
()
!=
0
)
{
if
(
attrs
.
size
()
!=
0
)
{
writer
->
WriteObjectKeyValue
(
"attrs"
,
attrs
);
writer
->
WriteObjectKeyValue
(
"attrs"
,
attrs
);
...
@@ -140,6 +142,7 @@ struct JSONGraph {
...
@@ -140,6 +142,7 @@ struct JSONGraph {
helper
.
DeclareField
(
"nodes"
,
&
nodes
);
helper
.
DeclareField
(
"nodes"
,
&
nodes
);
helper
.
DeclareField
(
"arg_nodes"
,
&
arg_nodes
);
helper
.
DeclareField
(
"arg_nodes"
,
&
arg_nodes
);
helper
.
DeclareField
(
"heads"
,
&
heads
);
helper
.
DeclareField
(
"heads"
,
&
heads
);
helper
.
DeclareOptionalField
(
"node_row_ptr"
,
&
node_row_ptr
);
helper
.
DeclareOptionalField
(
"attrs"
,
&
attrs
);
helper
.
DeclareOptionalField
(
"attrs"
,
&
attrs
);
helper
.
ReadAllFields
(
reader
);
helper
.
ReadAllFields
(
reader
);
}
}
...
@@ -188,6 +191,7 @@ Graph LoadJSON(const Graph& src) {
...
@@ -188,6 +191,7 @@ Graph LoadJSON(const Graph& src) {
Graph
SaveJSON
(
const
Graph
&
src
)
{
Graph
SaveJSON
(
const
Graph
&
src
)
{
JSONGraph
jgraph
;
JSONGraph
jgraph
;
std
::
unordered_map
<
Node
*
,
uint32_t
>
node2index
;
std
::
unordered_map
<
Node
*
,
uint32_t
>
node2index
;
jgraph
.
node_row_ptr
.
push_back
(
0
);
DFSVisit
(
src
.
outputs
,
[
&
node2index
,
&
jgraph
](
const
NodePtr
&
n
)
{
DFSVisit
(
src
.
outputs
,
[
&
node2index
,
&
jgraph
](
const
NodePtr
&
n
)
{
uint32_t
nid
=
static_cast
<
uint32_t
>
(
jgraph
.
nodes
.
size
());
uint32_t
nid
=
static_cast
<
uint32_t
>
(
jgraph
.
nodes
.
size
());
node2index
[
n
.
get
()]
=
nid
;
node2index
[
n
.
get
()]
=
nid
;
...
@@ -204,6 +208,8 @@ Graph SaveJSON(const Graph& src) {
...
@@ -204,6 +208,8 @@ Graph SaveJSON(const Graph& src) {
for
(
const
NodePtr
&
c
:
n
->
control_deps
)
{
for
(
const
NodePtr
&
c
:
n
->
control_deps
)
{
jnode
.
control_deps
.
push_back
(
node2index
.
at
(
c
.
get
()));
jnode
.
control_deps
.
push_back
(
node2index
.
at
(
c
.
get
()));
}
}
jgraph
.
node_row_ptr
.
push_back
(
jgraph
.
node_row_ptr
.
back
()
+
n
->
num_outputs
());
jgraph
.
nodes
.
emplace_back
(
std
::
move
(
jnode
));
jgraph
.
nodes
.
emplace_back
(
std
::
move
(
jnode
));
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment