Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
0081ad9a
Commit
0081ad9a
authored
Jul 19, 2016
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Pass] Finish infershape testcase (#16)
parent
bd20bfd8
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
161 additions
and
14 deletions
+161
-14
nnvm/include/nnvm/pass.h
+2
-2
nnvm/include/nnvm/pass_functions.h
+76
-0
nnvm/src/core/pass.cc
+3
-6
nnvm/src/example/operator.cc
+26
-0
nnvm/src/pass/infer_shape.cc
+34
-3
nnvm/src/pass/order_mutation.cc
+1
-1
nnvm/src/pass/saveload_json.cc
+2
-2
nnvm/src/test_main.cc
+1
-0
nnvm/tests/python/test_graph.py
+16
-0
No files found.
nnvm/include/nnvm/pass.h
View file @
0081ad9a
...
@@ -23,7 +23,7 @@ namespace nnvm {
...
@@ -23,7 +23,7 @@ namespace nnvm {
* \param src The graph to be transformed.
* \param src The graph to be transformed.
* \return The generated graph.
* \return The generated graph.
*/
*/
typedef
std
::
function
<
Graph
(
const
Graph
&
src
)
>
PassFunction
;
typedef
std
::
function
<
Graph
(
Graph
src
)
>
PassFunction
;
/*!
/*!
* \brief Apply a series of pass transformations on g.
* \brief Apply a series of pass transformations on g.
...
@@ -31,7 +31,7 @@ typedef std::function<Graph (const Graph& src)> PassFunction;
...
@@ -31,7 +31,7 @@ typedef std::function<Graph (const Graph& src)> PassFunction;
* \param pass The name of pass to be applied.
* \param pass The name of pass to be applied.
* \return The transformed graph
* \return The transformed graph
*/
*/
Graph
ApplyPass
(
const
Graph
&
src
,
Graph
ApplyPass
(
Graph
src
,
const
std
::
vector
<
std
::
string
>&
pass
);
const
std
::
vector
<
std
::
string
>&
pass
);
/*!
/*!
...
...
nnvm/include/nnvm/pass_functions.h
0 → 100644
View file @
0081ad9a
/*!
* Copyright (c) 2016 by Contributors
* \file pass_functions.h
* \brief Pass functions that simply redirect the calls to ApplyPass
*
* This file serves as documentation on how to use functions implemented in "src/pass".
* It is totally optional to add these functions when you add a new pass, since
* ApplyPass can be directly called.
*/
#ifndef NNVM_PASS_FUNCTIONS_H_
#define NNVM_PASS_FUNCTIONS_H_
#include <string>
#include <memory>
#include "./base.h"
#include "./pass.h"
#include "./graph_attr_types.h"
namespace
nnvm
{
namespace
pass
{
/*!
* \brief Load a graph from JSON string, redirects to "LoadJSON" pass.
* \param json_str The json string.
* \return Loaded graph.
*/
inline
Graph
LoadJSON
(
const
std
::
string
&
json_str
)
{
Graph
ret
;
ret
.
attrs
[
"json"
]
=
std
::
make_shared
<
any
>
(
json_str
);
return
ApplyPass
(
ret
,
{
"LoadJSON"
});
}
/*!
* \brief Save a graph to json, redirects to "SaveJSON" pass.
* \param graph The to be saved.
* \return The json string.
*/
inline
std
::
string
SaveJSON
(
Graph
graph
)
{
Graph
ret
=
ApplyPass
(
std
::
move
(
graph
),
{
"SaveJSON"
});
return
ret
.
GetAttr
<
std
::
string
>
(
"json"
);
}
/*!
* \brief Add control flow dependencies between nodes
* To correctly order mutation and read to resolve
* write after read problem and read after write problems.
* \param src source graph
* \return A graph that added control flow dependencies.
*/
inline
Graph
OrderMutation
(
Graph
src
)
{
return
ApplyPass
(
std
::
move
(
src
),
{
"OrderMutation"
});
}
/*!
* \brief Infer shapes in the graph given the information.
* \param graph source graph
* \param shape_args The shapes of aruguments to the graph.
* \param shape_attr_key The key to the node attribute that can indicate shape.
* \return A graph with new attribute "shape" containing inferred shape of each NodeEntry.
* The index of ShapeVector is given by graph.indexed_graph().entry_id
*/
inline
Graph
InferShape
(
Graph
graph
,
ShapeVector
shape_args
=
{},
std
::
string
shape_attr_key
=
""
)
{
if
(
shape_args
.
size
()
!=
0
)
{
graph
.
attrs
[
"shape_args"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
shape_args
));
}
if
(
shape_attr_key
.
length
()
!=
0
)
{
graph
.
attrs
[
"shape_attr_key"
]
=
std
::
make_shared
<
any
>
(
std
::
move
(
shape_attr_key
));
}
return
ApplyPass
(
std
::
move
(
graph
),
{
"InferShape"
});
}
}
// namespace pass
}
// namespace nnvm
#endif // NNVM_PASS_FUNCTIONS_H_
nnvm/src/core/pass.cc
View file @
0081ad9a
...
@@ -22,7 +22,7 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) {
...
@@ -22,7 +22,7 @@ const PassFunctionReg* FindPassDep(const std::string&attr_name) {
return
nullptr
;
return
nullptr
;
}
}
Graph
ApplyPass
(
const
Graph
&
src
,
Graph
ApplyPass
(
Graph
g
,
const
std
::
vector
<
std
::
string
>&
pass
)
{
const
std
::
vector
<
std
::
string
>&
pass
)
{
std
::
vector
<
const
PassFunctionReg
*>
fpass
;
std
::
vector
<
const
PassFunctionReg
*>
fpass
;
for
(
auto
&
name
:
pass
)
{
for
(
auto
&
name
:
pass
)
{
...
@@ -32,11 +32,9 @@ Graph ApplyPass(const Graph& src,
...
@@ -32,11 +32,9 @@ Graph ApplyPass(const Graph& src,
fpass
.
push_back
(
reg
);
fpass
.
push_back
(
reg
);
}
}
Graph
g
;
const
Graph
*
s
=
&
src
;
for
(
auto
r
:
fpass
)
{
for
(
auto
r
:
fpass
)
{
for
(
auto
&
dep
:
r
->
graph_attr_dependency
)
{
for
(
auto
&
dep
:
r
->
graph_attr_dependency
)
{
if
(
s
->
attrs
.
count
(
dep
)
==
0
)
{
if
(
g
.
attrs
.
count
(
dep
)
==
0
)
{
auto
*
pass_dep
=
FindPassDep
(
dep
);
auto
*
pass_dep
=
FindPassDep
(
dep
);
std
::
string
msg
;
std
::
string
msg
;
if
(
pass_dep
!=
nullptr
)
{
if
(
pass_dep
!=
nullptr
)
{
...
@@ -48,8 +46,7 @@ Graph ApplyPass(const Graph& src,
...
@@ -48,8 +46,7 @@ Graph ApplyPass(const Graph& src,
<<
msg
;
<<
msg
;
}
}
}
}
g
=
r
->
body
(
*
s
);
g
=
r
->
body
(
std
::
move
(
g
));
s
=
&
g
;
}
}
return
g
;
return
g
;
}
}
...
...
nnvm/src/example/operator.cc
View file @
0081ad9a
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#include <nnvm/base.h>
#include <nnvm/base.h>
#include <nnvm/op.h>
#include <nnvm/op.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/node.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <utility>
#include <utility>
...
@@ -30,6 +31,31 @@ inline bool SameShape(const NodeAttrs& attrs,
...
@@ -30,6 +31,31 @@ inline bool SameShape(const NodeAttrs& attrs,
return
true
;
return
true
;
}
}
// simple demonstration of reshape.
NNVM_REGISTER_OP
(
reshape
)
.
describe
(
"reshape source to target shape"
)
.
set_num_inputs
(
1
)
.
set_attr_parser
(
[](
NodeAttrs
*
attrs
)
{
// parse attr parser to get target attribute
TShape
target
;
std
::
istringstream
is
(
attrs
->
dict
.
at
(
"target"
));
CHECK
(
is
>>
target
);
attrs
->
parsed
=
std
::
move
(
target
);
})
.
attr
<
FInferShape
>
(
"FInferShape"
,
[]
(
const
NodeAttrs
&
attrs
,
array_view
<
TShape
*>
ishape
,
array_view
<
TShape
*>
oshape
)
{
// get parsed attribute
const
TShape
&
target
=
nnvm
::
get
<
TShape
>
(
attrs
.
parsed
);
*
oshape
[
0
]
=
target
;
if
(
ishape
[
0
]
->
ndim
()
==
0
)
return
false
;
CHECK_EQ
(
ishape
[
0
]
->
Size
(),
target
.
Size
())
<<
"Reshape op: source target shape mismatch"
;
return
true
;
});
NNVM_REGISTER_OP
(
add
)
NNVM_REGISTER_OP
(
add
)
.
describe
(
"add two data together"
)
.
describe
(
"add two data together"
)
.
set_num_inputs
(
2
)
.
set_num_inputs
(
2
)
...
...
nnvm/src/pass/infer_shape.cc
View file @
0081ad9a
...
@@ -10,19 +10,42 @@
...
@@ -10,19 +10,42 @@
namespace
nnvm
{
namespace
nnvm
{
namespace
pass
{
namespace
pass
{
Graph
InferShape
(
const
Graph
&
src
)
{
Graph
InferShape
(
Graph
ret
)
{
Graph
ret
=
src
;
const
IndexedGraph
&
idx
=
ret
.
indexed_graph
();
const
IndexedGraph
&
idx
=
ret
.
indexed_graph
();
static
auto
&
finfer_shape
=
Op
::
GetAttr
<
FInferShape
>
(
"FInferShape"
);
static
auto
&
finfer_shape
=
Op
::
GetAttr
<
FInferShape
>
(
"FInferShape"
);
// reshape shape vector
// reshape shape vector
ShapeVector
rshape
(
idx
.
num_node_entries
());
ShapeVector
rshape
(
idx
.
num_node_entries
());
if
(
ret
.
attrs
.
count
(
"shape_args"
)
!=
0
)
{
const
ShapeVector
&
shape_args
=
ret
.
GetAttr
<
ShapeVector
>
(
"shape_args"
);
CHECK_LE
(
shape_args
.
size
(),
idx
.
arg_nodes
().
size
())
<<
"shape args is more than number of arguments"
;
for
(
size_t
i
=
0
;
i
<
shape_args
.
size
();
++
i
)
{
rshape
[
idx
.
entry_id
(
idx
.
arg_nodes
()[
i
],
0
)]
=
shape_args
[
i
];
}
}
std
::
string
shape_attr_key
;
if
(
ret
.
attrs
.
count
(
"shape_attr_key"
)
!=
0
)
{
shape_attr_key
=
ret
.
GetAttr
<
std
::
string
>
(
"shape_attr_key"
);
}
// temp space for shape inference.
// temp space for shape inference.
std
::
vector
<
TShape
*>
ishape
,
oshape
;
std
::
vector
<
TShape
*>
ishape
,
oshape
;
// number of completed nodes
// number of completed nodes
size_t
num_known
=
0
;
size_t
num_known
=
0
;
for
(
uint32_t
nid
=
0
;
nid
<
idx
.
num_nodes
();
++
nid
)
{
for
(
uint32_t
nid
=
0
;
nid
<
idx
.
num_nodes
();
++
nid
)
{
const
auto
&
inode
=
idx
[
nid
];
const
auto
&
inode
=
idx
[
nid
];
if
(
inode
.
source
->
is_variable
())
continue
;
if
(
inode
.
source
->
is_variable
())
{
if
(
shape_attr_key
.
length
()
!=
0
)
{
auto
it
=
inode
.
source
->
attrs
.
dict
.
find
(
shape_attr_key
);
if
(
it
!=
inode
.
source
->
attrs
.
dict
.
end
())
{
CHECK_EQ
(
inode
.
source
->
num_outputs
(),
1
);
std
::
istringstream
is
(
it
->
second
);
CHECK
(
is
>>
rshape
[
idx
.
entry_id
(
nid
,
0
)])
<<
"Invalid shape attribute"
;
}
}
continue
;
}
ishape
.
resize
(
inode
.
inputs
.
size
());
ishape
.
resize
(
inode
.
inputs
.
size
());
for
(
uint32_t
i
=
0
;
i
<
ishape
.
size
();
++
i
)
{
for
(
uint32_t
i
=
0
;
i
<
ishape
.
size
();
++
i
)
{
ishape
[
i
]
=
&
rshape
[
idx
.
entry_id
(
inode
.
inputs
[
i
])];
ishape
[
i
]
=
&
rshape
[
idx
.
entry_id
(
inode
.
inputs
[
i
])];
...
@@ -43,5 +66,13 @@ Graph InferShape(const Graph& src) {
...
@@ -43,5 +66,13 @@ Graph InferShape(const Graph& src) {
return
ret
;
return
ret
;
}
}
NNVM_REGISTER_PASS
(
InferShape
)
.
describe
(
"Infer the shape of each node entries."
)
.
set_body
(
InferShape
)
.
set_change_graph
(
false
)
.
provide_graph_attr
(
"shape"
);
DMLC_JSON_ENABLE_ANY
(
ShapeVector
,
list_shape
);
}
// namespace pass
}
// namespace pass
}
// namespace nnvm
}
// namespace nnvm
nnvm/src/pass/order_mutation.cc
View file @
0081ad9a
/*!
/*!
* Copyright (c) 2016 by Contributors
* Copyright (c) 2016 by Contributors
* \file
saveload_js
on.cc
* \file
order_mutati
on.cc
* \brief Add control flow dependencies between nodes
* \brief Add control flow dependencies between nodes
* To correctly order mutation and read to resolve
* To correctly order mutation and read to resolve
* write after read problem and read after write problems.
* write after read problem and read after write problems.
...
...
nnvm/src/pass/saveload_json.cc
View file @
0081ad9a
...
@@ -149,7 +149,7 @@ struct JSONGraph {
...
@@ -149,7 +149,7 @@ struct JSONGraph {
};
};
// Load a graph from JSON file.
// Load a graph from JSON file.
Graph
LoadJSON
(
const
Graph
&
src
)
{
Graph
LoadJSON
(
Graph
src
)
{
CHECK_NE
(
src
.
attrs
.
count
(
"json"
),
0
)
CHECK_NE
(
src
.
attrs
.
count
(
"json"
),
0
)
<<
"Load JSON require json to be presented."
;
<<
"Load JSON require json to be presented."
;
const
std
::
string
&
json_str
=
const
std
::
string
&
json_str
=
...
@@ -188,7 +188,7 @@ Graph LoadJSON(const Graph& src) {
...
@@ -188,7 +188,7 @@ Graph LoadJSON(const Graph& src) {
}
}
// save a graph to json
// save a graph to json
Graph
SaveJSON
(
const
Graph
&
src
)
{
Graph
SaveJSON
(
Graph
src
)
{
JSONGraph
jgraph
;
JSONGraph
jgraph
;
std
::
unordered_map
<
Node
*
,
uint32_t
>
node2index
;
std
::
unordered_map
<
Node
*
,
uint32_t
>
node2index
;
jgraph
.
node_row_ptr
.
push_back
(
0
);
jgraph
.
node_row_ptr
.
push_back
(
0
);
...
...
nnvm/src/test_main.cc
View file @
0081ad9a
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#include <nnvm/tuple.h>
#include <nnvm/tuple.h>
#include <nnvm/c_api.h>
#include <nnvm/c_api.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass_functions.h>
#include <dmlc/timer.h>
#include <dmlc/timer.h>
#include <string>
#include <string>
...
...
nnvm/tests/python/test_graph.py
View file @
0081ad9a
...
@@ -35,7 +35,23 @@ def test_order_mutation_pass():
...
@@ -35,7 +35,23 @@ def test_order_mutation_pass():
assert
nindex
[
'add1'
]
in
jnodes
[
nindex
[
'assign'
]][
'control_deps'
]
assert
nindex
[
'add1'
]
in
jnodes
[
nindex
[
'assign'
]][
'control_deps'
]
assert
jnodes
[
nindex
[
'assign'
]][
'inputs'
][
0
][
2
]
==
1
assert
jnodes
[
nindex
[
'assign'
]][
'inputs'
][
0
][
2
]
==
1
def
test_infer_shape
():
x
=
sym
.
Variable
(
'x'
,
shape
=
(
4
,
2
))
y
=
sym
.
add
(
x
,
x
,
name
=
'add1'
)
y
=
sym
.
reshape
(
y
,
target
=
(
2
,
4
),
name
=
"reshape1"
)
g
=
graph
.
create
(
y
)
g
.
_set_json_attr
(
"shape_attr_key"
,
"shape"
)
g
=
g
.
apply
(
'InferShape'
)
jgraph
=
json
.
loads
(
g
.
apply
(
'SaveJSON'
)
.
json_attr
(
'json'
))
jnodes
=
jgraph
[
'nodes'
]
jnode_row_ptr
=
jgraph
[
'node_row_ptr'
]
nindex
=
{
n
[
'name'
]:
i
for
i
,
n
in
enumerate
(
jnodes
)}
assert
g
.
json_attr
(
'shape'
)[
jnode_row_ptr
[
nindex
[
"reshape1"
]]]
==
[
2
,
4
]
assert
g
.
json_attr
(
'shape'
)[
jnode_row_ptr
[
nindex
[
"add1"
]]]
==
[
4
,
2
]
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_order_mutation_pass
()
test_order_mutation_pass
()
test_graph_json_attr
()
test_graph_json_attr
()
test_json_pass
()
test_json_pass
()
test_infer_shape
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment