Commit 6e2c7ede by Zhi Committed by Tianqi Chen

[Relay][Transform] quantize opt passes to pass manager (#3289)

parent 579e96da
...@@ -21,7 +21,9 @@ import numpy as np ...@@ -21,7 +21,9 @@ import numpy as np
from . import _quantize from . import _quantize
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module
from .. import ir_pass as _ir_pass from .. import ir_pass as _ir_pass
from .. import transform as _transform
from .. import op as _op from .. import op as _op
from ... import make as _make from ... import make as _make
from ..base import NodeBase, register_relay_node from ..base import NodeBase, register_relay_node
...@@ -178,26 +180,7 @@ def _set_conv_counter(n): ...@@ -178,26 +180,7 @@ def _set_conv_counter(n):
CONV_COUNTER = n CONV_COUNTER = n
def annotate(graph): def calibrate(graph, mod=None, ctx=None):
"""Given a float32 graph, annotate will rewrite the graph
and return back a graph which simulates the error brought by
current quantization scheme.
Parameters
---------
graph: Function
The original graph
Returns
-------
ret: Function
The graph after annotation
"""
_set_conv_counter(0) # reset counter
return _quantize.annotate(graph)
def calibrate(graph, dataset=None):
"""The calibrate procedure will try to calculate the content of """The calibrate procedure will try to calculate the content of
dom_scale, nbit, clip_min, clip_max for every `simulated_quantize` dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
operator. operator.
...@@ -207,8 +190,11 @@ def calibrate(graph, dataset=None): ...@@ -207,8 +190,11 @@ def calibrate(graph, dataset=None):
graph: Function graph: Function
The simulation graph after annotation. The simulation graph after annotation.
dataset: list of dict of Var -> NDArray mod: tvm.relay.Module
The calibration dataset. The module where calibration happens on.
ctx: tvm.relay.PassContext
The pass context used for calibration.
Returns Returns
------- -------
...@@ -253,54 +239,36 @@ def calibrate(graph, dataset=None): ...@@ -253,54 +239,36 @@ def calibrate(graph, dataset=None):
return _expr.bind(graph, const_params) return _expr.bind(graph, const_params)
def realize(graph): def annotate():
"""The realize pass will transform the simulated quantized """Given a float32 graph, this pass will rewrite the graph and return
graph, which computes with float32 actually, to a real low-bit a graph which simulates the error brought by the current quantization
integer graph. It will replace the simulated_quantize with scheme.
several fine-grained operators like add, multiply, and shift
as more as possible for performance (fusion, etc.)
Parameters
---------
graph: Function
The simulated graph after calibrating.
Returns Returns
------- -------
ret: Function ret: tvm.relay.Pass
The graph after realization The registered pass for quantization annotation.
""" """
return _quantize.realize(graph) return _quantize.QuantizeAnnotate()
def optimize(func, params=None):
""" Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"CanonicalizeOps" optimization before quantization.
# TODO(zhiics) These passes are executed one by one so far. We need to
# move them to the pass manager.
Parameters
---------
func: tvm.relay.Function
The original Relay function to be optimized.
params : dict of str to tvm.NDArray def realize():
Input parameters to the graph that do not change """The realize pass will transform the simulated quantized graph, which
during inference time. Used for constant folding. actually computes with float32, to a real low-bit integer graph. It will
replace the `simulated_quantize` with several fine-grained operators like
add, multiply, and shift as much as possible for better performance.
Returns Returns
------- -------
ret: tvm.relay.Function ret: tvm.relay.Pass
The graph after quantization The registered pass for quantization realization.
""" """
return _quantize.QuantizeRealize()
opt_passes = ["SimplifyInference",
"FoldScaleAxis",
"FoldConstant",
"CanonicalizeOps"]
if params: def _bind_params(func, params):
"""Bind the params to the expression.
"""
name_dict = {} name_dict = {}
for arg in func.params: for arg in func.params:
name = arg.name_hint name = arg.name_hint
...@@ -316,30 +284,7 @@ def optimize(func, params=None): ...@@ -316,30 +284,7 @@ def optimize(func, params=None):
if arg is None: if arg is None:
raise ValueError("Multiple args in the function have name %s" % k) raise ValueError("Multiple args in the function have name %s" % k)
bind_dict[arg] = _expr.const(v) bind_dict[arg] = _expr.const(v)
func = _expr.bind(func, bind_dict) return _expr.bind(func, bind_dict)
if "SimplifyInference" in opt_passes:
func = _ir_pass.infer_type(func)
func = _ir_pass.simplify_inference(func)
if "FoldConstant" in opt_passes:
func = _ir_pass.fold_constant(func)
if "FoldScaleAxis" in opt_passes:
func = _ir_pass.infer_type(func)
func = _ir_pass.backward_fold_scale_axis(func)
func = _ir_pass.infer_type(func)
func = _ir_pass.forward_fold_scale_axis(func)
func = _ir_pass.fold_constant(func)
if "CanonicalizeOps" in opt_passes:
func = _ir_pass.infer_type(func)
func = _ir_pass.canonicalize_ops(func)
if "FoldConstant" in opt_passes:
func = _ir_pass.fold_constant(func)
return func
def quantize(graph, params=None, dataset=None): def quantize(graph, params=None, dataset=None):
...@@ -365,11 +310,29 @@ def quantize(graph, params=None, dataset=None): ...@@ -365,11 +310,29 @@ def quantize(graph, params=None, dataset=None):
ret: Function ret: Function
The graph after quantization The graph after quantization
""" """
# TODO(zhiics) Move this to the pass manager. if params:
graph = optimize(graph, params) graph = _bind_params(graph, params)
graph = annotate(graph) mod = _module.Module.from_expr(graph)
graph = calibrate(graph, dataset) # Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
graph = realize(graph) # "CanonicalizeOps" optimization before quantization.
graph = _ir_pass.fold_constant(graph) optimize = _transform.Sequential([_transform.SimplifyInference(),
return graph _transform.FoldConstant(),
_transform.FoldScaleAxis(),
_transform.CanonicalizeOps(),
_transform.FoldConstant()])
calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
name="QuantizeCalibrate")
_set_conv_counter(0) # reset counter
quantize_seq = _transform.Sequential([annotate(),
calibrate_pass,
realize(),
_transform.FoldConstant()])
with _transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
mod = optimize(mod)
mod = quantize_seq(mod)
return mod[mod.entry_func.name_hint]
...@@ -313,6 +313,7 @@ Module FunctionPassNode::operator()(const Module& mod, ...@@ -313,6 +313,7 @@ Module FunctionPassNode::operator()(const Module& mod,
<< pass_info->name << pass_info->name
<< " with opt level: " << " with opt level: "
<< pass_info->opt_level; << pass_info->opt_level;
Module updated_mod = mod; Module updated_mod = mod;
// Execute the pass function and return a new module. // Execute the pass function and return a new module.
std::vector<std::pair<GlobalVar, Function> > updates; std::vector<std::pair<GlobalVar, Function> > updates;
......
...@@ -43,6 +43,8 @@ namespace tvm { ...@@ -43,6 +43,8 @@ namespace tvm {
namespace relay { namespace relay {
namespace quantize { namespace quantize {
using namespace relay::transform;
/*! \brief Attribute for simulated quantize operator */ /*! \brief Attribute for simulated quantize operator */
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> { struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
int kind; int kind;
...@@ -131,23 +133,6 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr") ...@@ -131,23 +133,6 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr")
static_cast<QAnnotateKind>(args[1].operator int())); static_cast<QAnnotateKind>(args[1].operator int()));
}); });
TVM_REGISTER_API("relay._quantize.annotate")
.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
if (e->derived_from<TempExprNode>()) {
const auto* n = e.as<QAnnotateExprNode>();
CHECK(n);
const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
}
return e;
};
return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, fmulti_ref);
});
// ============= // =============
// realize pass // realize pass
...@@ -536,14 +521,6 @@ Expr AvgPoolRealize(const Call& ref_call, ...@@ -536,14 +521,6 @@ Expr AvgPoolRealize(const Call& ref_call,
RELAY_REGISTER_OP("nn.avg_pool2d") RELAY_REGISTER_OP("nn.avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize); .set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
TVM_REGISTER_API("relay._quantize.realize")
.set_body_typed<Expr(Expr)>([](const Expr& e) {
Expr ret = ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr);
return ret;
});
// ============= // =============
// qconfig // qconfig
...@@ -613,6 +590,42 @@ TVM_REGISTER_API("relay._quantize._EnterQConfigScope") ...@@ -613,6 +590,42 @@ TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
TVM_REGISTER_API("relay._quantize._ExitQConfigScope") TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
.set_body_typed(QConfig::ExitQConfigScope); .set_body_typed(QConfig::ExitQConfigScope);
Pass QuantizeAnnotate() {
std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
if (e->derived_from<TempExprNode>()) {
const auto* n = e.as<QAnnotateExprNode>();
CHECK(n);
const PackedFunc* f =
runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
}
return e;
};
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(
ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref));
};
return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}
TVM_REGISTER_API("relay._quantize.QuantizeAnnotate")
.set_body_typed(QuantizeAnnotate);
Pass QuantizeRealizePass() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(
ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr));
};
return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {});
}
TVM_REGISTER_API("relay._quantize.QuantizeRealize")
.set_body_typed(QuantizeRealizePass);
} // namespace quantize } // namespace quantize
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment