Commit 6e2c7ede by Zhi Committed by Tianqi Chen

[Relay][Transform] quantize opt passes to pass manager (#3289)

parent 579e96da
...@@ -21,7 +21,9 @@ import numpy as np ...@@ -21,7 +21,9 @@ import numpy as np
from . import _quantize from . import _quantize
from .. import expr as _expr from .. import expr as _expr
from .. import module as _module
from .. import ir_pass as _ir_pass from .. import ir_pass as _ir_pass
from .. import transform as _transform
from .. import op as _op from .. import op as _op
from ... import make as _make from ... import make as _make
from ..base import NodeBase, register_relay_node from ..base import NodeBase, register_relay_node
...@@ -178,26 +180,7 @@ def _set_conv_counter(n): ...@@ -178,26 +180,7 @@ def _set_conv_counter(n):
CONV_COUNTER = n CONV_COUNTER = n
def annotate(graph): def calibrate(graph, mod=None, ctx=None):
"""Given a float32 graph, annotate will rewrite the graph
and return back a graph which simulates the error brought by
current quantization scheme.
Parameters
---------
graph: Function
The original graph
Returns
-------
ret: Function
The graph after annotation
"""
_set_conv_counter(0) # reset counter
return _quantize.annotate(graph)
def calibrate(graph, dataset=None):
"""The calibrate procedure will try to calculate the content of """The calibrate procedure will try to calculate the content of
dom_scale, nbit, clip_min, clip_max for every `simulated_quantize` dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
operator. operator.
...@@ -207,8 +190,11 @@ def calibrate(graph, dataset=None): ...@@ -207,8 +190,11 @@ def calibrate(graph, dataset=None):
graph: Function graph: Function
The simulation graph after annotation. The simulation graph after annotation.
dataset: list of dict of Var -> NDArray mod: tvm.relay.Module
The calibration dataset. The module where calibration happens on.
ctx: tvm.relay.PassContext
The pass context used for calibration.
Returns Returns
------- -------
...@@ -253,93 +239,52 @@ def calibrate(graph, dataset=None): ...@@ -253,93 +239,52 @@ def calibrate(graph, dataset=None):
return _expr.bind(graph, const_params) return _expr.bind(graph, const_params)
def realize(graph): def annotate():
"""The realize pass will transform the simulated quantized """Given a float32 graph, this pass will rewrite the graph and return
graph, which computes with float32 actually, to a real low-bit a graph which simulates the error brought by the current quantization
integer graph. It will replace the simulated_quantize with scheme.
several fine-grained operators like add, multiply, and shift
as more as possible for performance (fusion, etc.)
Parameters
---------
graph: Function
The simulated graph after calibrating.
Returns Returns
------- -------
ret: Function ret: tvm.relay.Pass
The graph after realization The registered pass for quantization annotation.
""" """
return _quantize.realize(graph) return _quantize.QuantizeAnnotate()
def optimize(func, params=None): def realize():
""" Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and """The realize pass will transform the simulated quantized graph, which
"CanonicalizeOps" optimization before quantization. actually computes with float32, to a real low-bit integer graph. It will
replace the `simulated_quantize` with several fine-grained operators like
# TODO(zhiics) These passes are executed one by one so far. We need to add, multiply, and shift as much as possible for better performance.
# move them to the pass manager.
Parameters
---------
func: tvm.relay.Function
The original Relay function to be optimized.
params : dict of str to tvm.NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Returns Returns
------- -------
ret: tvm.relay.Function ret: tvm.relay.Pass
The graph after quantization The registered pass for quantization realization.
""" """
return _quantize.QuantizeRealize()
opt_passes = ["SimplifyInference",
"FoldScaleAxis",
"FoldConstant",
"CanonicalizeOps"]
if params: def _bind_params(func, params):
name_dict = {} """Bind the params to the expression.
for arg in func.params: """
name = arg.name_hint name_dict = {}
if name in name_dict: for arg in func.params:
name_dict[name] = None name = arg.name_hint
else: if name in name_dict:
name_dict[name] = arg name_dict[name] = None
bind_dict = {} else:
for k, v in params.items(): name_dict[name] = arg
if k not in name_dict: bind_dict = {}
continue for k, v in params.items():
arg = name_dict[k] if k not in name_dict:
if arg is None: continue
raise ValueError("Multiple args in the function have name %s" % k) arg = name_dict[k]
bind_dict[arg] = _expr.const(v) if arg is None:
func = _expr.bind(func, bind_dict) raise ValueError("Multiple args in the function have name %s" % k)
bind_dict[arg] = _expr.const(v)
if "SimplifyInference" in opt_passes: return _expr.bind(func, bind_dict)
func = _ir_pass.infer_type(func)
func = _ir_pass.simplify_inference(func)
if "FoldConstant" in opt_passes:
func = _ir_pass.fold_constant(func)
if "FoldScaleAxis" in opt_passes:
func = _ir_pass.infer_type(func)
func = _ir_pass.backward_fold_scale_axis(func)
func = _ir_pass.infer_type(func)
func = _ir_pass.forward_fold_scale_axis(func)
func = _ir_pass.fold_constant(func)
if "CanonicalizeOps" in opt_passes:
func = _ir_pass.infer_type(func)
func = _ir_pass.canonicalize_ops(func)
if "FoldConstant" in opt_passes:
func = _ir_pass.fold_constant(func)
return func
def quantize(graph, params=None, dataset=None): def quantize(graph, params=None, dataset=None):
...@@ -365,11 +310,29 @@ def quantize(graph, params=None, dataset=None): ...@@ -365,11 +310,29 @@ def quantize(graph, params=None, dataset=None):
ret: Function ret: Function
The graph after quantization The graph after quantization
""" """
# TODO(zhiics) Move this to the pass manager. if params:
graph = optimize(graph, params) graph = _bind_params(graph, params)
graph = annotate(graph) mod = _module.Module.from_expr(graph)
graph = calibrate(graph, dataset) # Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
graph = realize(graph) # "CanonicalizeOps" optimization before quantization.
graph = _ir_pass.fold_constant(graph) optimize = _transform.Sequential([_transform.SimplifyInference(),
return graph _transform.FoldConstant(),
_transform.FoldScaleAxis(),
_transform.CanonicalizeOps(),
_transform.FoldConstant()])
calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
name="QuantizeCalibrate")
_set_conv_counter(0) # reset counter
quantize_seq = _transform.Sequential([annotate(),
calibrate_pass,
realize(),
_transform.FoldConstant()])
with _transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
mod = optimize(mod)
mod = quantize_seq(mod)
return mod[mod.entry_func.name_hint]
...@@ -313,6 +313,7 @@ Module FunctionPassNode::operator()(const Module& mod, ...@@ -313,6 +313,7 @@ Module FunctionPassNode::operator()(const Module& mod,
<< pass_info->name << pass_info->name
<< " with opt level: " << " with opt level: "
<< pass_info->opt_level; << pass_info->opt_level;
Module updated_mod = mod; Module updated_mod = mod;
// Execute the pass function and return a new module. // Execute the pass function and return a new module.
std::vector<std::pair<GlobalVar, Function> > updates; std::vector<std::pair<GlobalVar, Function> > updates;
......
...@@ -43,6 +43,8 @@ namespace tvm { ...@@ -43,6 +43,8 @@ namespace tvm {
namespace relay { namespace relay {
namespace quantize { namespace quantize {
using namespace relay::transform;
/*! \brief Attribute for simulated quantize operator */ /*! \brief Attribute for simulated quantize operator */
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> { struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
int kind; int kind;
...@@ -131,23 +133,6 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr") ...@@ -131,23 +133,6 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr")
static_cast<QAnnotateKind>(args[1].operator int())); static_cast<QAnnotateKind>(args[1].operator int()));
}); });
TVM_REGISTER_API("relay._quantize.annotate")
.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
if (e->derived_from<TempExprNode>()) {
const auto* n = e.as<QAnnotateExprNode>();
CHECK(n);
const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
}
return e;
};
return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, fmulti_ref);
});
// ============= // =============
// realize pass // realize pass
...@@ -536,14 +521,6 @@ Expr AvgPoolRealize(const Call& ref_call, ...@@ -536,14 +521,6 @@ Expr AvgPoolRealize(const Call& ref_call,
RELAY_REGISTER_OP("nn.avg_pool2d") RELAY_REGISTER_OP("nn.avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize); .set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
TVM_REGISTER_API("relay._quantize.realize")
.set_body_typed<Expr(Expr)>([](const Expr& e) {
Expr ret = ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr);
return ret;
});
// ============= // =============
// qconfig // qconfig
...@@ -613,6 +590,42 @@ TVM_REGISTER_API("relay._quantize._EnterQConfigScope") ...@@ -613,6 +590,42 @@ TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
TVM_REGISTER_API("relay._quantize._ExitQConfigScope") TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
.set_body_typed(QConfig::ExitQConfigScope); .set_body_typed(QConfig::ExitQConfigScope);
Pass QuantizeAnnotate() {
std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
if (e->derived_from<TempExprNode>()) {
const auto* n = e.as<QAnnotateExprNode>();
CHECK(n);
const PackedFunc* f =
runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
}
return e;
};
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(
ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref));
};
return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}
TVM_REGISTER_API("relay._quantize.QuantizeAnnotate")
.set_body_typed(QuantizeAnnotate);
Pass QuantizeRealizePass() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(
ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr));
};
return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {});
}
TVM_REGISTER_API("relay._quantize.QuantizeRealize")
.set_body_typed(QuantizeRealizePass);
} // namespace quantize } // namespace quantize
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment