Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6e2c7ede
Commit
6e2c7ede
authored
Jun 13, 2019
by
Zhi
Committed by
Tianqi Chen
Jun 13, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay][Transform] quantize opt passes to pass manager (#3289)
parent
579e96da
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
107 additions
and
130 deletions
+107
-130
python/tvm/relay/quantize/quantize.py
+68
-105
src/relay/pass/pass_manager.cc
+1
-0
src/relay/pass/quantize.cc
+38
-25
No files found.
python/tvm/relay/quantize/quantize.py
View file @
6e2c7ede
...
@@ -21,7 +21,9 @@ import numpy as np
...
@@ -21,7 +21,9 @@ import numpy as np
from
.
import
_quantize
from
.
import
_quantize
from
..
import
expr
as
_expr
from
..
import
expr
as
_expr
from
..
import
module
as
_module
from
..
import
ir_pass
as
_ir_pass
from
..
import
ir_pass
as
_ir_pass
from
..
import
transform
as
_transform
from
..
import
op
as
_op
from
..
import
op
as
_op
from
...
import
make
as
_make
from
...
import
make
as
_make
from
..base
import
NodeBase
,
register_relay_node
from
..base
import
NodeBase
,
register_relay_node
...
@@ -178,26 +180,7 @@ def _set_conv_counter(n):
...
@@ -178,26 +180,7 @@ def _set_conv_counter(n):
CONV_COUNTER
=
n
CONV_COUNTER
=
n
def
annotate
(
graph
):
def
calibrate
(
graph
,
mod
=
None
,
ctx
=
None
):
"""Given a float32 graph, annotate will rewrite the graph
and return back a graph which simulates the error brought by
current quantization scheme.
Parameters
---------
graph: Function
The original graph
Returns
-------
ret: Function
The graph after annotation
"""
_set_conv_counter
(
0
)
# reset counter
return
_quantize
.
annotate
(
graph
)
def
calibrate
(
graph
,
dataset
=
None
):
"""The calibrate procedure will try to calculate the content of
"""The calibrate procedure will try to calculate the content of
dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
operator.
operator.
...
@@ -207,8 +190,11 @@ def calibrate(graph, dataset=None):
...
@@ -207,8 +190,11 @@ def calibrate(graph, dataset=None):
graph: Function
graph: Function
The simulation graph after annotation.
The simulation graph after annotation.
dataset: list of dict of Var -> NDArray
mod: tvm.relay.Module
The calibration dataset.
The module where calibration happens on.
ctx: tvm.relay.PassContext
The pass context used for calibration.
Returns
Returns
-------
-------
...
@@ -253,93 +239,52 @@ def calibrate(graph, dataset=None):
...
@@ -253,93 +239,52 @@ def calibrate(graph, dataset=None):
return
_expr
.
bind
(
graph
,
const_params
)
return
_expr
.
bind
(
graph
,
const_params
)
def
realize
(
graph
):
def
annotate
():
"""The realize pass will transform the simulated quantized
"""Given a float32 graph, this pass will rewrite the graph and return
graph, which computes with float32 actually, to a real low-bit
a graph which simulates the error brought by the current quantization
integer graph. It will replace the simulated_quantize with
scheme.
several fine-grained operators like add, multiply, and shift
as more as possible for performance (fusion, etc.)
Parameters
---------
graph: Function
The simulated graph after calibrating.
Returns
Returns
-------
-------
ret:
Function
ret:
tvm.relay.Pass
The
graph after realization
The
registered pass for quantization annotation.
"""
"""
return
_quantize
.
realize
(
graph
)
return
_quantize
.
QuantizeAnnotate
(
)
def
optimize
(
func
,
params
=
None
):
def
realize
():
""" Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"""The realize pass will transform the simulated quantized graph, which
"CanonicalizeOps" optimization before quantization.
actually computes with float32, to a real low-bit integer graph. It will
replace the `simulated_quantize` with several fine-grained operators like
# TODO(zhiics) These passes are executed one by one so far. We need to
add, multiply, and shift as much as possible for better performance.
# move them to the pass manager.
Parameters
---------
func: tvm.relay.Function
The original Relay function to be optimized.
params : dict of str to tvm.NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Returns
Returns
-------
-------
ret: tvm.relay.
Function
ret: tvm.relay.
Pass
The
graph after quantization
The
registered pass for quantization realization.
"""
"""
return
_quantize
.
QuantizeRealize
()
opt_passes
=
[
"SimplifyInference"
,
"FoldScaleAxis"
,
"FoldConstant"
,
"CanonicalizeOps"
]
if
params
:
def
_bind_params
(
func
,
params
):
name_dict
=
{}
"""Bind the params to the expression.
for
arg
in
func
.
params
:
"""
name
=
arg
.
name_hint
name_dict
=
{}
if
name
in
name_dict
:
for
arg
in
func
.
params
:
name_dict
[
name
]
=
None
name
=
arg
.
name_hint
else
:
if
name
in
name_dict
:
name_dict
[
name
]
=
arg
name_dict
[
name
]
=
None
bind_dict
=
{}
else
:
for
k
,
v
in
params
.
items
():
name_dict
[
name
]
=
arg
if
k
not
in
name_dict
:
bind_dict
=
{}
continue
for
k
,
v
in
params
.
items
():
arg
=
name_dict
[
k
]
if
k
not
in
name_dict
:
if
arg
is
None
:
continue
raise
ValueError
(
"Multiple args in the function have name
%
s"
%
k
)
arg
=
name_dict
[
k
]
bind_dict
[
arg
]
=
_expr
.
const
(
v
)
if
arg
is
None
:
func
=
_expr
.
bind
(
func
,
bind_dict
)
raise
ValueError
(
"Multiple args in the function have name
%
s"
%
k
)
bind_dict
[
arg
]
=
_expr
.
const
(
v
)
if
"SimplifyInference"
in
opt_passes
:
return
_expr
.
bind
(
func
,
bind_dict
)
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
simplify_inference
(
func
)
if
"FoldConstant"
in
opt_passes
:
func
=
_ir_pass
.
fold_constant
(
func
)
if
"FoldScaleAxis"
in
opt_passes
:
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
backward_fold_scale_axis
(
func
)
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
forward_fold_scale_axis
(
func
)
func
=
_ir_pass
.
fold_constant
(
func
)
if
"CanonicalizeOps"
in
opt_passes
:
func
=
_ir_pass
.
infer_type
(
func
)
func
=
_ir_pass
.
canonicalize_ops
(
func
)
if
"FoldConstant"
in
opt_passes
:
func
=
_ir_pass
.
fold_constant
(
func
)
return
func
def
quantize
(
graph
,
params
=
None
,
dataset
=
None
):
def
quantize
(
graph
,
params
=
None
,
dataset
=
None
):
...
@@ -365,11 +310,29 @@ def quantize(graph, params=None, dataset=None):
...
@@ -365,11 +310,29 @@ def quantize(graph, params=None, dataset=None):
ret: Function
ret: Function
The graph after quantization
The graph after quantization
"""
"""
# TODO(zhiics) Move this to the pass manager.
if
params
:
graph
=
optimize
(
graph
,
params
)
graph
=
_bind_params
(
graph
,
params
)
graph
=
annotate
(
graph
)
mod
=
_module
.
Module
.
from_expr
(
graph
)
graph
=
calibrate
(
graph
,
dataset
)
# Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
graph
=
realize
(
graph
)
# "CanonicalizeOps" optimization before quantization.
graph
=
_ir_pass
.
fold_constant
(
graph
)
optimize
=
_transform
.
Sequential
([
_transform
.
SimplifyInference
(),
return
graph
_transform
.
FoldConstant
(),
_transform
.
FoldScaleAxis
(),
_transform
.
CanonicalizeOps
(),
_transform
.
FoldConstant
()])
calibrate_pass
=
_transform
.
function_pass
(
calibrate
,
opt_level
=
1
,
name
=
"QuantizeCalibrate"
)
_set_conv_counter
(
0
)
# reset counter
quantize_seq
=
_transform
.
Sequential
([
annotate
(),
calibrate_pass
,
realize
(),
_transform
.
FoldConstant
()])
with
_transform
.
PassContext
(
opt_level
=
3
,
required_pass
=
[
"QuantizeAnnotate"
,
"QuantizeCalibrate"
,
"QuantizeRealize"
]):
mod
=
optimize
(
mod
)
mod
=
quantize_seq
(
mod
)
return
mod
[
mod
.
entry_func
.
name_hint
]
src/relay/pass/pass_manager.cc
View file @
6e2c7ede
...
@@ -313,6 +313,7 @@ Module FunctionPassNode::operator()(const Module& mod,
...
@@ -313,6 +313,7 @@ Module FunctionPassNode::operator()(const Module& mod,
<<
pass_info
->
name
<<
pass_info
->
name
<<
" with opt level: "
<<
" with opt level: "
<<
pass_info
->
opt_level
;
<<
pass_info
->
opt_level
;
Module
updated_mod
=
mod
;
Module
updated_mod
=
mod
;
// Execute the pass function and return a new module.
// Execute the pass function and return a new module.
std
::
vector
<
std
::
pair
<
GlobalVar
,
Function
>
>
updates
;
std
::
vector
<
std
::
pair
<
GlobalVar
,
Function
>
>
updates
;
...
...
src/relay/pass/quantize.cc
View file @
6e2c7ede
...
@@ -43,6 +43,8 @@ namespace tvm {
...
@@ -43,6 +43,8 @@ namespace tvm {
namespace
relay
{
namespace
relay
{
namespace
quantize
{
namespace
quantize
{
using
namespace
relay
::
transform
;
/*! \brief Attribute for simulated quantize operator */
/*! \brief Attribute for simulated quantize operator */
struct
SimulatedQuantizeAttrs
:
public
tvm
::
AttrsNode
<
SimulatedQuantizeAttrs
>
{
struct
SimulatedQuantizeAttrs
:
public
tvm
::
AttrsNode
<
SimulatedQuantizeAttrs
>
{
int
kind
;
int
kind
;
...
@@ -131,23 +133,6 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr")
...
@@ -131,23 +133,6 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr")
static_cast
<
QAnnotateKind
>
(
args
[
1
].
operator
int
()));
static_cast
<
QAnnotateKind
>
(
args
[
1
].
operator
int
()));
});
});
TVM_REGISTER_API
(
"relay._quantize.annotate"
)
.
set_body_typed
<
Expr
(
Expr
)
>
([]
(
const
Expr
&
expr
)
{
std
::
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref
=
[](
const
Expr
&
e
)
{
if
(
e
->
derived_from
<
TempExprNode
>
())
{
const
auto
*
n
=
e
.
as
<
QAnnotateExprNode
>
();
CHECK
(
n
);
const
PackedFunc
*
f
=
runtime
::
Registry
::
Get
(
"relay.quantize.attach_simulated_quantize"
);
Expr
ret
=
(
*
f
)(
n
->
expr
,
static_cast
<
int
>
(
kQInput
));
return
static_cast
<
Expr
>
(
QAnnotateExprNode
::
make
(
ret
,
kQInput
));
}
return
e
;
};
return
ForwardRewrite
(
expr
,
"FQAnnotateRewrite"
,
nullptr
,
fmulti_ref
);
});
// =============
// =============
// realize pass
// realize pass
...
@@ -536,14 +521,6 @@ Expr AvgPoolRealize(const Call& ref_call,
...
@@ -536,14 +521,6 @@ Expr AvgPoolRealize(const Call& ref_call,
RELAY_REGISTER_OP
(
"nn.avg_pool2d"
)
RELAY_REGISTER_OP
(
"nn.avg_pool2d"
)
.
set_attr
<
FForwardRewrite
>
(
"FQRealizeRewrite"
,
AvgPoolRealize
);
.
set_attr
<
FForwardRewrite
>
(
"FQRealizeRewrite"
,
AvgPoolRealize
);
TVM_REGISTER_API
(
"relay._quantize.realize"
)
.
set_body_typed
<
Expr
(
Expr
)
>
([](
const
Expr
&
e
)
{
Expr
ret
=
ForwardRewrite
(
e
,
"FQRealizeRewrite"
,
nullptr
,
nullptr
);
return
ret
;
});
// =============
// =============
// qconfig
// qconfig
...
@@ -613,6 +590,42 @@ TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
...
@@ -613,6 +590,42 @@ TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
TVM_REGISTER_API
(
"relay._quantize._ExitQConfigScope"
)
TVM_REGISTER_API
(
"relay._quantize._ExitQConfigScope"
)
.
set_body_typed
(
QConfig
::
ExitQConfigScope
);
.
set_body_typed
(
QConfig
::
ExitQConfigScope
);
Pass
QuantizeAnnotate
()
{
std
::
function
<
Expr
(
const
Expr
&
)
>
fmulti_ref
=
[](
const
Expr
&
e
)
{
if
(
e
->
derived_from
<
TempExprNode
>
())
{
const
auto
*
n
=
e
.
as
<
QAnnotateExprNode
>
();
CHECK
(
n
);
const
PackedFunc
*
f
=
runtime
::
Registry
::
Get
(
"relay.quantize.attach_simulated_quantize"
);
Expr
ret
=
(
*
f
)(
n
->
expr
,
static_cast
<
int
>
(
kQInput
));
return
static_cast
<
Expr
>
(
QAnnotateExprNode
::
make
(
ret
,
kQInput
));
}
return
e
;
};
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
ForwardRewrite
(
f
,
"FQAnnotateRewrite"
,
fmulti_ref
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"QuantizeAnnotate"
,
{});
}
TVM_REGISTER_API
(
"relay._quantize.QuantizeAnnotate"
)
.
set_body_typed
(
QuantizeAnnotate
);
Pass
QuantizeRealizePass
()
{
runtime
::
TypedPackedFunc
<
Function
(
Function
,
Module
,
PassContext
)
>
pass_func
=
[
=
](
Function
f
,
Module
m
,
PassContext
pc
)
{
return
Downcast
<
Function
>
(
ForwardRewrite
(
f
,
"FQRealizeRewrite"
,
nullptr
,
nullptr
));
};
return
CreateFunctionPass
(
pass_func
,
1
,
"QuantizeRealize"
,
{});
}
TVM_REGISTER_API
(
"relay._quantize.QuantizeRealize"
)
.
set_body_typed
(
QuantizeRealizePass
);
}
// namespace quantize
}
// namespace quantize
}
// namespace relay
}
// namespace relay
}
// namespace tvm
}
// namespace tvm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment