Unverified Commit 500ff051 by Wuwei Lin, committed by GitHub

[Relay][Quantize] Integrate data-aware calibration into quantization (#4295)

* [Relay][Quantize] Integrate data-aware calibration into quantization

* Update _calibrate.py

* trigger ci

* Address comments

* address comments
parent af52eba1
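
With this change, calibration is configured through qconfig and driven by a calibration dataset passed to quantize. A minimal usage sketch (editor's illustration, not part of the diff; `mod`, `params`, and `calib_dataset` are placeholders for a Relay module, its parameters, and an iterable of input batches):

    from tvm.relay import quantize as qtz

    # calibrate_mode='kl_divergence' profiles the network on calib_dataset;
    # the default 'global_scale' mode needs no dataset.
    with qtz.qconfig(calibrate_mode='kl_divergence', weight_scale='max'):
        qmod = qtz.quantize(mod, params=params, dataset=calib_dataset)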
@@ -21,4 +21,3 @@ from __future__ import absolute_import as _abs
from .quantize import *
from ._partition import register_partition_function
from ._annotate import register_annotate_function
from .kl_divergence import kl_divergence_scale
@@ -57,6 +57,7 @@ _reg.register_schedule("relay.op.annotation.simulated_quantize",
                       _reg.schedule_injective)
_reg.register_pattern("relay.op.annotation.simulated_quantize",
                      _reg.OpPattern.ELEMWISE)
_reg.register_schedule("annotation.cast_hint", _reg.schedule_injective)
@register_relay_node
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Find scales for quantization on the dataset."""
from __future__ import absolute_import
import logging
import multiprocessing as mp
import numpy as np
import tvm
from . import _quantize
from . import quantize
from .. import op as _op
from .. import expr as _expr
from .. import module as _module
from .. import analysis as _analysis
from .. import transform as _transform
from .. import build_module as _build_module
from ...contrib import graph_runtime
from .kl_divergence import _find_scale_by_kl
def collect_stats(mod, dataset):
    """Given an annotated graph, create a profile graph to collect profile data from the
    calibration dataset. This pass collects the inputs of simulated_quantize ops into a
    tuple. The simulated_quantize ops are rewritten to identity mode. The tuple is the
    output of the profile graph.

    Parameters
    ----------
    mod: Module
        The simulation graph after annotation.

    dataset: Iterable[dict of str to NDArray]
        The calibration dataset.

    Returns
    -------
    ret: list of ndarray
        List of output data of each layer
    """
    logging.info("collecting statistics for calibration...")
    func = mod['main']
    func = _quantize.CreateStatsCollector(func)
    target = tvm.target.current_target() or 'llvm'
    with _transform.build_config(opt_level=3):
        graph, lib, params = _build_module.build(func, target=target)

    runtime = graph_runtime.create(graph, lib, tvm.context(target))
    runtime.set_input(**params)

    num_outputs = runtime.get_num_outputs()
    outputs = [[] for i in range(num_outputs)]

    for batch in dataset:
        runtime.set_input(**batch)
        runtime.run()
        for i in range(num_outputs):
            output = runtime.get_output(i).asnumpy()
            outputs[i].append(output)

    for i in range(num_outputs):
        outputs[i] = np.concatenate(outputs[i]).reshape(-1)
    return outputs
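
# Editor's sketch (not part of the diff): `dataset` is an iterable of dicts
# mapping input names to arrays, fed directly to runtime.set_input above.
# The input name and shape below are illustrative assumptions.
#
#     calib_dataset = [
#         {'data': np.random.uniform(size=(1, 3, 224, 224)).astype('float32')}
#         for _ in range(10)
#     ]
#     layer_stats = collect_stats(mod, calib_dataset)  # one flat array per layer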
def _kl_scale(stats):
    with mp.Pool() as pool:
        logging.info("finding threshold with kl for calibration...")
        scales = list(pool.map(_find_scale_by_kl, stats))

    def func(sq_call):  # pylint: disable=unused-argument
        # Hand out the precomputed scales one per call, in the same order in
        # which collect_stats emitted the statistics.
        scale = scales[func.scale_idx]
        func.scale_idx += 1
        return scale
    func.scale_idx = 0

    return func
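
# Editor's note (not part of the diff): the stateful closure above is
# equivalent to handing out the precomputed thresholds from an iterator,
# one per simulated_quantize call, in collection order:
#
#     def _kl_scale_iter(stats):
#         it = iter([_find_scale_by_kl(s) for s in stats])
#         return lambda sq_call: next(it)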
def _set_params(mod, input_scale_func, weight_scale_func):
    quantize_op = _op.get("relay.op.annotation.simulated_quantize")
    cfg = quantize.current_qconfig()
    const_params = {}

    def visit_func(expr):
        """Visitor function for the post-order traversal."""
        if isinstance(expr, _expr.Call) and expr.op == quantize_op:
            _, ndom_scale, nclip_min, nclip_max = expr.args
            attrs = expr.attrs
            kind = attrs.kind
            nbit = cfg.get_nbit_by_kind(kind)
            valid_bit = nbit - attrs.sign

            # set scale
            if kind == quantize.QAnnotateKind.WEIGHT:
                assert isinstance(expr.args[0], _expr.Constant)
                scale = weight_scale_func(expr)
            else:
                scale = input_scale_func(expr)

            def _make_const(val):
                return _expr.const(val, 'float32')

            valid_range = 2**valid_bit
            const_params[ndom_scale] = _make_const(scale / valid_range)
            const_params[nclip_min] = _make_const(- (valid_range - 1))
            const_params[nclip_max] = _make_const((valid_range - 1))

    func = mod['main']
    _analysis.post_order_visit(func, visit_func)
    func = _expr.bind(func, const_params)
    return _module.Module.from_expr(func)
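
# Editor's worked example (assumed values, not part of the diff): with
# nbit=8 and sign=1, a calibrated scale of 4.0 gives valid_bit = 7,
# valid_range = 2**7 = 128, dom_scale = 4.0 / 128 = 0.03125, and a clip
# range of [-127, 127]; simulated quantization then computes roughly
# clip(round(x / 0.03125), -127, 127) * 0.03125.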
# weight scale functions
def _power2_scale(sq_call):  # pylint: disable=unused-argument
    """calculate weight scale with nearest power-of-2 scale"""
    var = sq_call.args[0]
    assert isinstance(var, _expr.Constant)
    val = np.amax(np.abs(var.data.asnumpy()))
    return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0


def _max_scale(sq_call):
    """calculate weight scale with maximum absolute value"""
    var = sq_call.args[0]
    assert isinstance(var, _expr.Constant)
    val = np.amax(np.abs(var.data.asnumpy()))
    return val


# input scale functions
def _global_scale(sq_call):  # pylint: disable=unused-argument
    cfg = quantize.current_qconfig()
    return cfg.global_scale
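
# Editor's example (assumed values, not part of the diff): for a weight
# tensor whose maximum absolute value is 5.3, _max_scale returns 5.3, while
# _power2_scale rounds up to the next power of two, 2**ceil(log2(5.3)) = 8.0.
# Power-of-two scales let the realize pass replace divisions with shifts.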
def calibrate(dataset=None):
    """The calibrate procedure will try to calculate the content of
    dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
    operator.

    Parameters
    ----------
    dataset: Optional[Iterable[NDArray]]
        The calibration dataset.

    Returns
    -------
    ret: Function
        The module pass function.
    """
    def wrapped_func(mod, ctx):  # pylint: disable=unused-argument
        """make transform.module pass happy"""
        cfg = quantize.current_qconfig()

        if cfg.calibrate_mode == 'kl_divergence':
            stats = collect_stats(mod, dataset)
            input_scale_func = _kl_scale(stats)
        elif cfg.calibrate_mode == 'global_scale':
            input_scale_func = _global_scale
        else:
            raise ValueError("Unknown calibrate mode {}".format(cfg.calibrate_mode))

        if cfg.weight_scale == 'max':
            weight_scale_func = _max_scale
        elif cfg.weight_scale == 'power2':
            weight_scale_func = _power2_scale
        else:
            raise ValueError("Unknown weight scale mode {}".format(cfg.weight_scale))

        return _set_params(mod, input_scale_func, weight_scale_func)
    return wrapped_func
@@ -45,7 +45,7 @@ def _smooth_distribution(p, eps=0.0001):
# pylint: disable=invalid-name
def kl_divergence_scale(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255):
def _find_scale_by_kl(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255):
    """Given a tensor, find the optimal threshold for quantizing it.
    The reference distribution is `q`, and the candidate distribution is `p`.
    `q` is a truncated version of the original distribution.
@@ -54,6 +54,8 @@ def kl_divergence_scale(arr, quantized_dtype='int8', num_quantize
    http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
    """
    assert isinstance(arr, np.ndarray)
    assert stats is not None, ("scipy needs to be installed to use "
                               "kl_divergence calibration during quantization")

    min_val = np.min(arr)
    max_val = np.max(arr)
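
# Editor's sketch (not part of the diff): _find_scale_by_kl takes a flattened
# array of observed values and returns the clipping threshold minimizing the
# KL divergence between the original and the quantized distribution, e.g.:
#
#     data = np.random.normal(size=1 << 20).astype('float32')
#     threshold = _find_scale_by_kl(data)  # scalar clipping threshold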
......
@@ -17,14 +17,10 @@
# pylint: disable=unused-argument
"""Automatic quantization toolkit."""
from __future__ import absolute_import
import numpy as np
from . import _quantize
from ._calibrate import calibrate
from .. import expr as _expr
from .. import module as _module
from .. import analysis as _analysis
from .. import transform as _transform
from .. import op as _op
from ... import make as _make
from ..base import NodeBase, register_relay_node
@@ -78,7 +74,9 @@ class QConfig(NodeBase):
        "dtype_input": "int8",
        "dtype_weight": "int8",
        "dtype_activation": "int32",
        "calibrate_mode": "global_scale",
        "global_scale": 8.0,
        "weight_scale": "power2",
        "skip_conv_layers": [0],
        "do_simulation": False,
        "round_for_shift": True,
@@ -143,9 +141,20 @@ def qconfig(**kwargs):
    nbit_dict: dict of QAnnotateKind -> int
        Number of bits for every kind of annotate field.

    calibrate_mode: str
        The calibration mode, 'global_scale' or 'kl_divergence' (see the example below).
        global_scale: use the user-provided global scale.
        kl_divergence: find scales by KL divergence on the calibration dataset.

    global_scale: float
        The global scale for calibration.

    weight_scale: str
        The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT).
        power2: Find the maximum of the absolute value of the tensor, and then round up to a
        power of two.
        max: Find the maximum of the absolute value of the tensor.

    skip_conv_layers: list
        Specifies which layers to skip. Provide a list of indices
        that indicate which conv2d layers to leave untouched, starting from 0.
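    Example (editor's sketch, not part of the diff; ``mod`` and ``params``
    are placeholders)::

        with qconfig(calibrate_mode='global_scale', global_scale=8.0):
            qmod = quantize(mod, params)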
@@ -249,113 +258,6 @@ def annotate():
    return _quantize.QuantizeAnnotate()


def collect_stats(graph):
    """Given an annotated graph, create a profile graph to collect profile data from the
    calibration dataset. This pass collects the inputs of simulated_quantize ops into a
    tuple. The simulated_quantize ops are rewritten to identity mode. The tuple is the
    output of the profile graph.

    Parameters
    ----------
    graph: Function
        The simulation graph after annotation.

    Returns
    -------
    ret: Function
        The profile graph which outputs a tuple of profile data.
    """
    return _quantize.CollectStats(graph)
def calibrate(graph, mod=None, ctx=None, weight_scales='power2', scales=None):
    """The calibrate procedure will try to calculate the content of
    dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
    operator.

    Parameters
    ----------
    graph: Function
        The simulation graph after annotation.

    mod: tvm.relay.Module
        The module where calibration happens on.

    ctx: tvm.relay.PassContext
        The pass context used for calibration.

    weight_scales: 'power2' or 'max'.
        The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT).
        power2: Find the maximum of the absolute value of the tensor, and then round up to a
        power of two.
        max: Find the maximum of the absolute value of the tensor.

    scales: List[float]
        Pre-calculated scales for input and activations. Length and the order of elements of the
        scales list should match the output tuple of the profile graph created by collect_stats.

    Returns
    -------
    ret: Function
        The graph after calibration
    """
    def power2_scale(arr):
        """calculate weight scale with nearest power-of-2 scale"""
        val = np.amax(np.abs(arr.asnumpy()))
        return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0

    def max_scale(arr):
        """calculate weight scale with maximum absolute value"""
        val = np.amax(np.abs(arr.asnumpy()))
        return val

    scale_idx = 0

    cfg = current_qconfig()
    const_params = {}
    quantize_op = _op.get("relay.op.annotation.simulated_quantize")

    def visit_func(expr):
        """Internal visit function"""
        nonlocal scale_idx
        if isinstance(expr, _expr.Call) and expr.op == quantize_op:
            _, ndom_scale, nclip_min, nclip_max = expr.args
            attrs = expr.attrs
            kind = attrs.kind
            nbit = cfg.get_nbit_by_kind(kind)
            valid_bit = nbit - attrs.sign

            if kind in [QAnnotateKind.WEIGHT]:
                if all([isinstance(arg, _expr.Constant)
                        for arg in [ndom_scale, nclip_min, nclip_max]]):
                    return
                var = expr.args[0]
                assert isinstance(var, _expr.Constant)
                if weight_scales == 'max':
                    scale = max_scale(var.data)
                elif weight_scales == 'power2':
                    scale = power2_scale(var.data)
                else:
                    raise ValueError('{} not supported'.format(weight_scales))
            elif scales is not None:
                scale = scales[scale_idx]
                scale_idx += 1
            else:
                scale = cfg.global_scale

            def _make_const(val):
                return _expr.const(val, 'float32')

            valid_range = 2**valid_bit
            const_params[ndom_scale] = _make_const(scale / valid_range)
            const_params[nclip_min] = _make_const(- (valid_range - 1))
            const_params[nclip_max] = _make_const((valid_range - 1))

    _analysis.post_order_visit(graph, visit_func)
    ret = _expr.bind(graph, const_params)
    return ret
def realize():
    """The realize pass will transform the simulated quantized graph, which
    actually computes with float32, to a real low-bit integer graph. It will
@@ -391,7 +293,7 @@ def _bind_params(func, params):
    return _expr.bind(func, bind_dict)


def prerequisite_optimize(graph, params=None):
def prerequisite_optimize(mod, params=None):
    """ Prerequisite optimization passes for quantization. Perform
    "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
    "CanonicalizeOps" optimization before quantization. """
@@ -402,15 +304,13 @@ def prerequisite_optimize(graph, params=None):
                                      _transform.FoldConstant()])

    if params:
        graph = _bind_params(graph, params)
        mod['main'] = _bind_params(mod['main'], params)

    mod = _module.Module.from_expr(graph)
    with _transform.PassContext(opt_level=3):
        mod = optimize(mod)
    return mod["main"]
        mod = optimize(mod)
    return mod
def quantize(graph, params=None, dataset=None):
def quantize(mod, params=None, dataset=None):
    """ The quantization procedure. Before running the three main
    procedures of quantization, "annotate", "calibrate" and "realize",
    we need to do "SimplifyInference", "FoldScaleAxis", "FoldConstant"
@@ -418,8 +318,8 @@ def quantize(graph, params=None, dataset=None):

    Parameters
    ----------
    graph: Function
        The original graph.
    mod: Module
        The original module.

    params : dict of str to NDArray
        Input parameters to the graph that do not change
@@ -433,11 +333,10 @@ def quantize(graph, params=None, dataset=None):
    ret: Function
        The graph after quantization
    """
    graph = prerequisite_optimize(graph, params)
    mod = prerequisite_optimize(mod, params)

    mod = _module.Module.from_expr(graph)
    calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
                                              name="QuantizeCalibrate")
    calibrate_pass = _transform.module_pass(calibrate(dataset), opt_level=1,
                                            name="QuantizeCalibrate")
    quant_passes = [partition(),
                    annotate(),
                    calibrate_pass]
@@ -452,4 +351,4 @@ def quantize(graph, params=None, dataset=None):
    with quantize_context():
        mod = quantize_seq(mod)

    return mod["main"]
    return mod
@@ -87,12 +87,12 @@ class StatsCollector : private ExprMutator {
* \param expr The simulation graph after annotation.
* \return The profile graph.
*/
Expr CollectStats(const Expr& expr) {
Expr CreateStatsCollector(const Expr& expr) {
return StatsCollector().Collect(expr);
}
TVM_REGISTER_API("relay._quantize.CollectStats")
.set_body_typed(CollectStats);
TVM_REGISTER_API("relay._quantize.CreateStatsCollector")
.set_body_typed(CreateStatsCollector);
} // namespace quantize
} // namespace relay
......
@@ -123,7 +123,9 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
    p->stream << "nbit_input=" << op->nbit_input << ", ";
    p->stream << "nbit_weight=" << op->nbit_weight << ", ";
    p->stream << "nbit_activation=" << op->nbit_activation << ", ";
    p->stream << "calibrate_mode=" << op->calibrate_mode << ", ";
    p->stream << "global_scale=" << op->global_scale << ", ";
    p->stream << "weight_scale=" << op->weight_scale << ", ";
    p->stream << "skip_conv_layers=" << op->skip_conv_layers << ", ";
    p->stream << "do_simulation=" << op->do_simulation << ", ";
    p->stream << "round_for_shift=" << op->round_for_shift << ", ";
......
@@ -70,7 +70,9 @@ class QConfigNode : public Node {
  DataType dtype_input = Int(8);
  DataType dtype_weight = Int(8);
  DataType dtype_activation = Int(32);
  std::string calibrate_mode = "global_scale";
  double global_scale = 8.0;
  std::string weight_scale = "power2";
  Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr));
  bool do_simulation = false;
  bool round_for_shift = true;
@@ -84,7 +86,9 @@ class QConfigNode : public Node {
    v->Visit("dtype_input", &dtype_input);
    v->Visit("dtype_weight", &dtype_weight);
    v->Visit("dtype_activation", &dtype_activation);
    v->Visit("calibrate_mode", &calibrate_mode);
    v->Visit("global_scale", &global_scale);
    v->Visit("weight_scale", &weight_scale);
    v->Visit("skip_conv_layers", &skip_conv_layers);
    v->Visit("do_simulation", &do_simulation);
    v->Visit("round_for_shift", &round_for_shift);
......