Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
8ffa4ac3
Commit
8ffa4ac3
authored
Jul 19, 2016
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Pass] enable infer type (#17)
parent
0081ad9a
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
152 additions
and
22 deletions
+152
-22
nnvm/include/nnvm/graph_attr_types.h
+15
-0
nnvm/include/nnvm/op_attr_types.h
+19
-9
nnvm/include/nnvm/pass_functions.h
+20
-0
nnvm/src/example/operator.cc
+23
-0
nnvm/src/pass/infer_shape_type.cc
+47
-13
nnvm/tests/python/test_graph.py
+28
-0
No files found.
nnvm/include/nnvm/graph_attr_types.h
View file @
8ffa4ac3
...
@@ -39,6 +39,21 @@ using JSONString = std::string;
...
@@ -39,6 +39,21 @@ using JSONString = std::string;
*/
*/
using
ShapeVector
=
std
::
vector
<
TShape
>
;
using
ShapeVector
=
std
::
vector
<
TShape
>
;
/*!
* \brief The result holder of type of each NodeEntry in the graph.
* \note Stored under graph.attrs["dtype"], provided by Pass "InferType"
*
* \code
* Graph g = ApplyPass(src_graph, {"InferType"});
* const DTypeVector& types = g.GetAttr<ShapeVector>("dtype");
* // get shape by entry id
* int entry_type = dtypes[g.indexed_graph().entry_id(my_entry)];
* \endcode
*
* \sa FInferType
*/
using
DTypeVector
=
std
::
vector
<
int
>
;
}
// namespace nnvm
}
// namespace nnvm
#endif // NNVM_GRAPH_ATTR_TYPES_H_
#endif // NNVM_GRAPH_ATTR_TYPES_H_
nnvm/include/nnvm/op_attr_types.h
View file @
8ffa4ac3
...
@@ -51,24 +51,34 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs
...
@@ -51,24 +51,34 @@ using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs
using
FMutateInput
=
std
::
function
<
bool
(
const
NodeAttrs
&
attrs
,
uint32_t
index
)
>
;
using
FMutateInput
=
std
::
function
<
bool
(
const
NodeAttrs
&
attrs
,
uint32_t
index
)
>
;
/*!
/*!
* \brief Inference function of certain type.
* \tparam AttrType The type of the attribute to be infered.
* \return whether all attributes are inferred.
*/
template
<
typename
AttrType
>
using
FInferNodeEntryAttr
=
std
::
function
<
bool
(
const
NodeAttrs
&
attrs
,
array_view
<
AttrType
*>
in_attrs
,
array_view
<
AttrType
*>
out_attrs
)
>
;
/*!
* \brief Shape inference function.
* \brief Shape inference function.
* Update the shapes given the input shape information.
* Update the shapes given the input shape information.
* TShape.ndim() == 0 means the shape is still unknown.
* TShape.ndim() == 0 means the shape is still unknown.
*
*
* \param attrs The attributes of the node.
* \param in_shapes Array of shapes from the inputs.
* \param out_shapes Array of shapes from the outputs.
*
* \return Whether all the shapes are known.
*
* \note Register under "FInferShape",
* \note Register under "FInferShape",
* by default do not update any shapes.
* by default do not update any shapes.
*
*
* FInferShape is needed by shape inference
* FInferShape is needed by shape inference
*/
*/
using
FInferShape
=
std
::
function
<
bool
(
const
NodeAttrs
&
attrs
,
using
FInferShape
=
FInferNodeEntryAttr
<
TShape
>
;
array_view
<
TShape
*>
in_shapes
,
array_view
<
TShape
*>
out_shapes
)
>
;
/*!
* \brief Type inference function.
* Update the type given the known type information.
*
* \note Register under "FInferType",
* by default set all the output types to 0.
*/
using
FInferType
=
FInferNodeEntryAttr
<
int
>
;
}
// namespace nnvm
}
// namespace nnvm
...
...
nnvm/include/nnvm/pass_functions.h
View file @
8ffa4ac3
...
@@ -71,6 +71,26 @@ inline Graph InferShape(Graph graph,
...
@@ -71,6 +71,26 @@ inline Graph InferShape(Graph graph,
return
ApplyPass
(
std
::
move
(
graph
),
{
"InferShape"
});
return
ApplyPass
(
std
::
move
(
graph
),
{
"InferShape"
});
}
}
/*!
* \brief Infer types in the graph given the information.
* \param graph source graph
* \param shape_args The shapes of aruguments to the graph.
* \param shape_attr_key The key to the node attribute that can indicate shape.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id
*/
inline
Graph
InferType
(
Graph
graph
,
DTypeVector
type_args
=
{},
std
::
string
type_attr_key
=
""
)
{
if
(
type_args
.
size
()
!=
0
)
{
graph
.
attrs
[
"type_args"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
type_args
));
}
if
(
type_attr_key
.
length
()
!=
0
)
{
graph
.
attrs
[
"type_attr_key"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
type_attr_key
));
}
return
ApplyPass
(
std
::
move
(
graph
),
{
"InferType"
});
}
}
// namespace pass
}
// namespace pass
}
// namespace nnvm
}
// namespace nnvm
#endif // NNVM_PASS_FUNCTIONS_H_
#endif // NNVM_PASS_FUNCTIONS_H_
nnvm/src/example/operator.cc
View file @
8ffa4ac3
...
@@ -13,6 +13,7 @@ namespace myproject {
...
@@ -13,6 +13,7 @@ namespace myproject {
using
nnvm
::
FListInputNames
;
using
nnvm
::
FListInputNames
;
using
nnvm
::
FMutateInput
;
using
nnvm
::
FMutateInput
;
using
nnvm
::
FInferShape
;
using
nnvm
::
FInferShape
;
using
nnvm
::
FInferType
;
using
nnvm
::
NodeAttrs
;
using
nnvm
::
NodeAttrs
;
using
nnvm
::
TShape
;
using
nnvm
::
TShape
;
using
nnvm
::
array_view
;
using
nnvm
::
array_view
;
...
@@ -56,6 +57,28 @@ NNVM_REGISTER_OP(reshape)
...
@@ -56,6 +57,28 @@ NNVM_REGISTER_OP(reshape)
return
true
;
return
true
;
});
});
NNVM_REGISTER_OP
(
cast
)
.
describe
(
"cast source type to target"
)
.
set_num_inputs
(
1
)
.
set_attr_parser
(
[](
NodeAttrs
*
attrs
)
{
// parse attr parser to get target attribute
int
dtype
;
std
::
istringstream
is
(
attrs
->
dict
.
at
(
"dtype"
));
CHECK
(
is
>>
dtype
);
attrs
->
parsed
=
std
::
move
(
dtype
);
})
.
attr
<
FInferShape
>
(
"FInferShape"
,
SameShape
)
.
attr
<
FInferType
>
(
"FInferType"
,
[](
const
NodeAttrs
&
attrs
,
array_view
<
int
*>
itype
,
array_view
<
int
*>
otype
)
{
*
otype
[
0
]
=
nnvm
::
get
<
int
>
(
attrs
.
parsed
);
return
true
;
});
NNVM_REGISTER_OP
(
add
)
NNVM_REGISTER_OP
(
add
)
.
describe
(
"add two data together"
)
.
describe
(
"add two data together"
)
.
set_num_inputs
(
2
)
.
set_num_inputs
(
2
)
...
...
nnvm/src/pass/infer_shape.cc
→
nnvm/src/pass/infer_shape
_type
.cc
View file @
8ffa4ac3
/*!
/*!
* Copyright (c) 2016 by Contributors
* Copyright (c) 2016 by Contributors
* \file infer_shape.cc
* \file infer_shape.cc
* \brief Inference the shapes given
* \brief Inference the shapes given
existin information.
*/
*/
#include <nnvm/pass.h>
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/op_attr_types.h>
...
@@ -10,14 +10,23 @@
...
@@ -10,14 +10,23 @@
namespace
nnvm
{
namespace
nnvm
{
namespace
pass
{
namespace
pass
{
Graph
InferShape
(
Graph
ret
)
{
template
<
typename
AttrType
>
Graph
InferAttr
(
Graph
&&
ret
,
const
AttrType
def_value
,
const
char
*
infer_name
,
const
char
*
arg_name
,
const
char
*
attr_key_name
,
const
char
*
attr_name
,
const
char
*
known_name
)
{
using
AttrVector
=
std
::
vector
<
AttrType
>
;
const
IndexedGraph
&
idx
=
ret
.
indexed_graph
();
const
IndexedGraph
&
idx
=
ret
.
indexed_graph
();
static
auto
&
finfer_shape
=
Op
::
GetAttr
<
FInferShape
>
(
"FInferShape"
);
static
auto
&
finfer_shape
=
Op
::
GetAttr
<
FInferNodeEntryAttr
<
AttrType
>
>
(
infer_name
);
// reshape shape vector
// reshape shape vector
ShapeVector
rshape
(
idx
.
num_node_entries
()
);
AttrVector
rshape
(
idx
.
num_node_entries
(),
def_value
);
if
(
ret
.
attrs
.
count
(
"shape_args"
)
!=
0
)
{
if
(
ret
.
attrs
.
count
(
arg_name
)
!=
0
)
{
const
ShapeVector
&
shape_args
=
ret
.
GetAttr
<
ShapeVector
>
(
"shape_args"
);
const
AttrVector
&
shape_args
=
ret
.
GetAttr
<
AttrVector
>
(
arg_name
);
CHECK_LE
(
shape_args
.
size
(),
idx
.
arg_nodes
().
size
())
CHECK_LE
(
shape_args
.
size
(),
idx
.
arg_nodes
().
size
())
<<
"shape args is more than number of arguments"
;
<<
"shape args is more than number of arguments"
;
for
(
size_t
i
=
0
;
i
<
shape_args
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
shape_args
.
size
();
++
i
)
{
...
@@ -25,12 +34,12 @@ Graph InferShape(Graph ret) {
...
@@ -25,12 +34,12 @@ Graph InferShape(Graph ret) {
}
}
}
}
std
::
string
shape_attr_key
;
std
::
string
shape_attr_key
;
if
(
ret
.
attrs
.
count
(
"shape_attr_key"
)
!=
0
)
{
if
(
ret
.
attrs
.
count
(
attr_key_name
)
!=
0
)
{
shape_attr_key
=
ret
.
GetAttr
<
std
::
string
>
(
"shape_attr_key"
);
shape_attr_key
=
ret
.
GetAttr
<
std
::
string
>
(
attr_key_name
);
}
}
// temp space for shape inference.
// temp space for shape inference.
std
::
vector
<
TSha
pe
*>
ishape
,
oshape
;
std
::
vector
<
AttrTy
pe
*>
ishape
,
oshape
;
// number of completed nodes
// number of completed nodes
size_t
num_known
=
0
;
size_t
num_known
=
0
;
for
(
uint32_t
nid
=
0
;
nid
<
idx
.
num_nodes
();
++
nid
)
{
for
(
uint32_t
nid
=
0
;
nid
<
idx
.
num_nodes
();
++
nid
)
{
...
@@ -41,7 +50,7 @@ Graph InferShape(Graph ret) {
...
@@ -41,7 +50,7 @@ Graph InferShape(Graph ret) {
if
(
it
!=
inode
.
source
->
attrs
.
dict
.
end
())
{
if
(
it
!=
inode
.
source
->
attrs
.
dict
.
end
())
{
CHECK_EQ
(
inode
.
source
->
num_outputs
(),
1
);
CHECK_EQ
(
inode
.
source
->
num_outputs
(),
1
);
std
::
istringstream
is
(
it
->
second
);
std
::
istringstream
is
(
it
->
second
);
CHECK
(
is
>>
rshape
[
idx
.
entry_id
(
nid
,
0
)])
<<
"Invalid
shape
attribute"
;
CHECK
(
is
>>
rshape
[
idx
.
entry_id
(
nid
,
0
)])
<<
"Invalid attribute"
;
}
}
}
}
continue
;
continue
;
...
@@ -60,19 +69,44 @@ Graph InferShape(Graph ret) {
...
@@ -60,19 +69,44 @@ Graph InferShape(Graph ret) {
}
}
}
}
// set the shapes
// set the shapes
ret
.
attrs
[
"shape"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
rshape
));
ret
.
attrs
[
attr_name
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
rshape
));
// number of nodes who knows the shape.
// number of nodes who knows the shape.
ret
.
attrs
[
"shape_num_known_nodes"
]
=
std
::
make_shared
<
any
>
(
num_known
);
ret
.
attrs
[
known_name
]
=
std
::
make_shared
<
any
>
(
num_known
);
return
ret
;
return
ret
;
}
}
NNVM_REGISTER_PASS
(
InferShape
)
NNVM_REGISTER_PASS
(
InferShape
)
.
describe
(
"Infer the shape of each node entries."
)
.
describe
(
"Infer the shape of each node entries."
)
.
set_body
(
InferShape
)
.
set_body
([](
Graph
ret
)
{
return
InferAttr
<
TShape
>
(
std
::
move
(
ret
),
TShape
(),
"FInferShape"
,
"shape_args"
,
"shape_attr_key"
,
"shape"
,
"shape_num_known_nodes"
);
})
.
set_change_graph
(
false
)
.
set_change_graph
(
false
)
.
provide_graph_attr
(
"shape"
);
.
provide_graph_attr
(
"shape"
);
NNVM_REGISTER_PASS
(
InferType
)
.
describe
(
"Infer the dtype of each node entries."
)
.
set_body
([](
Graph
ret
)
{
return
InferAttr
<
int
>
(
std
::
move
(
ret
),
0
,
"FInferType"
,
"dtype_args"
,
"dtype_attr_key"
,
"dtype"
,
"dtype_num_known_nodes"
);
})
.
set_change_graph
(
false
)
.
provide_graph_attr
(
"dtype"
);
DMLC_JSON_ENABLE_ANY
(
ShapeVector
,
list_shape
);
DMLC_JSON_ENABLE_ANY
(
ShapeVector
,
list_shape
);
DMLC_JSON_ENABLE_ANY
(
DTypeVector
,
list_int
);
}
// namespace pass
}
// namespace pass
}
// namespace nnvm
}
// namespace nnvm
nnvm/tests/python/test_graph.py
View file @
8ffa4ac3
...
@@ -49,9 +49,37 @@ def test_infer_shape():
...
@@ -49,9 +49,37 @@ def test_infer_shape():
assert
g
.
json_attr
(
'shape'
)[
jnode_row_ptr
[
nindex
[
"reshape1"
]]]
==
[
2
,
4
]
assert
g
.
json_attr
(
'shape'
)[
jnode_row_ptr
[
nindex
[
"reshape1"
]]]
==
[
2
,
4
]
assert
g
.
json_attr
(
'shape'
)[
jnode_row_ptr
[
nindex
[
"add1"
]]]
==
[
4
,
2
]
assert
g
.
json_attr
(
'shape'
)[
jnode_row_ptr
[
nindex
[
"add1"
]]]
==
[
4
,
2
]
def
test_infer_shape
():
x
=
sym
.
Variable
(
'x'
,
shape
=
(
4
,
2
))
y
=
sym
.
add
(
x
,
x
,
name
=
'add1'
)
y
=
sym
.
reshape
(
y
,
target
=
(
2
,
4
),
name
=
"reshape1"
)
g
=
graph
.
create
(
y
)
g
.
_set_json_attr
(
"shape_attr_key"
,
"shape"
)
g
=
g
.
apply
(
'InferShape'
)
jgraph
=
json
.
loads
(
g
.
apply
(
'SaveJSON'
)
.
json_attr
(
'json'
))
jnodes
=
jgraph
[
'nodes'
]
jnode_row_ptr
=
jgraph
[
'node_row_ptr'
]
nindex
=
{
n
[
'name'
]:
i
for
i
,
n
in
enumerate
(
jnodes
)}
assert
g
.
json_attr
(
'shape'
)[
jnode_row_ptr
[
nindex
[
"reshape1"
]]]
==
[
2
,
4
]
assert
g
.
json_attr
(
'shape'
)[
jnode_row_ptr
[
nindex
[
"add1"
]]]
==
[
4
,
2
]
def
test_infer_type
():
x
=
sym
.
Variable
(
'x'
)
y
=
sym
.
add
(
x
,
x
,
name
=
'add1'
)
y
=
sym
.
cast
(
y
,
dtype
=
1
,
name
=
"cast1"
)
g
=
graph
.
create
(
y
)
g
=
g
.
apply
(
'InferType'
)
jgraph
=
json
.
loads
(
g
.
apply
(
'SaveJSON'
)
.
json_attr
(
'json'
))
jnodes
=
jgraph
[
'nodes'
]
jnode_row_ptr
=
jgraph
[
'node_row_ptr'
]
nindex
=
{
n
[
'name'
]:
i
for
i
,
n
in
enumerate
(
jnodes
)}
assert
g
.
json_attr
(
'dtype'
)[
jnode_row_ptr
[
nindex
[
"cast1"
]]]
==
1
assert
g
.
json_attr
(
'dtype'
)[
jnode_row_ptr
[
nindex
[
"add1"
]]]
==
0
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_order_mutation_pass
()
test_order_mutation_pass
()
test_graph_json_attr
()
test_graph_json_attr
()
test_json_pass
()
test_json_pass
()
test_infer_shape
()
test_infer_shape
()
test_infer_type
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment