Commit 36702a76 by Josh Fromm, committed by Wuwei Lin

[Relay] Option to select which convolution layers are quantized. (#3173)

* Stashing for later maybe.

* Added new option to leave specific layers unquantized.

* Better error checking.

* remove unneeded import

* tab to spaces

* pylint fixes

* more pylint fixes
parent 0f2a3086
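
For context, a minimal usage sketch of the option this commit introduces. The sketch is not part of the commit; the surrounding `relay.quantize.quantize` call and the example indices are assumptions based on the Relay quantization API of this period.

```python
# Hedged usage sketch (assumed API, illustrative indices): quantize a Relay
# function but leave the conv2d layers at indices 0 and 2 in float by
# passing the new skip_conv_layers option to qconfig.
from tvm import relay

def quantize_except_some(func, params):
    with relay.quantize.qconfig(skip_conv_layers=[0, 2]):
        # conv2d layers whose counter index appears in skip_conv_layers are
        # not annotated for quantization and keep their original dtype.
        return relay.quantize.quantize(func, params=params)
```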
@@ -156,6 +156,13 @@ def conv2d_rewrite(ref_call, new_args, ctx):
     if cnt < current_qconfig().skip_k_conv:
         _set_conv_counter(cnt + 1)
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt in leave_alone_indices:
+            _set_conv_counter(cnt + 1)
+            return None
     _set_conv_counter(cnt + 1)
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
@@ -168,6 +175,7 @@ def conv2d_rewrite(ref_call, new_args, ctx):
     rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
     expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
     return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
@@ -178,6 +186,11 @@ def dense_rewrite(ref_call, new_args, ctx):
     cnt = _conv_counter()
     if cnt < current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@@ -194,8 +207,13 @@ def dense_rewrite(ref_call, new_args, ctx):
 @register_annotate_function("multiply")
 def multiply_rewrite(ref_call, new_args, ctx):
     """Rewrite function for multiply."""
-    if _conv_counter() <= current_qconfig().skip_k_conv:
+    cnt = _conv_counter()
+    if cnt <= current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@@ -216,8 +234,13 @@ def multiply_rewrite(ref_call, new_args, ctx):
 @register_annotate_function("add")
 def add_rewrite(ref_call, new_args, ctx):
     """Rewrite function for add."""
-    if _conv_counter() <= current_qconfig().skip_k_conv:
+    cnt = _conv_counter()
+    if cnt <= current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
     lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
     rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
@@ -244,8 +267,13 @@ def add_rewrite(ref_call, new_args, ctx):
 def identity_rewrite(ref_call, new_args, ctx):
     """Simply forward the original operation"""
-    if _conv_counter() <= current_qconfig().skip_k_conv:
+    cnt = _conv_counter()
+    if cnt <= current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
     x_expr, x_kind = _get_expr_kind(new_args[0])
     if x_kind is None:
@@ -262,8 +290,14 @@ register_annotate_function("nn.avg_pool2d", identity_rewrite)
 def pool2d_rewrite(ref_call, new_args, ctx):
     """Rewrite function for max pool2d"""
-    if _conv_counter() <= current_qconfig().skip_k_conv:
+    cnt = _conv_counter()
+    if cnt <= current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
     expr, x_kind = _get_expr_kind(new_args[0])
     if x_kind is None:
@@ -280,8 +314,13 @@ register_annotate_function("nn.max_pool2d", pool2d_rewrite)
 @register_annotate_function("concatenate")
 def concatenate_rewrite(ref_call, new_args, ctx):
     """Rewrite function for concatenate"""
-    if _conv_counter() <= current_qconfig().skip_k_conv:
+    cnt = _conv_counter()
+    if cnt <= current_qconfig().skip_k_conv:
         return None
+    if current_qconfig().skip_conv_layers is not None:
+        leave_alone_indices = [int(x) for x in current_qconfig().skip_conv_layers]
+        if cnt - 1 in leave_alone_indices:
+            return None
     input_tuple = new_args[0]
     expr_list = [_get_expr_kind(x)[0] for x in input_tuple]
...
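The same guard recurs in every rewrite above. One way to read it: by the time `dense`, `multiply`, `add`, the identity ops, pool2d and `concatenate` are rewritten, the conv counter has already been advanced past the conv2d that produced the current activation, so those functions compare `cnt - 1` against the skip list, while `conv2d_rewrite` compares `cnt` itself. A hypothetical helper (not part of the commit) that captures the repeated pattern, reusing the module's existing `_conv_counter` and `current_qconfig` helpers:

```python
def _conv_is_skipped(offset=0):
    """Return True if the conv layer at (counter - offset) is in skip_conv_layers."""
    skip = current_qconfig().skip_conv_layers
    if skip is None:
        return False
    leave_alone_indices = [int(x) for x in skip]
    return (_conv_counter() - offset) in leave_alone_indices
```

Under this reading, `conv2d_rewrite` would call it with `offset=0` before bumping the counter, and the downstream rewrites with `offset=1`.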
@@ -71,6 +71,7 @@ class QConfig(NodeBase):
         "dtype_activation": "int32",
         "global_scale": 8.0,
         "skip_k_conv": 1,
+        "skip_conv_layers": None,
         "round_for_shift": True,
         "store_lowbit_output": True,
         "debug_enabled_ops": None,
@@ -139,6 +140,10 @@ def qconfig(**kwargs):
     skip_k_conv: int
         The number of skipped conv2d.
+    skip_conv_layers: list
+        Different way of specifying which layers to avoid. Provide a list of indices
+        that indicate which conv2d layers to leave untouched.
     round_for_shift: boolean
         Whether to add bias for rounding during shift.
...
@@ -596,6 +596,7 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   p->stream << "nbit_activation=" << op->nbit_activation << ", ";
   p->stream << "global_scale=" << op->global_scale << ", ";
   p->stream << "skip_k_conv==" << op->skip_k_conv << ", ";
+  p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
   p->stream << "round_for_shift==" << op->round_for_shift << ", ";
   p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", ";
   p->stream << "debug_enabled_ops==" << op->debug_enabled_ops << ", ";
...
@@ -126,6 +126,7 @@ class QConfigNode : public Node {
   DataType dtype_activation = Int(32);
   double global_scale = 8.0;
   int skip_k_conv = 1;
+  Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr));
   bool round_for_shift = true;
   bool store_lowbit_output = true;
   Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
@@ -140,6 +141,7 @@ class QConfigNode : public Node {
     v->Visit("dtype_activation", &dtype_activation);
     v->Visit("global_scale", &global_scale);
     v->Visit("skip_k_conv", &skip_k_conv);
+    v->Visit("skip_conv_layers", &skip_conv_layers);
     v->Visit("round_for_shift", &round_for_shift);
     v->Visit("store_lowbit_output", &store_lowbit_output);
     v->Visit("debug_enabled_ops", &debug_enabled_ops);
...
@@ -105,7 +105,9 @@ def conv2d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCHW', ou
         return winograd_cuda(cfg, data, kernel, strides, padding, dilation, layout, out_dtype,
                              pre_computed=False)
     if cfg.template_key == 'int8':
-        return conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
+        if (data.dtype == 'int8' or data.dtype == 'uint8'):
+            return conv2d_NCHWc_int8(
+                cfg, data, kernel, strides, padding, dilation, layout, out_dtype)
     if layout == 'NCHW':
         return nn.conv2d_nchw(data, kernel, strides, padding, dilation, out_dtype)
...
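The TOPI change above is related: once some conv2d layers are left unquantized, a graph can mix float and int8 convolutions, and a CUDA schedule tuned with the `'int8'` template may be handed float data. The added check only dispatches to `conv2d_NCHWc_int8` when the input really is (u)int8 and otherwise falls through to the generic schedules. A tiny illustrative sketch of that dispatch rule (names here are placeholders, not TOPI APIs):

```python
def choose_conv2d_impl(template_key, data_dtype):
    # Only take the NCHWc int8 kernel for genuine (u)int8 data; float
    # inputs from skipped (unquantized) layers fall through to the
    # ordinary conv2d implementations.
    if template_key == 'int8' and data_dtype in ('int8', 'uint8'):
        return 'conv2d_NCHWc_int8'
    return 'conv2d_nchw'
```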