Unverified Commit bfb4884e by ziheng Committed by GitHub

[QUANTIZE] Memorizing the quantize node mapping (#3233)

* [QUANTIZE] Support for clip operator

* [QUANTIZE] Memorizing the quantize node mapping.

* [QUANTIZE] Remove use_stop_fusion and skip_k_conv in qconfig

* update

* update

* update

* update
parent b796e335
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
"""The interface of expr function exposed from C++.""" """The interface of expr function exposed from C++."""
from __future__ import absolute_import from __future__ import absolute_import
import logging
from ... import build_module as _build from ... import build_module as _build
from ... import container as _container from ... import container as _container
from ..._ffi.function import _init_api, register_func from ..._ffi.function import _init_api, register_func
...@@ -50,8 +49,8 @@ def lower(sch, inputs, func_name, source_func): ...@@ -50,8 +49,8 @@ def lower(sch, inputs, func_name, source_func):
# pylint: disable=broad-except # pylint: disable=broad-except
try: try:
f = _build.lower(sch, inputs, name=func_name) f = _build.lower(sch, inputs, name=func_name)
logging.debug("lower function %s", func_name) # logging.debug("lower function %s", func_name)
logging.debug("%s", _build.lower(sch, inputs, simple_mode=True)) # logging.debug("%s", _build.lower(sch, inputs, simple_mode=True))
except Exception: except Exception:
msg = traceback.format_exc() msg = traceback.format_exc()
msg += "Error during compile function\n" msg += "Error during compile function\n"
......
...@@ -22,7 +22,7 @@ import warnings ...@@ -22,7 +22,7 @@ import warnings
import topi import topi
from . import _quantize from . import _quantize
from .quantize import QAnnotateKind, current_qconfig from .quantize import QAnnotateKind, current_qconfig
from .quantize import _conv_counter, _set_conv_counter from .quantize import annotate_context
from .. import expr as _expr from .. import expr as _expr
from .. import op as _op from .. import op as _op
from ..op import op as _reg from ..op import op as _reg
...@@ -116,7 +116,6 @@ def register_annotate_function(op_name, frewrite=None, level=10): ...@@ -116,7 +116,6 @@ def register_annotate_function(op_name, frewrite=None, level=10):
return _register(frewrite) if frewrite is not None else _register return _register(frewrite) if frewrite is not None else _register
@register_func("relay.quantize.attach_simulated_quantize")
def attach_simulated_quantize(data, kind, sign=True, rounding="round"): def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
"""Attach a simulated quantize operation after input data expr. """Attach a simulated quantize operation after input data expr.
...@@ -133,11 +132,20 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"): ...@@ -133,11 +132,20 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
if data.attrs.kind == kind and data.attrs.sign == sign and data.attrs.rounding == rounding: if data.attrs.kind == kind and data.attrs.sign == sign and data.attrs.rounding == rounding:
return data return data
actx = annotate_context()
key = tuple([data, kind, sign, rounding])
if key in actx.qnode_map:
return actx.qnode_map[key]
dom_scale = _expr.var("dom_scale") dom_scale = _expr.var("dom_scale")
clip_min = _expr.var("clip_min") clip_min = _expr.var("clip_min")
clip_max = _expr.var("clip_max") clip_max = _expr.var("clip_max")
return _quantize.simulated_quantize( qnode = _quantize.simulated_quantize(
data, dom_scale, clip_min, clip_max, kind, sign, rounding) data, dom_scale, clip_min, clip_max, kind, sign, rounding)
actx.qnode_map[key] = qnode
return qnode
register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
@register_annotate_function("nn.contrib_conv2d_NCHWc") @register_annotate_function("nn.contrib_conv2d_NCHWc")
...@@ -152,18 +160,13 @@ def conv2d_rewrite(ref_call, new_args, ctx): ...@@ -152,18 +160,13 @@ def conv2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for conv2d. Lhs of conv will be quantized to """Rewrite function for conv2d. Lhs of conv will be quantized to
input field, and rhs of conv will be quantized to weight field. input field, and rhs of conv will be quantized to weight field.
Output would be in activation field""" Output would be in activation field"""
cnt = _conv_counter() actx = annotate_context()
if cnt < current_qconfig().skip_k_conv:
_set_conv_counter(cnt + 1)
return None
if current_qconfig().skip_conv_layers is not None: if current_qconfig().skip_conv_layers is not None:
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers] skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if cnt in leave_alone_indices: if actx.conv2d_counter() in skipped_indices:
_set_conv_counter(cnt + 1) actx.count_conv2d()
return None return None
actx.count_conv2d()
_set_conv_counter(cnt + 1)
lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
...@@ -179,17 +182,21 @@ def conv2d_rewrite(ref_call, new_args, ctx): ...@@ -179,17 +182,21 @@ def conv2d_rewrite(ref_call, new_args, ctx):
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
def check_to_skip():
    """Tell whether the operator being annotated should be left unquantized.

    The decision follows ``skip_conv_layers`` of the current qconfig. The
    conv2d rewrite advances the conv2d counter before returning, so
    ``conv2d_counter() - 1`` is the index of the most recently visited conv2d;
    the current operator is skipped when that index is listed.

    Returns
    -------
    skip : bool
        True when the latest conv2d index is in ``skip_conv_layers``; False
        otherwise, including when ``skip_conv_layers`` is None.
    """
    skip_layers = current_qconfig().skip_conv_layers
    if skip_layers is None:
        return False

    skipped_indices = [int(idx) for idx in skip_layers]
    latest_conv2d_index = annotate_context().conv2d_counter() - 1
    return latest_conv2d_index in skipped_indices
@register_annotate_function("nn.dense") @register_annotate_function("nn.dense")
def dense_rewrite(ref_call, new_args, ctx): def dense_rewrite(ref_call, new_args, ctx):
"""Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of """Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
dense will be quantized to weight field. Output would be in activation field.""" dense will be quantized to weight field. Output would be in activation field."""
cnt = _conv_counter() if check_to_skip():
if cnt < current_qconfig().skip_k_conv:
return None return None
if current_qconfig().skip_conv_layers is not None:
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if cnt - 1 in leave_alone_indices:
return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
...@@ -207,13 +214,8 @@ def dense_rewrite(ref_call, new_args, ctx): ...@@ -207,13 +214,8 @@ def dense_rewrite(ref_call, new_args, ctx):
@register_annotate_function("multiply") @register_annotate_function("multiply")
def multiply_rewrite(ref_call, new_args, ctx): def multiply_rewrite(ref_call, new_args, ctx):
"""Rewrite function for multiply.""" """Rewrite function for multiply."""
cnt = _conv_counter() if check_to_skip():
if cnt <= current_qconfig().skip_k_conv:
return None return None
if current_qconfig().skip_conv_layers is not None:
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if cnt - 1 in leave_alone_indices:
return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
...@@ -234,13 +236,8 @@ def multiply_rewrite(ref_call, new_args, ctx): ...@@ -234,13 +236,8 @@ def multiply_rewrite(ref_call, new_args, ctx):
@register_annotate_function("add") @register_annotate_function("add")
def add_rewrite(ref_call, new_args, ctx): def add_rewrite(ref_call, new_args, ctx):
"""Rewrite function for add.""" """Rewrite function for add."""
cnt = _conv_counter() if check_to_skip():
if cnt <= current_qconfig().skip_k_conv:
return None return None
if current_qconfig().skip_conv_layers is not None:
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if cnt - 1 in leave_alone_indices:
return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
...@@ -265,15 +262,25 @@ def add_rewrite(ref_call, new_args, ctx): ...@@ -265,15 +262,25 @@ def add_rewrite(ref_call, new_args, ctx):
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
@register_annotate_function("stop_fusion")
def stop_fusion_rewrite(ref_call, new_args, ctx):
    """Rewrite function for stop_fusion.

    The input of stop_fusion is quantized to the input field and the
    forwarded call is annotated with the input kind as well.

    Parameters
    ----------
    ref_call : the original stop_fusion call being rewritten.
    new_args : the already-rewritten arguments; only ``new_args[0]`` is used.
    ctx : rewrite context (unused here).

    Returns
    -------
    ret : QAnnotateExpr or None
        None when the operator is skipped (see ``check_to_skip``) or when the
        input carries no annotation kind; otherwise the forwarded call
        wrapped as ``QAnnotateKind.INPUT``.
    """
    if check_to_skip():
        return None

    x_expr, x_kind = _get_expr_kind(new_args[0])
    if x_kind is None:
        # Input was never annotated, so there is nothing to quantize here.
        return None

    ret_expr = attach_simulated_quantize(x_expr, QAnnotateKind.INPUT)
    ret_expr = _forward_op(ref_call, [ret_expr])
    return QAnnotateExpr(ret_expr, QAnnotateKind.INPUT)
def identity_rewrite(ref_call, new_args, ctx): def identity_rewrite(ref_call, new_args, ctx):
"""Simply forward the original operation""" """Simply forward the original operation"""
cnt = _conv_counter() if check_to_skip():
if cnt <= current_qconfig().skip_k_conv:
return None return None
if current_qconfig().skip_conv_layers is not None:
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if cnt - 1 in leave_alone_indices:
return None
x_expr, x_kind = _get_expr_kind(new_args[0]) x_expr, x_kind = _get_expr_kind(new_args[0])
if x_kind is None: if x_kind is None:
...@@ -283,6 +290,7 @@ def identity_rewrite(ref_call, new_args, ctx): ...@@ -283,6 +290,7 @@ def identity_rewrite(ref_call, new_args, ctx):
return QAnnotateExpr(ret_expr, x_kind) return QAnnotateExpr(ret_expr, x_kind)
register_annotate_function("clip", identity_rewrite)
register_annotate_function("nn.relu", identity_rewrite) register_annotate_function("nn.relu", identity_rewrite)
register_annotate_function("strided_slice", identity_rewrite) register_annotate_function("strided_slice", identity_rewrite)
register_annotate_function("nn.avg_pool2d", identity_rewrite) register_annotate_function("nn.avg_pool2d", identity_rewrite)
...@@ -290,13 +298,8 @@ register_annotate_function("nn.avg_pool2d", identity_rewrite) ...@@ -290,13 +298,8 @@ register_annotate_function("nn.avg_pool2d", identity_rewrite)
def pool2d_rewrite(ref_call, new_args, ctx): def pool2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for max pool2d""" """Rewrite function for max pool2d"""
cnt = _conv_counter() if check_to_skip():
if cnt <= current_qconfig().skip_k_conv:
return None return None
if current_qconfig().skip_conv_layers is not None:
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if cnt - 1 in leave_alone_indices:
return None
expr, x_kind = _get_expr_kind(new_args[0]) expr, x_kind = _get_expr_kind(new_args[0])
...@@ -314,13 +317,8 @@ register_annotate_function("nn.max_pool2d", pool2d_rewrite) ...@@ -314,13 +317,8 @@ register_annotate_function("nn.max_pool2d", pool2d_rewrite)
@register_annotate_function("concatenate") @register_annotate_function("concatenate")
def concatenate_rewrite(ref_call, new_args, ctx): def concatenate_rewrite(ref_call, new_args, ctx):
"""Rewrite function for concatenate""" """Rewrite function for concatenate"""
cnt = _conv_counter() if check_to_skip():
if cnt <= current_qconfig().skip_k_conv:
return None return None
if current_qconfig().skip_conv_layers is not None:
leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if cnt - 1 in leave_alone_indices:
return None
input_tuple = new_args[0] input_tuple = new_args[0]
expr_list = [_get_expr_kind(x)[0] for x in input_tuple] expr_list = [_get_expr_kind(x)[0] for x in input_tuple]
......
...@@ -71,12 +71,10 @@ class QConfig(NodeBase): ...@@ -71,12 +71,10 @@ class QConfig(NodeBase):
"dtype_weight": "int8", "dtype_weight": "int8",
"dtype_activation": "int32", "dtype_activation": "int32",
"global_scale": 8.0, "global_scale": 8.0,
"skip_k_conv": 1, "skip_conv_layers": [0],
"skip_conv_layers": None,
"round_for_shift": True, "round_for_shift": True,
"store_lowbit_output": True, "store_lowbit_output": True,
"debug_enabled_ops": None, "debug_enabled_ops": None,
"use_stop_fusion": True
} }
# pylint: disable=no-member # pylint: disable=no-member
...@@ -138,11 +136,8 @@ def qconfig(**kwargs): ...@@ -138,11 +136,8 @@ def qconfig(**kwargs):
global_scale: float global_scale: float
The global scale for calibration. The global scale for calibration.
skip_k_conv: int
The number of skipped conv2d.
skip_conv_layers: list skip_conv_layers: list
Different way of specifying which layers to avoid. Provide a list of indices Specifying which layers to be skipped. Provide a list of indices
that indicate which conv2d layers to leave untouched. that indicate which conv2d layers to leave untouched.
round_for_shift: boolean round_for_shift: boolean
...@@ -152,9 +147,10 @@ def qconfig(**kwargs): ...@@ -152,9 +147,10 @@ def qconfig(**kwargs):
Whether to store low-bit integer back as output before dequantizing. Whether to store low-bit integer back as output before dequantizing.
Some accelerators need this, e.g. VTA. Some accelerators need this, e.g. VTA.
use_stop_fusion: boolean debug_enabled_ops: None or list of str
Whether add stop_fusion when casting to dtype_activation. stop_fusion forces lowbit Partially quantize specified operators for debugging. The default value
results to be stored in memory. is None, which means it will try to call all operators' annotate rewrite
function.
Returns Returns
------- -------
...@@ -166,18 +162,35 @@ def qconfig(**kwargs): ...@@ -166,18 +162,35 @@ def qconfig(**kwargs):
return _make.node("relay.quantize.QConfig", **node_args) return _make.node("relay.quantize.QConfig", **node_args)
class AnnotateContext(object):
    """A global singleton annotate scope.

    Holds the state shared by the annotate rewrite functions during one
    quantization run:

    * ``qnode_map`` memoizes simulated-quantize nodes, keyed by
      ``(data, kind, sign, rounding)``.
    * a counter of the conv2d layers visited so far.

    Entering the scope (``with annotate_context(): ...``) resets both, so
    state left over from a previous run cannot leak into the next one.
    """
    Current = None

    def __init__(self):
        self.qnode_map = dict()
        self._conv2d_counter = 0

    def __enter__(self):
        # The instance is a process-wide singleton; without clearing the
        # memo table here, entries from earlier runs would stay alive (and
        # could be returned) forever.
        self.qnode_map.clear()
        self._conv2d_counter = 0
        return self

    def conv2d_counter(self):
        """Get the counter for conv2d."""
        return self._conv2d_counter

    def count_conv2d(self):
        """Increase the value of the conv2d counter by one."""
        self._conv2d_counter += 1

    def __exit__(self, ptype, value, traceback):
        # Nothing to release; exceptions are not suppressed.
        pass


def annotate_context():
    """Get the global singleton scope, creating it on first use."""
    if AnnotateContext.Current is None:
        AnnotateContext.Current = AnnotateContext()
    return AnnotateContext.Current
def calibrate(graph, mod=None, ctx=None): def calibrate(graph, mod=None, ctx=None):
...@@ -324,15 +337,15 @@ def quantize(graph, params=None, dataset=None): ...@@ -324,15 +337,15 @@ def quantize(graph, params=None, dataset=None):
calibrate_pass = _transform.function_pass(calibrate, opt_level=1, calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
name="QuantizeCalibrate") name="QuantizeCalibrate")
_set_conv_counter(0) # reset counter
quantize_seq = _transform.Sequential([annotate(), quantize_seq = _transform.Sequential([annotate(),
calibrate_pass, calibrate_pass,
realize(), realize(),
_transform.FoldConstant()]) _transform.FoldConstant()])
with _transform.PassContext(opt_level=3, with annotate_context():
required_pass=["QuantizeAnnotate", with _transform.PassContext(opt_level=3,
"QuantizeCalibrate", required_pass=["QuantizeAnnotate",
"QuantizeRealize"]): "QuantizeCalibrate",
mod = optimize(mod) "QuantizeRealize"]):
mod = quantize_seq(mod) mod = optimize(mod)
mod = quantize_seq(mod)
return mod[mod.entry_func.name_hint] return mod[mod.entry_func.name_hint]
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -393,7 +393,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, ...@@ -393,7 +393,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args,
} else if (ref_arg && ref_arg->op.same_as(simulated_quantize) && } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) { ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
auto new_arg = Cast(ret[i], cfg->dtype_input); auto new_arg = Cast(ret[i], cfg->dtype_input);
if (cfg->use_stop_fusion) { if (cfg->store_lowbit_output) {
new_arg = StopFusion(new_arg); new_arg = StopFusion(new_arg);
} }
ret.Set(i, Cast(new_arg, dtype)); ret.Set(i, Cast(new_arg, dtype));
...@@ -431,6 +431,28 @@ Expr AddRealize(const Call& ref_call, ...@@ -431,6 +431,28 @@ Expr AddRealize(const Call& ref_call,
RELAY_REGISTER_OP("add") RELAY_REGISTER_OP("add")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize); .set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
/*!
 * \brief Realize rewrite for the clip operator.
 *
 * When the single input has been realized to an integer expression, the clip
 * bounds are divided by the input's domain scale and the clip is re-issued on
 * the integer data, keeping the input's scale and dtype. Otherwise a null
 * Expr is returned.
 */
Expr ClipRealize(const Call& ref_call,
                 const Array<Expr>& new_args,
                 const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 1);

  const auto* quantized = new_args[0].as<QRealizeIntExprNode>();
  if (quantized == nullptr) {
    // A non-realized input must not be any other temporary expression.
    CHECK(!new_args[0]->derived_from<TempExprNode>());
    return Expr(nullptr);
  }

  const auto* orig_attrs = ref_call->attrs.as<ClipAttrs>();
  const double scale = GetScalarFromConstant<float>(quantized->dom_scale);

  // Express the clip bounds in the scaled integer domain of the input.
  auto scaled_attrs = make_node<ClipAttrs>();
  scaled_attrs->a_min = orig_attrs->a_min / scale;
  scaled_attrs->a_max = orig_attrs->a_max / scale;

  Expr clipped = CallNode::make(ref_call->op, {quantized->data},
                                Attrs(scaled_attrs), ref_call->type_args);
  return QRealizeIntExprNode::make(clipped, quantized->dom_scale, quantized->dtype);
}

RELAY_REGISTER_OP("clip")
.set_attr<FForwardRewrite>("FQRealizeRewrite", ClipRealize);
Expr ConcatenateRealize(const Call& ref_call, Expr ConcatenateRealize(const Call& ref_call,
const Array<Expr>& new_args, const Array<Expr>& new_args,
...@@ -572,12 +594,10 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -572,12 +594,10 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "nbit_weight=" << op->nbit_weight << ", "; p->stream << "nbit_weight=" << op->nbit_weight << ", ";
p->stream << "nbit_activation=" << op->nbit_activation << ", "; p->stream << "nbit_activation=" << op->nbit_activation << ", ";
p->stream << "global_scale=" << op->global_scale << ", "; p->stream << "global_scale=" << op->global_scale << ", ";
p->stream << "skip_k_conv==" << op->skip_k_conv << ", ";
p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", "; p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
p->stream << "round_for_shift==" << op->round_for_shift << ", "; p->stream << "round_for_shift==" << op->round_for_shift << ", ";
p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", "; p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", ";
p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", "; p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
p->stream << "use_stop_fusion==" << op->use_stop_fusion;
p->stream << ")"; p->stream << ")";
}); });
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -125,12 +125,10 @@ class QConfigNode : public Node { ...@@ -125,12 +125,10 @@ class QConfigNode : public Node {
DataType dtype_weight = Int(8); DataType dtype_weight = Int(8);
DataType dtype_activation = Int(32); DataType dtype_activation = Int(32);
double global_scale = 8.0; double global_scale = 8.0;
int skip_k_conv = 1;
Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr)); Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr));
bool round_for_shift = true; bool round_for_shift = true;
bool store_lowbit_output = true; bool store_lowbit_output = true;
Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr)); Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
bool use_stop_fusion = true;
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("nbit_input", &nbit_input); v->Visit("nbit_input", &nbit_input);
...@@ -140,12 +138,10 @@ class QConfigNode : public Node { ...@@ -140,12 +138,10 @@ class QConfigNode : public Node {
v->Visit("dtype_weight", &dtype_weight); v->Visit("dtype_weight", &dtype_weight);
v->Visit("dtype_activation", &dtype_activation); v->Visit("dtype_activation", &dtype_activation);
v->Visit("global_scale", &global_scale); v->Visit("global_scale", &global_scale);
v->Visit("skip_k_conv", &skip_k_conv);
v->Visit("skip_conv_layers", &skip_conv_layers); v->Visit("skip_conv_layers", &skip_conv_layers);
v->Visit("round_for_shift", &round_for_shift); v->Visit("round_for_shift", &round_for_shift);
v->Visit("store_lowbit_output", &store_lowbit_output); v->Visit("store_lowbit_output", &store_lowbit_output);
v->Visit("debug_enabled_ops", &debug_enabled_ops); v->Visit("debug_enabled_ops", &debug_enabled_ops);
v->Visit("use_stop_fusion", &use_stop_fusion);
} }
static constexpr const char* _type_key = "relay.quantize.QConfig"; static constexpr const char* _type_key = "relay.quantize.QConfig";
......
...@@ -81,7 +81,7 @@ def test_quantize_pass(): ...@@ -81,7 +81,7 @@ def test_quantize_pass():
graph = make_graph(data) graph = make_graph(data)
dataset, params = make_dataset(graph, 10) dataset, params = make_dataset(graph, 10)
with qtz.qconfig(skip_k_conv=0, global_scale=4.0, with qtz.qconfig(skip_conv_layers=None, global_scale=4.0,
round_for_shift=False, store_lowbit_output=False): round_for_shift=False, store_lowbit_output=False):
qgraph0 = qtz.quantize(graph, params) qgraph0 = qtz.quantize(graph, params)
qgraph0 = relay.ir_pass.infer_type(qgraph0) qgraph0 = relay.ir_pass.infer_type(qgraph0)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment