wenyuanbo / tic

Commit 2c0b79ae
Authored Sep 11, 2016 by Minjie Wang; committed by Tianqi Chen on May 29, 2018
Parent: 24f1999c

ApplyPass -> ApplyPasses; Refactored infer pass; (#43)

* ApplyPass -> ApplyPasses; Refactored infer pass;
* lint fix

Showing 9 changed files, with 67 additions and 49 deletions.
nnvm/include/nnvm/c_api.h              +5   -5
nnvm/include/nnvm/graph.h              +3   -3
nnvm/include/nnvm/graph_attr_types.h   +5   -5
nnvm/include/nnvm/pass.h               +13  -2
nnvm/include/nnvm/pass_functions.h     +7   -7
nnvm/python/nnvm/graph.py              +1   -1
nnvm/src/c_api/c_api_graph.cc          +5   -5
nnvm/src/core/pass.cc                  +2   -2
nnvm/src/pass/infer_shape_type.cc      +26  -19
nnvm/include/nnvm/c_api.h

@@ -329,16 +329,16 @@ NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle,
                                           const char* key,
                                           SymbolHandle list);
 /*!
- * \brief Apply pass on the src graph.
+ * \brief Apply passes on the src graph.
  * \param src The source graph handle.
  * \param num_pass The number of pass to be applied.
  * \param pass_names The names of the pass.
  * \param dst The result graph.
  * \return 0 when success, -1 when failure happens
  */
-NNVM_DLL int NNGraphApplyPass(GraphHandle src,
-                              nn_uint num_pass,
-                              const char** pass_names,
-                              GraphHandle *dst);
+NNVM_DLL int NNGraphApplyPasses(GraphHandle src,
+                                nn_uint num_pass,
+                                const char** pass_names,
+                                GraphHandle *dst);
 #endif  // NNVM_C_API_H_
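For context, a minimal sketch of driving the renamed C entry point. This example is not part of the commit: the wrapper name InferShapeAndType is made up for illustration, the pass names are placeholders, and it assumes only the NNGraphApplyPasses declaration shown above.

#include <nnvm/c_api.h>

// Apply two passes in one C API call. Returns 0 on success, -1 on failure,
// as documented in c_api.h above. `src` must be a valid GraphHandle.
int InferShapeAndType(GraphHandle src, GraphHandle* dst) {
  const char* pass_names[] = {"InferShape", "InferType"};
  return NNGraphApplyPasses(src, 2, pass_names, dst);
}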
nnvm/include/nnvm/graph.h

@@ -179,11 +179,11 @@ class IndexedGraph {
    * \param other The source graph.
    */
   explicit IndexedGraph(const Graph& other);
-  // node pointers in CSR structure.
+  // Node pointers in CSR structure.
   std::vector<Node> nodes_;
-  // index all to input nodes
+  // Index to all input nodes.
   std::vector<uint32_t> input_nodes_;
-  // index to mutable input nodes
+  // Index to all mutable input nodes.
   std::unordered_set<uint32_t> mutable_input_nodes_;
   // space to store the outputs entries
   std::vector<NodeEntry> outputs_;
nnvm/include/nnvm/graph_attr_types.h

@@ -18,7 +18,7 @@ namespace nnvm {
  * \note Stored under ret.attrs["json"], provided by Pass "SaveJSON"
  * \code
- * Graph ret = ApplyPass(src_graph, {"SaveJSON"});
+ * Graph ret = ApplyPass(src_graph, "SaveJSON");
  * const JSONString& json = ret.GetAttr<JSONString>("shape");
  * \endcode
  */
@@ -29,7 +29,7 @@ using JSONString = std::string;
  * \note Stored under graph.attrs["shape"], provided by Pass "InferShape"
  *
  * \code
- * Graph g = ApplyPass(src_graph, {"InferShape"});
+ * Graph g = ApplyPass(src_graph, "InferShape");
  * const ShapeVector& shapes = g.GetAttr<ShapeVector>("shape");
  * // get shape by entry id
  * TShape entry_shape = shapes[g.indexed_graph().entry_id(my_entry)];
@@ -44,7 +44,7 @@ using ShapeVector = std::vector<TShape>;
  * \note Stored under graph.attrs["dtype"], provided by Pass "InferType"
  *
  * \code
- * Graph g = ApplyPass(src_graph, {"InferType"});
+ * Graph g = ApplyPass(src_graph, "InferType");
  * const DTypeVector& types = g.GetAttr<DTypeVector>("dtype");
  * // get shape by entry id
  * int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)];
@@ -59,7 +59,7 @@ using DTypeVector = std::vector<int>;
  * \note Stored under graph.attrs["device"], provided by Pass "PlaceDevice"
  *
  * \code
- * Graph g = ApplyPass(src_graph, {"PlaceDevice"});
+ * Graph g = ApplyPass(src_graph, "PlaceDevice");
  * const &device = g.GetAttr<DeviceVector>("device");
  * // get device by node_id
  * int device_type = device[g.indexed_graph().node_id(my_node)];
@@ -83,7 +83,7 @@ using DeviceAssignMap = std::unordered_map<std::string, int>;
  * If the storage id is -1 then the storage is not assigned.
  *
  * \code
- * Graph g = ApplyPass(src_graph, {"PlanMemory"});
+ * Graph g = ApplyPass(src_graph, "PlanMemory");
  * const &storage = g.GetAttr<StorageVector>("storage");
  * // get storage id by entry
  * int storage_id = storage[g.indexed_graph().entry_id(my_entry)];
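The updated \code snippets above translate directly into caller code. A minimal sketch, not part of the commit: the function name EntryShape is made up for illustration, src_graph and my_entry are placeholders, and it assumes the post-commit single-string ApplyPass overload plus the "shape" attribute documented above.

#include <utility>

#include <nnvm/graph.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/tuple.h>

// Run shape inference, then read back the shape of one node entry.
nnvm::TShape EntryShape(nnvm::Graph src_graph, const nnvm::NodeEntry& my_entry) {
  nnvm::Graph g = nnvm::ApplyPass(std::move(src_graph), "InferShape");
  const nnvm::ShapeVector& shapes = g.GetAttr<nnvm::ShapeVector>("shape");
  // Shapes are indexed by entry id in the indexed graph.
  return shapes[g.indexed_graph().entry_id(my_entry)];
}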
nnvm/include/nnvm/pass.h

@@ -29,11 +29,22 @@ typedef std::function<Graph (Graph src)> PassFunction;
 /*!
  * \brief Apply a series of pass transformations on the input graph.
  * \param src The graph to be transformed.
+ * \param passes A list of pass names to be applied.
+ * \return The transformed graph
+ */
+Graph ApplyPasses(Graph src,
+                  const std::vector<std::string>& passes);
+
+/*!
+ * \brief Apply one pass to the graph.
+ * \param src The graph to be transformed.
  * \param pass The name of pass to be applied.
  * \return The transformed graph.
  */
-Graph ApplyPass(Graph src,
-                const std::vector<std::string>& pass);
+inline Graph ApplyPass(Graph src,
+                       const std::string& pass) {
+  return ApplyPasses(src, {pass});
+}
 
 /*!
  * \brief Registry entry for DataIterator factory functions.
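To show the intended division of labor between the two entry points declared above, here is a small usage sketch. It is not from the commit: the function name RunPasses is made up for illustration and the graph and pass names are placeholders; only ApplyPasses and the inline ApplyPass wrapper come from the header.

#include <string>
#include <utility>
#include <vector>

#include <nnvm/graph.h>
#include <nnvm/pass.h>

nnvm::Graph RunPasses(nnvm::Graph g) {
  // Batch form introduced by this commit: several passes in one call.
  g = nnvm::ApplyPasses(std::move(g), {"InferShape", "InferType"});
  // Convenience form: one pass by name, forwarded to ApplyPasses({pass}).
  g = nnvm::ApplyPass(std::move(g), "PlanMemory");
  return g;
}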
nnvm/include/nnvm/pass_functions.h

@@ -28,7 +28,7 @@ namespace pass {
 inline Graph LoadJSON(const std::string& json_str) {
   Graph ret;
   ret.attrs["json"] = std::make_shared<any>(json_str);
-  return ApplyPass(ret, {"LoadJSON"});
+  return ApplyPass(ret, "LoadJSON");
 }
 
 /*!
@@ -37,7 +37,7 @@ inline Graph LoadJSON(const std::string& json_str) {
  * \return The json string.
  */
 inline std::string SaveJSON(Graph graph) {
-  Graph ret = ApplyPass(std::move(graph), {"SaveJSON"});
+  Graph ret = ApplyPass(std::move(graph), "SaveJSON");
   return ret.GetAttr<std::string>("json");
 }
@@ -52,7 +52,7 @@ inline std::string SaveJSON(Graph graph) {
  * \return A graph with proper control flow dependencies added.
  */
 inline Graph OrderMutation(Graph src) {
-  return ApplyPass(std::move(src), {"OrderMutation"});
+  return ApplyPass(std::move(src), "OrderMutation");
 }
 
 /*!
@@ -73,7 +73,7 @@ inline Graph InferShape(Graph graph,
   if (shape_attr_key.length() != 0) {
     graph.attrs["shape_attr_key"] = std::make_shared<any>(std::move(shape_attr_key));
   }
-  return ApplyPass(std::move(graph), {"InferShape"});
+  return ApplyPass(std::move(graph), "InferShape");
 }
 
 /*!
@@ -94,7 +94,7 @@ inline Graph InferType(Graph graph,
   if (dtype_attr_key.length() != 0) {
     graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(dtype_attr_key));
   }
-  return ApplyPass(std::move(graph), {"InferType"});
+  return ApplyPass(std::move(graph), "InferType");
 }
 
 /*!
@@ -118,7 +118,7 @@ inline Graph PlaceDevice(Graph graph,
   graph.attrs["device_group_attr_key"] = std::make_shared<any>(std::move(device_group_attr_key));
   graph.attrs["device_assign_map"] = std::make_shared<any>(std::move(device_assign_map));
   graph.attrs["device_copy_op"] = std::make_shared<any>(std::move(device_copy_op));
-  return ApplyPass(std::move(graph), {"PlaceDevice"});
+  return ApplyPass(std::move(graph), "PlaceDevice");
 }
 
 /*!
@@ -149,7 +149,7 @@ inline Graph Gradient(
     graph.attrs["grad_mirror_fun"] = std::make_shared<any>(mirror_fun);
   }
-  return ApplyPass(std::move(graph), {"Gradient"});
+  return ApplyPass(std::move(graph), "Gradient");
 }
 
 }  // namespace pass
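None of the helper signatures above changed, only their ApplyPass call sites. A small round-trip sketch using two of them, not part of the commit: the function name RoundTrip is made up for illustration and json_str is a placeholder serialized graph.

#include <string>
#include <utility>

#include <nnvm/graph.h>
#include <nnvm/pass_functions.h>

std::string RoundTrip(const std::string& json_str) {
  // Deserialize a graph from JSON, then immediately re-serialize it.
  nnvm::Graph g = nnvm::pass::LoadJSON(json_str);
  return nnvm::pass::SaveJSON(std::move(g));
}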
nnvm/python/nnvm/graph.py

@@ -113,7 +113,7 @@ class Graph(object):
         cpass = c_array(ctypes.c_char_p, [c_str(key) for key in passes])
         ghandle = GraphHandle()
         npass = nn_uint(len(passes))
-        check_call(_LIB.NNGraphApplyPass(self.handle, npass, cpass, ctypes.byref(ghandle)))
+        check_call(_LIB.NNGraphApplyPasses(self.handle, npass, cpass, ctypes.byref(ghandle)))
         return Graph(ghandle)
nnvm/src/c_api/c_api_graph.cc

@@ -82,17 +82,17 @@ int NNGraphGetJSONAttr(GraphHandle handle,
   API_END();
 }
 
-int NNGraphApplyPass(GraphHandle src,
-                     nn_uint num_pass,
-                     const char** pass_names,
-                     GraphHandle *dst) {
+int NNGraphApplyPasses(GraphHandle src,
+                       nn_uint num_pass,
+                       const char** pass_names,
+                       GraphHandle *dst) {
   Graph* g = new Graph();
   API_BEGIN();
   std::vector<std::string> vpass;
   for (nn_uint i = 0; i < num_pass; ++i) {
     vpass.emplace_back(std::string(pass_names[i]));
   }
-  *g = ApplyPass(*static_cast<Graph*>(src), vpass);
+  *g = ApplyPasses(*static_cast<Graph*>(src), vpass);
   *dst = g;
   API_END_HANDLE_ERROR(delete g);
 }
nnvm/src/core/pass.cc

@@ -22,8 +22,8 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) {
   return nullptr;
 }
 
-Graph ApplyPass(Graph g,
-                const std::vector<std::string>& pass) {
+Graph ApplyPasses(Graph g,
+                  const std::vector<std::string>& pass) {
   std::vector<const PassFunctionReg*> fpass;
   for (auto& name : pass) {
     auto* reg = dmlc::Registry<PassFunctionReg>::Find(name);
nnvm/src/pass/infer_shape_type.cc

@@ -13,7 +13,7 @@ namespace {
 template<typename AttrType, typename IsNone>
 Graph InferAttr(Graph &&ret,
-                const AttrType def_value,
+                const AttrType default_val,
                 const char* infer_name,
                 const char* input_name,
                 const char* attr_key_name,
@@ -23,16 +23,16 @@ Graph InferAttr(Graph &&ret,
   using AttrVector = std::vector<AttrType>;
   const IndexedGraph& idx = ret.indexed_graph();
   static auto& finfer_shape =
-      Op::GetAttr<FInferNodeEntryAttr<AttrType> >(infer_name);
+      Op::GetAttr<FInferNodeEntryAttr<AttrType>>(infer_name);
   static auto& backward_map =
       Op::GetAttr<FBackwardOutToInIndex>("FBackwardOutToInIndex");
   // reshape shape vector
-  AttrVector rshape(idx.num_node_entries(), def_value);
+  AttrVector rshape(idx.num_node_entries(), default_val);
   if (ret.attrs.count(input_name) != 0) {
     const AttrVector& shape_args = ret.GetAttr<AttrVector>(input_name);
     CHECK_LE(shape_args.size(), idx.input_nodes().size())
-        << "shape args is more than number of arguments";
+        << "More provided shapes than number of arguments.";
     for (size_t i = 0; i < shape_args.size(); ++i) {
       rshape[idx.entry_id(idx.input_nodes()[i], 0)] = shape_args[i];
     }
@@ -46,36 +46,41 @@ Graph InferAttr(Graph &&ret,
     ret.attrs.erase(attr_key_name);
   }
-  // temp space for shape inference.
+  // Temp space for shape inference.
   std::vector<AttrType> ishape, oshape;
   // number of completed nodes
   size_t num_unknown = 0;
   for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
     const auto& inode = idx[nid];
-    uint32_t num_inputs = inode.inputs.size();
-    uint32_t num_outputs = inode.source->num_outputs();
+    const uint32_t num_inputs = inode.inputs.size();
+    const uint32_t num_outputs = inode.source->num_outputs();
     if (inode.source->is_variable()) {
-      if (shape_attr_key.length() != 0 && fis_none(rshape[idx.entry_id(nid, 0)])) {
+      // Variable node. No operator. Only one output entry.
+      CHECK(inode.source->op() == nullptr);
+      CHECK_EQ(num_outputs, 1);
+      const uint32_t out_ent_id = idx.entry_id(nid, 0);
+      if (shape_attr_key.length() != 0 && fis_none(rshape[out_ent_id])) {
        auto it = inode.source->attrs.dict.find(shape_attr_key);
         if (it != inode.source->attrs.dict.end()) {
-          CHECK_EQ(num_outputs, 1);
           std::istringstream is(it->second);
-          CHECK(is >> rshape[idx.entry_id(nid, 0)]) << "Invalid attribute";
+          CHECK(is >> rshape[out_ent_id]) << "Invalid attribute";
         }
       }
-    } else if (finfer_shape.count(inode.source->op())) {
-      ishape.resize(num_inputs, def_value);
+      continue;
+    }
+    // Forward operator inference.
+    if (finfer_shape.count(inode.source->op())) {
+      ishape.resize(num_inputs, default_val);
       for (uint32_t i = 0; i < ishape.size(); ++i) {
         ishape[i] = rshape[idx.entry_id(inode.inputs[i])];
       }
-      oshape.resize(num_outputs, def_value);
+      oshape.resize(num_outputs, default_val);
       for (uint32_t i = 0; i < oshape.size(); ++i) {
         oshape[i] = rshape[idx.entry_id(nid, i)];
       }
-      num_unknown +=
-          !(finfer_shape[inode.source->op()](inode.source->attrs, &ishape, &oshape));
+      // Call inference function of the operator.
+      bool forward_known = finfer_shape[inode.source->op()](
+          inode.source->attrs, &ishape, &oshape);
+      num_unknown += !forward_known;
+      // Save to the result map.
       for (uint32_t i = 0; i < num_inputs; ++i) {
         rshape[idx.entry_id(inode.inputs[i])] = ishape[i];
       }
@@ -83,10 +88,12 @@ Graph InferAttr(Graph &&ret,
         rshape[idx.entry_id(nid, i)] = oshape[i];
       }
     } else if (backward_map.count(inode.source->op())) {
-      // backward operator inference.
+      // Backward operator inference.
       CHECK_GE(inode.control_deps.size(), 1)
           << "BackwardOp need to have control_deps to its forward op";
-      const auto& fnode = idx[inode.control_deps[0]];
+      const IndexedGraph::Node& fnode = idx[inode.control_deps[0]];
+      // Inference the outputs of backward operator (equal to the inputs
+      // of its corresponding forward operator).
       std::vector<uint32_t> out_map =
           backward_map[inode.source->op()](inode.source->attrs);
       bool known = true;
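The key change in the inference loop above is that the result of the operator's inference call is now stored in a named boolean before being folded into the counter of unresolved nodes. A detached sketch of that counting pattern follows; none of these names are nnvm APIs, and InferOne is a stand-in with an arbitrary "known" rule.

#include <cstddef>
#include <vector>

// Stand-in inference hook: the result is "known" only when no input is zero.
static bool InferOne(const std::vector<int>& in, std::vector<int>* out) {
  for (int v : in) {
    if (v == 0) return false;
  }
  *out = in;  // trivially propagate the inputs
  return true;
}

// Mirror of the accounting in InferAttr: name the result, then accumulate.
static size_t CountUnknown(const std::vector<std::vector<int>>& inputs) {
  size_t num_unknown = 0;
  for (const std::vector<int>& in : inputs) {
    std::vector<int> out;
    const bool forward_known = InferOne(in, &out);
    num_unknown += !forward_known;
  }
  return num_unknown;
}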