Commit 21e8dfac by Wuwei Lin Committed by ziheng

[Relay][Quantization] Speed-aware quantization scheme improvement (#2723)

* [Relay][Quantization] Speed-aware quantization scheme improvement

* Add comment

* Add use_stop_fusion to qconfig

* Update comment
parent b0a0ae4d
......@@ -9,7 +9,7 @@ from ..build_module import build as _tvm_build_module
from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt
from . import ir_pass
from . import expr
from . import expr as _expr
from .backend import interpreter as _interpreter
from .backend import graph_runtime_codegen as _graph_gen
......@@ -22,6 +22,7 @@ OPT_PASS_LEVEL = {
"FoldScaleAxis": 3,
"AlterOpLayout": 3,
"CanonicalizeOps": 3,
"EliminateCommonSubexpr": 3,
}
......@@ -126,8 +127,8 @@ def _bind_params_by_name(func, params):
arg = name_dict[k]
if arg is None:
raise ValueError("Multiple args in the function have name %s" % k)
bind_dict[arg] = expr.const(v)
return expr.bind(func, bind_dict)
bind_dict[arg] = _expr.const(v)
return _expr.bind(func, bind_dict)
def optimize(func, target=None, params=None):
......@@ -162,6 +163,16 @@ def optimize(func, target=None, params=None):
func = ir_pass.infer_type(func)
func = ir_pass.simplify_inference(func)
if cfg.pass_enabled("EliminateCommonSubexpr"):
def fskip(expr):
    """Predicate passed to ir_pass.eliminate_common_subexpr.

    Returns True only when `expr` is a relay Call whose op is named
    'cast' and whose attrs.dtype equals 'int32'; otherwise returns False.
    """
    is_call = isinstance(expr, _expr.Call)
    # Short-circuit so op/attrs are only inspected on Call nodes.
    return bool(is_call and expr.op.name == 'cast' and expr.attrs.dtype == 'int32')
func = ir_pass.infer_type(func)
func = ir_pass.eliminate_common_subexpr(func, fskip)
if cfg.pass_enabled("CombineParallelConv2D"):
func = ir_pass.infer_type(func)
func = ir_pass.combine_parallel_conv2d(func)
......
......@@ -192,6 +192,9 @@ def add_rewrite(ref_call, new_args, ctx):
else:
# quantize rhs to INPUT field if it is not Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
# quantize rhs to INPUT field if both lhs and rhs are ACTIVATION
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
......
......@@ -58,6 +58,7 @@ class QConfig(NodeBase):
"round_for_shift": True,
"store_lowbit_output": True,
"debug_enabled_ops": None,
"use_stop_fusion": True
}
# pylint: disable=no-member
......@@ -129,6 +130,10 @@ def qconfig(**kwargs):
Whether to store low-bit integer back as output before dequantizing.
Some accelerators need this, e.g. VTA.
use_stop_fusion: boolean
Whether to add stop_fusion when casting to dtype_activation. stop_fusion forces
low-bit results to be stored in memory.
Returns
-------
config: QConfig
......
......@@ -124,7 +124,7 @@ TVM_REGISTER_API("relay._quantize.annotate")
}
return e;
};
return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, nullptr);
return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, fmulti_ref);
});
......@@ -329,9 +329,11 @@ float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
/* \brief Unify the dom scale of arguments */
Array<Expr> UnifyDTypeScale(const Array<Expr>& args,
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
const Array<Expr>& args,
DataType* dtype_ptr,
Expr* scale_ptr) {
static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
const QConfig& cfg = QConfig::Current();
std::vector<const QRealizeIntExprNode*> nptrs;
......@@ -344,10 +346,19 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& args,
}
// unify the data type
CHECK_EQ(ref_args.size(), args.size());
DataType dtype = cfg->dtype_activation;
for (size_t i = 0; i < ret.size(); ++i) {
auto ref_arg = ref_args[i].as<CallNode>();
if (nptrs[i]->dtype != dtype) {
ret.Set(i, Cast(ret[i], dtype));
} else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
auto new_arg = Cast(ret[i], cfg->dtype_input);
if (cfg->use_stop_fusion) {
new_arg = StopFusion(new_arg);
}
ret.Set(i, Cast(new_arg, dtype));
}
}
......@@ -371,7 +382,7 @@ Expr AddRealize(const Call& ref_call,
if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
DataType dtype;
Expr dom_scale;
Array<Expr> ret_args = UnifyDTypeScale(new_args, &dtype, &dom_scale);
Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
Expr ret = ForwardOp(ref_call, ret_args);
return QRealizeIntExprNode::make(ret, dom_scale, dtype);
}
......@@ -387,15 +398,19 @@ Expr ConcatenateRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
CHECK_EQ(new_args.size(), 1);
CHECK_EQ(ref_call->args.size(), 1);
const auto* tuple = new_args[0].as<TupleNode>();
const auto* ref_tuple = ref_call->args[0].as<TupleNode>();
CHECK(tuple);
CHECK(ref_tuple);
const Array<Expr>& arr = tuple->fields;
const Array<Expr>& ref_arr = ref_tuple->fields;
if (arr[0].as<QRealizeIntExprNode>()) {
DataType dtype;
Expr dom_scale;
Array<Expr> ret_args = UnifyDTypeScale(arr, &dtype, &dom_scale);
Array<Expr> ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale);
Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)});
return QRealizeIntExprNode::make(ret, dom_scale, dtype);
} else {
......@@ -530,7 +545,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "skip_k_conv==" << op->skip_k_conv << ", ";
p->stream << "round_for_shift==" << op->round_for_shift << ", ";
p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", ";
p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", ";
p->stream << "use_stop_fusion==" << op->use_stop_fusion;
p->stream << ")";
});
......
......@@ -110,6 +110,7 @@ class QConfigNode : public Node {
bool round_for_shift = true;
bool store_lowbit_output = true;
Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
bool use_stop_fusion = true;
void VisitAttrs(AttrVisitor* v) final {
v->Visit("nbit_input", &nbit_input);
......@@ -123,6 +124,7 @@ class QConfigNode : public Node {
v->Visit("round_for_shift", &round_for_shift);
v->Visit("store_lowbit_output", &store_lowbit_output);
v->Visit("debug_enabled_ops", &debug_enabled_ops);
v->Visit("use_stop_fusion", &use_stop_fusion);
}
static constexpr const char* _type_key = "relay.quantize.QConfig";
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment