Commit 6e2c7ede by Zhi Committed by Tianqi Chen

[Relay][Transform] quantize opt passes to pass manager (#3289)

parent 579e96da
......@@ -21,7 +21,9 @@ import numpy as np
from . import _quantize
from .. import expr as _expr
from .. import module as _module
from .. import ir_pass as _ir_pass
from .. import transform as _transform
from .. import op as _op
from ... import make as _make
from ..base import NodeBase, register_relay_node
......@@ -178,26 +180,7 @@ def _set_conv_counter(n):
CONV_COUNTER = n
def annotate(graph):
"""Given a float32 graph, annotate will rewrite the graph
and return back a graph which simulates the error brought by
current quantization scheme.
Parameters
---------
graph: Function
The original graph
Returns
-------
ret: Function
The graph after annotation
"""
_set_conv_counter(0) # reset counter
return _quantize.annotate(graph)
def calibrate(graph, dataset=None):
def calibrate(graph, mod=None, ctx=None):
"""The calibrate procedure will try to calculate the content of
dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
operator.
......@@ -207,8 +190,11 @@ def calibrate(graph, dataset=None):
graph: Function
The simulation graph after annotation.
dataset: list of dict of Var -> NDArray
The calibration dataset.
mod: tvm.relay.Module
The module where calibration happens on.
ctx: tvm.relay.PassContext
The pass context used for calibration.
Returns
-------
......@@ -253,93 +239,52 @@ def calibrate(graph, dataset=None):
return _expr.bind(graph, const_params)
def realize(graph):
"""The realize pass will transform the simulated quantized
graph, which computes with float32 actually, to a real low-bit
integer graph. It will replace the simulated_quantize with
several fine-grained operators like add, multiply, and shift
as more as possible for performance (fusion, etc.)
Parameters
---------
graph: Function
The simulated graph after calibrating.
def annotate():
"""Given a float32 graph, this pass will rewrite the graph and return
a graph which simulates the error brought by the current quantization
scheme.
Returns
-------
ret: Function
The graph after realization
ret: tvm.relay.Pass
The registered pass for quantization annotation.
"""
return _quantize.realize(graph)
return _quantize.QuantizeAnnotate()
def optimize(func, params=None):
""" Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"CanonicalizeOps" optimization before quantization.
# TODO(zhiics) These passes are executed one by one so far. We need to
# move them to the pass manager.
Parameters
---------
func: tvm.relay.Function
The original Relay function to be optimized.
params : dict of str to tvm.NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
def realize():
"""The realize pass will transform the simulated quantized graph, which
actually computes with float32, to a real low-bit integer graph. It will
replace the `simulated_quantize` with several fine-grained operators like
add, multiply, and shift as much as possible for better performance.
Returns
-------
ret: tvm.relay.Function
The graph after quantization
ret: tvm.relay.Pass
The registered pass for quantization realization.
"""
return _quantize.QuantizeRealize()
opt_passes = ["SimplifyInference",
"FoldScaleAxis",
"FoldConstant",
"CanonicalizeOps"]
if params:
name_dict = {}
for arg in func.params:
name = arg.name_hint
if name in name_dict:
name_dict[name] = None
else:
name_dict[name] = arg
bind_dict = {}
for k, v in params.items():
if k not in name_dict:
continue
arg = name_dict[k]
if arg is None:
raise ValueError("Multiple args in the function have name %s" % k)
bind_dict[arg] = _expr.const(v)
func = _expr.bind(func, bind_dict)
if "SimplifyInference" in opt_passes:
func = _ir_pass.infer_type(func)
func = _ir_pass.simplify_inference(func)
if "FoldConstant" in opt_passes:
func = _ir_pass.fold_constant(func)
if "FoldScaleAxis" in opt_passes:
func = _ir_pass.infer_type(func)
func = _ir_pass.backward_fold_scale_axis(func)
func = _ir_pass.infer_type(func)
func = _ir_pass.forward_fold_scale_axis(func)
func = _ir_pass.fold_constant(func)
if "CanonicalizeOps" in opt_passes:
func = _ir_pass.infer_type(func)
func = _ir_pass.canonicalize_ops(func)
if "FoldConstant" in opt_passes:
func = _ir_pass.fold_constant(func)
return func
def _bind_params(func, params):
    """Bind constant parameters to the matching variables of a function.

    Each entry of `params` whose key matches the name hint of one of the
    function's arguments is substituted into the function as a constant.
    Names shared by several arguments are ambiguous and raise an error
    when a binding for them is requested.
    """
    # Map each argument name to its Var; a name that appears more than
    # once maps to None so ambiguous bindings can be detected below.
    args_by_name = {}
    for var in func.params:
        hint = var.name_hint
        args_by_name[hint] = None if hint in args_by_name else var
    bindings = {}
    for name, value in params.items():
        if name not in args_by_name:
            continue
        var = args_by_name[name]
        if var is None:
            raise ValueError("Multiple args in the function have name %s" % name)
        bindings[var] = _expr.const(value)
    return _expr.bind(func, bindings)
def quantize(graph, params=None, dataset=None):
......@@ -365,11 +310,29 @@ def quantize(graph, params=None, dataset=None):
ret: Function
The graph after quantization
"""
# TODO(zhiics) Move this to the pass manager.
graph = optimize(graph, params)
graph = annotate(graph)
graph = calibrate(graph, dataset)
graph = realize(graph)
graph = _ir_pass.fold_constant(graph)
return graph
if params:
graph = _bind_params(graph, params)
mod = _module.Module.from_expr(graph)
# Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
# "CanonicalizeOps" optimization before quantization.
optimize = _transform.Sequential([_transform.SimplifyInference(),
_transform.FoldConstant(),
_transform.FoldScaleAxis(),
_transform.CanonicalizeOps(),
_transform.FoldConstant()])
calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
name="QuantizeCalibrate")
_set_conv_counter(0) # reset counter
quantize_seq = _transform.Sequential([annotate(),
calibrate_pass,
realize(),
_transform.FoldConstant()])
with _transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
mod = optimize(mod)
mod = quantize_seq(mod)
return mod[mod.entry_func.name_hint]
......@@ -313,6 +313,7 @@ Module FunctionPassNode::operator()(const Module& mod,
<< pass_info->name
<< " with opt level: "
<< pass_info->opt_level;
Module updated_mod = mod;
// Execute the pass function and return a new module.
std::vector<std::pair<GlobalVar, Function> > updates;
......
......@@ -43,6 +43,8 @@ namespace tvm {
namespace relay {
namespace quantize {
using namespace relay::transform;
/*! \brief Attribute for simulated quantize operator */
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
int kind;
......@@ -131,23 +133,6 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr")
static_cast<QAnnotateKind>(args[1].operator int()));
});
TVM_REGISTER_API("relay._quantize.annotate")
.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
if (e->derived_from<TempExprNode>()) {
const auto* n = e.as<QAnnotateExprNode>();
CHECK(n);
const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
}
return e;
};
return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, fmulti_ref);
});
// =============
// realize pass
......@@ -536,14 +521,6 @@ Expr AvgPoolRealize(const Call& ref_call,
RELAY_REGISTER_OP("nn.avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
TVM_REGISTER_API("relay._quantize.realize")
.set_body_typed<Expr(Expr)>([](const Expr& e) {
Expr ret = ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr);
return ret;
});
// =============
// qconfig
......@@ -613,6 +590,42 @@ TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
.set_body_typed(QConfig::ExitQConfigScope);
// Create the annotation pass: rewrites a float32 function into one that
// simulates the error introduced by the current quantization scheme.
Pass QuantizeAnnotate() {
  // Handler for expressions referenced from multiple places: attach a
  // simulated-quantize node (kQInput) to annotated temporary expressions
  // and leave everything else untouched.
  std::function<Expr(const Expr&)> multiref_handler = [](const Expr& expr) {
    if (!expr->derived_from<TempExprNode>()) {
      return expr;
    }
    const auto* annotated = expr.as<QAnnotateExprNode>();
    CHECK(annotated);
    const PackedFunc* attach =
        runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
    Expr rewritten = (*attach)(annotated->expr, static_cast<int>(kQInput));
    return static_cast<Expr>(QAnnotateExprNode::make(rewritten, kQInput));
  };
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
      [=](Function func, Module mod, PassContext ctx) {
        Expr rewritten = ForwardRewrite(func, "FQAnnotateRewrite", multiref_handler);
        return Downcast<Function>(rewritten);
      };
  return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}

TVM_REGISTER_API("relay._quantize.QuantizeAnnotate")
.set_body_typed(QuantizeAnnotate);
// Create the realize pass: lowers the simulated quantized graph (which
// still computes in float32) to a real low-bit integer graph by applying
// the per-operator "FQRealizeRewrite" rules.
Pass QuantizeRealizePass() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
      [=](Function func, Module mod, PassContext ctx) {
        Expr realized = ForwardRewrite(func, "FQRealizeRewrite", nullptr, nullptr);
        return Downcast<Function>(realized);
      };
  return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {});
}

TVM_REGISTER_API("relay._quantize.QuantizeRealize")
.set_body_typed(QuantizeRealizePass);
} // namespace quantize
} // namespace relay
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment