Commit 7c95535c, authored 7 years ago by Tianqi Chen
[PASS] PrecomputePrune, add testcase (#14)
* [PASS] PrecomputePrune, add testcase
* update comment
parent d27c11e0

Showing 14 changed files with 310 additions and 48 deletions
nnvm/Makefile (+2, -1)
nnvm/python/nnvm/compiler/__init__.py (+1, -1)
nnvm/python/nnvm/compiler/build_module.py (+78, -7)
nnvm/python/nnvm/compiler/graph_attr.py (+48, -16)
nnvm/python/nnvm/compiler/graph_pass.py (+55, -0)
nnvm/python/nnvm/graph.py (+6, -0)
nnvm/python/nnvm/top/attr_dict.py (+15, -0)
nnvm/python/nnvm/top/tensor.py (+17, -0)
nnvm/src/compiler/packed_func_ext.cc (+14, -0)
nnvm/src/compiler/pass/graph_fuse.cc (+0, -8)
nnvm/src/compiler/pass/layout_transform.cc (+1, -1)
nnvm/src/compiler/pass/precompute_prune.cc (+21, -12)
nnvm/tests/python/compiler/test_build.py (+34, -2)
nnvm/tests/python/compiler/test_graph_pass.py (+18, -0)
nnvm/Makefile

@@ -30,7 +30,7 @@ ifneq ($(ADD_CFLAGS), NONE)
 endif
 
 ifneq ($(ADD_LDFLAGS), NONE)
-	LFFLAGS += $(ADD_LDFLAGS)
+	LDFLAGS += $(ADD_LDFLAGS)
 endif
 
 # plugin
@@ -46,6 +46,7 @@ ifeq ($(UNAME_S), Darwin)
 	SHARED_LIBRARY_SUFFIX := dylib
 	WHOLE_ARCH= -all_load
 	NO_WHOLE_ARCH= -noall_load
+	LDFLAGS += -undefined dynamic_lookup
 else
 	SHARED_LIBRARY_SUFFIX := so
 	WHOLE_ARCH= --whole-archive
nnvm/python/nnvm/compiler/__init__.py

@@ -4,7 +4,7 @@ from __future__ import absolute_import
 import tvm
 
 from . import build_module
-from .build_module import build
+from .build_module import build, precompute_prune, _run_graph
 
 from .. import symbol as _symbol
 from .. import graph as _graph
nnvm/python/nnvm/compiler/build_module.py

@@ -3,8 +3,9 @@
 from __future__ import absolute_import as _abs
 
 import tvm
-from . import graph_attr
+from . import graph_attr, graph_pass
 from .. import graph as _graph
+from .. import runtime
 
 @tvm.register_func("nnvm.compiler.lower")
 def _lower(sch, inputs, func_name):
@@ -18,9 +19,6 @@ def _build(funcs, target):
     return tvm.build(funcs, target=target)
 
-_move_module = tvm.get_global_func("nnvm.compiler._move_module")
-
 def optimize(graph):
     """Perform graph optimization
@@ -70,10 +68,83 @@ def build(graph, target, shape, dtype="float32"):
         raise TypeError("require shape to be dict")
     graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
-    graph = graph_attr.set_shape(graph, shape)
-    graph = graph_attr.set_dtype(graph, dtype)
+    graph = graph_attr.set_shape_inputs(graph, shape)
+    graph = graph_attr.set_dtype_inputs(graph, dtype)
     graph._set_json_attr("target", target, "str")
     graph = graph.apply("InferShape").apply("InferType")
     graph = graph.apply("GraphFusePartition").apply("GraphFuse")
-    libmod = _move_module(graph)
+    libmod = graph_attr._move_out_module(graph, "module")
     return graph, libmod
+
+
+def _run_graph(graph, params):
+    """Helper utility to build and run and get outputs, only use cpu mode.
+
+    Parameters
+    ----------
+    graph : Graph
+        The graph to be executed.
+
+    params : dict of str to ndarray
+        The parameter dictionary.
+
+    Returns
+    -------
+    out_dict : dict of str to tvm.NDArray
+        The output dictionaries.
+    """
+    graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
+    shape = {k: v.shape for k, v in params.items()}
+    dtype = {k: v.dtype for k, v in params.items()}
+    target = "llvm"
+    ctx = tvm.cpu(0)
+    _, oshape = graph_pass.infer_shape(graph, **shape)
+    _, odtype = graph_pass.infer_dtype(graph, **dtype)
+    graph, libmod = build(graph, target, shape, dtype)
+    m = runtime.create(graph, libmod, ctx)
+    set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
+    for k, v in params.items():
+        set_input(k, tvm.nd.array(v))
+    run()
+    out_data = []
+    for i, kv in enumerate(zip(oshape, odtype)):
+        shape, dtype = kv
+        arr = tvm.nd.empty(shape, dtype, ctx)
+        get_output(i, arr)
+        out_data.append(arr)
+    return out_data
+
+
+def precompute_prune(graph, params):
+    """Precompute the part of graph that can be pre-computed.
+
+    This will create a new graph that only contains the ops
+    that need to be computed depending on input, as well as an
+    updated version of the param dict that pre-computes some of
+    the intermediate results.
+
+    Parameters
+    ----------
+    graph : Graph
+        The input graph
+
+    params : dict of str -> tvm.NDArray
+        The parameter dictionary of the graph
+
+    Returns
+    -------
+    pruned_graph : Graph
+        The pruned graph
+
+    new_params : dict of str -> tvm.NDArray
+        The updated dictionary of parameters.
+    """
+    graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
+    graph._set_json_attr("param_name_list", list(params.keys()), "list_str")
+    graph = graph.apply("PrecomputePrune")
+    pre_graph = graph_attr._move_out_graph(graph, "precompute_graph")
+    if not pre_graph.symbol.list_output_names():
+        return graph, params
+    out_names = pre_graph.json_attr("output_names")
+    out_arrs = _run_graph(pre_graph, params)
+    return graph, dict(zip(out_names, out_arrs))
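Taken together, these helpers split compilation from parameter folding. A minimal sketch of how they compose, mirroring test_precompute_prune added later in this commit (it assumes nnvm is built with the TVM runtime):

```python
# Hedged sketch of the new workflow: fold the constant subgraph of
# z = (x + 1) + y, then run only the remaining execution graph.
import numpy as np
import tvm
import nnvm.symbol as sym
import nnvm.compiler

x = sym.Variable("x") + 1      # depends only on parameter "x": foldable
y = sym.Variable("y")          # genuine runtime input
z = y + x

params = {"x": tvm.nd.array(np.ones((10, 10)).astype("float32"))}
# "x + 1" is evaluated once here; its result joins the returned params.
graph, new_params = nnvm.compiler.precompute_prune(z, params)
new_params["y"] = tvm.nd.array(np.ones((10, 10)).astype("float32"))
out = nnvm.compiler._run_graph(graph, new_params)[0]
```

When nothing is foldable, precompute_prune returns the graph and params untouched (the empty-output early return above).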
nnvm/python/nnvm/compiler/graph_attr.py

+# pylint: disable=invalid-name
 """Utilities to access graph attributes"""
 from __future__ import absolute_import as _abs
 
-def set_shape(g, shape):
-    """Set the shape of graph nodes in the graph attribute.
+import tvm
+
+def set_shape_inputs(g, shape):
+    """Set the shape of input graph nodes in the graph attribute.
 
     Parameters
     ----------
@@ -17,20 +20,24 @@ def set_shape(g, shape):
     g : Graph
         The updated graph with updated shape.
     """
-    index = g.index
-    list_shape = [[]] * index.num_node_entries
-    for k, v in shape.items():
-        list_shape[index.entry_id(k)] = v
-    g._set_json_attr("shape", list_shape, 'list_shape')
+    list_shape = [
+        shape.get(name, ()) for name in g.index.input_names]
+    g._set_json_attr("shape_inputs", list_shape, 'list_shape')
     return g
 
 
-DTYPE_DICT = {
+DTYPE_TO_TCODE = {
     "default": -1,
     "float32": 0
 }
 
-def set_dtype(g, dtype):
-    """Set the dtype of graph nodes
+TCODE_TO_DTYPE = {
+    -1: None,
+    0: "float32"
+}
+
+def set_dtype_inputs(g, dtype):
+    """Set the dtype inputs of graph nodes
 
     Parameters
     ----------
@@ -45,12 +52,37 @@ def set_dtype(g, dtype):
     g : Graph
         The updated graph with updated dtype.
     """
-    index = g.index
     if isinstance(dtype, dict):
-        list_dtype = [-1] * index.num_node_entries
-        for k, v in dtype.items():
-            list_dtype[index.entry_id(k)] = DTYPE_DICT[v]
+        list_dtype = [
+            DTYPE_TO_TCODE[dtype.get(name, "default")]
+            for name in g.index.input_names]
     else:
-        list_dtype = [DTYPE_DICT[dtype]] * index.num_node_entries
-    g._set_json_attr("dtype", list_dtype, "list_int")
+        list_dtype = [DTYPE_TO_TCODE[dtype]] * len(g.index.input_names)
+    g._set_json_attr("dtype_inputs", list_dtype, "list_int")
     return g
 
 
+def set_layout_inputs(g, layout):
+    """Set the layout inputs of graph nodes
+
+    Parameters
+    ----------
+    g : Graph
+        The input graph
+
+    layout : dict of str to str or str
+        The input layout
+
+    Returns
+    -------
+    g : Graph
+        The updated graph with updated layout.
+    """
+    list_shape = [
+        layout.get(name, "default") for name in g.index.input_names]
+    g._set_json_attr("layout_inputs", list_shape, 'list_str')
+    return g
+
+_move_out_module = tvm.get_global_func("nnvm.graph_attr._move_module")
+_move_out_graph = tvm.get_global_func("nnvm.graph_attr._move_graph")
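The tcode tables are a simple two-way mapping between dtype names and integer codes. A short, self-contained sketch of how set_dtype_inputs resolves a partially specified dict (the input names here are hypothetical stand-ins for a real graph's inputs):

```python
# Hedged sketch of the tcode round trip used above.
DTYPE_TO_TCODE = {"default": -1, "float32": 0}
TCODE_TO_DTYPE = {-1: None, 0: "float32"}

dtype = {"x": "float32"}            # "y" left unspecified
input_names = ["x", "y"]
list_dtype = [DTYPE_TO_TCODE[dtype.get(name, "default")]
              for name in input_names]
assert list_dtype == [0, -1]        # unspecified inputs carry tcode -1
assert TCODE_TO_DTYPE[list_dtype[0]] == "float32"
```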
nnvm/python/nnvm/compiler/graph_pass.py

+# pylint: disable=invalid-name
 """Namespace of graph pass.
 
 Principle:
@@ -5,3 +6,57 @@ Principle:
 - Composable API: break graph transformation pass as segments of small transformations.
 """
 from __future__ import absolute_import as _abs
+
+from . import graph_attr
+
+def infer_shape(graph, **shape):
+    """Infer the shape given the shape of inputs.
+
+    Parameters
+    ----------
+    graph : Graph
+        The graph to perform shape inference from
+
+    Returns
+    -------
+    in_shape : list of tuple
+        Shape of inputs
+
+    out_shape : list of tuple
+        Shape of outputs
+    """
+    graph = graph_attr.set_shape_inputs(graph, shape)
+    graph = graph.apply("InferShape")
+    shape = graph.json_attr("shape")
+    index = graph.index
+    input_shape = [shape[index.entry_id(x)] for x in index.input_names]
+    output_shape = [shape[index.entry_id(x)] for x in index.output_entries]
+    return input_shape, output_shape
+
+
+def infer_dtype(graph, **dtype):
+    """Infer the type given the types of inputs.
+
+    Parameters
+    ----------
+    graph : Graph
+        The graph to perform type inference from
+
+    Returns
+    -------
+    in_dtype : list of tuple
+        Dtype of inputs
+
+    out_dtype : list of tuple
+        Dtype of outputs
+    """
+    graph = graph_attr.set_dtype_inputs(graph, dtype)
+    graph = graph.apply("InferType")
+    dtype = graph.json_attr("dtype")
+    index = graph.index
+    input_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
+                   for x in index.input_names]
+    output_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
+                    for x in index.output_entries]
+    return input_dtype, output_dtype
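Input shapes and dtypes are passed as keyword arguments keyed by input variable name; a condensed sketch of the calling convention, mirroring test_graph_pass.py added later in this commit:

```python
# Hedged sketch of the inference entry points.
import nnvm
from nnvm.compiler import graph_pass

g = nnvm.graph.create(nnvm.symbol.Variable("x") * 2)
ishape, oshape = graph_pass.infer_shape(g, x=(10, 20))   # oshape[0] -> (10, 20)
idtype, odtype = graph_pass.infer_dtype(g, x="float32")  # odtype[0] -> "float32"
```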
nnvm/python/nnvm/graph.py

@@ -24,6 +24,8 @@ class GraphIndex(object):
         self.nodes = jgraph["nodes"]
         self.entry_ptr = jgraph["node_row_ptr"]
         self._name2nodeid = {n["name"]: i for i, n in enumerate(self.nodes)}
+        self.input_names = graph.symbol.list_input_names()
+        self.output_entries = jgraph["heads"]
 
     @property
     def num_nodes(self):
@@ -66,6 +68,10 @@ class GraphIndex(object):
         index : int
             The entry index
         """
+        if isinstance(key, (list, tuple)):
+            if len(key) != 3:
+                raise ValueError("Expect entry index to be tuple of 3 elems")
+            key, value_index, _ = key
         idx = self.node_id(key) if isinstance(key, str) else key
         assert value_index < self.entry_ptr[idx + 1]
         return self.entry_ptr[idx] + value_index
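entry_id now also accepts the 3-element entries stored in jgraph["heads"], which is what lets infer_shape index output_entries directly. A hedged sketch:

```python
# Hedged sketch: entry_id accepts a node name, a node id, or a
# 3-element entry [node_id, value_index, version] as found in "heads".
import nnvm

g = nnvm.graph.create(nnvm.symbol.Variable("x") * 2)
index = g.index
head = index.output_entries[0]   # e.g. [node_id, value_index, version]
eid = index.entry_id(head)       # list/tuple form handled by the new branch
```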
nnvm/python/nnvm/top/attr_dict.py

@@ -68,6 +68,21 @@ class AttrDict(object):
         """
         return int(self[key])
 
+    def get_float(self, key):
+        """Get float from attr dict
+
+        Parameters
+        ----------
+        key : str
+            The attr key
+
+        Returns
+        -------
+        value : float
+            The result value
+        """
+        return float(self[key])
+
     def get_bool(self, key):
         """Get bool from attr dict
nnvm/python/nnvm/top/tensor.py

@@ -17,6 +17,17 @@ def _schedule_broadcast(_, outs, target):
         tvm.schedule.AutoInlineInjective(s)
         return s
 
+def _compute_binary_scalar(f):
+    """auxiliary function"""
+    @tvm.tag_scope("ewise")
+    def _compute(attrs, x):
+        x = x[0]
+        scalar = attrs.get_float("scalar")
+        scalar = tvm.const(scalar, x.dtype)
+        return tvm.compute(x.shape, lambda *i: f(x(*i), scalar))
+    return _compute
+
+
 _fschedule_broadcast = tvm.convert(_schedule_broadcast)
 
 # exp
@@ -25,6 +36,12 @@ reg.register_compute("exp",
 reg.register_pattern("exp", OpPattern.ELEM_WISE)
 reg.register_schedule("exp", _fschedule_broadcast)
 
+# add scalar
+reg.register_compute("__add_scalar__",
+                     _compute_binary_scalar(lambda x, y: x + y))
+reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__add_scalar__", _fschedule_broadcast)
+
 # broadcast_add
 reg.register_compute("broadcast_add",
                      lambda _, x: topi.broadcast_add(x[0], x[1]))
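The factory makes further scalar variants one-liners. A hedged sketch of a hypothetical subtract-by-scalar registration (not part of this commit; it assumes tensor.py's own reg, OpPattern, and _fschedule_broadcast are in scope):

```python
# Hedged sketch: reusing _compute_binary_scalar for an illustrative
# "__sub_scalar__" op; only "__add_scalar__" is registered by this commit.
reg.register_compute("__sub_scalar__",
                     _compute_binary_scalar(lambda x, y: x - y))
reg.register_pattern("__sub_scalar__", OpPattern.ELEM_WISE)
reg.register_schedule("__sub_scalar__", _fschedule_broadcast)
```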
nnvm/src/compiler/packed_func_ext.cc

@@ -104,5 +104,19 @@ TVM_REGISTER_GLOBAL("nnvm._register_pattern")
     Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
     op.set_attr<TOpPattern>("TOpPattern", args[1].operator int(), args[2]);
   });
 
+TVM_REGISTER_GLOBAL("nnvm.graph_attr._move_module")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+    const nnvm::Graph& g = args[0].AsExtension<Graph>();
+    *rv = const_cast<nnvm::Graph*>(&g)->
+        MoveCopyAttr<tvm::runtime::Module>(args[1]);
+  });
+
+TVM_REGISTER_GLOBAL("nnvm.graph_attr._move_graph")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+    const nnvm::Graph& g = args[0].AsExtension<Graph>();
+    *rv = const_cast<nnvm::Graph*>(&g)->
+        MoveCopyAttr<nnvm::Graph>(args[1]);
+  });
 }  // namespace compiler
 }  // namespace nnvm
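These two globals are the C++ halves of the `_move_out_module`/`_move_out_graph` handles; the Python side simply binds them by name, exactly as graph_attr.py above does:

```python
# How the Python side picks up the PackedFuncs registered above
# (taken from graph_attr.py in this commit).
import tvm

_move_out_module = tvm.get_global_func("nnvm.graph_attr._move_module")
_move_out_graph = tvm.get_global_func("nnvm.graph_attr._move_graph")
```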
nnvm/src/compiler/pass/graph_fuse.cc

@@ -381,13 +381,5 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
 NNVM_REGISTER_PASS(GraphFuse)
 .set_body(GraphFuse);
 
-TVM_REGISTER_GLOBAL("nnvm.compiler._move_module")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
-    const nnvm::Graph& g = args[0].AsExtension<Graph>();
-    *rv = const_cast<nnvm::Graph*>(&g)->
-        MoveCopyAttr<tvm::runtime::Module>("module");
-  });
-
 }  // namespace compiler
 }  // namespace nnvm
nnvm/src/compiler/pass/layout_transform.cc

@@ -44,7 +44,7 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {
   const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
   const std::vector<TLayoutInfo>& input_layouts =
-      src.GetAttr<std::vector<TLayoutInfo> >("layout");
+      src.GetAttr<std::vector<TLayoutInfo> >("layout_inputs");
 
   const IndexedGraph& idx = src.indexed_graph();
   std::vector<TLayoutInfo> produce_vec(idx.num_node_entries(), GetDefaultLayout());
nnvm/src/compiler/pass/prune_graph.cc → nnvm/src/compiler/pass/precompute_prune.cc

 /*!
  * Copyright (c) 2017 by Contributors
- * \file prune_graph.cc
- * \brief Prune the graph to do constant folding.
+ * \file precompute_prune.cc
+ * \brief Split the graph into a pre-compute graph and an execution graph.
  *
- * This pass breaks the graph into pre-compute graph
- * and the execution graph.
+ * The pre-compute graph outputs parameters that can be taken
+ * by the execution graph during the execution phase.
  */
 #include <nnvm/graph.h>
 #include <nnvm/op_attr_types.h>
@@ -16,11 +16,15 @@
 namespace nnvm {
 namespace compiler {
 
-nnvm::Graph PruneGraph(nnvm::Graph src) {
-  const auto& params = src.GetAttr<std::unordered_set<std::string> >("params");
+nnvm::Graph PrecomputePrune(nnvm::Graph src) {
+  const auto& plist = src.GetAttr<std::vector<std::string> >("param_name_list");
+  std::unordered_set<std::string> params(plist.begin(), plist.end());
 
   std::unordered_set<nnvm::Node*> pruned;
   nnvm::NodeEntryMap<nnvm::NodePtr> entry_var;
+  std::unordered_set<std::string> unique_name;
 
   DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) {
     bool can_be_pruned = true;
     if (n->is_variable()) {
@@ -45,7 +49,12 @@ nnvm::Graph PruneGraph(nnvm::Graph src) {
         nnvm::NodePtr var = nnvm::Node::Create();
         var->attrs.name = e.node->attrs.name + "_output" + std::to_string(e.index);
         entry_var.emplace(e, var);
+        CHECK(!unique_name.count(var->attrs.name));
+        unique_name.insert(var->attrs.name);
       }
+      // TODO(ziheng): this pass now mutates the original graph structure
+      // This might not be a good thing, change to copy the structure instead
+      //
       e = nnvm::NodeEntry{entry_var.at(e), 0, 0};
     }
   }
@@ -56,21 +65,21 @@ nnvm::Graph PruneGraph(nnvm::Graph src) {
   pre_graph.outputs.reserve(entry_var.size());
   std::vector<std::string> output_names;
   output_names.reserve(entry_var.size());
 
   for (auto kv : entry_var) {
     if (kv.first.node->is_variable()) continue;
     pre_graph.outputs.emplace_back(kv.first);
     output_names.emplace_back(kv.second->attrs.name);
   }
-  pre_graph.attrs["pruned_params"] =
+  // new parameter list
+  pre_graph.attrs["output_names"] =
       std::make_shared<dmlc::any>(std::move(output_names));
-  src.attrs["pre_graph"] =
+  src.attrs["precompute_graph"] =
       std::make_shared<dmlc::any>(std::move(pre_graph));
   return src;
 }
 
-NNVM_REGISTER_PASS(PruneGraph)
-.set_body(PruneGraph);
+NNVM_REGISTER_PASS(PrecomputePrune)
+.set_body(PrecomputePrune);
 }  // namespace compiler
 }  // namespace nnvm
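From Python, the pass is driven entirely through JSON/graph attributes. A hedged sketch of the contract, distilled from build_module.precompute_prune above (`graph`, `params`, and `graph_attr` are assumed to be as in that module):

```python
# Hedged sketch of the PrecomputePrune attribute protocol.
graph._set_json_attr("param_name_list", list(params.keys()), "list_str")
graph = graph.apply("PrecomputePrune")      # runs the C++ pass above
# the folded subgraph and its output names come back as graph attributes
pre_graph = graph_attr._move_out_graph(graph, "precompute_graph")
out_names = pre_graph.json_attr("output_names")
```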
nnvm/tests/python/compiler/test_build.py

@@ -17,8 +17,8 @@ def test_compile():
     m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
     # get member functions
     set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
-    na = tvm.nd.array(np.ones(shape).astype(dtype))
-    nb = tvm.nd.array(np.ones(shape).astype(dtype))
+    na = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
+    nb = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
     # set inputs
     set_input("x", na)
     set_input("y", nb)
@@ -30,5 +30,37 @@ def test_compile():
     np.testing.assert_allclose(
         out.asnumpy(), np.exp(na.asnumpy() + nb.asnumpy()))
 
+def test_run():
+    x = sym.Variable("x")
+    y = sym.Variable("y")
+    z = sym.exp(y + x)
+    shape = (10, 10)
+    dtype = tvm.float32
+    nx = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
+    ny = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
+    res = nnvm.compiler._run_graph(z, {"x": nx, "y": ny})
+    np.testing.assert_allclose(
+        res[0].asnumpy(), np.exp(nx.asnumpy() + ny.asnumpy()))
+
+
+def test_precompute_prune():
+    x = sym.Variable("x") + 1
+    y = sym.Variable("y")
+    z = y + x
+    shape = (10, 10)
+    dtype = tvm.float32
+    nx = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
+    ny = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
+    params = {"x": nx}
+    graph, pdict = nnvm.compiler.precompute_prune(z, params)
+    pdict["y"] = ny
+    res = nnvm.compiler._run_graph(z, pdict)
+    np.testing.assert_allclose(
+        res[0].asnumpy(), nx.asnumpy() + 1 + ny.asnumpy())
+
 if __name__ == "__main__":
     test_compile()
+    test_run()
+    test_precompute_prune()
nnvm/tests/python/compiler/test_graph_pass.py (new file, 0 → 100644)

+"""Unittest cases for graph pass"""
+import nnvm
+import nnvm.compiler
+from nnvm.compiler import graph_pass
+
+def test_infer_attr():
+    x = nnvm.symbol.Variable("x")
+    y = x * 2
+    g = nnvm.graph.create(y)
+    ishape, oshape = graph_pass.infer_shape(g, x=(10, 20))
+    assert tuple(oshape[0]) == (10, 20)
+
+    itype, otype = graph_pass.infer_dtype(g, x="float32")
+    assert otype[0] == "float32"
+
+if __name__ == "__main__":
+    test_infer_attr()