Unverified Commit 7eb1f353 by ziheng Committed by GitHub

[QUANTIZE] Refactor quantization codebase and fix model accuracy (#3543)

* Refactor.

* update

* update

* update

* update

* update

* update
parent 60fc9f74
......@@ -52,6 +52,18 @@ namespace relay {
TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
/*!
* \brief Check whether an expression is constant.
*
* If the inputs of an expression are all constant, it means the expression
* itself is constant also.
*
* \param e the expression.
*
* \return whether the expression is constant.
*/
TVM_DLL bool ConstantCheck(const Expr& e);
/*!
* \brief Compare two expressions for structural equivalence.
*
* This comparison operator respects scoping and compares
......
......@@ -44,6 +44,19 @@ struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
}
};
/*!
* \brief Annotate an expression to be cast into specific data type.
*/
struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
  /*! \brief The data type this hint asks the annotated expression to be cast into. */
  DataType dtype;

  // Declare the attribute fields under the type key "relay.attrs.CastHintAttrs".
  TVM_DECLARE_ATTRS(CastHintAttrs, "relay.attrs.CastHintAttrs") {
    TVM_ATTR_FIELD(dtype)
      .describe(
        "The data type denoted to be cast.");
  }
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
......@@ -91,6 +91,22 @@ def check_kind(t, mod=None):
return _analysis.check_kind(t)
def check_constant(expr):
    """Tell whether a Relay expression is constant.

    Thin wrapper that forwards to ``_analysis.check_constant``.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The expression to inspect.

    Returns
    -------
    result : bool
        Whether the expression is constant.
    """
    is_constant = _analysis.check_constant(expr)
    return is_constant
def free_vars(expr):
"""Get free Vars from expression expr in Post DFS order.
......
......@@ -19,5 +19,6 @@
from __future__ import absolute_import as _abs
from .quantize import *
from ._partition import register_partition_function
from ._annotate import register_annotate_function
from .kl_divergence import kl_divergence_scale
......@@ -20,14 +20,15 @@ from __future__ import absolute_import
import warnings
import topi
from . import _quantize
from .quantize import QAnnotateKind, current_qconfig
from .quantize import annotate_context
from ..._ffi.function import register_func
from .. import expr as _expr
from .. import analysis as _analysis
from .. import op as _op
from ..op import op as _reg
from ..base import register_relay_node
from ..._ffi.function import register_func
from . import _quantize
from .quantize import QAnnotateKind, current_qconfig, quantize_context
from .quantize import _forward_op
@_reg.register_compute("relay.op.annotation.simulated_quantize")
......@@ -75,12 +76,6 @@ class QAnnotateExpr(_expr.TempExpr):
_quantize.make_annotate_expr, expr, kind)
def _forward_op(ref_call, args):
"""forward the operator of ref_call with provided arguments"""
return _expr.Call(
ref_call.op, args, ref_call.attrs, ref_call.type_args)
def _get_expr_kind(anno):
"""Get the expression and QAnnotateKind from QAnnotateExpr or Expr"""
if isinstance(anno, QAnnotateExpr):
......@@ -113,7 +108,7 @@ def register_annotate_function(op_name, frewrite=None, level=10):
if not current_qconfig().guard(ref_call):
return default_rewrite(ref_call, new_args, ctx)
return func(ref_call, new_args, ctx)
_op.op._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level)
_reg._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level)
return frewrite_with_guard
return _register(frewrite) if frewrite is not None else _register
......@@ -135,17 +130,17 @@ def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
if data.attrs.kind == kind and data.attrs.sign == sign and data.attrs.rounding == rounding:
return data
actx = annotate_context()
qctx = quantize_context()
key = tuple([data, kind, sign, rounding])
if key in actx.qnode_map:
return actx.qnode_map[key]
if key in qctx.qnode_map:
return qctx.qnode_map[key]
dom_scale = _expr.var("dom_scale")
clip_min = _expr.var("clip_min")
clip_max = _expr.var("clip_max")
qnode = _quantize.simulated_quantize(
data, dom_scale, clip_min, clip_max, kind, sign, rounding)
actx.qnode_map[key] = qnode
qctx.qnode_map[key] = qnode
return qnode
register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize)
......@@ -163,13 +158,8 @@ def conv2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for conv2d. Lhs of conv will be quantized to
input field, and rhs of conv will be quantized to weight field.
Output would be in activation field"""
actx = annotate_context()
if current_qconfig().skip_conv_layers is not None:
skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if actx.conv2d_counter() in skipped_indices:
actx.count_conv2d()
if quantize_context().check_to_skip(ref_call):
return None
actx.count_conv2d()
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
......@@ -185,21 +175,12 @@ def conv2d_rewrite(ref_call, new_args, ctx):
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
def check_to_skip():
"""Check the index of conv2d layer to decide whether to skip the current operator."""
if current_qconfig().skip_conv_layers is not None:
skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if annotate_context().conv2d_counter() - 1 in skipped_indices:
return True
return False
# TODO(tmoreau89,ziheng) need to include an option to turn off dense quant
# @register_annotate_function("nn.dense")
def dense_rewrite(ref_call, new_args, ctx):
"""Rewrite function for dense. Lhs of dense will be quantized to input field, and rhs of
dense will be quantized to weight field. Output would be in activation field."""
if check_to_skip():
if quantize_context().check_to_skip(ref_call):
return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
......@@ -219,7 +200,7 @@ def dense_rewrite(ref_call, new_args, ctx):
@register_annotate_function("multiply")
def multiply_rewrite(ref_call, new_args, ctx):
"""Rewrite function for multiply."""
if check_to_skip():
if quantize_context().check_to_skip(ref_call):
return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
......@@ -243,13 +224,14 @@ def multiply_rewrite(ref_call, new_args, ctx):
@register_annotate_function("add")
def add_rewrite(ref_call, new_args, ctx):
"""Rewrite function for add."""
if check_to_skip():
if quantize_context().check_to_skip(ref_call):
return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
if lhs_kind is None and rhs_kind is None:
# trivial case
return None
if lhs_kind is None and rhs_kind is not None:
......@@ -260,11 +242,10 @@ def add_rewrite(ref_call, new_args, ctx):
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
if lhs_kind is not None and rhs_kind is None:
if isinstance(rhs_expr, _expr.Constant):
# quantize rhs to WEIGHT field if it is Constant
if _analysis.check_constant(rhs_expr):
# - introduced by batch_norm: add(out, const)
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
else:
# quantize rhs to INPUT field if it is not Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
......@@ -274,7 +255,6 @@ def add_rewrite(ref_call, new_args, ctx):
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
# quantize rhs to INPUT field if both lhs and rhs are ACTIVATION
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
......@@ -285,24 +265,9 @@ def add_rewrite(ref_call, new_args, ctx):
raise ValueError()
@register_annotate_function("stop_fusion")
def stop_fusion_rewrite(ref_call, new_args, ctx):
"""Rewrite function for add."""
if check_to_skip():
return None
x_expr, x_kind = _get_expr_kind(new_args[0])
if x_kind is None:
return None
ret_expr = attach_simulated_quantize(x_expr, QAnnotateKind.INPUT)
ret_expr = _forward_op(ref_call, [ret_expr])
return QAnnotateExpr(ret_expr, QAnnotateKind.INPUT)
def identity_rewrite(ref_call, new_args, ctx):
"""Simply forward the original operation"""
if check_to_skip():
if quantize_context().check_to_skip(ref_call):
return None
x_expr, x_kind = _get_expr_kind(new_args[0])
......@@ -322,7 +287,7 @@ register_annotate_function("annotation.stop_fusion", identity_rewrite)
def pool2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for max pool2d"""
if check_to_skip():
if quantize_context().check_to_skip(ref_call):
return None
expr, x_kind = _get_expr_kind(new_args[0])
......@@ -339,14 +304,14 @@ def pool2d_rewrite(ref_call, new_args, ctx):
register_annotate_function("nn.max_pool2d", pool2d_rewrite)
@register_annotate_function("annotation.force_cast")
def force_cast_rewrite(ref_call, new_args, ctx):
@register_annotate_function("annotation.cast_hint")
def cast_hint_rewrite(ref_call, new_args, ctx):
"""Rewrite function to force cast"""
if check_to_skip():
return None
expr, x_kind = _get_expr_kind(new_args[0])
if quantize_context().check_to_skip(ref_call):
return expr
if x_kind is None:
return new_args[0]
if x_kind == QAnnotateKind.ACTIVATION:
......@@ -359,7 +324,7 @@ def force_cast_rewrite(ref_call, new_args, ctx):
@register_annotate_function("concatenate")
def concatenate_rewrite(ref_call, new_args, ctx):
"""Rewrite function for concatenate"""
if check_to_skip():
if quantize_context().check_to_skip(ref_call):
return None
input_tuple = new_args[0]
......@@ -377,69 +342,18 @@ def concatenate_rewrite(ref_call, new_args, ctx):
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
# Graph rewrite function registration for VTA target
def register_vta_rewrite(op_name, frewrite=None, level=10):
def _register(func):
return _op.op._Register(op_name, "FQVTARewrite", func, level)
return _register(frewrite) if frewrite is not None else _register
@register_relay_node
class QVTAExpr(_expr.TempExpr):
def __init__(self, expr):
self.__init_handle_by_constructor__(
_quantize.make_vta_expr, expr)
def realize(self):
return _quantize.temp_expr_realize(self)
def vta_expr_check(expr):
if isinstance(expr, QVTAExpr):
return True, expr.expr
return False, expr
@register_vta_rewrite("nn.conv2d")
def conv2d_vta_rewrite(ref_call, new_args, ctx):
"""Rewrite function for conv2d for VTA target"""
actx = annotate_context()
if current_qconfig().skip_conv_layers is not None:
skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if actx.conv2d_counter() in skipped_indices:
actx.count_conv2d()
@register_annotate_function("nn.global_avg_pool2d")
def global_avg_pool2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for global_avg_pool2d for stopping quantize"""
if quantize_context().check_to_skip(ref_call):
return None
actx.count_conv2d()
data_cond, data = vta_expr_check(new_args[0])
kernel_cond, kernel = vta_expr_check(new_args[1])
assert not kernel_cond
if data_cond:
data = new_args[0].realize()
ret = _forward_op(ref_call, [data, kernel])
return QVTAExpr(ret)
expr, x_kind = _get_expr_kind(new_args[0])
def identity_vta_rewrite(ref_call, new_args, ctx):
cond, expr = vta_expr_check(new_args[0])
if cond:
return QVTAExpr(_forward_op(ref_call, [expr]))
if x_kind is None:
return None
expr = _forward_op(ref_call, [new_args[0].realize()])
register_vta_rewrite("nn.relu", identity_vta_rewrite)
register_vta_rewrite("nn.max_pool2d", identity_vta_rewrite)
@register_vta_rewrite("add")
def add_vta_rewrite(ref_call, new_args, ctx):
"""Rewrite function for ewise add for VTA target"""
lhs_cond, lhs = vta_expr_check(new_args[0])
rhs_cond, rhs = vta_expr_check(new_args[1])
if lhs_cond and rhs_cond:
lhs = new_args[0].realize()
rhs = new_args[1].realize()
return _forward_op(ref_call, [lhs, rhs])
elif lhs_cond and not rhs_cond:
return QVTAExpr(_forward_op(ref_call, [lhs, rhs]))
return None
# stop quantize after global_avg_pool2d
quantize_context().stop_quantize()
return expr
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=unused-argument,inconsistent-return-statements
"""Internal module for registering attribute for annotation."""
from __future__ import absolute_import
from ... import target as _target
from .. import expr as _expr
from .. import analysis as _analysis
from ..base import register_relay_node
from ..op import op as _reg
from . import _quantize
from .quantize import _forward_op
def register_partition_function(op_name, frewrite=None, level=10):
    """Register a partition rewrite function for an operator.

    Usable either as a direct call (``frewrite`` given) or as a decorator
    factory (``frewrite`` omitted).

    Parameters
    ----------
    op_name : str
        Name of the operator the rewrite function is attached to.

    frewrite : callable, optional
        The rewrite function. When None, a decorator is returned instead.

    level : int, optional
        Level forwarded to ``_reg._Register``.

    Returns
    -------
    ret : object
        The result of ``_reg._Register`` when ``frewrite`` is given, otherwise
        a decorator that performs the registration.
    """
    def _register(func):
        # The rewrite rule is stored under the "FQPartitionRewrite" attribute key.
        return _reg._Register(op_name, "FQPartitionRewrite", func, level)

    if frewrite is None:
        return _register
    return _register(frewrite)
@register_relay_node
class QPartitionExpr(_expr.TempExpr):
    """Temporary expression wrapping ``expr`` during the partition rewrite.

    The underlying node is created through ``_quantize.make_partition_expr``.
    """
    def __init__(self, expr):
        constructor = _quantize.make_partition_expr
        self.__init_handle_by_constructor__(constructor, expr)
def partition_expr_check(expr):
    """Unwrap ``expr`` if it is a QPartitionExpr.

    Parameters
    ----------
    expr : tvm.relay.Expr
        Expression that may or may not be a QPartitionExpr.

    Returns
    -------
    cond : bool
        True when ``expr`` is a QPartitionExpr.

    expr : tvm.relay.Expr
        The wrapped ``expr.expr`` when ``cond`` is True, otherwise the input
        itself.
    """
    if not isinstance(expr, QPartitionExpr):
        return False, expr
    return True, expr.expr
@register_partition_function("nn.conv2d")
def conv2d_partition_function(ref_call, new_args, ctx):
    """Partition rewrite rule for conv2d.

    The data input is realized if it is a QPartitionExpr; the kernel must not
    be one. The forwarded conv2d call is returned wrapped in a
    QPartitionExpr.
    """
    data_is_partitioned, data = partition_expr_check(new_args[0])
    kernel_is_partitioned, kernel = partition_expr_check(new_args[1])
    assert not kernel_is_partitioned

    if data_is_partitioned:
        # Use the realized form of the incoming partition expression as input.
        data = new_args[0].realize()

    return QPartitionExpr(_forward_op(ref_call, [data, kernel]))
def identity_partition_function(ref_call, new_args, ctx):
    """Partition rewrite rule that propagates the partition marker.

    When the first argument is a QPartitionExpr, the operator is forwarded on
    the unwrapped expression and the result is wrapped again. Otherwise None
    is returned.
    """
    is_partitioned, inner = partition_expr_check(new_args[0])
    if not is_partitioned:
        return None
    forwarded = _forward_op(ref_call, [inner])
    return QPartitionExpr(forwarded)
# These operators reuse identity_partition_function, which only inspects
# the first argument of the call.
register_partition_function("clip", identity_partition_function)
register_partition_function("nn.relu", identity_partition_function)
register_partition_function("nn.max_pool2d", identity_partition_function)
def add_partition_generic(ref_call, new_args, ctx):
    """Rewrite function for ewise add for partition for generic devices.

    Parameters
    ----------
    ref_call : tvm.relay.Call
        The original ``add`` call.

    new_args : list of tvm.relay.Expr
        The rewritten arguments; either may be a QPartitionExpr.

    ctx : object
        Rewrite context (unused).

    Returns
    -------
    ret : tvm.relay.Expr or None
        The rewritten call, a QPartitionExpr wrapping it, or None when
        neither argument is a QPartitionExpr.
    """
    lhs_cond, lhs = partition_expr_check(new_args[0])
    rhs_cond, rhs = partition_expr_check(new_args[1])

    if lhs_cond and rhs_cond:
        # - introduced by ResNet, for the first residual connection
        #     ...
        #     %0 = nn.conv2d(%data, %meta[relay.Constant])
        #     %1 = add(%0, %meta[relay.Constant])
        #     %2 = nn.relu(%1)
        #     %3 = nn.max_pool2d(%2)
        #     ...
        #     %9 = nn.conv2d(%8, %meta[relay.Constant])
        #     %10 = add(%9, %meta[relay.Constant])
        #     %11 = add(%3, %10)  <- need to insert annotations for %3, %10
        #     ...
        lhs = new_args[0].realize()
        rhs = new_args[1].realize()
        return _forward_op(ref_call, [lhs, rhs])

    if not lhs_cond and rhs_cond:
        # - introduced by residual connection in ResNet
        #     ...
        #     %13 = nn.conv2d(%12, %meta[relay.Constant])
        #     %14 = add(%13, %meta[relay.Constant])
        #     %15 = annotation.cast_hint(%14, 'int8')
        #     %16 = annotation.stop_fusion(%15)
        #     %17 = add(%5, %16)
        #     %18 = nn.relu(%17)
        #     ...
        #     %24 = nn.conv2d(%23, %meta[relay.Constant])
        #     %25 = add(%24, %meta[relay.Constant])
        #     %26 = add(%18, %25)  <- need to insert annotations for %25
        #     ...
        rhs = new_args[1].realize()
        return _forward_op(ref_call, [lhs, rhs])

    if lhs_cond and not rhs_cond:
        if _analysis.check_constant(rhs):
            # - introduced by batch_norm: add(out, bias)
            return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
        # - introduced by residual connection in MobileNetV2
        #     ...
        #     %81 = add(%80, meta[relay.Constant])
        #     %82 = annotation.cast_hint(%81, 'int8')
        #     %83 = annotation.stop_fusion(%82)
        #     %84 = add(%79, %83)
        #     ...
        #     %95 = nn.conv2d(%94, %meta[relay.Constant])
        #     %96 = add(%95, %meta[relay.Constant])
        #     %97 = add(%96, %84)  <- need to insert annotations for %96
        #     ...
        lhs = new_args[0].realize()
        return _forward_op(ref_call, [lhs, rhs])

    # trivial case: neither side is a QPartitionExpr. The three branches above
    # cover every other combination, so no further fallback is needed.
    return None
# TODO(ziheng) enhance `register_partition_function` to dispatch
# for target automatically
@register_partition_function("add")
def add_partition_function(ref_call, new_args, ctx):
    """Rewrite function for ewise add for partition.

    Dispatches on the current target. Both paths currently use the generic
    rule; the cuda branch is kept as a placeholder for target-specific rules.
    """
    target = _target.current_target()
    # No target scope may be active, in which case there is nothing to
    # inspect and the generic rule applies.
    if target is not None and 'cuda' in target.keys:
        #TODO(wuwei/ziheng) cuda specific rules
        return add_partition_generic(ref_call, new_args, ctx)
    return add_partition_generic(ref_call, new_args, ctx)
@register_partition_function("multiply")
def multiply_partition_function(ref_call, new_args, ctx):
    """Rewrite function for ewise multiply for partition.

    When the lhs is a QPartitionExpr, the multiply is forwarded on the
    unwrapped operands and wrapped again. Otherwise neither operand may be a
    QPartitionExpr and None is returned.
    """
    lhs_cond, lhs = partition_expr_check(new_args[0])
    rhs_cond, rhs = partition_expr_check(new_args[1])

    if not lhs_cond:
        assert (not lhs_cond) and (not rhs_cond)
        return None

    # introduced by bn: multiply(out, scale)
    forwarded = _forward_op(ref_call, [lhs, rhs])
    return QPartitionExpr(forwarded)
......@@ -50,6 +50,12 @@ def kind2str(kind):
return str_map[kind]
def _forward_op(ref_call, args):
    """Build a new call to the operator of ``ref_call`` using ``args``.

    The attrs and type_args of ``ref_call`` are reused unchanged.
    """
    op = ref_call.op
    attrs = ref_call.attrs
    type_args = ref_call.type_args
    return _expr.Call(op, args, attrs, type_args)
@register_relay_node("relay.quantize.QConfig")
class QConfig(NodeBase):
"""Configure the quantization behavior by setting config variables.
......@@ -74,8 +80,8 @@ class QConfig(NodeBase):
"dtype_activation": "int32",
"global_scale": 8.0,
"skip_conv_layers": [0],
"do_simulation": False,
"round_for_shift": True,
"store_lowbit_output": True,
"debug_enabled_ops": None,
}
......@@ -92,6 +98,7 @@ class QConfig(NodeBase):
self.handle = handle
def guard(self, ref_call):
"""Return true if op is enabled, otherwise return false"""
op_name = ref_call.op.name
if self.debug_enabled_ops is not None:
name_list = [x.value for x in self.debug_enabled_ops]
......@@ -126,9 +133,7 @@ def current_qconfig():
"""Get the current quantization configuration."""
return _quantize._GetCurrentQConfig()
# TODO(tmoreau89, ZihengJiang) the skip parameters are
# hacky - we should explore a more future-proof way to
# skip operators based on pattern matching
def qconfig(**kwargs):
"""Configure the quantization behavior by setting config variables.
......@@ -142,15 +147,14 @@ def qconfig(**kwargs):
skip_conv_layers: list
Specifying which layers to be skipped. Provide a list of indices
that indicate which conv2d layers to leave untouched.
that indicate which conv2d layers to leave untouched. Start from 0.
do_simulation: boolean
Whether to do simulation with float operation only.
round_for_shift: boolean
Whether to add bias for rounding during shift.
store_lowbit_output: boolean
Whether to store low-bit integer back as output before dequantizing.
Some accelerators need this, e.g. VTA.
debug_enabled_ops: None or list of str
Partially quantize specified operators for debugging. The default value
is None, which means will try to call all operators' annotate rewrite
......@@ -166,35 +170,79 @@ def qconfig(**kwargs):
return _make.node("relay.quantize.QConfig", **node_args)
class AnnotateContext(object):
"""A global singleton annotate scope"""
class QuantizeContext(object):
"""An internally used global context object for annotation,
for putting some state variables like `conv2d_counter`."""
Current = None
def __init__(self):
self.qnode_map = dict()
self._conv2d_counter = 0
self._stop_quantize = False
def __enter__(self):
self._conv2d_counter = 0
return self
def conv2d_counter(self):
"""Get the counter for conv2d."""
return self._conv2d_counter
def check_to_skip(self, ref_call):
"""Check the index of conv2d layer to decide whether to
skip the current operator."""
if self._stop_quantize:
return True
def count_conv2d(self):
"""Increase the value of the conv2d counter by one."""
if current_qconfig().skip_conv_layers is not None:
# check skip conv layers
skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if self._conv2d_counter in skipped_indices:
if ref_call.op.name == 'nn.conv2d':
self._conv2d_counter += 1
return True
if ref_call.op.name == 'nn.conv2d':
self._conv2d_counter += 1
return False
def stop_quantize(self):
self._stop_quantize = True
def reset(self):
self._conv2d_counter = 0
self._stop_quantize = False
def __enter__(self):
self.reset()
return self
def __exit__(self, ptype, value, traceback):
pass
def annotate_context():
def quantize_context():
"""Get the global singleton scope"""
if AnnotateContext.Current is None:
AnnotateContext.Current = AnnotateContext()
return AnnotateContext.Current
if QuantizeContext.Current is None:
QuantizeContext.Current = QuantizeContext()
return QuantizeContext.Current
def partition():
    """Partition graph into small low-precision sections by `cast_hint` and
    `stop_fusion`.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass for quantization partitioning.
    """
    partition_pass = _quantize.QuantizePartition()
    return partition_pass
def annotate():
    """Create the annotation pass.

    Given a float32 graph, the pass rewrites it into a graph that simulates
    the error brought by the current quantization scheme.

    Returns
    -------
    ret: tvm.relay.Pass
        The registered pass for quantization annotation.
    """
    annotate_pass = _quantize.QuantizeAnnotate()
    return annotate_pass
def collect_stats(graph):
......@@ -300,20 +348,8 @@ def calibrate(graph, mod=None, ctx=None, weight_scales='power2', scales=None):
const_params[nclip_max] = _make_const((valid_range - 1))
_analysis.post_order_visit(graph, visit_func)
return _expr.bind(graph, const_params)
def annotate():
"""Given a float32 graph, this pass will rewrite the graph and return
a graph which simulates the error brought by the current quantization
scheme.
Returns
-------
ret: tvm.relay.Pass
The registered pass for quantization annotation.
"""
return _quantize.QuantizeAnnotate()
ret = _expr.bind(graph, const_params)
return ret
def realize():
......@@ -330,17 +366,6 @@ def realize():
return _quantize.QuantizeRealize()
def rewrite_for_vta():
"""Performs rewriting for VTA target.
Returns
-------
ret: tvm.relay.Pass
The registered pass for VTA rewrite.
"""
return _quantize.QuantizeRewriteForVTA()
def _bind_params(func, params):
"""Bind the params to the expression.
"""
......@@ -362,6 +387,25 @@ def _bind_params(func, params):
return _expr.bind(func, bind_dict)
def prerequisite_optimize(graph, params=None):
    """Run the optimization passes required before quantization.

    The passes run in this order under opt_level 3: "SimplifyInference",
    "FoldConstant", "FoldScaleAxis", "CanonicalizeOps", "FoldConstant".

    Parameters
    ----------
    graph : Function
        The function to optimize.

    params : dict, optional
        Parameters bound into the function via `_bind_params` before
        optimizing. Skipped when empty or None.

    Returns
    -------
    ret : Function
        The "main" function of the optimized module.
    """
    pass_list = [
        _transform.SimplifyInference(),
        _transform.FoldConstant(),
        _transform.FoldScaleAxis(),
        _transform.CanonicalizeOps(),
        _transform.FoldConstant(),
    ]
    optimize = _transform.Sequential(pass_list)

    if params:
        graph = _bind_params(graph, params)

    mod = _module.Module.from_expr(graph)
    with _transform.PassContext(opt_level=3):
        optimized = optimize(mod)
    return optimized["main"]
def quantize(graph, params=None, dataset=None):
""" The quantization procedure. Before running the three main
procedure of quantization, "annotate", "calibrate" and "realize"
......@@ -385,33 +429,23 @@ def quantize(graph, params=None, dataset=None):
ret: Function
The graph after quantization
"""
if params:
graph = _bind_params(graph, params)
graph = prerequisite_optimize(graph, params)
mod = _module.Module.from_expr(graph)
# Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
# "CanonicalizeOps" optimization before quantization.
optimize = _transform.Sequential([_transform.SimplifyInference(),
_transform.FoldConstant(),
_transform.FoldScaleAxis(),
_transform.CanonicalizeOps(),
_transform.FoldConstant()])
calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
name="QuantizeCalibrate")
# Quantize pass list
quant_passes = [annotate(),
calibrate_pass,
realize(),
_transform.FoldConstant()]
if current_qconfig().store_lowbit_output:
quant_passes = [rewrite_for_vta()] + quant_passes
quant_passes = [partition(),
annotate(),
calibrate_pass]
if not current_qconfig().do_simulation:
quant_passes.append(realize())
quant_passes.append(_transform.FoldConstant())
quantize_seq = _transform.Sequential(quant_passes)
with _transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
mod = optimize(mod)
with quantize_context():
mod = quantize_seq(mod)
return mod["main"]
......@@ -83,13 +83,18 @@ TVM_ADD_FILELINE)
return {topi::identity(inputs[0])};
});
Expr ForceCast(Expr data) {
static const Op& op = Op::Get("annotation.force_cast");
return CallNode::make(op, {data}, Attrs{}, {});
// relay.annotation.cast_hint
TVM_REGISTER_NODE_TYPE(CastHintAttrs);
Expr CastHint(Expr data, DataType dtype) {
auto attrs = make_node<CastHintAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("annotation.cast_hint");
return CallNode::make(op, {data}, Attrs{attrs}, {});
}
RELAY_REGISTER_OP("annotation.force_cast")
.describe(R"code(Annotate an expression to force a cast.)code"
RELAY_REGISTER_OP("annotation.cast_hint")
.describe(R"code(Annotate an expression to be cast into specific data type.)code"
TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input data.")
......
......@@ -66,6 +66,13 @@ class ConstantChecker : private ExprVisitor {
}
};
// Public entry point for the constant check: delegates to a fresh
// ConstantChecker instance for every call.
bool ConstantCheck(const Expr& e) {
  return ConstantChecker().Check(e);
}

// Expose ConstantCheck through the global API registry under the name
// "relay._analysis.check_constant".
TVM_REGISTER_API("relay._analysis.check_constant")
.set_body_typed(ConstantCheck);
// TODO(tvm-team) consider combine dead-code with constant folder.
// or make a more powerful partial evaluator.
......
......@@ -31,6 +31,7 @@
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/reduce.h>
......@@ -420,7 +421,7 @@ Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array
Expr StopFusion(Expr data);
Expr ForceCast(Expr data);
Expr CastHint(Expr data, DataType dtype);
} // namespace relay
} // namespace tvm
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
*
* \file annotate.cc
*
* \brief Annotating the graph with simulated quantize operators.
*/
#include <tvm/relay/transform.h>
#include <tvm/relay/analysis.h>
#include "./quantize.h"
namespace tvm {
namespace relay {
namespace quantize {
using namespace relay::transform;
class QAnnotateExpr;
/*!
 * \brief Temporary expression node that pairs an expression with a
 *  QAnnotateKind while the annotate rewrite is running.
 */
class QAnnotateExprNode : public TempExprNode {
 public:
  /*! \brief The wrapped expression. */
  Expr expr;
  /*! \brief The annotation kind associated with expr. */
  QAnnotateKind kind;

  // Expose both fields to attribute visitors under the names "expr" and "kind".
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("expr", &expr);
    v->Visit("kind", &kind);
  }

  /*! \brief Create a QAnnotateExpr from an expression and its kind. */
  TVM_DLL static QAnnotateExpr make(Expr expr, QAnnotateKind kind);

  /*! \brief Convert this temporary expression back to a regular expression. */
  Expr Realize() const final;

  static constexpr const char* _type_key = "relay.QAnnotateExpr";
  TVM_DECLARE_NODE_TYPE_INFO(QAnnotateExprNode, TempExprNode);
};

// Define the reference class QAnnotateExpr for QAnnotateExprNode.
RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr);
// Realizing an annotated expression simply drops the annotation kind and
// hands back the wrapped expression.
Expr QAnnotateExprNode::Realize() const {
  return this->expr;
}
// Allocate a QAnnotateExprNode, fill in both fields and wrap it in its
// reference type.
QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) {
  NodePtr<QAnnotateExprNode> node = make_node<QAnnotateExprNode>();
  node->kind = kind;
  node->expr = expr;
  return QAnnotateExpr(node);
}
// Expose QAnnotateExprNode::make through the API registry.
// args[0] is the expression; args[1] is the annotation kind passed as an
// int and converted to QAnnotateKind.
TVM_REGISTER_API("relay._quantize.make_annotate_expr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = QAnnotateExprNode::make(args[0],
      static_cast<QAnnotateKind>(args[1].operator int()));
  });
/*!
 * \brief Create the function pass that annotates a graph for quantization.
 *
 * The pass runs ForwardRewrite with the "FQAnnotateRewrite" attribute and then
 * appends every free variable of the rewritten function to its parameter list.
 *
 * \return The pass, registered under the name "QuantizeAnnotate".
 */
Pass QuantizeAnnotate() {
  // TODO(tvm-teams): since partition has added cast_hint in different
  // branches, try to remove this in the future.
  // Callback for expressions referenced more than once: an annotated
  // expression gets a simulated_quantize attached and is re-tagged as kQInput.
  std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
      if (e->derived_from<TempExprNode>()) {
        const auto* n = e.as<QAnnotateExprNode>();
        CHECK(n);
        const PackedFunc* f =
            runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
        // Registry::Get yields nullptr for an unknown name; fail with a clear
        // message instead of dereferencing a null pointer.
        CHECK(f != nullptr)
            << "relay.quantize.attach_simulated_quantize is not registered";
        Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
        return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
      }
      return e;
    };

  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
    [=](Function f, Module m, PassContext pc) {
      auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
      // Free variables left in the rewritten body become extra parameters of
      // the returned function.
      auto new_params = func->params;
      for (const auto& x : FreeVars(func)) {
        new_params.push_back(x);
      }
      return FunctionNode::make(new_params,
                                func->body,
                                func->ret_type,
                                func->type_params,
                                func->attrs);
  };
  return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}
// Expose the pass constructor through the API registry under the name
// "relay._quantize.QuantizeAnnotate".
TVM_REGISTER_API("relay._quantize.QuantizeAnnotate")
.set_body_typed(QuantizeAnnotate);
} // namespace quantize
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
*
* \file partition.cc
*
* \brief Partition a graph into sections for quantization.
*/
#include <tvm/relay/transform.h>
#include "../pattern_util.h"
#include "./quantize.h"
namespace tvm {
namespace relay {
namespace quantize {
using namespace relay::transform;
class QPartitionExpr;
/*!
 * \brief Temporary expression node used by the partition rewrite to wrap
 *  an expression.
 */
class QPartitionExprNode : public TempExprNode {
 public:
  /*! \brief The original expression */
  Expr expr;

  // Expose the wrapped expression to attribute visitors under the name "expr".
  void VisitAttrs(tvm::AttrVisitor* v) final {
    v->Visit("expr", &expr);
  }

  /*! \brief Create a QPartitionExpr wrapping the given expression. */
  TVM_DLL static QPartitionExpr make(Expr expr);

  /*! \brief Convert this temporary expression back to a regular expression. */
  Expr Realize() const final;

  static constexpr const char* _type_key = "relay.QPartitionExpr";
  TVM_DECLARE_NODE_TYPE_INFO(QPartitionExprNode, TempExprNode);
};

// Define the reference class QPartitionExpr for QPartitionExprNode.
RELAY_DEFINE_NODE_REF(QPartitionExpr, QPartitionExprNode, TempExpr);
// Realize the partition marker: wrap the original expression in a cast_hint
// to the configured input dtype, then in a stop_fusion.
Expr QPartitionExprNode::Realize() const {
  const QConfig& cfg = QConfig::Current();
  return StopFusion(CastHint(this->expr, cfg->dtype_input));
}
// Allocate a QPartitionExprNode holding `expr` and wrap it in its reference
// type.
QPartitionExpr QPartitionExprNode::make(Expr expr) {
  NodePtr<QPartitionExprNode> node = make_node<QPartitionExprNode>();
  node->expr = expr;
  return QPartitionExpr(node);
}
// Expose QPartitionExprNode::make through the API registry; args[0] is the
// expression to wrap.
TVM_REGISTER_API("relay._quantize.make_partition_expr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
    *ret = QPartitionExprNode::make(args[0]);
  });
/*!
 * \brief Create the function pass that partitions a graph for quantization.
 *
 * The pass runs ForwardRewrite with the "FQPartitionRewrite" attribute.
 *
 * \return The pass, registered under the name "QuantizePartition".
 */
Pass QuantizePartition() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
    [=](Function f, Module m, PassContext pc) {
      return Downcast<Function>(
          ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr));
  };
  return CreateFunctionPass(pass_func, 1, "QuantizePartition", {});
}
// Expose the pass constructor through the API registry under the name
// "relay._quantize.QuantizePartition".
TVM_REGISTER_API("relay._quantize.QuantizePartition")
.set_body_typed(QuantizePartition);
} // namespace quantize
} // namespace relay
} // namespace tvm
......@@ -26,17 +26,9 @@
* for compression and acceleration.
*/
#include <dmlc/thread_local.h>
#include <tvm/base.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <cmath>
#include <string>
#include <vector>
#include <stack>
#include <utility>
#include "../pattern_util.h"
#include "./quantize.h"
......@@ -44,8 +36,6 @@ namespace tvm {
namespace relay {
namespace quantize {
using namespace relay::transform;
TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs);
bool SimulatedQuantizeRel(const Array<Type>& types,
......@@ -91,490 +81,6 @@ TVM_REGISTER_API("relay._quantize.simulated_quantize")
});
// =============
// annotate pass
/*!
 * \brief Realize an annotate temp expression.
 *
 * When `store_lowbit_output` is set in the current QConfig, the wrapped
 * expression is passed to the packed function
 * "relay.quantize.attach_simulated_quantize" with kind kQInput; otherwise the
 * wrapped expression is returned unchanged.
 *
 * \return The realized expression.
 */
Expr QAnnotateExprNode::Realize() const {
  const auto& cfg = QConfig::Current();
  if (!cfg->store_lowbit_output) {
    return expr;
  }
  // store low bit output back for VTA
  const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
  // Registry::Get yields nullptr for an unregistered name; fail with a
  // message instead of dereferencing a null pointer.
  CHECK(f != nullptr)
    << "relay.quantize.attach_simulated_quantize is not registered";
  return (*f)(this->expr, static_cast<int>(kQInput));
}
/*!
 * \brief Build a QAnnotateExpr.
 * \param expr The expression to wrap.
 * \param kind The annotate kind stored next to it.
 * \return The new temporary expression.
 */
QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) {
  auto node = make_node<QAnnotateExprNode>();
  node->kind = kind;
  node->expr = expr;
  return QAnnotateExpr(node);
}

// Frontend entry point: args[0] is the expression, args[1] the kind as int.
TVM_REGISTER_API("relay._quantize.make_annotate_expr")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    const int kind_code = args[1].operator int();
    *rv = QAnnotateExprNode::make(args[0], static_cast<QAnnotateKind>(kind_code));
  });
// =============
// realize pass
/*!
 * \brief Dequantize the integer data back to float32.
 *
 * With `store_lowbit_output` enabled the data is first cast to the configured
 * input dtype; it is then cast to float32 and multiplied by `dom_scale`.
 *
 * \return The dequantized expression.
 */
Expr QRealizeIntExprNode::Realize() const {
  const auto& cfg = QConfig::Current();
  Expr value = this->data;
  if (cfg->store_lowbit_output) {
    value = Cast(value, cfg->dtype_input);
  }
  // dequantize
  Expr as_float = Cast(value, Float(32));
  return Multiply(as_float, this->dom_scale);
}
/*!
 * \brief Build a QRealizeIntExpr.
 * \param data The integer-domain expression.
 * \param dom_scale The scale stored with the data.
 * \param dtype The current data type of `data`.
 * \return The new temporary expression.
 */
QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) {
  NodePtr<QRealizeIntExprNode> node = make_node<QRealizeIntExprNode>();
  node->dtype = std::move(dtype);
  node->dom_scale = std::move(dom_scale);
  node->data = std::move(data);
  return QRealizeIntExpr(node);
}
/*!
 * \brief Re-create a call with the same op, attrs and type args as
 *  `ref_call`, but with new arguments.
 */
inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
  const CallNode* ref = ref_call.operator->();
  return CallNode::make(ref->op, args, ref->attrs, ref->type_args);
}
/*!
 * \brief Compute `data * s1 / s2`, preferring integer operations.
 *
 * A left shift is used when s1 / s2 is a power of two, an integer multiply
 * when it is a whole number, and a float32 multiply followed by rounding and
 * a cast back to `dtype` otherwise. The ratio must be greater than one.
 */
inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
  // here we assume the dtype of data is dtype activation
  if (s1 == s2) return data;

  const float factor = s1 / s2;
  const float shift_factor = std::log2(factor);
  CHECK_GT(shift_factor, 0);

  if (static_cast<int>(shift_factor) == shift_factor) {
    Expr nbit = MakeConstantScalar(dtype, static_cast<int>(shift_factor));
    return LeftShift(data, nbit);
  }
  if (static_cast<int>(factor) == factor) {
    return Multiply(data, MakeConstantScalar(dtype, factor));
  }
  Expr as_float = Cast(data, Float(32));
  Expr scaled = Multiply(as_float, MakeConstantScalar(Float(32), factor));
  return Cast(Round(scaled), dtype);
}
/*!
 * \brief Realize a simulated_quantize call.
 *
 * new_args is read as {data, dom_scale, clip_min, clip_max}; the last three
 * must be scalar constants. If the data is already a QRealizeIntExpr it is
 * rescaled from its own scale to `dom_scale` (clip only, integer shift, or
 * float multiply); otherwise it is quantized from a real-valued expression.
 *
 * \param ref_call The original simulated_quantize call.
 * \param new_args The rewritten arguments.
 * \param ctx Unused.
 * \return A QRealizeIntExpr carrying the quantized data and `dom_scale`.
 */
Expr QuantizeRealize(const Call& ref_call,
                     const Array<Expr>& new_args,
                     const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  // do not handle data type cast
  const auto param = ref_call->attrs.as<SimulatedQuantizeAttrs>();
  CHECK(param != nullptr)
    << "simulated_quantize call without SimulatedQuantizeAttrs";
  CHECK_EQ(param->rounding, "round");

  Expr dom_scale = new_args[1];
  Expr clip_min = new_args[2];
  Expr clip_max = new_args[3];

  float dom_scale_imm = GetScalarFromConstant<float>(dom_scale);
  float clip_min_imm = GetScalarFromConstant<float>(clip_min);
  float clip_max_imm = GetScalarFromConstant<float>(clip_max);

  // x * idom_scale = y * odom_scale
  // => y = x * idom_scale / odom_scale
  if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
    // int32->int8
    Expr data = n->data;
    float idom_scale_imm = GetScalarFromConstant<float>(n->dom_scale);
    float odom_scale_imm = GetScalarFromConstant<float>(dom_scale);
    if (idom_scale_imm == odom_scale_imm) {
      // same domain scale, only clip
      data = Clip(data, clip_min_imm, clip_max_imm);
      return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
    }

    float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
    CHECK_NE(shift_nbit, 0);
    if (static_cast<int>(shift_nbit) == shift_nbit) {
      if (shift_nbit > 0) {
        // use right shift
        if (cfg->round_for_shift) {
          float round_bias = std::pow(2.0, shift_nbit - 1);
          data = Add(data, MakeConstantScalar(cfg->dtype_activation,
                                              static_cast<int>(round_bias)));
        }
        data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
                                                   static_cast<int>(shift_nbit)));
      } else {
        // shift_nbit is negative in this branch, so the shift amount is its
        // magnitude; passing the negative value would shift by a negative count.
        data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation,
                                                  static_cast<int>(-shift_nbit)));
      }
      data = Clip(data, clip_min_imm, clip_max_imm);
      return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
    } else {
      // float computation
      data = Cast(data, Float(32));
      Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale));
      Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
      return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
    }
  }

  // quantize from real
  CHECK(!new_args[0]->derived_from<TempExprNode>());
  Expr data = new_args[0];
  Expr scaled_data = Multiply(data, MakeConstantScalar(Float(32), 1 / dom_scale_imm));
  Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
  return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
}
/*!
 * \brief Run the FoldConstant pass on a single expression.
 *
 * The expression is wrapped into a module, folded, and the "main" function is
 * looked up again. The function itself is returned when the input was a
 * function, otherwise its body.
 */
Expr FoldConstantOpt(const Expr& expr) {
  auto folded_mod = transform::FoldConstant()(ModuleNode::FromExpr(expr));
  auto entry_func = folded_mod->Lookup("main");
  if (expr.as<FunctionNode>() != nullptr) {
    return entry_func;
  }
  return entry_func->body;
}

RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
.set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);
/*!
 * \brief Realize nn.conv2d on quantized operands.
 *
 * Returns an undefined Expr when neither argument is a TempExpr. Otherwise
 * both must be QRealizeIntExpr: the data is cast to dtype_input (if needed),
 * the weight to dtype_weight, and the convolution is re-created with
 * out_dtype set to dtype_activation. The result scale is the constant-folded
 * product of both operand scales.
 */
Expr Conv2dRealize(const Call& ref_call,
                   const Array<Expr>& new_args,
                   const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 2);
  const bool any_temp = new_args[0]->derived_from<TempExprNode>() ||
                        new_args[1]->derived_from<TempExprNode>();
  if (!any_temp) {
    return Expr(nullptr);
  }

  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
  CHECK(lhs);
  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
  CHECK(rhs);

  Expr ldata = lhs->data;
  if (lhs->dtype != cfg->dtype_input) {
    ldata = Cast(ldata, cfg->dtype_input);
  }
  Expr rdata = Cast(rhs->data, cfg->dtype_weight);

  // copy the original attributes and override only the output dtype
  auto new_attrs = make_node<Conv2DAttrs>();
  *new_attrs = *(ref_call->attrs.as<Conv2DAttrs>());
  const DataType out_dtype = cfg->dtype_activation;
  new_attrs->out_dtype = out_dtype;

  Expr conv = CallNode::make(ref_call->op,
                             {ldata, rdata}, Attrs(new_attrs), ref_call->type_args);
  Expr dom_scale = FoldConstantOpt(Multiply(lhs->dom_scale, rhs->dom_scale));
  return QRealizeIntExprNode::make(conv, dom_scale, out_dtype);
}

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
/*!
 * \brief Realize nn.dense on quantized operands.
 *
 * Returns an undefined Expr unless both arguments are TempExprs. Both must
 * then be QRealizeIntExpr: the data is cast to dtype_input (if needed), the
 * weight to dtype_weight, and the dense op is re-created with out_dtype set
 * to dtype_activation. The result scale is the constant-folded product of
 * both operand scales.
 */
Expr DenseRealize(const Call& ref_call,
                  const Array<Expr>& new_args,
                  const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 2);
  if (!new_args[0]->derived_from<TempExprNode>() || !new_args[1]->derived_from<TempExprNode>()) {
    return Expr(nullptr);
  }
  // A TempExpr is not necessarily a QRealizeIntExpr; check the downcast
  // (as Conv2dRealize does) rather than dereferencing a possible nullptr.
  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
  CHECK(lhs);
  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
  CHECK(rhs);

  Expr ldata = lhs->data;
  if (lhs->dtype != cfg->dtype_input) {
    ldata = Cast(ldata, cfg->dtype_input);
  }
  Expr rdata = Cast(rhs->data, cfg->dtype_weight);

  // copy the original attributes and override only the output dtype
  const auto ref_attrs = ref_call->attrs.as<DenseAttrs>();
  auto attrs = make_node<DenseAttrs>();
  *attrs = *ref_attrs;
  DataType out_dtype = cfg->dtype_activation;
  attrs->out_dtype = out_dtype;

  Expr ret = CallNode::make(ref_call->op,
                            {ldata, rdata}, Attrs(attrs), ref_call->type_args);
  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
  Expr dom_scale = FoldConstantOpt(mul);
  return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
}

RELAY_REGISTER_OP("nn.dense")
.set_attr<FForwardRewrite>("FQRealizeRewrite", DenseRealize);
/*!
 * \brief Realize multiply on quantized operands.
 *
 * When both arguments are QRealizeIntExpr they are cast to dtype_activation
 * (if needed), multiplied, and the result scale is the constant-folded
 * product of both scales. Otherwise neither argument may be a TempExpr and an
 * undefined Expr is returned.
 */
Expr MulRealize(const Call& ref_call,
                const Array<Expr>& new_args,
                const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 2);
  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
  if (lhs == nullptr || rhs == nullptr) {
    CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
    return Expr(nullptr);
  }

  // execute the operation with activation data type.
  const DataType dtype = cfg->dtype_activation;
  auto as_activation = [&dtype](const QRealizeIntExprNode* node) -> Expr {
    if (node->dtype != dtype) {
      return Cast(node->data, dtype);
    }
    return node->data;
  };
  Expr ldata = as_activation(lhs);
  Expr rdata = as_activation(rhs);

  Expr product = ForwardOp(ref_call, {ldata, rdata});
  Expr dom_scale = FoldConstantOpt(Multiply(lhs->dom_scale, rhs->dom_scale));
  return QRealizeIntExprNode::make(product, dom_scale, dtype);
}

RELAY_REGISTER_OP("multiply")
.set_attr<FForwardRewrite>("FQRealizeRewrite", MulRealize);
/*!
 * \brief Pick the common scale for a list of quantized operands.
 *
 * For exactly two operands the smaller of their scales is used; for any other
 * count the scale is global_scale / 2^(nbit_activation - 1) from the current
 * QConfig.
 */
float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
  if (nptrs.size() != 2) {
    const QConfig& cfg = QConfig::Current();
    float scale = cfg->global_scale;
    return scale / std::pow(2.0, cfg->nbit_activation - 1);
  }
  // x = a * s1, y = b * s2
  // x + y = (a * s1 / s2 + b) * s2, if s1 > s2
  //       = (a + b * s2 / s1) * s1, if s2 > s1
  const float s1 = GetScalarFromConstant<float>(nptrs[0]->dom_scale);
  const float s2 = GetScalarFromConstant<float>(nptrs[1]->dom_scale);
  if (s1 > s2) {
    return s2;
  }
  return s1;
}
/*!
 * \brief Bring all quantized arguments to one data type and one scale.
 *
 * \param ref_args The arguments of the original call (same length as args).
 * \param args The rewritten arguments; each must be a QRealizeIntExpr.
 * \param dtype_ptr Output: the unified data type.
 * \param scale_ptr Output: the unified scale as a float32 constant.
 * \return The integer-domain data of every argument, cast and rescaled.
 */
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
                            DataType* dtype_ptr, Expr* scale_ptr) {
  static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
  const QConfig& cfg = QConfig::Current();

  std::vector<const QRealizeIntExprNode*> nptrs;
  Array<Expr> ret;
  for (const auto& arg : args) {
    const auto* nptr = arg.as<QRealizeIntExprNode>();
    CHECK(nptr);
    nptrs.push_back(nptr);
    ret.push_back(nptr->data);
  }

  // unify the data type
  CHECK_EQ(ref_args.size(), args.size());
  const bool use_input_dtype = ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input;
  DataType dtype = cfg->dtype_activation;
  if (use_input_dtype) {
    dtype = cfg->dtype_input;
  }

  const size_t num_args = ret.size();
  for (size_t i = 0; i < num_args; ++i) {
    if (nptrs[i]->dtype != dtype) {
      ret.Set(i, Cast(ret[i], dtype));
      continue;
    }
    const auto* ref_arg = ref_args[i].as<CallNode>();
    if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&
        ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {
      // round-trip through dtype_input behind a fusion barrier
      Expr lowered = StopFusion(Cast(ret[i], cfg->dtype_input));
      ret.Set(i, Cast(lowered, dtype));
    }
  }

  // unify the dom_scale
  const float s = ChooseDomScale(nptrs);
  Expr dom_scale = MakeConstantScalar(Float(32), s);
  for (size_t i = 0; i < num_args; ++i) {
    const float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
    ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype));
  }

  *dtype_ptr = dtype;
  *scale_ptr = dom_scale;
  return ret;
}
/*!
 * \brief Realize add on quantized operands.
 *
 * When both arguments are QRealizeIntExpr they are unified in dtype and scale
 * and the add is re-created on them. Otherwise neither argument may be a
 * TempExpr and an undefined Expr is returned.
 */
Expr AddRealize(const Call& ref_call,
                const Array<Expr>& new_args,
                const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 2);
  const bool both_quantized =
      new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>();
  if (!both_quantized) {
    CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
    return Expr(nullptr);
  }

  DataType dtype;
  Expr dom_scale;
  Array<Expr> unified = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
  return QRealizeIntExprNode::make(ForwardOp(ref_call, unified), dom_scale, dtype);
}

RELAY_REGISTER_OP("add")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
/*!
 * \brief Realize clip on a quantized operand.
 *
 * The clip bounds of the original call are divided by the operand's scale so
 * that they apply in the integer domain. A non-quantized argument must not be
 * a TempExpr and yields an undefined Expr.
 */
Expr ClipRealize(const Call& ref_call,
                 const Array<Expr>& new_args,
                 const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 1);
  const auto* n = new_args[0].as<QRealizeIntExprNode>();
  if (n == nullptr) {
    CHECK(!new_args[0]->derived_from<TempExprNode>());
    return Expr(nullptr);
  }

  const auto ref_attrs = ref_call->attrs.as<ClipAttrs>();
  auto scaled_attrs = make_node<ClipAttrs>();
  const double dom_scale = GetScalarFromConstant<float>(n->dom_scale);
  scaled_attrs->a_min = ref_attrs->a_min / dom_scale;
  scaled_attrs->a_max = ref_attrs->a_max / dom_scale;

  Expr clipped = CallNode::make(ref_call->op,
                                {n->data}, Attrs(scaled_attrs), ref_call->type_args);
  return QRealizeIntExprNode::make(clipped, n->dom_scale, n->dtype);
}

RELAY_REGISTER_OP("clip")
.set_attr<FForwardRewrite>("FQRealizeRewrite", ClipRealize);
/*!
 * \brief Realize concatenate on a tuple of quantized fields.
 *
 * The single argument must be a tuple. If its first field is a
 * QRealizeIntExpr, all fields are unified in dtype and scale and the
 * concatenate is re-created on the new tuple; otherwise an undefined Expr is
 * returned.
 */
Expr ConcatenateRealize(const Call& ref_call,
                        const Array<Expr>& new_args,
                        const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 1);
  CHECK_EQ(ref_call->args.size(), 1);

  const auto* tuple = new_args[0].as<TupleNode>();
  const auto* ref_tuple = ref_call->args[0].as<TupleNode>();
  CHECK(tuple);
  CHECK(ref_tuple);
  const Array<Expr>& arr = tuple->fields;
  const Array<Expr>& ref_arr = ref_tuple->fields;

  if (!arr[0].as<QRealizeIntExprNode>()) {
    for (const auto& arg : new_args) {
      CHECK(!arg->derived_from<TempExprNode>());
    }
    return Expr(nullptr);
  }

  DataType dtype;
  Expr dom_scale;
  Array<Expr> unified = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale);
  Expr concat = ForwardOp(ref_call, {TupleNode::make(unified)});
  return QRealizeIntExprNode::make(concat, dom_scale, dtype);
}

RELAY_REGISTER_OP("concatenate")
.set_attr<FForwardRewrite>("FQRealizeRewrite", ConcatenateRealize);
/*!
 * \brief Forward the original operator onto the integer data.
 *
 * Scale and dtype of the operand are carried over unchanged. A non-quantized
 * argument must not be a TempExpr and yields an undefined Expr.
 */
Expr IdentityRealize(const Call& ref_call,
                     const Array<Expr>& new_args,
                     const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 1);
  const auto* n = new_args[0].as<QRealizeIntExprNode>();
  if (n == nullptr) {
    CHECK(!new_args[0]->derived_from<TempExprNode>());
    return Expr(nullptr);
  }
  Expr forwarded = ForwardOp(ref_call, {n->data});
  return QRealizeIntExprNode::make(forwarded, n->dom_scale, n->dtype);
}

RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

RELAY_REGISTER_OP("strided_slice")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

RELAY_REGISTER_OP("annotation.stop_fusion")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
/*!
 * \brief Realize a unary operator after casting its operand to dtype_input.
 *
 * The result keeps the operand's scale and is tagged with dtype_input. A
 * non-quantized argument must not be a TempExpr and yields an undefined Expr.
 */
Expr CastDtypeInputRealize(const Call& ref_call,
                           const Array<Expr>& new_args,
                           const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 1);
  const auto* n = new_args[0].as<QRealizeIntExprNode>();
  if (n == nullptr) {
    CHECK(!new_args[0]->derived_from<TempExprNode>());
    return Expr(nullptr);
  }
  Expr narrowed = Cast(n->data, cfg->dtype_input);
  Expr forwarded = ForwardOp(ref_call, {narrowed});
  return QRealizeIntExprNode::make(forwarded, n->dom_scale, cfg->dtype_input);
}

RELAY_REGISTER_OP("nn.max_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
/*!
 * \brief Realize nn.avg_pool2d on a quantized operand.
 *
 * The operand is cast to dtype_activation when it has a different dtype and
 * the original operator is applied to it; the scale is unchanged. A
 * non-quantized argument must not be a TempExpr and yields an undefined Expr.
 */
Expr AvgPoolRealize(const Call& ref_call,
                    const Array<Expr>& new_args,
                    const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 1);
  const auto* n = new_args[0].as<QRealizeIntExprNode>();
  if (n == nullptr) {
    CHECK(!new_args[0]->derived_from<TempExprNode>());
    return Expr(nullptr);
  }

  Expr operand = n->data;
  if (n->dtype != cfg->dtype_activation) {
    operand = Cast(n->data, cfg->dtype_activation);
  }
  Expr pooled = ForwardOp(ref_call, {operand});
  return QRealizeIntExprNode::make(pooled, n->dom_scale, cfg->dtype_activation);
}

RELAY_REGISTER_OP("nn.avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
/*!
 * \brief Realize annotation.force_cast as a cast to dtype_input.
 *
 * The annotation call itself is dropped; only the cast of the integer data
 * remains. A non-quantized argument must not be a TempExpr and yields an
 * undefined Expr.
 */
Expr ForceCastRealize(const Call& ref_call,
                      const Array<Expr>& new_args,
                      const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 1);
  const auto* n = new_args[0].as<QRealizeIntExprNode>();
  if (n == nullptr) {
    CHECK(!new_args[0]->derived_from<TempExprNode>());
    return Expr(nullptr);
  }
  Expr narrowed = Cast(n->data, cfg->dtype_input);
  return QRealizeIntExprNode::make(narrowed, n->dom_scale, cfg->dtype_input);
}

RELAY_REGISTER_OP("annotation.force_cast")
.set_attr<FForwardRewrite>("FQRealizeRewrite", ForceCastRealize);
// Frontend entry point: run the "FQRealizeRewrite" forward rewrite on an expression.
TVM_REGISTER_API("relay._quantize.realize")
.set_body_typed<Expr(Expr)>([](const Expr& e) {
    return ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr);
  });
// =============
// qconfig
/*! \brief Create a QConfig backed by a freshly constructed QConfigNode. */
QConfig qconfig() {
  auto node = make_node<QConfigNode>();
  return QConfig(node);
}
/*! \brief Entry to hold the BuildConfig context stack. */
struct TVMQConfigThreadLocalEntry {
/*! \brief The default build config if the stack is empty */
......@@ -584,7 +90,7 @@ struct TVMQConfigThreadLocalEntry {
std::stack<QConfig> context_stack;
TVMQConfigThreadLocalEntry() :
default_config(qconfig()) {
default_config(make_node<QConfigNode>()) {
}
};
......@@ -620,8 +126,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
p->stream << "nbit_activation=" << op->nbit_activation << ", ";
p->stream << "global_scale=" << op->global_scale << ", ";
p->stream << "skip_conv_layers==" << op->skip_conv_layers << ", ";
p->stream << "do_simulation==" << op->do_simulation << ", ";
p->stream << "round_for_shift==" << op->round_for_shift << ", ";
p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", ";
p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
p->stream << ")";
});
......@@ -635,95 +141,6 @@ TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
// Expose QConfig::ExitQConfigScope to the frontend under the global
// function name "relay._quantize._ExitQConfigScope".
TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
.set_body_typed(QConfig::ExitQConfigScope);
/*!
 * \brief Create the function pass that runs the "FQAnnotateRewrite" forward
 *  rewrite.
 *
 * Expressions that are TempExprs when the multi-reference callback sees them
 * must be QAnnotateExpr; they get a simulated quantize of kind kQInput
 * attached. After rewriting, all free variables of the function are appended
 * to its parameter list.
 *
 * \return The pass, registered under the name "QuantizeAnnotate".
 */
Pass QuantizeAnnotate() {
  std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
    if (e->derived_from<TempExprNode>()) {
      const auto* n = e.as<QAnnotateExprNode>();
      CHECK(n);
      const PackedFunc* f =
          runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
      // Registry::Get yields nullptr for an unregistered name; fail with a
      // message instead of dereferencing a null pointer.
      CHECK(f != nullptr)
        << "relay.quantize.attach_simulated_quantize is not registered";
      Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
      return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
    }
    return e;
  };

  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
    [=](Function f, Module m, PassContext pc) {
      auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
      // variables left free by the rewrite become extra parameters
      auto new_params = func->params;
      for (const auto& x : FreeVars(func)) {
        new_params.push_back(x);
      }
      return FunctionNode::make(new_params,
                                func->body,
                                func->ret_type,
                                func->type_params,
                                func->attrs);
  };
  return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}

TVM_REGISTER_API("relay._quantize.QuantizeAnnotate")
.set_body_typed(QuantizeAnnotate);
/*!
 * \brief Create the function pass that runs the "FQRealizeRewrite" forward
 *  rewrite over a function.
 * \return The pass, registered under the name "QuantizeRealize".
 */
Pass QuantizeRealizePass() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> rewrite_func =
    [=](Function f, Module m, PassContext pc) {
      Expr rewritten = ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr);
      return Downcast<Function>(rewritten);
    };
  return CreateFunctionPass(rewrite_func, 1, "QuantizeRealize", {});
}

TVM_REGISTER_API("relay._quantize.QuantizeRealize")
.set_body_typed(QuantizeRealizePass);
/*!
 * \brief Create the function pass that runs the "FQVTARewrite" forward
 *  rewrite over a function.
 * \return The pass, registered under the name "QuantizeRewriteForVTA".
 */
Pass QuantizeRewriteForVTAPass() {
  runtime::TypedPackedFunc<Function(Function, Module, PassContext)> rewrite_func =
    [=](Function f, Module m, PassContext pc) {
      Expr rewritten = ForwardRewrite(f, "FQVTARewrite", nullptr, nullptr);
      return Downcast<Function>(rewritten);
    };
  return CreateFunctionPass(rewrite_func, 1, "QuantizeRewriteForVTA", {});
}

TVM_REGISTER_API("relay._quantize.QuantizeRewriteForVTA")
.set_body_typed(QuantizeRewriteForVTAPass);
// =============
// Insert stop_fusion for vta.
/*!
 * \brief Realize the VTA marker: force-cast the wrapped expression and put a
 *  stop-fusion annotation around the result.
 */
Expr QVTAExprNode::Realize() const {
  return StopFusion(ForceCast(this->expr));
}
/*!
 * \brief Build a QVTAExpr around an expression.
 * \param expr The expression to wrap.
 * \return The new temporary expression.
 */
QVTAExpr QVTAExprNode::make(Expr expr) {
  auto node = make_node<QVTAExprNode>();
  node->expr = expr;
  return QVTAExpr(node);
}

// Frontend entry point: args[0] is the expression to wrap.
TVM_REGISTER_API("relay._quantize.make_vta_expr")
.set_body([](TVMArgs args, TVMRetValue* rv) {
    *rv = QVTAExprNode::make(args[0]);
  });
// Frontend entry point: wrap an expression in a stop-fusion annotation.
TVM_REGISTER_API("relay._quantize.make_stop_fusion")
.set_body_typed<Expr(Expr)>([](const Expr& input) {
    Expr fenced = StopFusion(input);
    return fenced;
  });
// Frontend entry point: realize a QVTAExpr. Any other expression type
// fails the CHECK.
TVM_REGISTER_API("relay._quantize.temp_expr_realize")
.set_body_typed<Expr(Expr)>([](const Expr& input) {
    const QVTAExprNode* n = input.as<QVTAExprNode>();
    CHECK(n);
    Expr realized = n->Realize();
    return realized;
  });
} // namespace quantize
} // namespace relay
} // namespace tvm
......@@ -59,104 +59,8 @@ struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
}
};
/*!
 * \brief TempExpr used during annotate forward rewrite.
 */
class QAnnotateExpr;
/*!
 * \brief TempExprNode used during annotate forward rewrite.
 *
 * It pairs an expression with the annotate kind assigned to it.
 */
class QAnnotateExprNode : public TempExprNode {
public:
/*! \brief The original expression */
Expr expr;
/*! \brief The kind of annotate field */
QAnnotateKind kind;
// Expose both fields to attribute visitors under the keys "expr" and "kind".
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("expr", &expr);
v->Visit("kind", &kind);
}
/*! \brief Create a QAnnotateExpr from an expression and its kind. */
TVM_DLL static QAnnotateExpr make(Expr expr, QAnnotateKind kind);
/*! \brief Convert this temporary node into a regular expression. */
Expr Realize() const final;
static constexpr const char* _type_key = "relay.QAnnotateExpr";
TVM_DECLARE_NODE_TYPE_INFO(QAnnotateExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr);
/*!
 * \brief TempExpr used to insert `force_cast` for VTA.
 */
class QVTAExpr;
/*!
 * \brief TempExprNode used to insert `force_cast` for VTA.
 */
class QVTAExprNode : public TempExprNode {
public:
/*! \brief The original expression */
Expr expr;
// Expose `expr` to attribute visitors under the key "expr".
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("expr", &expr);
}
/*! \brief Create a QVTAExpr wrapping the given expression. */
TVM_DLL static QVTAExpr make(Expr expr);
/*! \brief Convert this temporary node into a regular expression. */
Expr Realize() const final;
static constexpr const char* _type_key = "relay.QVTAExpr";
TVM_DECLARE_NODE_TYPE_INFO(QVTAExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(QVTAExpr, QVTAExprNode, TempExpr);
/*! \brief TempExpr used during realize forward rewrite. */
class QRealizeExpr;
/*! \brief TempExpr representing integer. */
class QRealizeIntExpr;
/*!
 * \brief Base node of the realize temp expressions; it only holds the data
 *  expression. It is declared as a base node type and does not define
 *  Realize() itself.
 */
class QRealizeExprNode : public TempExprNode {
public:
/*! \brief The original expression */
Expr data;
static constexpr const char* _type_key = "relay.quantize.QRealizeExpr";
TVM_DECLARE_BASE_NODE_INFO(QRealizeExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(QRealizeExpr, QRealizeExprNode, TempExpr);
/*!
 * \brief Realize temp expression for integer data: the inherited `data`
 *  together with the scale and the data type it currently has.
 */
class QRealizeIntExprNode : public QRealizeExprNode {
public:
/*! \brief scale expression; Realize() multiplies the data by it */
Expr dom_scale;
/*! \brief current data type */
DataType dtype;
// Expose the three fields to attribute visitors.
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data);
v->Visit("dom_scale", &dom_scale);
v->Visit("dtype", &dtype);
}
/*! \brief Convert this temporary node into a regular expression. */
Expr Realize() const final;
/*! \brief Create a QRealizeIntExpr from data, scale and data type. */
TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype);
static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr";
TVM_DECLARE_NODE_TYPE_INFO(QRealizeIntExprNode, QRealizeExprNode);
};
RELAY_DEFINE_NODE_REF(QRealizeIntExpr, QRealizeIntExprNode, QRealizeExpr);
class QConfig;
/*!
* \brief Container for build configuration options
*/
......@@ -170,8 +74,8 @@ class QConfigNode : public Node {
DataType dtype_activation = Int(32);
double global_scale = 8.0;
Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr));
bool do_simulation = false;
bool round_for_shift = true;
bool store_lowbit_output = true;
Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
void VisitAttrs(AttrVisitor* v) final {
......@@ -183,8 +87,8 @@ class QConfigNode : public Node {
v->Visit("dtype_activation", &dtype_activation);
v->Visit("global_scale", &global_scale);
v->Visit("skip_conv_layers", &skip_conv_layers);
v->Visit("do_simulation", &do_simulation);
v->Visit("round_for_shift", &round_for_shift);
v->Visit("store_lowbit_output", &store_lowbit_output);
v->Visit("debug_enabled_ops", &debug_enabled_ops);
}
......@@ -250,12 +154,6 @@ struct QConfigContext {
}
};
/*!
 * \brief Construct a QConfig containing a new QConfigNode
 * \return The new QConfig
 */
TVM_DLL QConfig qconfig();
} // namespace quantize
} // namespace relay
} // namespace tvm
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
*
* \file realize.cc
*
* \brief Realizing the simulated graph into real low-precision
* graph.
*/
#include <tvm/relay/transform.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include "./quantize.h"
#include "../pattern_util.h"
namespace tvm {
namespace relay {
namespace quantize {
using namespace relay::transform;
// Reference types; defined by the RELAY_DEFINE_NODE_REF invocations in this file.
class QRealizeExpr;
class QRealizeIntExpr;
/*!
 * \brief Base node of the realize temp expressions; it only holds the data
 *  expression. It is declared as a base node type and does not define
 *  Realize() itself.
 */
class QRealizeExprNode : public TempExprNode {
public:
/*! \brief The expression carried by this node */
Expr data;
static constexpr const char* _type_key = "relay.quantize.QRealizeExpr";
TVM_DECLARE_BASE_NODE_INFO(QRealizeExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(QRealizeExpr, QRealizeExprNode, TempExpr);
/*!
 * \brief Realize temp expression for integer data: the inherited `data`
 *  together with the scale and the data type it currently has.
 */
class QRealizeIntExprNode : public QRealizeExprNode {
public:
/*! \brief scale expression; Realize() multiplies the data by it */
Expr dom_scale;
/*! \brief data type recorded for `data` */
DataType dtype;
// Expose the three fields to attribute visitors.
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data);
v->Visit("dom_scale", &dom_scale);
v->Visit("dtype", &dtype);
}
/*! \brief Convert this temporary node into a regular expression. */
Expr Realize() const final;
/*! \brief Create a QRealizeIntExpr from data, scale and data type. */
TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype);
static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr";
TVM_DECLARE_NODE_TYPE_INFO(QRealizeIntExprNode, QRealizeExprNode);
};
RELAY_DEFINE_NODE_REF(QRealizeIntExpr, QRealizeIntExprNode, QRealizeExpr);
/*!
 * \brief Dequantize: cast the integer data to float32 and multiply it by
 *  `dom_scale`.
 * \return The dequantized expression.
 */
Expr QRealizeIntExprNode::Realize() const {
  // dequantize
  Expr as_float = Cast(this->data, Float(32));
  return Multiply(as_float, this->dom_scale);
}
/*!
 * \brief Build a QRealizeIntExpr.
 * \param data The integer-domain expression.
 * \param dom_scale The scale stored with the data.
 * \param dtype The data type recorded for `data`.
 * \return The new temporary expression.
 */
QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) {
  NodePtr<QRealizeIntExprNode> node = make_node<QRealizeIntExprNode>();
  node->dtype = std::move(dtype);
  node->dom_scale = std::move(dom_scale);
  node->data = std::move(data);
  return QRealizeIntExpr(node);
}
/*!
 * \brief Re-create a call with the same op, attrs and type args as
 *  `ref_call`, but with new arguments.
 */
inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
  const CallNode* ref = ref_call.operator->();
  return CallNode::make(ref->op, args, ref->attrs, ref->type_args);
}
/*
 * calculate `data * s1 / s2`, use shift if possible
 *
 * data:  expression to rescale
 * s1/s2: current scale and target scale; s1 / s2 must be greater than one
 * dtype: data type used for the integer constants and the final cast
 */
inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
// here we assume the dtype of data is dtype activation
if (s1 == s2) return data;
float factor = s1 / s2;
float shift_factor = std::log2(factor);
// log2(factor) > 0 means factor > 1, i.e. data is only ever scaled up here
CHECK_GT(shift_factor, 0);
if (static_cast<int>(shift_factor) == shift_factor) {
// factor is a power of two: left shift by log2(factor)
return LeftShift(data, MakeConstantScalar(dtype,
static_cast<int>(shift_factor)));
} else if (static_cast<int>(factor) == factor) {
// factor is a whole number: integer multiply
return Multiply(data, MakeConstantScalar(dtype, factor));
} else {
// NOTE(review): LOG(FATAL) is expected to abort here, which would make the
// float fallback below unreachable - confirm whether that is intended.
LOG(FATAL) << "fall back to float computation";
data = Cast(data, Float(32));
data = Multiply(data, MakeConstantScalar(Float(32), factor));
return Cast(Round(data), dtype);
}
}
/*!
 * \brief Realize a simulated_quantize call.
 *
 * new_args is read as {data, dom_scale, clip_min, clip_max}; the last three
 * must be scalar constants. Quantized data is rescaled from its own scale to
 * `dom_scale` (clip only, rounding right shift, or float multiply); the
 * output scale must not be smaller than the input scale. Real-valued data is
 * divided by the scale, rounded and clipped.
 *
 * \param ref_call The original simulated_quantize call.
 * \param new_args The rewritten arguments.
 * \param ctx Unused.
 * \return A QRealizeIntExpr carrying the quantized data and `dom_scale`.
 */
Expr QuantizeRealize(const Call& ref_call,
                     const Array<Expr>& new_args,
                     const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  // do not handle data type cast
  const auto param = ref_call->attrs.as<SimulatedQuantizeAttrs>();
  CHECK_EQ(param->rounding, "round");

  Expr dom_scale = new_args[1];
  const float out_scale = GetScalarFromConstant<float>(dom_scale);
  const float lo = GetScalarFromConstant<float>(new_args[2]);
  const float hi = GetScalarFromConstant<float>(new_args[3]);

  const auto* n = new_args[0].as<QRealizeIntExprNode>();
  if (n == nullptr) {
    // quantize from real
    CHECK(!new_args[0]->derived_from<TempExprNode>());
    Expr scaled = Multiply(new_args[0], MakeConstantScalar(Float(32), 1 / out_scale));
    Expr rounded = Clip(Round(scaled), lo, hi);
    return QRealizeIntExprNode::make(rounded, dom_scale, Float(32));
  }

  // x * idom_scale = y * odom_scale
  // => y = x * idom_scale / odom_scale
  // int32->int8
  Expr data = n->data;
  const float in_scale = GetScalarFromConstant<float>(n->dom_scale);
  if (in_scale == out_scale) {
    // same domain scale, only clip
    return QRealizeIntExprNode::make(Clip(data, lo, hi), dom_scale, n->dtype);
  }

  const float shift_nbit = std::log2(out_scale / in_scale);
  CHECK_GT(shift_nbit, 0);
  if (static_cast<int>(shift_nbit) != shift_nbit) {
    // float computation
    Expr as_float = Cast(data, Float(32));
    Expr scaled = Multiply(as_float, Divide(n->dom_scale, dom_scale));
    Expr rounded = Clip(Round(scaled), lo, hi);
    return QRealizeIntExprNode::make(rounded, dom_scale, Float(32));
  }

  // use right shift
  if (cfg->round_for_shift) {
    float round_bias = std::pow(2.0, shift_nbit - 1);
    data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast<int>(round_bias)));
  }
  data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
                                             static_cast<int>(shift_nbit)));
  return QRealizeIntExprNode::make(Clip(data, lo, hi), dom_scale, n->dtype);
}
/*!
 * \brief Run the FoldConstant pass on a single expression.
 *
 * The expression is wrapped into a module, folded, and the "main" function is
 * looked up again. The function itself is returned when the input was a
 * function, otherwise its body.
 */
Expr FoldConstantOpt(const Expr& expr) {
  auto folded_mod = transform::FoldConstant()(ModuleNode::FromExpr(expr));
  auto entry_func = folded_mod->Lookup("main");
  if (expr.as<FunctionNode>() != nullptr) {
    return entry_func;
  }
  return entry_func->body;
}

RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
.set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);
/*!
 * \brief Realize nn.conv2d on quantized operands.
 *
 * Returns an undefined Expr when neither argument is a TempExpr. Otherwise
 * both must be QRealizeIntExpr: the data is cast to dtype_input (if needed),
 * the weight to dtype_weight, and the convolution is re-created with
 * out_dtype set to dtype_activation. The result scale is the constant-folded
 * product of both operand scales.
 */
Expr Conv2dRealize(const Call& ref_call,
                   const Array<Expr>& new_args,
                   const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 2);
  const bool any_temp = new_args[0]->derived_from<TempExprNode>() ||
                        new_args[1]->derived_from<TempExprNode>();
  if (!any_temp) {
    return Expr(nullptr);
  }

  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
  CHECK(lhs);
  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
  CHECK(rhs);

  Expr ldata = lhs->data;
  if (lhs->dtype != cfg->dtype_input) {
    ldata = Cast(ldata, cfg->dtype_input);
  }
  Expr rdata = Cast(rhs->data, cfg->dtype_weight);

  // copy the original attributes and override only the output dtype
  auto new_attrs = make_node<Conv2DAttrs>();
  *new_attrs = *(ref_call->attrs.as<Conv2DAttrs>());
  const DataType out_dtype = cfg->dtype_activation;
  new_attrs->out_dtype = out_dtype;

  Expr conv = CallNode::make(ref_call->op,
                             {ldata, rdata}, Attrs(new_attrs), ref_call->type_args);
  Expr dom_scale = FoldConstantOpt(Multiply(lhs->dom_scale, rhs->dom_scale));
  return QRealizeIntExprNode::make(conv, dom_scale, out_dtype);
}

RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
/*!
 * \brief Realize nn.dense on quantized operands.
 *
 * Returns an undefined Expr unless both arguments are TempExprs. Both must
 * then be QRealizeIntExpr: the data is cast to dtype_input (if needed), the
 * weight to dtype_weight, and the dense op is re-created with out_dtype set
 * to dtype_activation. The result scale is the constant-folded product of
 * both operand scales.
 */
Expr DenseRealize(const Call& ref_call,
                  const Array<Expr>& new_args,
                  const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 2);
  if (!new_args[0]->derived_from<TempExprNode>() || !new_args[1]->derived_from<TempExprNode>()) {
    return Expr(nullptr);
  }
  // A TempExpr is not necessarily a QRealizeIntExpr; check the downcast
  // (as Conv2dRealize does) rather than dereferencing a possible nullptr.
  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
  CHECK(lhs);
  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
  CHECK(rhs);

  Expr ldata = lhs->data;
  if (lhs->dtype != cfg->dtype_input) {
    ldata = Cast(ldata, cfg->dtype_input);
  }
  Expr rdata = Cast(rhs->data, cfg->dtype_weight);

  // copy the original attributes and override only the output dtype
  const auto ref_attrs = ref_call->attrs.as<DenseAttrs>();
  auto attrs = make_node<DenseAttrs>();
  *attrs = *ref_attrs;
  DataType out_dtype = cfg->dtype_activation;
  attrs->out_dtype = out_dtype;

  Expr ret = CallNode::make(ref_call->op,
                            {ldata, rdata}, Attrs(attrs), ref_call->type_args);
  Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
  Expr dom_scale = FoldConstantOpt(mul);
  return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
}

RELAY_REGISTER_OP("nn.dense")
.set_attr<FForwardRewrite>("FQRealizeRewrite", DenseRealize);
/*!
 * \brief Realize multiply on quantized operands.
 *
 * When both arguments are QRealizeIntExpr they are cast to dtype_activation
 * (if needed), multiplied, and the result scale is the constant-folded
 * product of both scales. Otherwise neither argument may be a TempExpr and an
 * undefined Expr is returned.
 */
Expr MulRealize(const Call& ref_call,
                const Array<Expr>& new_args,
                const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 2);
  const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
  const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
  if (lhs == nullptr || rhs == nullptr) {
    CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
    return Expr(nullptr);
  }

  // execute the operation with activation data type.
  const DataType dtype = cfg->dtype_activation;
  Expr ldata = lhs->data;
  Expr rdata = rhs->data;
  if (lhs->dtype == dtype) {
    CHECK_EQ(lhs->dtype, dtype);
  } else {
    ldata = Cast(ldata, dtype);
  }
  if (rhs->dtype == dtype) {
    CHECK_EQ(rhs->dtype, dtype);
  } else {
    rdata = Cast(rdata, dtype);
  }

  Expr product = ForwardOp(ref_call, {ldata, rdata});
  Expr dom_scale = FoldConstantOpt(Multiply(lhs->dom_scale, rhs->dom_scale));
  return QRealizeIntExprNode::make(product, dom_scale, dtype);
}

RELAY_REGISTER_OP("multiply")
.set_attr<FForwardRewrite>("FQRealizeRewrite", MulRealize);
/*!
 * \brief Pick the common scale for a list of quantized operands.
 *
 * For exactly two operands the smaller of their scales is used; for any other
 * count the scale is global_scale / 2^(nbit_activation - 1) from the current
 * QConfig.
 */
float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
  if (nptrs.size() != 2) {
    const QConfig& cfg = QConfig::Current();
    float scale = cfg->global_scale;
    return scale / std::pow(2.0, cfg->nbit_activation - 1);
  }
  // x = a * s1, y = b * s2
  // x + y = (a * s1 / s2 + b) * s2, if s1 > s2
  //       = (a + b * s2 / s1) * s1, if s2 > s1
  const float s1 = GetScalarFromConstant<float>(nptrs[0]->dom_scale);
  const float s2 = GetScalarFromConstant<float>(nptrs[1]->dom_scale);
  if (s1 > s2) {
    return s2;
  }
  return s1;
}
/*!
 * \brief Bring all quantized arguments to one data type and one dom scale.
 *
 * The common data type is dtype_input when there are exactly two arguments
 * and the second one already has dtype_input; otherwise it is
 * dtype_activation. The common scale is the one picked by ChooseDomScale.
 *
 * \param ref_args The arguments of the original call; inspected to detect
 *        simulated_quantize calls of kind kQInput.
 * \param args The rewritten arguments; every one must be a QRealizeIntExpr.
 * \param dtype_ptr Output: the common data type.
 * \param scale_ptr Output: the common dom scale as a float32 constant.
 *
 * \return The integer data of every argument, cast and rescaled.
 */
Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,
                            DataType* dtype_ptr, Expr* scale_ptr) {
  static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
  const QConfig& cfg = QConfig::Current();

  // Collect the quantized nodes together with their integer payloads.
  std::vector<const QRealizeIntExprNode*> qnodes;
  Array<Expr> unified;
  for (auto arg : args) {
    const auto* qnode = arg.as<QRealizeIntExprNode>();
    CHECK(qnode);
    qnodes.push_back(qnode);
    unified.push_back(qnode->data);
  }
  CHECK_EQ(ref_args.size(), args.size());
  const size_t num_args = unified.size();

  // Unify the data type.
  const bool rhs_has_input_dtype =
      num_args == 2 && qnodes[1]->dtype == cfg->dtype_input;
  DataType dtype = rhs_has_input_dtype ? cfg->dtype_input : cfg->dtype_activation;
  for (size_t i = 0; i < num_args; ++i) {
    if (qnodes[i]->dtype != dtype) {
      unified.Set(i, Cast(unified[i], dtype));
      continue;
    }
    const auto* ref_arg = ref_args[i].as<CallNode>();
    const bool is_input_quantize =
        ref_arg && ref_arg->op.same_as(simulated_quantize) &&
        ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput;
    if (is_input_quantize) {
      // Round-trip through dtype_input behind a stop_fusion barrier, then
      // return to the common data type.
      Expr narrowed = StopFusion(Cast(unified[i], cfg->dtype_input));
      unified.Set(i, Cast(narrowed, dtype));
    }
  }

  // Unify the dom scale.
  float target_scale = ChooseDomScale(qnodes);
  Expr dom_scale = MakeConstantScalar(Float(32), target_scale);
  for (size_t i = 0; i < num_args; ++i) {
    float own_scale = GetScalarFromConstant<float>(qnodes[i]->dom_scale);
    unified.Set(i, MulAndDiv(unified[i], own_scale, target_scale, dtype));
  }

  *dtype_ptr = dtype;
  *scale_ptr = dom_scale;
  return unified;
}
/*!
 * \brief Realize an add whose two operands are both quantized.
 *
 * The operands are unified to a common data type and dom scale by
 * UnifyDTypeScale before the original operator is forwarded.
 *
 * \param ref_call The original add call.
 * \param new_args The rewritten arguments; exactly two are expected.
 * \param ctx Rewrite context (not used by this function).
 *
 * \return A QRealizeIntExpr when both arguments are QRealizeIntExpr nodes;
 *         otherwise a null Expr, after checking that neither argument is a
 *         TempExpr.
 */
Expr AddRealize(const Call& ref_call,
                const Array<Expr>& new_args,
                const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 2);

  const bool both_quantized =
      new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>();
  if (!both_quantized) {
    CHECK(!new_args[0]->derived_from<TempExprNode>() &&
          !new_args[1]->derived_from<TempExprNode>());
    return Expr(nullptr);
  }

  DataType dtype;
  Expr dom_scale;
  Array<Expr> unified_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
  return QRealizeIntExprNode::make(ForwardOp(ref_call, unified_args), dom_scale, dtype);
}

RELAY_REGISTER_OP("add")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
/*!
 * \brief Realize a clip whose operand is quantized.
 *
 * A fresh ClipAttrs node is created whose bounds are the original bounds
 * divided by the operand's dom scale; data type and dom scale pass through
 * unchanged.
 *
 * \param ref_call The original clip call.
 * \param new_args The rewritten arguments; exactly one is expected.
 * \param ctx Rewrite context (not used by this function).
 *
 * \return A QRealizeIntExpr when the argument is a QRealizeIntExpr node;
 *         otherwise a null Expr, after checking that the argument is not a
 *         TempExpr.
 */
Expr ClipRealize(const Call& ref_call,
                 const Array<Expr>& new_args,
                 const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 1);

  const auto* n = new_args[0].as<QRealizeIntExprNode>();
  if (n == nullptr) {
    CHECK(!new_args[0]->derived_from<TempExprNode>());
    return Expr(nullptr);
  }

  const auto ref_attrs = ref_call->attrs.as<ClipAttrs>();
  auto attrs = make_node<ClipAttrs>();
  // Divide the bounds by the scale so they apply to the integer data.
  double dom_scale = GetScalarFromConstant<float>(n->dom_scale);
  attrs->a_min = ref_attrs->a_min / dom_scale;
  attrs->a_max = ref_attrs->a_max / dom_scale;

  Expr clipped = CallNode::make(ref_call->op, {n->data}, Attrs(attrs), ref_call->type_args);
  return QRealizeIntExprNode::make(clipped, n->dom_scale, n->dtype);
}

RELAY_REGISTER_OP("clip")
.set_attr<FForwardRewrite>("FQRealizeRewrite", ClipRealize);
/*!
 * \brief Realize a concatenate over a tuple of quantized fields.
 *
 * Whether the rewrite applies is decided from the first tuple field alone;
 * UnifyDTypeScale then requires every field to be quantized.
 *
 * \param ref_call The original concatenate call; its single argument must be
 *        a tuple.
 * \param new_args The rewritten arguments; a single tuple is expected.
 * \param ctx Rewrite context (not used by this function).
 *
 * \return A QRealizeIntExpr when the first field is a QRealizeIntExpr node;
 *         otherwise a null Expr, after checking that no element of new_args
 *         is a TempExpr.
 */
Expr ConcatenateRealize(const Call& ref_call,
                        const Array<Expr>& new_args,
                        const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 1);
  CHECK_EQ(ref_call->args.size(), 1);

  const auto* tuple = new_args[0].as<TupleNode>();
  const auto* ref_tuple = ref_call->args[0].as<TupleNode>();
  CHECK(tuple);
  CHECK(ref_tuple);
  const Array<Expr>& fields = tuple->fields;
  const Array<Expr>& ref_fields = ref_tuple->fields;

  if (!fields[0].as<QRealizeIntExprNode>()) {
    for (auto arg : new_args) {
      CHECK(!arg->derived_from<TempExprNode>());
    }
    return Expr(nullptr);
  }

  DataType dtype;
  Expr dom_scale;
  Array<Expr> unified_fields = UnifyDTypeScale(ref_fields, fields, &dtype, &dom_scale);
  Expr ret = ForwardOp(ref_call, {TupleNode::make(unified_fields)});
  return QRealizeIntExprNode::make(ret, dom_scale, dtype);
}

RELAY_REGISTER_OP("concatenate")
.set_attr<FForwardRewrite>("FQRealizeRewrite", ConcatenateRealize);
/*!
 * \brief Forward the original operator onto the quantized data unchanged.
 *
 * The dom scale and data type of the operand are carried over to the result.
 *
 * \param ref_call The original call.
 * \param new_args The rewritten arguments; exactly one is expected.
 * \param ctx Rewrite context (not used by this function).
 *
 * \return A QRealizeIntExpr when the argument is a QRealizeIntExpr node;
 *         otherwise a null Expr, after checking that the argument is not a
 *         TempExpr.
 */
Expr IdentityRealize(const Call& ref_call,
                     const Array<Expr>& new_args,
                     const NodeRef& ctx) {
  CHECK_EQ(new_args.size(), 1);

  const auto* n = new_args[0].as<QRealizeIntExprNode>();
  if (n == nullptr) {
    CHECK(!new_args[0]->derived_from<TempExprNode>());
    return Expr(nullptr);
  }

  Expr forwarded = ForwardOp(ref_call, {n->data});
  return QRealizeIntExprNode::make(forwarded, n->dom_scale, n->dtype);
}

RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

RELAY_REGISTER_OP("strided_slice")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

RELAY_REGISTER_OP("annotation.stop_fusion")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
/*!
 * \brief Realize a unary operator by first casting its input to dtype_input.
 *
 * The cast is applied unconditionally; the dom scale is carried over and the
 * result is tagged with dtype_input.
 *
 * \param ref_call The original call.
 * \param new_args The rewritten arguments; exactly one is expected.
 * \param ctx Rewrite context (not used by this function).
 *
 * \return A QRealizeIntExpr when the argument is a QRealizeIntExpr node;
 *         otherwise a null Expr, after checking that the argument is not a
 *         TempExpr.
 */
Expr CastDtypeInputRealize(const Call& ref_call,
                           const Array<Expr>& new_args,
                           const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 1);

  const auto* n = new_args[0].as<QRealizeIntExprNode>();
  if (n == nullptr) {
    CHECK(!new_args[0]->derived_from<TempExprNode>());
    return Expr(nullptr);
  }

  Expr narrowed = Cast(n->data, cfg->dtype_input);
  Expr forwarded = ForwardOp(ref_call, {narrowed});
  return QRealizeIntExprNode::make(forwarded, n->dom_scale, cfg->dtype_input);
}

RELAY_REGISTER_OP("nn.max_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", CastDtypeInputRealize);
/*!
 * \brief Realize average pooling on quantized data in the activation dtype.
 *
 * The input is cast to dtype_activation only when it has a different data
 * type; the dom scale is carried over unchanged.
 *
 * \param ref_call The original nn.avg_pool2d call.
 * \param new_args The rewritten arguments; exactly one is expected.
 * \param ctx Rewrite context (not used by this function).
 *
 * \return A QRealizeIntExpr when the argument is a QRealizeIntExpr node;
 *         otherwise a null Expr, after checking that the argument is not a
 *         TempExpr.
 */
Expr AvgPoolRealize(const Call& ref_call,
                    const Array<Expr>& new_args,
                    const NodeRef& ctx) {
  const QConfig& cfg = QConfig::Current();
  CHECK_EQ(new_args.size(), 1);

  const auto* n = new_args[0].as<QRealizeIntExprNode>();
  if (n == nullptr) {
    CHECK(!new_args[0]->derived_from<TempExprNode>());
    return Expr(nullptr);
  }

  Expr data = (n->dtype != cfg->dtype_activation)
      ? Cast(n->data, cfg->dtype_activation)
      : n->data;
  Expr pooled = ForwardOp(ref_call, {data});
  return QRealizeIntExprNode::make(pooled, n->dom_scale, cfg->dtype_activation);
}

RELAY_REGISTER_OP("nn.avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
/*!
 * \brief Realize a cast_hint annotation as an actual cast.
 *
 * The quantized data is cast to the data type stored in the call's
 * CastHintAttrs. The annotation call itself is not forwarded.
 *
 * \param ref_call The original annotation.cast_hint call.
 * \param new_args The rewritten arguments; exactly one is expected.
 * \param ctx Rewrite context (not used by this function).
 *
 * \return A QRealizeIntExpr when the argument is a QRealizeIntExpr node;
 *         otherwise a null Expr, after checking that the argument is not a
 *         TempExpr.
 */
Expr CastHintRealize(const Call& ref_call,
                     const Array<Expr>& new_args,
                     const NodeRef& ctx) {
  const auto param = ref_call->attrs.as<CastHintAttrs>();
  CHECK_EQ(new_args.size(), 1);

  const auto* n = new_args[0].as<QRealizeIntExprNode>();
  if (n == nullptr) {
    CHECK(!new_args[0]->derived_from<TempExprNode>());
    return Expr(nullptr);
  }

  Expr casted = Cast(n->data, param->dtype);
  return QRealizeIntExprNode::make(casted, n->dom_scale, param->dtype);
}

RELAY_REGISTER_OP("annotation.cast_hint")
.set_attr<FForwardRewrite>("FQRealizeRewrite", CastHintRealize);
/*!
 * \brief Create the function pass that applies the "FQRealizeRewrite" rules.
 *
 * The pass body runs ForwardRewrite over the function with the
 * "FQRealizeRewrite" attribute name and two null trailing arguments.
 *
 * \return The pass, created under the name "QuantizeRealize".
 */
Pass QuantizeRealizePass() {
  using FPassFunc = runtime::TypedPackedFunc<Function(Function, Module, PassContext)>;
  FPassFunc pass_func = [=](Function f, Module m, PassContext pc) {
    Expr rewritten = ForwardRewrite(f, "FQRealizeRewrite", nullptr, nullptr);
    return Downcast<Function>(rewritten);
  };
  // NOTE(review): the literal 1 is presumably the pass opt level and the
  // empty list the required passes; confirm against CreateFunctionPass.
  return CreateFunctionPass(pass_func, 1, "QuantizeRealize", {});
}

TVM_REGISTER_API("relay._quantize.QuantizeRealize")
.set_body_typed(QuantizeRealizePass);
} // namespace quantize
} // namespace relay
} // namespace tvm
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections import namedtuple
import tvm
from tvm import relay
from tvm.relay import quantize as qtz
import mxnet as mx
from mxnet import gluon
import logging
import os
# Show INFO-level (and more severe) log records from this script.
logging.basicConfig(level=logging.INFO)
# One accuracy experiment. test_quantize_acc reads model, nbit_input,
# dtype_input, dtype_output, global_scale and expected_acc (the top-1 accuracy
# that must be exceeded); nbit_output is carried but not read there.
Config = namedtuple('Config', ['model', 'nbit_input', 'dtype_input', 'nbit_output', 'dtype_output', 'global_scale', 'expected_acc'])
def get_val_data(model_name,
                 rec_val,
                 batch_size,
                 num_workers=4):
    """Build the validation record iterator and a batch-splitting helper.

    Parameters
    ----------
    model_name : str
        Model name; 'inceptionv3' selects 299x299 inputs, anything else 224x224.
    rec_val : str
        Path to the validation record file; a leading ``~`` is expanded.
    batch_size : int
        Number of images per batch.
    num_workers : int
        Number of preprocessing threads handed to the iterator.

    Returns
    -------
    val_data : mx.io.ImageRecordIter
        The validation data iterator.
    batch_fn : callable
        ``batch_fn(batch, ctx)`` returning ``(data, label)`` split and loaded
        onto the contexts in ``ctx``.
    """
    record_path = os.path.expanduser(rec_val)
    mean_r, mean_g, mean_b = 123.68, 116.779, 103.939
    std_r, std_g, std_b = 58.393, 57.12, 57.375

    def batch_fn(batch, ctx):
        def _load(arrays):
            return gluon.utils.split_and_load(arrays[0], ctx_list=ctx, batch_axis=0)

        data = _load(batch.data)
        label = _load(batch.label)
        return data, label

    if model_name == 'inceptionv3':
        img_size = 299
    else:
        img_size = 224

    val_data = mx.io.ImageRecordIter(
        path_imgrec=record_path,
        preprocess_threads=num_workers,
        shuffle=False,
        batch_size=batch_size,
        resize=256,
        data_shape=(3, img_size, img_size),
        mean_r=mean_r,
        mean_g=mean_g,
        mean_b=mean_b,
        std_r=std_r,
        std_g=std_g,
        std_b=std_b,
    )
    return val_data, batch_fn
def get_model(model_name, batch_size, qconfig, target=None, original=False, simulated=False):
    """Load a pretrained gluon model and return it as a relay function.

    Parameters
    ----------
    model_name : str
        Name of the gluon model zoo vision model.
    batch_size : int
        Batch dimension of the "data" input.
    qconfig : context manager
        Quantization configuration entered around ``qtz.quantize``.
    target : object, optional
        Accepted but not used by this function.
    original : bool
        When true, return the function right after ``prerequisite_optimize``
        without quantizing it.
    simulated : bool
        Accepted but not used by this function.

    Returns
    -------
    qfunc : relay function
        The optimized function, quantized unless ``original`` is true.
    """
    pretrained = gluon.model_zoo.vision.get_model(model_name, pretrained=True)
    side = 299 if model_name == 'inceptionv3' else 224
    input_shape = (batch_size, 3, side, side)
    mod, params = relay.frontend.from_mxnet(pretrained, {"data": input_shape})
    net = mod['main']

    with relay.build_config(opt_level=3):
        qfunc = relay.quantize.prerequisite_optimize(net, params=params)
    logging.debug('original')
    logging.debug(qfunc.astext(show_meta_data=False))

    if not original:
        with qconfig:
            logging.debug('current quantize config')
            logging.debug(qtz.current_qconfig())
            qfunc = qtz.quantize(qfunc)
            logging.debug('after quantize')
            logging.debug(qfunc.astext(show_meta_data=False))
    return qfunc
def eval_acc(model, dataset, batch_fn, target=tvm.target.cuda(), ctx=tvm.gpu(), log_interval=100):
    """Build ``model``, run it over ``dataset`` and return the top-1 accuracy.

    Parameters
    ----------
    model : relay function
        The function to build and evaluate.
    dataset : data iterator
        Iterator that provides ``reset()`` and ``batch_size`` and yields batches.
    batch_fn : callable
        ``batch_fn(batch, ctx_list)`` returning ``(data, label)`` lists.
    target : tvm target
        Build target.
    ctx : tvm context
        Context the graph runtime is created on.
    log_interval : int
        Log the running accuracies every ``log_interval`` batches.

    Returns
    -------
    top1 : float
        Top-1 accuracy accumulated over every batch of ``dataset``.
    """
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(model, target)
    # create runtime module
    m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
    m.set_input(**params)

    # setup evaluation metric
    dataset.reset()
    batch_size = dataset.batch_size
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)
    acc_top1.reset()
    acc_top5.reset()

    # Execute
    for i, batch in enumerate(dataset):
        data, label = batch_fn(batch, [mx.cpu(0)])
        m.run(data=data[0].asnumpy())
        # Copy the output to the host once and feed the same values to both metrics.
        out_np = m.get_output(0).asnumpy()
        acc_top1.update(label, [mx.nd.array(out_np)])
        acc_top5.update(label, [mx.nd.array(out_np)])

        if not (i + 1) % log_interval:
            _, top1 = acc_top1.get()
            _, top5 = acc_top5.get()
            nsamples = (i + 1) * batch_size
            logging.info('[%d samples] validation: acc-top1=%f acc-top5=%f', nsamples, top1, top5)

    # Read the metrics after the last batch. The values bound inside the loop
    # are only refreshed every log_interval batches and do not exist at all
    # when the dataset has fewer than log_interval batches.
    _, top1 = acc_top1.get()
    _, top5 = acc_top5.get()
    logging.info('[final] validation: acc-top1=%f acc-top5=%f', top1, top5)
    return top1
def test_quantize_acc(cfg, rec_val):
    """Quantize one model per ``cfg`` and check its top-1 accuracy.

    Parameters
    ----------
    cfg : Config
        Experiment description. The input bit width and dtype are applied to
        both inputs and weights; ``dtype_output`` is used as the activation
        dtype.
    rec_val : str
        Path to the validation record file.

    Returns
    -------
    acc : float
        The measured accuracy, asserted to be greater than ``cfg.expected_acc``.
    """
    batch_size = 32
    quant_options = dict(
        skip_conv_layers=[0],
        nbit_input=cfg.nbit_input,
        nbit_weight=cfg.nbit_input,
        global_scale=cfg.global_scale,
        dtype_input=cfg.dtype_input,
        dtype_weight=cfg.dtype_input,
        dtype_activation=cfg.dtype_output,
        debug_enabled_ops=None,
    )
    qconfig = qtz.qconfig(**quant_options)

    model = get_model(cfg.model, batch_size, qconfig, tvm.target.cuda())
    val_data, batch_fn = get_val_data(cfg.model, rec_val=rec_val, batch_size=batch_size)

    acc = eval_acc(model, val_data, batch_fn)
    assert acc > cfg.expected_acc
    return acc
if __name__ == "__main__":
    #TODO(for user): replace the line with the path to imagenet validation dataset
    rec_val = "/scratch/tqchen/imagenet/val.rec"

    # Each entry is run in order; a failed accuracy assertion stops the script.
    configs = [
        Config('mobilenetv2_1.0',
               nbit_input=8, dtype_input='int8',
               nbit_output=32, dtype_output='int32',
               global_scale=4.0, expected_acc=0.666),
        Config('resnet18_v1',
               nbit_input=8, dtype_input='int8',
               nbit_output=16, dtype_output='int16',
               global_scale=8.0, expected_acc=0.692),
        Config('resnet18_v1',
               nbit_input=8, dtype_input='int8',
               nbit_output=32, dtype_output='int32',
               global_scale=8.0, expected_acc=0.692),
        Config('resnet34_v1',
               nbit_input=8, dtype_input='int8',
               nbit_output=32, dtype_output='int32',
               global_scale=8.0, expected_acc=0.733),
        Config('resnet50_v1',
               nbit_input=8, dtype_input='int8',
               nbit_output=32, dtype_output='int32',
               global_scale=8.0, expected_acc=0.747),
        Config('resnet101_v1',
               nbit_input=8, dtype_input='int8',
               nbit_output=32, dtype_output='int32',
               global_scale=8.0, expected_acc=0.756),
        # TODO: need to fix accuracy
        # Config('mobilenetv2_1.0', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=4.0),
    ]

    results = []
    for config in configs:
        acc = test_quantize_acc(config, rec_val)
        results.append((config, acc))

    # Print the (config, accuracy) pairs once every experiment has passed.
    for res in results:
        print(res)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import math
import numpy as np
import tvm
from tvm import relay
from tvm.relay import quantize as qtz
from tvm.relay import transform
def run_infer_type(expr):
    """Run type inference on ``expr`` and return the typed result.

    The expression is wrapped in a module, ``InferType`` is applied, and the
    module's "main" entry is read back. A ``relay.Function`` input yields the
    typed function; any other expression yields the body of "main".
    """
    typed_mod = transform.InferType()(relay.Module.from_expr(expr))
    main_fn = typed_mod["main"]
    if isinstance(expr, relay.Function):
        return main_fn
    return main_fn.body
def make_dataset(graph, size=100):
    """Create random inputs and parameters for every parameter of ``graph``.

    Parameters
    ----------
    graph : relay function
        Function whose typed parameters drive the shapes and dtypes.
    size : int
        Number of samples generated for the parameter named 'data'.

    Returns
    -------
    dataset : list of dict
        ``size`` dictionaries of the form ``{'data': array}``.
    params : dict
        One random array per remaining parameter, keyed by its name hint.
    """
    typed_params = run_infer_type(graph).params

    def _random_array(var):
        # Uniform values in [-1, 1) with the annotated shape and dtype.
        annotation = var.type_annotation
        values = np.random.uniform(-1.0, 1.0, size=annotation.concrete_shape)
        return tvm.ndarray.array(values.astype(annotation.dtype))

    params = {}
    for var in typed_params:
        name = var.name_hint
        if name != 'data':
            params[name] = _random_array(var)
            continue
        dataset = [{'data': _random_array(var)} for _ in range(size)]
    return dataset, params
def test_simulated_quantize():
    """Check the inferred types of an attached simulated_quantize call.

    The call must have the same type as its first argument, and arguments
    1 through 3 must be float32 scalars.
    """
    data = relay.var("data", relay.ty.TensorType((3, 4, 5, 6), "float32"))
    call = run_infer_type(qtz._annotate.attach_simulated_quantize(data, 1))

    assert call.checked_type == call.args[0].checked_type
    for position in (1, 2, 3):
        assert call.args[position].checked_type == relay.ty.TensorType(tuple(), "float32")
def test_quantize_pass():
    """Compare the quantize pass against a hand-written quantized graph.

    A single conv2d is quantized with ``qtz.quantize`` and, separately, built
    by hand from explicit scale / round / clip / cast operators. Both graphs
    are evaluated on the same input and must agree within rtol=1e-3.
    """
    n, c, h, w = 1, 3, 224, 224

    def quantize_weight(arr):
        # Scale by the next power of two above the largest magnitude, then
        # round to int8 and clip to [-127, 127].
        values = arr.asnumpy()
        peak = np.amax(np.abs(values))
        scale = 2**math.ceil(math.log(peak, 2))
        quantized = np.around(values / scale * 128).astype('int8')
        quantized = np.clip(quantized, -127, 127)
        return relay.const(quantized, 'int8')

    def make_graph(data):
        weight = relay.var("conv_weight")
        conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
        return relay.Function(relay.analysis.free_vars(conv), conv)

    def make_qgraph(data, weight):
        scaled = data * relay.const(32.0)
        rounded = relay.round(scaled)
        clipped = relay.clip(rounded, a_min=-127, a_max=127)
        as_int8 = clipped.astype('int8')
        conv = relay.nn.conv2d(as_int8, weight, kernel_size=(3, 3),
                               padding=(1, 1), channels=c, out_dtype='int32')
        as_float = conv.astype('float32')
        rescaled = relay.multiply(as_float, relay.const(0.00024414062))
        return relay.Function(relay.analysis.free_vars(rescaled), rescaled)

    # Seed before any random data is drawn so the inputs are reproducible.
    np.random.seed(42)

    data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
    float_graph = make_graph(data)
    dataset, params = make_dataset(float_graph, 10)

    with qtz.qconfig(skip_conv_layers=None, global_scale=4.0,
                     round_for_shift=False, store_lowbit_output=False):
        qgraph0 = qtz.quantize(float_graph, params)
        qgraph0 = run_infer_type(qgraph0)

    conv_weight = quantize_weight(params['conv_weight'])
    qgraph1 = run_infer_type(make_qgraph(data, conv_weight))

    executor = relay.create_executor('graph')
    sample = dataset[0]['data']
    res0 = executor.evaluate(qgraph0)(sample)
    res1 = executor.evaluate(qgraph1)(sample)
    tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3)
if __name__ == "__main__":
    # Run every test in this file, in definition order.
    for test_case in (test_simulated_quantize, test_quantize_pass):
        test_case()
#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Abort on the first failing command.
set -e
# Treat expansion of an unset variable as an error.
set -u
# Resolve the tvm and topi Python packages from the source tree.
# NOTE(review): the paths are relative, so this presumably runs from the
# repository root; confirm against the CI invocation.
export PYTHONPATH=python:topi/python
# Rebuild cython
make cython3
# Remove compiled bytecode files left under the package trees.
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
rm -rf topi/python/topi/*.pyc topi/python/topi/*/*.pyc topi/python/topi/*/*/*.pyc topi/python/topi/*/*/*/*.pyc
# Run the tests under topi/tests/python/nightly with nose in verbose mode.
python3 -m nose -v topi/tests/python/nightly
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment