Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6e2c7ede
Commit
6e2c7ede
authored
Jun 13, 2019
by
Zhi
Committed by
Tianqi Chen
Jun 13, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Transform] quantize opt passes to pass manager (#3289)
parent
579e96da
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
107 additions
and
130 deletions
+107
-130
python/tvm/relay/quantize/quantize.py
+68
-105
src/relay/pass/pass_manager.cc
+1
-0
src/relay/pass/quantize.cc
+38
-25
No files found.
python/tvm/relay/quantize/quantize.py
View file @
6e2c7ede
...
...
@@ -21,7 +21,9 @@ import numpy as np
from
.
import
_quantize
from
..
import
expr
as
_expr
from
..
import
module
as
_module
from
..
import
ir_pass
as
_ir_pass
from
..
import
transform
as
_transform
from
..
import
op
as
_op
from
...
import
make
as
_make
from
..base
import
NodeBase
,
register_relay_node
...
...
@@ -178,26 +180,7 @@ def _set_conv_counter(n):
CONV_COUNTER
=
n
def
annotate
(
graph
):
"""Given a float32 graph, annotate will rewrite the graph
and return back a graph which simulates the error brought by
current quantization scheme.
Parameters
---------
graph: Function
The original graph
Returns
-------
ret: Function
The graph after annotation
"""
_set_conv_counter
(
0
)
# reset counter
return
_quantize
.
annotate
(
graph
)
def
calibrate
(
graph
,
dataset
=
None
):
def
calibrate
(
graph
,
mod
=
None
,
ctx
=
None
):
"""The calibrate procedure will try to calculate the content of
dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
operator.
...
...
@@ -207,8 +190,11 @@ def calibrate(graph, dataset=None):
graph: Function
The simulation graph after annotation.
dataset: list of dict of Var -> NDArray
The calibration dataset.
mod: tvm.relay.Module
The module where calibration happens on.
ctx: tvm.relay.PassContext
The pass context used for calibration.
Returns
-------
...
...
@@ -253,93 +239,52 @@ def calibrate(graph, dataset=None):
return
_expr
.
bind
(
graph
,
const_params
)
def
realize
(
graph
):
"""The realize pass will transform the simulated quantized
graph, which computes with float32 actually, to a real low-bit
integer graph. It will replace the simulated_quantize with
several fine-grained operators like add, multiply, and shift
as more as possible for performance (fusion, etc.)
Parameters
---------
graph: Function
The simulated graph after calibrating.
def
annotate
():
"""Given a float32 graph, this pass will rewrite the graph and return
a graph which simulates the error brought by the current quantization
scheme.
Returns
-------
ret:
Function
The
graph after realization
ret:
tvm.relay.Pass
The
registered pass for quantization annotation.
"""
return
_quantize
.
realize
(
graph
)
return
_quantize
.
QuantizeAnnotate
(
)
def
optimize
(
func
,
params
=
None
):
""" Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"CanonicalizeOps" optimization before quantization.
# TODO(zhiics) These passes are executed one by one so far. We need to
# move them to the pass manager.
Parameters
---------
func: tvm.relay.Function
The original Relay function to be optimized.
params : dict of str to tvm.NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
def
realize
():
"""The realize pass will transform the simulated quantized graph, which
actually computes with float32, to a real low-bit integer graph. It will
replace the `simulated_quantize` with several fine-grained operators like
add, multiply, and shift as much as possible for better performance.
Returns
-------
ret: tvm.relay.
Function
The
graph after quantization
ret: tvm.relay.
Pass
The
registered pass for quantization realization.
"""
return
_quantize
.
QuantizeRealize
()
opt_passes
=
[
"SimplifyInference"
,
"FoldScaleAxis"
,
"FoldConstant"
,
"CanonicalizeOps"
]
if
params
:
name_dict
=
{}
for
arg
in
func
.
params
:
name
=
arg
.
name_hint
if
name
in
name_dict
:
name_dict
[
name
]
=
None
else
:
name_dict
[
name
]
=
arg
bind_dict
=
{}
for
k
,
v
in
params
.
items
():
if
k
not
in
name_dict
:
continue
arg
=
name_dict
[
k
]
if
arg
is
None
:
raise
ValueError
(
"Multiple args in the function have name
%
s"
%
k
)
bind_dict
[
arg
]
=
_expr
.
const
(
v
)
func
=
_expr
.
bind
(
func
,
bind_dict
)
if
"SimplifyInference"
in
opt_passes
:
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
simplify_inference
(
func
)
if
"FoldConstant"
in
opt_passes
:
func
=
_ir_pass
.
fold_constant
(
func
)
if
"FoldScaleAxis"
in
opt_passes
:
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
backward_fold_scale_axis
(
func
)
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
forward_fold_scale_axis
(
func
)
func
=
_ir_pass
.
fold_constant
(
func
)
if
"CanonicalizeOps"
in
opt_passes
:
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
canonicalize_ops
(
func
)
if
"FoldConstant"
in
opt_passes
:
func
=
_ir_pass
.
fold_constant
(
func
)
return
func
def _bind_params(func, params):
    """Bind named parameter values to the matching arguments of ``func``.

    Parameters
    ----------
    func : object exposing ``params``
        Each entry of ``func.params`` is looked up by its ``name_hint``.

    params : dict
        Mapping from argument name to the value to bind. Names that do not
        match any argument of ``func`` are ignored.

    Returns
    -------
    The result of ``_expr.bind`` applied to ``func`` and the collected
    argument -> constant mapping.

    Raises
    ------
    ValueError
        If a name in ``params`` is shared by more than one argument.
    """
    vars_by_name = {}
    for var in func.params:
        hint = var.name_hint
        # A None entry marks a name that is used by several arguments.
        vars_by_name[hint] = None if hint in vars_by_name else var

    bindings = {}
    for key, value in params.items():
        if key in vars_by_name:
            target = vars_by_name[key]
            if target is None:
                raise ValueError("Multiple args in the function have name %s" % key)
            bindings[target] = _expr.const(value)
    return _expr.bind(func, bindings)
def
quantize
(
graph
,
params
=
None
,
dataset
=
None
):
...
...
@@ -365,11 +310,29 @@ def quantize(graph, params=None, dataset=None):
ret: Function
The graph after quantization
"""
# TODO(zhiics) Move this to the pass manager.
graph
=
optimize
(
graph
,
params
)
graph
=
annotate
(
graph
)
graph
=
calibrate
(
graph
,
dataset
)
graph
=
realize
(
graph
)
graph
=
_ir_pass
.
fold_constant
(
graph
)
return
graph
if
params
:
graph
=
_bind_params
(
graph
,
params
)
mod
=
_module
.
Module
.
from_expr
(
graph
)
# Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
# "CanonicalizeOps" optimization before quantization.
optimize
=
_transform
.
Sequential
([
_transform
.
SimplifyInference
(),
_transform
.
FoldConstant
(),
_transform
.
FoldScaleAxis
(),
_transform
.
CanonicalizeOps
(),
_transform
.
FoldConstant
()])
calibrate_pass
=
_transform
.
function_pass
(
calibrate
,
opt_level
=
1
,
name
=
"QuantizeCalibrate"
)
_set_conv_counter
(
0
)
# reset counter
quantize_seq
=
_transform
.
Sequential
([
annotate
(),
calibrate_pass
,
realize
(),
_transform
.
FoldConstant
()])
with
_transform
.
PassContext
(
opt_level
=
3
,
required_pass
=
[
"QuantizeAnnotate"
,
"QuantizeCalibrate"
,
"QuantizeRealize"
]):
mod
=
optimize
(
mod
)
mod
=
quantize_seq
(
mod
)
return
mod
[
mod
.
entry_func
.
name_hint
]
src/relay/pass/pass_manager.cc
View file @
6e2c7ede
...
...
@@ -313,6 +313,7 @@ Module FunctionPassNode::operator()(const Module& mod,
<<
pass_info
->
name
<<
" with opt level: "
<<
pass_info
->
opt_level
;
Module
updated_mod
=
mod
;
// Execute the pass function and return a new module.
std
::
vector
<
std
::
pair
<
GlobalVar
,
Function
>
>
updates
;
...
...
src/relay/pass/quantize.cc
View file @
6e2c7ede
...
...
@@ -43,6 +43,8 @@ namespace tvm {
namespace
relay
{
namespace
quantize
{
using
namespace
relay
::
transform
;
/*! \brief Attribute for simulated quantize operator */
struct
SimulatedQuantizeAttrs
:
public
tvm
::
AttrsNode
<
SimulatedQuantizeAttrs
>
{
int
kind
;
...
...
@@ -131,23 +133,6 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr")
static_cast
<
QAnnotateKind
>
(
args
[
1
].
operator
int
()));
});
TVM_REGISTER_API
(
"relay._quantize.annotate"
)
.
set_body_typed
<
Expr
(
Expr
)
>
([]
(
const
Expr
&
expr
)
{
std
::
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref
=
[](
const
Expr
&
e
)
{
if
(
e
->
derived_from
<
TempExprNode
>
())
{
const
auto
*
n
=
e
.
as
<
QAnnotateExprNode
>
();
CHECK
(
n
);
const
PackedFunc
*
f
=
runtime
::
Registry
::
Get
(
"relay.quantize.attach_simulated_quantize"
);
Expr
ret
=
(
*
f
)(
n
->
expr
,
static_cast
<
int
>
(
kQInput
));
return
static_cast
<
Expr
>
(
QAnnotateExprNode
::
make
(
ret
,
kQInput
));
}
return
e
;
};
return
ForwardRewrite
(
expr
,
"FQAnnotateRewrite"
,
nullptr
,
fmulti_ref
);
});
// =============
// realize pass
...
...
@@ -536,14 +521,6 @@ Expr AvgPoolRealize(const Call& ref_call,
RELAY_REGISTER_OP
(
"nn.avg_pool2d"
)
.
set_attr
<
FForwardRewrite
>
(
"FQRealizeRewrite"
,
AvgPoolRealize
);
TVM_REGISTER_API
(
"relay._quantize.realize"
)
.
set_body_typed
<
Expr
(
Expr
)
>
([](
const
Expr
&
e
)
{
Expr
ret
=
ForwardRewrite
(
e
,
"FQRealizeRewrite"
,
nullptr
,
nullptr
);
return
ret
;
});
// =============
// qconfig
...
...
@@ -613,6 +590,42 @@ TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
TVM_REGISTER_API
(
"relay._quantize._ExitQConfigScope"
)
.
set_body_typed
(
QConfig
::
ExitQConfigScope
);
/*!
 * \brief Create the function pass that performs quantization annotation.
 *
 * The pass runs ForwardRewrite over each function using the
 * "FQAnnotateRewrite" attribute, and is created under the name
 * "QuantizeAnnotate".
 *
 * \return The created pass.
 */
Pass QuantizeAnnotate() {
  // Callback handed to ForwardRewrite. An expression derived from
  // TempExprNode must be a QAnnotateExprNode; its wrapped expression is sent
  // through the globally registered
  // "relay.quantize.attach_simulated_quantize" function with kind kQInput and
  // the result is wrapped again as a kQInput annotation. Any other expression
  // is returned unchanged.
  std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
    if (e->derived_from<TempExprNode>()) {
      const auto* n = e.as<QAnnotateExprNode>();
      CHECK(n);
      const PackedFunc* f =
          runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
      // Registry::Get hands back a pointer; fail with a clear message instead
      // of dereferencing a null pointer when the function is not registered.
      CHECK(f) << "relay.quantize.attach_simulated_quantize is not registered";
      Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
      return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
    }
    return e;
  };

  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
    [=](Function f, Module m, PassContext pc) {
      // fmulti_ref goes in the fourth argument with the third left as
      // nullptr, mirroring the four-argument ForwardRewrite call used by the
      // realize pass. Passing it third would bind it to a different callback
      // slot. NOTE(review): ForwardRewrite's declaration is not visible in
      // this file view - confirm the parameter order.
      return Downcast<Function>(
          ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
  };
  return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}
// Register QuantizeAnnotate under the global name
// "relay._quantize.QuantizeAnnotate" so it can be looked up by that name.
TVM_REGISTER_API
(
"relay._quantize.QuantizeAnnotate"
)
.
set_body_typed
(
QuantizeAnnotate
);
/*!
 * \brief Create the function pass that performs quantization realization.
 *
 * Each function is run through ForwardRewrite with the "FQRealizeRewrite"
 * attribute and no extra callbacks. The pass is created under the name
 * "QuantizeRealize".
 *
 * \return The created pass.
 */
Pass QuantizeRealizePass() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
    [=](Function func, Module mod, PassContext ctx) {
      // The module and pass context are not consulted by this rewrite.
      Expr rewritten = ForwardRewrite(func, "FQRealizeRewrite", nullptr, nullptr);
      return Downcast<Function>(rewritten);
  };
  return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {});
}
// Register QuantizeRealizePass under the global name
// "relay._quantize.QuantizeRealize" so it can be looked up by that name.
TVM_REGISTER_API
(
"relay._quantize.QuantizeRealize"
)
.
set_body_typed
(
QuantizeRealizePass
);
}
// namespace quantize
}
// namespace relay
}
// namespace tvm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment