Commit 45878ff2, authored by Wuwei Lin and committed by Tianqi Chen

[Relay][Quantization] Fix add_rewrite and UnifyDTypeScale (#3534)

* [Relay][Quantization] Fix issue introduced in #3135

* Recover StopFusion

* Fix fmultiref

* Fix lint
parent 8471f811
......@@ -260,13 +260,11 @@ def add_rewrite(ref_call, new_args, ctx):
if isinstance(rhs_expr, _expr.Constant):
# quantize rhs to WEIGHT field if it is Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
assert lhs_kind == QAnnotateKind.ACTIVATION
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
else:
# quantize rhs to INPUT field if it is not Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
raise ValueError
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
if lhs_kind is not None and rhs_kind is not None:
if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT:
......@@ -277,6 +275,10 @@ def add_rewrite(ref_call, new_args, ctx):
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT:
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
raise ValueError()
@register_annotate_function("stop_fusion")
......
......@@ -135,22 +135,6 @@ TVM_REGISTER_API("relay._quantize.make_annotate_expr")
});
TVM_REGISTER_API("relay._quantize.annotate")
.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
if (e->derived_from<TempExprNode>()) {
const auto* n = e.as<QAnnotateExprNode>();
CHECK(n);
const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
}
return e;
};
return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, nullptr);
});
// =============
// realize pass
......@@ -395,10 +379,9 @@ float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
/* \brief Unify the dom scale of arguments */
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
const Array<Expr>& args,
DataType* dtype_ptr,
Expr* scale_ptr) {
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
DataType* dtype_ptr, Expr* scale_ptr) {
static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
const QConfig& cfg = QConfig::Current();
std::vector<const QRealizeIntExprNode*> nptrs;
......@@ -413,14 +396,21 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
// unify the data type
CHECK_EQ(ref_args.size(), args.size());
DataType dtype;
if (nptrs[0]->dtype == cfg->dtype_activation) {
DataType dtype = cfg->dtype_activation;
ret.Set(1, Cast(ret[1], dtype));
} else if (nptrs[1]->dtype == cfg->dtype_input) {
DataType dtype = cfg->dtype_input;
ret.Set(0, Cast(ret[0], dtype));
if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {
dtype = cfg->dtype_input;
} else {
LOG(FATAL) << "should not touch here.";
dtype = cfg->dtype_activation;
}
for (size_t i = 0; i < ret.size(); ++i) {
auto ref_arg = ref_args[i].as<CallNode>();
if (nptrs[i]->dtype != dtype) {
ret.Set(i, Cast(ret[i], dtype));
} else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
auto new_arg = Cast(ret[i], cfg->dtype_input);
new_arg = StopFusion(new_arg);
ret.Set(i, Cast(new_arg, dtype));
}
}
// unify the dom_scale
......@@ -447,6 +437,7 @@ Expr AddRealize(const Call& ref_call,
Expr ret = ForwardOp(ref_call, ret_args);
return QRealizeIntExprNode::make(ret, dom_scale, dtype);
}
CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
return Expr(nullptr);
}
......@@ -674,7 +665,7 @@ Pass QuantizeAnnotate() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", fmulti_ref));
auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
auto new_params = func->params;
for (const auto& x : FreeVars(func)) {
new_params.push_back(x);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment