Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
7c95535c
Commit
7c95535c
authored
Sep 19, 2017
by
Tianqi Chen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[PASS] PrecomputePrune, add testcase (#14)
* [PASS] PrecomputePrune, add testcase * update comment
parent
d27c11e0
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
310 additions
and
48 deletions
+310
-48
nnvm/Makefile
+2
-1
nnvm/python/nnvm/compiler/__init__.py
+1
-1
nnvm/python/nnvm/compiler/build_module.py
+78
-7
nnvm/python/nnvm/compiler/graph_attr.py
+48
-16
nnvm/python/nnvm/compiler/graph_pass.py
+55
-0
nnvm/python/nnvm/graph.py
+6
-0
nnvm/python/nnvm/top/attr_dict.py
+15
-0
nnvm/python/nnvm/top/tensor.py
+17
-0
nnvm/src/compiler/packed_func_ext.cc
+14
-0
nnvm/src/compiler/pass/graph_fuse.cc
+0
-8
nnvm/src/compiler/pass/layout_transform.cc
+1
-1
nnvm/src/compiler/pass/precompute_prune.cc
+21
-12
nnvm/tests/python/compiler/test_build.py
+34
-2
nnvm/tests/python/compiler/test_graph_pass.py
+18
-0
No files found.
nnvm/Makefile
View file @
7c95535c
...
...
@@ -30,7 +30,7 @@ ifneq ($(ADD_CFLAGS), NONE)
endif
ifneq
($(ADD_LDFLAGS),
NONE)
L
F
FLAGS
+=
$(ADD_LDFLAGS)
L
D
FLAGS
+=
$(ADD_LDFLAGS)
endif
# plugin
...
...
@@ -46,6 +46,7 @@ ifeq ($(UNAME_S), Darwin)
SHARED_LIBRARY_SUFFIX
:=
dylib
WHOLE_ARCH
=
-all_load
NO_WHOLE_ARCH
=
-noall_load
LDFLAGS
+=
-undefined
dynamic_lookup
else
SHARED_LIBRARY_SUFFIX
:=
so
WHOLE_ARCH
=
--whole-archive
...
...
nnvm/python/nnvm/compiler/__init__.py
View file @
7c95535c
...
...
@@ -4,7 +4,7 @@ from __future__ import absolute_import
import
tvm
from
.
import
build_module
from
.
build_module
import
build
from
.
build_module
import
build
,
precompute_prune
,
_run_graph
from
..
import
symbol
as
_symbol
from
..
import
graph
as
_graph
...
...
nnvm/python/nnvm/compiler/build_module.py
View file @
7c95535c
...
...
@@ -3,8 +3,9 @@
from
__future__
import
absolute_import
as
_abs
import
tvm
from
.
import
graph_attr
from
.
import
graph_attr
,
graph_pass
from
..
import
graph
as
_graph
from
..
import
runtime
@tvm.register_func
(
"nnvm.compiler.lower"
)
def
_lower
(
sch
,
inputs
,
func_name
):
...
...
@@ -18,9 +19,6 @@ def _build(funcs, target):
return
tvm
.
build
(
funcs
,
target
=
target
)
_move_module
=
tvm
.
get_global_func
(
"nnvm.compiler._move_module"
)
def
optimize
(
graph
):
"""Perform graph optimization
...
...
@@ -70,10 +68,83 @@ def build(graph, target, shape, dtype="float32"):
raise
TypeError
(
"require shape to be dict"
)
graph
=
graph
if
isinstance
(
graph
,
_graph
.
Graph
)
else
_graph
.
create
(
graph
)
graph
=
graph_attr
.
set_shape
(
graph
,
shape
)
graph
=
graph_attr
.
set_dtype
(
graph
,
dtype
)
graph
=
graph_attr
.
set_shape
_inputs
(
graph
,
shape
)
graph
=
graph_attr
.
set_dtype
_inputs
(
graph
,
dtype
)
graph
.
_set_json_attr
(
"target"
,
target
,
"str"
)
graph
=
graph
.
apply
(
"InferShape"
)
.
apply
(
"InferType"
)
graph
=
graph
.
apply
(
"GraphFusePartition"
)
.
apply
(
"GraphFuse"
)
libmod
=
_move_module
(
graph
)
libmod
=
graph_attr
.
_move_out_module
(
graph
,
"module"
)
return
graph
,
libmod
def
_run_graph
(
graph
,
params
):
"""Helper utility to build and run and get outputs, only use cpu mode.
Parameters
----------
graph : Graph
The graph to be executed.
params: dict of str to ndarray
The parameter dictionary.
Returns
-------
out_dict: dict of str to tvm.NDArray
The output dictionaries.
"""
graph
=
graph
if
isinstance
(
graph
,
_graph
.
Graph
)
else
_graph
.
create
(
graph
)
shape
=
{
k
:
v
.
shape
for
k
,
v
in
params
.
items
()}
dtype
=
{
k
:
v
.
dtype
for
k
,
v
in
params
.
items
()}
target
=
"llvm"
ctx
=
tvm
.
cpu
(
0
)
_
,
oshape
=
graph_pass
.
infer_shape
(
graph
,
**
shape
)
_
,
odtype
=
graph_pass
.
infer_dtype
(
graph
,
**
dtype
)
graph
,
libmod
=
build
(
graph
,
target
,
shape
,
dtype
)
m
=
runtime
.
create
(
graph
,
libmod
,
ctx
)
set_input
,
run
,
get_output
=
m
[
"set_input"
],
m
[
"run"
],
m
[
"get_output"
]
for
k
,
v
in
params
.
items
():
set_input
(
k
,
tvm
.
nd
.
array
(
v
))
run
()
out_data
=
[]
for
i
,
kv
in
enumerate
(
zip
(
oshape
,
odtype
)):
shape
,
dtype
=
kv
arr
=
tvm
.
nd
.
empty
(
shape
,
dtype
,
ctx
)
get_output
(
i
,
arr
)
out_data
.
append
(
arr
)
return
out_data
def
precompute_prune
(
graph
,
params
):
"""Precompute the part of graph that can be pre-computed.
This will create a new graph that only contains the ops
that need to be computed depending on input as well as
updated version of param dict that pre-computes some of
intermediate results.
Parameters
----------
graph : Graph
The input graph
params : dict of str -> tvm.NDArray
The parameter dictionary of the graph
Returns
-------
pruned_graph : Graph
The pruned graph
new_params : dict of str-> tvm.NDArray
The updated dictionary of parameters.
"""
graph
=
graph
if
isinstance
(
graph
,
_graph
.
Graph
)
else
_graph
.
create
(
graph
)
graph
.
_set_json_attr
(
"param_name_list"
,
list
(
params
.
keys
()),
"list_str"
)
graph
=
graph
.
apply
(
"PrecomputePrune"
)
pre_graph
=
graph_attr
.
_move_out_graph
(
graph
,
"precompute_graph"
)
if
not
pre_graph
.
symbol
.
list_output_names
():
return
graph
,
params
out_names
=
pre_graph
.
json_attr
(
"output_names"
)
out_arrs
=
_run_graph
(
pre_graph
,
params
)
return
graph
,
dict
(
zip
(
out_names
,
out_arrs
))
nnvm/python/nnvm/compiler/graph_attr.py
View file @
7c95535c
# pylint: disable=invalid-name
"""Utilities to access graph attributes"""
from
__future__
import
absolute_import
as
_abs
def
set_shape
(
g
,
shape
):
"""Set the shape of graph nodes in the graph attribute.
import
tvm
def
set_shape_inputs
(
g
,
shape
):
"""Set the shape of input graph nodes in the graph attribute.
Parameters
----------
...
...
@@ -17,20 +20,24 @@ def set_shape(g, shape):
g : Graph
The updated graph with updated shape.
"""
index
=
g
.
index
list_shape
=
[[]]
*
index
.
num_node_entries
for
k
,
v
in
shape
.
items
():
list_shape
[
index
.
entry_id
(
k
)]
=
v
g
.
_set_json_attr
(
"shape"
,
list_shape
,
'list_shape'
)
list_shape
=
[
shape
.
get
(
name
,
())
for
name
in
g
.
index
.
input_names
]
g
.
_set_json_attr
(
"shape_inputs"
,
list_shape
,
'list_shape'
)
return
g
DTYPE_DICT
=
{
DTYPE_TO_TCODE
=
{
"default"
:
-
1
,
"float32"
:
0
}
def
set_dtype
(
g
,
dtype
):
"""Set the dtype of graph nodes
TCODE_TO_DTYPE
=
{
-
1
:
None
,
0
:
"float32"
}
def
set_dtype_inputs
(
g
,
dtype
):
"""Set the dtype inputs of graph nodes
Parameters
----------
...
...
@@ -45,12 +52,37 @@ def set_dtype(g, dtype):
g : Graph
The updated graph with updated dtype.
"""
index
=
g
.
index
if
isinstance
(
dtype
,
dict
):
list_dtype
=
[
-
1
]
*
index
.
num_node_entries
for
k
,
v
in
dtype
.
items
():
list_dtype
[
index
.
entry_id
(
k
)]
=
DTYPE_DICT
[
v
]
list_dtype
=
[
DTYPE_TO_TCODE
[
dtype
.
get
(
name
,
"default"
)]
for
name
in
g
.
index
.
input_names
]
else
:
list_dtype
=
[
DTYPE_DICT
[
dtype
]]
*
index
.
num_node_entries
g
.
_set_json_attr
(
"dtype"
,
list_dtype
,
"list_int"
)
list_dtype
=
[
DTYPE_TO_TCODE
[
dtype
]]
*
len
(
g
.
index
.
input_names
)
g
.
_set_json_attr
(
"dtype_inputs"
,
list_dtype
,
"list_int"
)
return
g
def
set_layout_inputs
(
g
,
layout
):
"""Set the layout inputs of graph nodes
Parameters
----------
g : Graph
The input graph
layout : dict of str to str or str
The input layout
Returns
-------
g : Graph
The updated graph with updated dtype.
"""
list_shape
=
[
layout
.
get
(
name
,
"default"
)
for
name
in
g
.
index
.
input_names
]
g
.
_set_json_attr
(
"layout_inputs"
,
list_shape
,
'list_str'
)
return
g
_move_out_module
=
tvm
.
get_global_func
(
"nnvm.graph_attr._move_module"
)
_move_out_graph
=
tvm
.
get_global_func
(
"nnvm.graph_attr._move_graph"
)
nnvm/python/nnvm/compiler/graph_pass.py
View file @
7c95535c
# pylint: disable=invalid-name
"""Namespace of graph pass.
Principle:
...
...
@@ -5,3 +6,57 @@ Principle:
- Composable API: break graph transformation pass as segments of small transformations.
"""
from
__future__
import
absolute_import
as
_abs
from
.
import
graph_attr
def
infer_shape
(
graph
,
**
shape
):
"""Infer the shape given the shape of inputs.
Parameters
----------
graph : Graph
The graph to perform shape inference from
Returns
-------
in_shape : list of tuple
Shape of inputs
out_shape: list of tuple
Shape of outputs
"""
graph
=
graph_attr
.
set_shape_inputs
(
graph
,
shape
)
graph
=
graph
.
apply
(
"InferShape"
)
shape
=
graph
.
json_attr
(
"shape"
)
index
=
graph
.
index
input_shape
=
[
shape
[
index
.
entry_id
(
x
)]
for
x
in
index
.
input_names
]
output_shape
=
[
shape
[
index
.
entry_id
(
x
)]
for
x
in
index
.
output_entries
]
return
input_shape
,
output_shape
def
infer_dtype
(
graph
,
**
dtype
):
"""Infer the type given the typeS of inputs.
Parameters
----------
graph : Graph
The graph to perform type inference from
Returns
-------
in_dtype : list of tuple
Dtype of inputs
out_dtype: list of tuple
Dtype of outputs
"""
graph
=
graph_attr
.
set_dtype_inputs
(
graph
,
dtype
)
graph
=
graph
.
apply
(
"InferType"
)
dtype
=
graph
.
json_attr
(
"dtype"
)
index
=
graph
.
index
input_dtype
=
[
graph_attr
.
TCODE_TO_DTYPE
[
dtype
[
index
.
entry_id
(
x
)]]
for
x
in
index
.
input_names
]
output_dtype
=
[
graph_attr
.
TCODE_TO_DTYPE
[
dtype
[
index
.
entry_id
(
x
)]]
for
x
in
index
.
output_entries
]
return
input_dtype
,
output_dtype
nnvm/python/nnvm/graph.py
View file @
7c95535c
...
...
@@ -24,6 +24,8 @@ class GraphIndex(object):
self
.
nodes
=
jgraph
[
"nodes"
]
self
.
entry_ptr
=
jgraph
[
"node_row_ptr"
]
self
.
_name2nodeid
=
{
n
[
"name"
]:
i
for
i
,
n
in
enumerate
(
self
.
nodes
)}
self
.
input_names
=
graph
.
symbol
.
list_input_names
()
self
.
output_entries
=
jgraph
[
"heads"
]
@property
def
num_nodes
(
self
):
...
...
@@ -66,6 +68,10 @@ class GraphIndex(object):
index : int
The entry index
"""
if
isinstance
(
key
,
(
list
,
tuple
)):
if
len
(
key
)
!=
3
:
raise
ValueError
(
"Expect entry index to be tuple of 3 elems"
)
key
,
value_index
,
_
=
key
idx
=
self
.
node_id
(
key
)
if
isinstance
(
key
,
str
)
else
key
assert
value_index
<
self
.
entry_ptr
[
idx
+
1
]
return
self
.
entry_ptr
[
idx
]
+
value_index
...
...
nnvm/python/nnvm/top/attr_dict.py
View file @
7c95535c
...
...
@@ -68,6 +68,21 @@ class AttrDict(object):
"""
return
int
(
self
[
key
])
def
get_float
(
self
,
key
):
"""Get float from attr dict
Parameters
----------
key : str
The attr key
Returns
-------
value : float
The result value
"""
return
float
(
self
[
key
])
def
get_bool
(
self
,
key
):
"""Get bool from attr dict
...
...
nnvm/python/nnvm/top/tensor.py
View file @
7c95535c
...
...
@@ -17,6 +17,17 @@ def _schedule_broadcast(_, outs, target):
tvm
.
schedule
.
AutoInlineInjective
(
s
)
return
s
def
_compute_binary_scalar
(
f
):
"""auxiliary function"""
@tvm.tag_scope
(
"ewise"
)
def
_compute
(
attrs
,
x
):
x
=
x
[
0
]
scalar
=
attrs
.
get_float
(
"scalar"
)
scalar
=
tvm
.
const
(
scalar
,
x
.
dtype
)
return
tvm
.
compute
(
x
.
shape
,
lambda
*
i
:
f
(
x
(
*
i
),
scalar
))
return
_compute
_fschedule_broadcast
=
tvm
.
convert
(
_schedule_broadcast
)
# exp
...
...
@@ -25,6 +36,12 @@ reg.register_compute("exp",
reg
.
register_pattern
(
"exp"
,
OpPattern
.
ELEM_WISE
)
reg
.
register_schedule
(
"exp"
,
_fschedule_broadcast
)
# add scalar
reg
.
register_compute
(
"__add_scalar__"
,
_compute_binary_scalar
(
lambda
x
,
y
:
x
+
y
))
reg
.
register_pattern
(
"__add_scalar__"
,
OpPattern
.
ELEM_WISE
)
reg
.
register_schedule
(
"__add_scalar__"
,
_fschedule_broadcast
)
# broadcast_add
reg
.
register_compute
(
"broadcast_add"
,
lambda
_
,
x
:
topi
.
broadcast_add
(
x
[
0
],
x
[
1
]))
...
...
nnvm/src/compiler/packed_func_ext.cc
View file @
7c95535c
...
...
@@ -104,5 +104,19 @@ TVM_REGISTER_GLOBAL("nnvm._register_pattern")
Op
&
op
=
::
dmlc
::
Registry
<
nnvm
::
Op
>::
Get
()
->
__REGISTER_OR_GET__
(
args
[
0
]);
op
.
set_attr
<
TOpPattern
>
(
"TOpPattern"
,
args
[
1
].
operator
int
(),
args
[
2
]);
});
TVM_REGISTER_GLOBAL
(
"nnvm.graph_attr._move_module"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
const
nnvm
::
Graph
&
g
=
args
[
0
].
AsExtension
<
Graph
>
();
*
rv
=
const_cast
<
nnvm
::
Graph
*>
(
&
g
)
->
MoveCopyAttr
<
tvm
::
runtime
::
Module
>
(
args
[
1
]);
});
TVM_REGISTER_GLOBAL
(
"nnvm.graph_attr._move_graph"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
const
nnvm
::
Graph
&
g
=
args
[
0
].
AsExtension
<
Graph
>
();
*
rv
=
const_cast
<
nnvm
::
Graph
*>
(
&
g
)
->
MoveCopyAttr
<
nnvm
::
Graph
>
(
args
[
1
]);
});
}
// namespace compiler
}
// namespace nnvm
nnvm/src/compiler/pass/graph_fuse.cc
View file @
7c95535c
...
...
@@ -381,13 +381,5 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
NNVM_REGISTER_PASS
(
GraphFuse
)
.
set_body
(
GraphFuse
);
TVM_REGISTER_GLOBAL
(
"nnvm.compiler._move_module"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
const
nnvm
::
Graph
&
g
=
args
[
0
].
AsExtension
<
Graph
>
();
*
rv
=
const_cast
<
nnvm
::
Graph
*>
(
&
g
)
->
MoveCopyAttr
<
tvm
::
runtime
::
Module
>
(
"module"
);
});
}
// namespace compiler
}
// namespace nnvm
nnvm/src/compiler/pass/layout_transform.cc
View file @
7c95535c
...
...
@@ -44,7 +44,7 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {
const
ShapeVector
&
shape_vec
=
src
.
GetAttr
<
ShapeVector
>
(
"shape"
);
const
std
::
vector
<
TLayoutInfo
>&
input_layouts
=
src
.
GetAttr
<
std
::
vector
<
TLayoutInfo
>
>
(
"layout
"
);
src
.
GetAttr
<
std
::
vector
<
TLayoutInfo
>
>
(
"layout_inputs
"
);
const
IndexedGraph
&
idx
=
src
.
indexed_graph
();
std
::
vector
<
TLayoutInfo
>
produce_vec
(
idx
.
num_node_entries
(),
GetDefaultLayout
());
...
...
nnvm/src/compiler/pass/pr
une_graph
.cc
→
nnvm/src/compiler/pass/pr
ecompute_prune
.cc
View file @
7c95535c
/*!
* Copyright (c) 2017 by Contributors
* \file pr
une_graph
.cc
* \brief
Prune the graph to do constant folding
.
* \file pr
ecompute_prune
.cc
* \brief
Split the graph into a pre-compute graph and a execution graph
.
*
* Th
is pass breaks the graph into pre-compute graph
*
and the execution graph
.
* Th
e pre-compute graph outputs parameters that can be taken
*
by execution graph during execution phase
.
*/
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
...
...
@@ -16,11 +16,15 @@
namespace
nnvm
{
namespace
compiler
{
nnvm
::
Graph
PruneGraph
(
nnvm
::
Graph
src
)
{
const
auto
&
params
=
src
.
GetAttr
<
std
::
unordered_set
<
std
::
string
>
>
(
"params"
);
nnvm
::
Graph
PrecomputePrune
(
nnvm
::
Graph
src
)
{
const
auto
&
plist
=
src
.
GetAttr
<
std
::
vector
<
std
::
string
>
>
(
"param_name_list"
);
std
::
unordered_set
<
std
::
string
>
params
(
plist
.
begin
(),
plist
.
end
());
std
::
unordered_set
<
nnvm
::
Node
*>
pruned
;
nnvm
::
NodeEntryMap
<
nnvm
::
NodePtr
>
entry_var
;
std
::
unordered_set
<
std
::
string
>
unique_name
;
DFSVisit
(
src
.
outputs
,
[
&
](
const
nnvm
::
NodePtr
&
n
)
{
bool
can_be_pruned
=
true
;
if
(
n
->
is_variable
())
{
...
...
@@ -45,7 +49,12 @@ nnvm::Graph PruneGraph(nnvm::Graph src) {
nnvm
::
NodePtr
var
=
nnvm
::
Node
::
Create
();
var
->
attrs
.
name
=
e
.
node
->
attrs
.
name
+
"_output"
+
std
::
to_string
(
e
.
index
);
entry_var
.
emplace
(
e
,
var
);
CHECK
(
!
unique_name
.
count
(
var
->
attrs
.
name
));
unique_name
.
insert
(
var
->
attrs
.
name
);
}
// TODO(ziheng): this pass now mutates the original graph structure
// This might not be a good thing, change to copy the structure instead
//
e
=
nnvm
::
NodeEntry
{
entry_var
.
at
(
e
),
0
,
0
};
}
}
...
...
@@ -56,21 +65,21 @@ nnvm::Graph PruneGraph(nnvm::Graph src) {
pre_graph
.
outputs
.
reserve
(
entry_var
.
size
());
std
::
vector
<
std
::
string
>
output_names
;
output_names
.
reserve
(
entry_var
.
size
());
for
(
auto
kv
:
entry_var
)
{
if
(
kv
.
first
.
node
->
is_variable
())
continue
;
pre_graph
.
outputs
.
emplace_back
(
kv
.
first
);
output_names
.
emplace_back
(
kv
.
second
->
attrs
.
name
);
}
pre_graph
.
attrs
[
"
pruned_param
s"
]
=
// new parameter list
pre_graph
.
attrs
[
"
output_name
s"
]
=
std
::
make_shared
<
dmlc
::
any
>
(
std
::
move
(
output_names
));
src
.
attrs
[
"pre_graph"
]
=
src
.
attrs
[
"pre
compute
_graph"
]
=
std
::
make_shared
<
dmlc
::
any
>
(
std
::
move
(
pre_graph
));
return
src
;
}
NNVM_REGISTER_PASS
(
PruneGraph
)
.
set_body
(
PruneGraph
);
NNVM_REGISTER_PASS
(
PrecomputePrune
)
.
set_body
(
PrecomputePrune
);
}
// namespace compiler
}
// namespace nnvm
nnvm/tests/python/compiler/test_build.py
View file @
7c95535c
...
...
@@ -17,8 +17,8 @@ def test_compile():
m
=
nnvm
.
runtime
.
create
(
graph
,
lib
,
tvm
.
cpu
(
0
))
# get member functions
set_input
,
run
,
get_output
=
m
[
"set_input"
],
m
[
"run"
],
m
[
"get_output"
]
na
=
tvm
.
nd
.
array
(
np
.
ones
(
shape
)
.
astype
(
dtype
))
nb
=
tvm
.
nd
.
array
(
np
.
ones
(
shape
)
.
astype
(
dtype
))
na
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
shape
)
.
astype
(
dtype
))
nb
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
shape
)
.
astype
(
dtype
))
# set inputs
set_input
(
"x"
,
na
)
set_input
(
"y"
,
nb
)
...
...
@@ -30,5 +30,37 @@ def test_compile():
np
.
testing
.
assert_allclose
(
out
.
asnumpy
(),
np
.
exp
(
na
.
asnumpy
()
+
nb
.
asnumpy
()))
def
test_run
():
x
=
sym
.
Variable
(
"x"
)
y
=
sym
.
Variable
(
"y"
)
z
=
sym
.
exp
(
y
+
x
)
shape
=
(
10
,
10
)
dtype
=
tvm
.
float32
nx
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
shape
)
.
astype
(
dtype
))
ny
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
shape
)
.
astype
(
dtype
))
res
=
nnvm
.
compiler
.
_run_graph
(
z
,
{
"x"
:
nx
,
"y"
:
ny
})
np
.
testing
.
assert_allclose
(
res
[
0
]
.
asnumpy
(),
np
.
exp
(
nx
.
asnumpy
()
+
ny
.
asnumpy
()))
def
test_precompute_prune
():
x
=
sym
.
Variable
(
"x"
)
+
1
y
=
sym
.
Variable
(
"y"
)
z
=
y
+
x
shape
=
(
10
,
10
)
dtype
=
tvm
.
float32
nx
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
shape
)
.
astype
(
dtype
))
ny
=
tvm
.
nd
.
array
(
np
.
random
.
uniform
(
size
=
shape
)
.
astype
(
dtype
))
params
=
{
"x"
:
nx
}
graph
,
pdict
=
nnvm
.
compiler
.
precompute_prune
(
z
,
params
)
pdict
[
"y"
]
=
ny
res
=
nnvm
.
compiler
.
_run_graph
(
z
,
pdict
)
np
.
testing
.
assert_allclose
(
res
[
0
]
.
asnumpy
(),
nx
.
asnumpy
()
+
1
+
ny
.
asnumpy
())
if
__name__
==
"__main__"
:
test_compile
()
test_run
()
test_precompute_prune
()
nnvm/tests/python/compiler/test_graph_pass.py
0 → 100644
View file @
7c95535c
"""Unittest cases for graph pass"""
import
nnvm
import
nnvm.compiler
from
nnvm.compiler
import
graph_pass
def
test_infer_attr
():
x
=
nnvm
.
symbol
.
Variable
(
"x"
)
y
=
x
*
2
g
=
nnvm
.
graph
.
create
(
y
)
ishape
,
oshape
=
graph_pass
.
infer_shape
(
g
,
x
=
(
10
,
20
))
assert
tuple
(
oshape
[
0
])
==
(
10
,
20
)
itype
,
otype
=
graph_pass
.
infer_dtype
(
g
,
x
=
"float32"
)
assert
otype
[
0
]
==
"float32"
if
__name__
==
"__main__"
:
test_infer_attr
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment