Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
f8f75ca2
Unverified
Commit
f8f75ca2
authored
Jan 21, 2020
by
masahi
Committed by
GitHub
Jan 21, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Expose relay BindParamsByName to Python (#4751)
* expose BindParamByName to python * fixed alpha equal test
parent
2c0c1849
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
124 additions
and
45 deletions
+124
-45
python/tvm/relay/build_module.py
+33
-6
src/relay/backend/build_module.cc
+47
-39
tests/python/relay/test_pass_fold_constant.py
+44
-0
No files found.
python/tvm/relay/build_module.py
View file @
f8f75ca2
...
...
@@ -51,6 +51,15 @@ def _update_target(target):
return
tgts
def
_convert_param_map
(
params
):
inputs
=
{}
for
name
,
param
in
params
.
items
():
if
isinstance
(
param
,
np
.
ndarray
):
param
=
_nd
.
array
(
param
)
inputs
[
name
]
=
_expr
.
const
(
param
)
return
inputs
class
BuildModule
(
object
):
"""Build a Relay function to run on TVM graph runtime. This class is used
to expose the `RelayBuildModule` APIs implemented in C++.
...
...
@@ -151,12 +160,7 @@ class BuildModule(object):
def
_set_params
(
self
,
params
):
inputs
=
{}
for
name
,
param
in
params
.
items
():
if
isinstance
(
param
,
np
.
ndarray
):
param
=
_nd
.
array
(
param
)
inputs
[
name
]
=
_expr
.
const
(
param
)
self
.
_set_params_func
(
inputs
)
self
.
_set_params_func
(
_convert_param_map
(
params
))
def
get_json
(
self
):
"""Return the json file of the built program."""
...
...
@@ -296,6 +300,29 @@ def optimize(mod, target=None, params=None):
return
mod
,
params
def
bind_params_by_name
(
func
,
params
):
"""Bind params to function by name.
This could be useful when assembling custom Relay optimization
passes that involve constant folding.
Parameters
----------
func : relay.Function
The function to bind parameters to.
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Returns
-------
func : relay.Function
The function with parameters bound
"""
inputs
=
_convert_param_map
(
params
)
return
_build_module
.
BindParamsByName
(
func
,
inputs
)
class
GraphExecutor
(
_interpreter
.
Executor
):
"""Wrapper around Executor interface.
...
...
src/relay/backend/build_module.cc
View file @
f8f75ca2
...
...
@@ -42,6 +42,43 @@ using TargetsMap = Map<tvm::Integer, tvm::Target>;
using
namespace
tvm
::
relay
::
transform
;
/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
relay
::
Function
BindParamsByName
(
relay
::
Function
func
,
const
std
::
unordered_map
<
std
::
string
,
runtime
::
NDArray
>&
params
)
{
std
::
unordered_map
<
std
::
string
,
relay
::
Var
>
name_dict
;
std
::
unordered_set
<
relay
::
Var
,
ObjectHash
,
ObjectEqual
>
repeat_var
;
for
(
auto
arg
:
func
->
params
)
{
const
auto
&
name
=
arg
->
name_hint
();
if
(
name_dict
.
count
(
name
))
{
repeat_var
.
insert
(
arg
);
}
else
{
name_dict
[
name
]
=
arg
;
}
}
std
::
unordered_map
<
relay
::
Var
,
Expr
,
ObjectHash
,
ObjectEqual
>
bind_dict
;
for
(
auto
&
kv
:
params
)
{
if
(
name_dict
.
count
(
kv
.
first
)
==
0
)
{
continue
;
}
auto
arg
=
name_dict
.
at
(
kv
.
first
);
if
(
repeat_var
.
count
(
arg
))
{
LOG
(
FATAL
)
<<
"Multiple args in the function have name "
<<
kv
.
first
;
}
bind_dict
[
arg
]
=
ConstantNode
::
make
(
kv
.
second
);
}
Expr
bound_expr
=
relay
::
Bind
(
func
,
bind_dict
);
Function
ret
=
Downcast
<
Function
>
(
bound_expr
);
CHECK
(
ret
.
defined
())
<<
"The returning type is expected to be a Relay Function."
<<
"
\n
"
;
return
ret
;
}
/*!
* \brief Output of building module
*
*/
...
...
@@ -249,45 +286,6 @@ class RelayBuildModule : public runtime::ModuleNode {
protected
:
/*!
* \brief Bind params to function by using name
* \param func Relay function
* \param params params dict
* \return relay::Function
*/
relay
::
Function
BindParamsByName
(
relay
::
Function
func
,
const
std
::
unordered_map
<
std
::
string
,
runtime
::
NDArray
>&
params
)
{
std
::
unordered_map
<
std
::
string
,
relay
::
Var
>
name_dict
;
std
::
unordered_set
<
relay
::
Var
,
ObjectHash
,
ObjectEqual
>
repeat_var
;
for
(
auto
arg
:
func
->
params
)
{
const
auto
&
name
=
arg
->
name_hint
();
if
(
name_dict
.
count
(
name
))
{
repeat_var
.
insert
(
arg
);
}
else
{
name_dict
[
name
]
=
arg
;
}
}
std
::
unordered_map
<
relay
::
Var
,
Expr
,
ObjectHash
,
ObjectEqual
>
bind_dict
;
for
(
auto
&
kv
:
params
)
{
if
(
name_dict
.
count
(
kv
.
first
)
==
0
)
{
continue
;
}
auto
arg
=
name_dict
.
at
(
kv
.
first
);
if
(
repeat_var
.
count
(
arg
))
{
LOG
(
FATAL
)
<<
"Multiple args in the function have name "
<<
kv
.
first
;
}
bind_dict
[
arg
]
=
ConstantNode
::
make
(
kv
.
second
);
}
Expr
bound_expr
=
relay
::
Bind
(
func
,
bind_dict
);
Function
ret
=
Downcast
<
Function
>
(
bound_expr
);
CHECK
(
ret
.
defined
())
<<
"The returning type is expected to be a Relay Function."
<<
"
\n
"
;
return
ret
;
}
/*!
* \brief Optimize a Relay Function.
*
* \param func The input Function where optmization will be applied on.
...
...
@@ -522,6 +520,16 @@ TVM_REGISTER_GLOBAL("relay.build_module._BuildModule")
*
rv
=
RelayBuildCreate
();
});
TVM_REGISTER_GLOBAL
(
"relay.build_module.BindParamsByName"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
Map
<
std
::
string
,
Constant
>
params
=
args
[
1
];
std
::
unordered_map
<
std
::
string
,
runtime
::
NDArray
>
params_
;
for
(
const
auto
&
kv
:
params
)
{
params_
[
kv
.
first
]
=
kv
.
second
->
data
;
}
*
rv
=
BindParamsByName
(
args
[
0
],
params_
);
});
}
// namespace backend
}
// namespace relay
}
// namespace tvm
tests/python/relay/test_pass_fold_constant.py
View file @
f8f75ca2
...
...
@@ -18,6 +18,8 @@ import numpy as np
import
tvm
from
tvm
import
relay
from
tvm.relay
import
transform
from
tvm.relay.build_module
import
bind_params_by_name
from
tvm.relay.testing
import
run_infer_type
,
create_workload
def
run_opt_pass
(
expr
,
opt_pass
):
...
...
@@ -161,6 +163,47 @@ def test_fold_full():
assert
relay
.
analysis
.
graph_equal
(
zz
,
zexpected
)
def
test_fold_batch_norm
():
def
expected
():
data
=
relay
.
var
(
"data"
,
relay
.
TensorType
((
1
,
3
,
224
,
224
),
"float32"
))
weight
=
relay
.
const
(
np
.
zeros
((
16
,
3
,
3
,
3
)))
bias
=
relay
.
const
(
np
.
zeros
((
16
,
1
,
1
)))
conv
=
relay
.
nn
.
conv2d
(
data
=
data
,
weight
=
weight
,
kernel_size
=
(
3
,
3
),
channels
=
16
,
padding
=
(
1
,
1
))
add
=
relay
.
add
(
conv
,
bias
)
return
relay
.
Function
(
relay
.
analysis
.
free_vars
(
add
),
add
)
remove_bn_pass
=
transform
.
Sequential
([
relay
.
transform
.
InferType
(),
relay
.
transform
.
SimplifyInference
(),
relay
.
transform
.
FoldConstant
(),
relay
.
transform
.
FoldScaleAxis
(),
])
data
=
relay
.
var
(
"data"
,
relay
.
TensorType
((
1
,
3
,
224
,
224
),
"float32"
))
weight
=
relay
.
var
(
"weight"
)
bn_gamma
=
relay
.
var
(
"bn_gamma"
)
bn_beta
=
relay
.
var
(
"bn_beta"
)
bn_mmean
=
relay
.
var
(
"bn_mean"
)
bn_mvar
=
relay
.
var
(
"bn_var"
)
conv
=
relay
.
nn
.
conv2d
(
data
=
data
,
weight
=
weight
,
kernel_size
=
(
3
,
3
),
channels
=
16
,
padding
=
(
1
,
1
))
bn_output
=
relay
.
nn
.
batch_norm
(
conv
,
bn_gamma
,
bn_beta
,
bn_mmean
,
bn_mvar
)
def
initializer
(
_
,
param
):
param
=
np
.
zeros
(
param
.
shape
)
mod
,
params
=
create_workload
(
bn_output
[
0
],
initializer
)
mod
[
"main"
]
=
bind_params_by_name
(
mod
[
"main"
],
params
)
with
relay
.
build_config
(
opt_level
=
3
):
mod
=
remove_bn_pass
(
mod
)
expect
=
run_infer_type
(
expected
())
assert
relay
.
analysis
.
graph_equal
(
mod
[
"main"
],
expect
)
if
__name__
==
"__main__"
:
test_fold_const
()
test_fold_let
()
...
...
@@ -168,3 +211,4 @@ if __name__ == "__main__":
test_fold_concat
()
test_fold_shape_of
()
test_fold_full
()
test_fold_batch_norm
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment