Unverified Commit 741b6bbe by ziheng Committed by GitHub

[OPT] Low-bit Quantization (#2116)

* [QUANTIZE] Quantization implementation.

* Update.

* Update.

* Update.

* Update.
parent da972bdf
...@@ -139,7 +139,6 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> { ...@@ -139,7 +139,6 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
} }
}; };
struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> { struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> {
Array<Integer> axes; Array<Integer> axes;
...@@ -151,16 +150,16 @@ struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> { ...@@ -151,16 +150,16 @@ struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> {
} }
}; };
// Clip /*! \brief Attributes for Clip operator */
struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> { struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
double a_min; double a_min;
double a_max; double a_max;
TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") { TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") {
TVM_ATTR_FIELD(a_min) TVM_ATTR_FIELD(a_min)
.describe("The minimum clip value."); .describe("The minimum clip value.");
TVM_ATTR_FIELD(a_max) TVM_ATTR_FIELD(a_max)
.describe("The maximum clip value."); .describe("The maximum clip value.");
} }
}; };
......
...@@ -551,6 +551,7 @@ inline ValueType OpMap<ValueType>::get(const Expr& expr, ...@@ -551,6 +551,7 @@ inline ValueType OpMap<ValueType>::get(const Expr& expr,
return map_.get<ValueType>(expr, def_value); return map_.get<ValueType>(expr, def_value);
} }
/*! /*!
* \brief Check that an expression is a "primtive operator". * \brief Check that an expression is a "primtive operator".
* *
......
...@@ -8,7 +8,7 @@ from . import expr ...@@ -8,7 +8,7 @@ from . import expr
from . import expr_functor from . import expr_functor
from . import module from . import module
from . import ir_pass from . import ir_pass
from .build_module import build, build_config, create_executor from .build_module import build, build_config, create_executor, optimize
from . import parser from . import parser
from . import debug from . import debug
...@@ -23,6 +23,7 @@ from . import vision ...@@ -23,6 +23,7 @@ from . import vision
from . import image from . import image
from . import frontend from . import frontend
from . import backend from . import backend
from . import quantize
from .scope_builder import ScopeBuilder from .scope_builder import ScopeBuilder
......
...@@ -129,7 +129,7 @@ def _bind_params_by_name(func, params): ...@@ -129,7 +129,7 @@ def _bind_params_by_name(func, params):
return expr.bind(func, bind_dict) return expr.bind(func, bind_dict)
def optimize(func, target, params=None): def optimize(func, target=None, params=None):
"""Perform target invariant optimizations. """Perform target invariant optimizations.
Parameters Parameters
...@@ -400,7 +400,7 @@ class GraphExecutor(_interpreter.Executor): ...@@ -400,7 +400,7 @@ class GraphExecutor(_interpreter.Executor):
graph_json, mod, params = build(func, target=self.target) graph_json, mod, params = build(func, target=self.target)
gmodule = _graph_rt.create(graph_json, mod, self.ctx) gmodule = _graph_rt.create(graph_json, mod, self.ctx)
if params: if params:
gmodule.set_input(*params) gmodule.set_input(**params)
def _graph_wrapper(*args, **kwargs): def _graph_wrapper(*args, **kwargs):
args = self._convert_args(func, args, kwargs) args = self._convert_args(func, args, kwargs)
......
#pylint: disable=wildcard-import, redefined-builtin
"""Automatic quantization utilities."""
from __future__ import absolute_import as _abs
from .quantize import *
from ._annotate import register_annotate_function
#pylint: disable=unused-argument
"""Internal module for registering attribute for annotation."""
from __future__ import absolute_import
import topi
from . import _quantize
from .quantize import QAnnotateKind, current_qconfig
from .quantize import _conv_counter, _set_conv_counter
from .. import expr as _expr
from .. import op as _op
from ..op import op as _reg
from ..base import register_relay_node
from ..._ffi.function import register_func
@_reg.register_compute("relay.op.annotation.simulated_quantize")
def simulated_quantize_compute(attrs, inputs, out_type, target):
"""Compiler for simulated_quantize."""
assert len(inputs) == 4
assert attrs.sign
assert attrs.rounding == "round"
data, scale, clip_min, clip_max = inputs
# simulate rounding error
scaled_data = topi.divide(data, scale)
clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min)
round_data = topi.round(clipped_data)
# recover data
rdata = topi.multiply(round_data, scale)
return [rdata]
_reg.register_schedule("relay.op.annotation.simulated_quantize",
_reg.schedule_injective)
_reg.register_pattern("relay.op.annotation.simulated_quantize",
_reg.OpPattern.OPAQUE)
@register_relay_node
class QAnnotateExpr(_expr.TempExpr):
"""A special kind of Expr for Annotating.
Parameters
---------
expr: Expr
the original relay ir expr.
kind: QAnnotateKind
the kind of annotation field.
"""
def __init__(self, expr, kind):
self.__init_handle_by_constructor__(
_quantize.make_annotate_expr, expr, kind)
def _forward_op(ref_call, args):
"""forward the operator of ref_call with provided arguments"""
return _expr.Call(
ref_call.op, args, ref_call.attrs, ref_call.type_args)
def _get_expr_kind(anno):
"""Get the expression and QAnnotateKind from QAnnotateExpr or Expr"""
if isinstance(anno, QAnnotateExpr):
return anno.expr, anno.kind
return anno, None
def register_annotate_function(op_name, frewrite=None, level=10):
"""register a rewrite function for operator, used by annotation.
Parameters
---------
op_name: str
The name of operation
frewrite : function, optional
The function to be registered.
level : int, optional
The priority level
"""
def default_rewrite(ref_call, new_args, ctx):
# recover from QAnnotateExpr
args = [_get_expr_kind(x)[0] for x in new_args]
return _forward_op(ref_call, args)
def _register(func):
"""internal register function"""
def frewrite_with_guard(ref_call, new_args, ctx):
if not current_qconfig().guard(ref_call):
return default_rewrite(ref_call, new_args, ctx)
return func(ref_call, new_args, ctx)
_op.op._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level)
return frewrite_with_guard
return _register(frewrite) if frewrite is not None else _register
@register_func("relay.quantize.attach_simulated_quantize")
def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
"""Attach a simulated quantize operation after input data expr.
Parameters
---------
data: Expr
the original data expr.
kind: QAnnotateKind
the kind of annotation field.
"""
dom_scale = _expr.var("dom_scale")
clip_min = _expr.var("clip_min")
clip_max = _expr.var("clip_max")
return _quantize.simulated_quantize(
data, dom_scale, clip_min, clip_max, kind, sign, rounding)
@register_annotate_function("nn.conv2d")
def conv2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for conv2d. Lhs of conv will be quantized to
input field, and rhs of conv will be quantized to weight field.
Output would be in activation field"""
cnt = _conv_counter()
if cnt < current_qconfig().skip_k_conv:
_set_conv_counter(cnt + 1)
return None
_set_conv_counter(cnt + 1)
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
assert rhs_kind is None
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
@register_annotate_function("multiply")
def multiply_rewrite(ref_call, new_args, ctx):
"""Rewrite function for multiply."""
if _conv_counter() <= current_qconfig().skip_k_conv:
return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
if lhs_kind is None and rhs_kind is None:
return None
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind is None:
# quantize lhs to INPUT field
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
# quantize rhs to WEIGHT field
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
raise ValueError
@register_annotate_function("add")
def add_rewrite(ref_call, new_args, ctx):
"""Rewrite function for add."""
if _conv_counter() <= current_qconfig().skip_k_conv:
return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
if lhs_kind is None and rhs_kind is None:
return None
if lhs_kind is None and rhs_kind is not None:
# quantize lhs to INPUT field if it is normal expression
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
if lhs_kind is not None and rhs_kind is None:
if isinstance(rhs_expr, _expr.Constant):
# quantize rhs to WEIGHT field if it is Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
else:
# quantize rhs to INPUT field if it is not Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
def identity_rewrite(ref_call, new_args, ctx):
"""Simply forward the original operation"""
if _conv_counter() <= current_qconfig().skip_k_conv:
return None
x_expr, x_kind = _get_expr_kind(new_args[0])
if x_kind is None:
return None
ret_expr = _forward_op(ref_call, [x_expr])
return QAnnotateExpr(ret_expr, x_kind)
register_annotate_function("nn.relu", identity_rewrite)
register_annotate_function("strided_slice", identity_rewrite)
register_annotate_function("nn.avg_pool2d", identity_rewrite)
def pool2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for max pool2d"""
if _conv_counter() <= current_qconfig().skip_k_conv:
return None
expr, x_kind = _get_expr_kind(new_args[0])
if x_kind is None:
return None
if x_kind == QAnnotateKind.ACTIVATION:
expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
register_annotate_function("nn.max_pool2d", pool2d_rewrite)
@register_annotate_function("concatenate")
def concatenate_rewrite(ref_call, new_args, ctx):
"""Rewrite function for concatenate"""
if _conv_counter() <= current_qconfig().skip_k_conv:
return None
input_tuple = new_args[0]
expr_list = [_get_expr_kind(x)[0] for x in input_tuple]
kind_list = [_get_expr_kind(x)[1] for x in input_tuple]
# make sure the inputs of concatenate are all normal
# expression or annotate expression
if kind_list[0] is None:
for k in kind_list:
assert k is None
return None
for k in kind_list:
assert k is not None
expr = _forward_op(ref_call, [_expr.Tuple(expr_list)])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
#pylint: disable=unused-argument
"""Internal module for quantization."""
from __future__ import absolute_import
from tvm._ffi.function import _init_api
_init_api("relay._quantize", __name__)
#pylint: disable=unused-argument
"""Automatic quantization toolkit."""
from __future__ import absolute_import
import numpy as np
from . import _quantize
from .. import expr as _expr
from .. import ir_pass as _ir_pass
from .. import build_module as _build
from .. import op as _op
from ... import make as _make
from ..base import NodeBase, register_relay_node
class QAnnotateKind(object):
"""Denote the kind of annotation field, corresponding
to different nbit configure."""
INPUT = 1
WEIGHT = 2
ACTIVATION = 3
def kind2str(kind):
"""Convert a `QAnnotateKind` to string"""
str_map = {
QAnnotateKind.INPUT: "input",
QAnnotateKind.WEIGHT: "weight",
QAnnotateKind.ACTIVATION: "activation",
}
assert kind in str_map
return str_map[kind]
@register_relay_node("relay.quantize.QConfig")
class QConfig(NodeBase):
"""Configure the quantization behavior by setting config variables.
Note
----
This object is backed by node system in C++, with arguments that can be
exchanged between python and C++.
Do not construct directly, use qconfig instead.
The fields that are backed by the C++ node are immutable once an instance
is constructed. See _node_defaults for the fields.
"""
_node_defaults = {
"nbit_input": 8,
"nbit_weight": 8,
"nbit_activation": 32,
"dtype_input": "int8",
"dtype_weight": "int8",
"dtype_activation": "int32",
"global_scale": 8.0,
"skip_k_conv": 1,
"round_for_shift": True,
"store_lowbit_output": True,
"debug_enabled_ops": None,
}
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : SymbolHandle
the handle to the underlying C++ Symbol
"""
super(QConfig, self).__init__(handle)
self.handle = handle
def guard(self, ref_call):
op_name = ref_call.op.name
if self.debug_enabled_ops is not None:
name_list = [x.value for x in self.debug_enabled_ops]
if op_name not in name_list:
return False
return True
def get_nbit_by_kind(self, kind):
name = kind2str(kind)
return getattr(self, 'nbit_' + name)
def get_dtype_by_kind(self, kind):
name = kind2str(kind)
return getattr(self, 'dtype_' + name)
def __enter__(self):
# pylint: disable=protected-access
_quantize._EnterQConfigScope(self)
return self
def __exit__(self, ptype, value, trace):
_quantize._ExitQConfigScope(self)
def __setattr__(self, name, value):
if name in QConfig._node_defaults:
raise AttributeError(
"'%s' object cannot set attribute '%s'" % (str(type(self)), name))
return super(QConfig, self).__setattr__(name, value)
def current_qconfig():
"""Get the current quantization configuration."""
return _quantize._GetCurrentQConfig()
def qconfig(**kwargs):
"""Configure the quantization behavior by setting config variables.
Parameters
---------
nbit_dict: dict of QAnnotateKind -> int
Number of bit for every kind of annotate field.
global_scale: float
The global scale for calibration.
skip_k_conv: int
The number of skipped conv2d.
round_for_shift: boolean
Whether to add bias for rounding during shift.
store_lowbit_output: boolean
Whether to store low-bit integer back as output before dequantizing.
Some accelerators need this, e.g. VTA.
Returns
-------
config: QConfig
The quantization configuration
"""
node_args = {k: v if k not in kwargs else kwargs[k]
for k, v in QConfig._node_defaults.items()}
return _make.node("relay.quantize.QConfig", **node_args)
CONV_COUNTER = 0
def _conv_counter():
"""Get the global counter for conv2d."""
return CONV_COUNTER
def _set_conv_counter(n):
"""Set the value of the global conv2d counter."""
global CONV_COUNTER
CONV_COUNTER = n
def annotate(graph):
"""Given a float32 graph, annotate will rewrite the graph
and return back a graph which simulates the error brought by
current quantization scheme.
Parameters
---------
graph: Function
The original graph
Returns
-------
ret: Function
The graph after annotation
"""
_set_conv_counter(0) # reset counter
return _quantize.annotate(graph)
def calibrate(graph, dataset=None):
"""The calibrate procedure will try to calculate the content of
dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
operator.
Parameters
---------
graph: Function
The simulation graph after annotation.
dataset: list of dict of Var -> NDArray
The calibration dataset.
Returns
-------
ret: Function
The graph after calibration
"""
def power2_scale(arr):
"""calculate weight scale with nearest mode-2 scale"""
val = np.amax(np.abs(arr.asnumpy()))
return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0
cfg = current_qconfig()
const_params = {}
quantize_op = _op.get("relay.op.annotation.simulated_quantize")
def visit_func(expr):
"""Internal visit function"""
if isinstance(expr, _expr.Call) and expr.op == quantize_op:
_, ndom_scale, nclip_min, nclip_max = expr.args
attrs = expr.attrs
kind = attrs.kind
nbit = cfg.get_nbit_by_kind(kind)
valid_bit = nbit - attrs.sign
if kind == QAnnotateKind.WEIGHT:
var = expr.args[0]
assert isinstance(var, _expr.Constant)
scale = power2_scale(var.data)
else:
scale = cfg.global_scale
def _make_const(val):
return _expr.const(val, 'float32')
valid_range = 2**valid_bit
const_params[ndom_scale] = _make_const(scale / valid_range)
const_params[nclip_min] = _make_const(- (valid_range - 1))
const_params[nclip_max] = _make_const((valid_range - 1))
_ir_pass.post_order_visit(graph, visit_func)
return _expr.bind(graph, const_params)
def realize(graph):
"""The realize pass will transform the simulated quantized
graph, which computes with float32 actually, to a real low-bit
integer graph. It will replace the simulated_quantize with
several fine-grained operators like add, multiply, and shift
as more as possible for performance (fusion, etc.)
Parameters
---------
graph: Function
The simulated graph after calibrating.
Returns
-------
ret: Function
The graph after realization
"""
return _quantize.realize(graph)
def quantize(graph, params=None, dataset=None):
""" The quantization procedure. Before running the three main
procedure of quantization, "annotate", "calibrate" and "realize"
, we need to do "SimplifyInference", "FoldScaleAxis", "FoldConstant"
first for optimizing.
Parameters
---------
graph: Function
The original graph.
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
dataset: list of dict of Var -> NDArray
The calibration dataset.
Returns
-------
ret: Function
The graph after quantization
"""
opt_passes = ["SimplifyInference",
"FoldScaleAxis",
"FoldConstant",
"CanonicalizeOps"]
with _build.build_config(add_pass=opt_passes):
graph = _build.optimize(graph, params=params)
graph = annotate(graph)
graph = calibrate(graph, dataset)
graph = realize(graph)
graph = _ir_pass.fold_constant(graph)
return graph
...@@ -228,11 +228,11 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { ...@@ -228,11 +228,11 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
void ExprVisitor::VisitType(const Type& t) { return; } void ExprVisitor::VisitType(const Type& t) { return; }
// visitor to implement apply // visitor to implement apply
class ExprApplyVisit : public ExprVisitor { class ExprApplyVisit : public ExprVisitor {
public: public:
explicit ExprApplyVisit(std::function<void(const Expr&)> f) : f_(f) {} explicit ExprApplyVisit(std::function<void(const Expr&)> f) : f_(f) {}
void VisitExpr(const Expr& e) final { void VisitExpr(const Expr& e) final {
if (visited_.count(e.get()) != 0) return; if (visited_.count(e.get()) != 0) return;
visited_.insert(e.get()); visited_.insert(e.get());
...@@ -257,7 +257,6 @@ TVM_REGISTER_API("relay._ir_pass.post_order_visit") ...@@ -257,7 +257,6 @@ TVM_REGISTER_API("relay._ir_pass.post_order_visit")
}); });
}); });
// Implement bind. // Implement bind.
class ExprBinder : public ExprMutator { class ExprBinder : public ExprMutator {
public: public:
......
...@@ -1601,7 +1601,6 @@ RELAY_REGISTER_OP("slice_like") ...@@ -1601,7 +1601,6 @@ RELAY_REGISTER_OP("slice_like")
.set_attr<FTVMCompute>("FTVMCompute", SliceLikeCompute) .set_attr<FTVMCompute>("FTVMCompute", SliceLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective); .set_attr<TOpPattern>("TOpPattern", kInjective);
// relay.layout_transform // relay.layout_transform
Array<Tensor> LayoutTransformCompute(const Attrs& attrs, Array<Tensor> LayoutTransformCompute(const Attrs& attrs,
const Array<Tensor>& inputs, const Array<Tensor>& inputs,
......
...@@ -103,8 +103,10 @@ This function takes a tensor, a minimum value `a_min`, and a maximum value `a_ma ...@@ -103,8 +103,10 @@ This function takes a tensor, a minimum value `a_min`, and a maximum value `a_ma
.set_attr<TOpPattern>("TOpPattern", kElemWise) .set_attr<TOpPattern>("TOpPattern", kElemWise)
.set_attr<TOpIsStateful>("TOpIsStateful", false) .set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout) .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attrs_type_key("relay.attrs.ClipAttrs")
.set_support_level(3); .set_support_level(3);
RELAY_REGISTER_UNARY_OP("floor") RELAY_REGISTER_UNARY_OP("floor")
.describe(R"code(Returns the floor of input array, computed element-wise. .describe(R"code(Returns the floor of input array, computed element-wise.
)code" TVM_ADD_FILELINE) )code" TVM_ADD_FILELINE)
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h> #include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/nn.h> #include <tvm/relay/attrs/nn.h>
#include <string> #include <string>
...@@ -20,6 +21,45 @@ namespace tvm { ...@@ -20,6 +21,45 @@ namespace tvm {
namespace relay { namespace relay {
/*! /*!
* \brief Dispatch DataType to the C++ data type
* during runtime.
*/
#define TVM_DTYPE_DISPATCH(type, DType, ...) \
if (type == Float(64)) { \
typedef double DType; \
{__VA_ARGS__} \
} else if (type == Float(32)) { \
typedef float DType; \
{__VA_ARGS__} \
} else if (type == Int(64)) { \
typedef int64_t DType; \
{__VA_ARGS__} \
} else if (type == Int(32)) { \
typedef int32_t DType; \
{__VA_ARGS__} \
} else if (type == Int(16)) { \
typedef int16_t DType; \
{__VA_ARGS__} \
} else if (type == Int(8)) { \
typedef int8_t DType; \
{__VA_ARGS__} \
} else if (type == UInt(64)) { \
typedef uint64_t DType; \
{__VA_ARGS__} \
} else if (type == UInt(32)) { \
typedef uint32_t DType; \
{__VA_ARGS__} \
} else if (type == UInt(16)) { \
typedef uint16_t DType; \
{__VA_ARGS__} \
} else if (type == UInt(8)) { \
typedef uint8_t DType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "unknown data type " << type; \
}
/*!
* \brief Try to match lhs and rhs via broadcasting rule, such that: * \brief Try to match lhs and rhs via broadcasting rule, such that:
* *
* rhs matches the dimension of lhs specified by lhs_axes * rhs matches the dimension of lhs specified by lhs_axes
...@@ -145,9 +185,10 @@ inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) { ...@@ -145,9 +185,10 @@ inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
*/ */
template<typename T> template<typename T>
inline Constant MakeConstantScalar(DataType dtype, T value) { inline Constant MakeConstantScalar(DataType dtype, T value) {
CHECK_EQ(sizeof(T) * 8, dtype.bits()) << "data type mismatch";
runtime::NDArray arr = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0}); runtime::NDArray arr = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0});
*static_cast<T*>(arr->data) = value; TVM_DTYPE_DISPATCH(dtype, DType, {
*static_cast<DType*>(arr->data) = value;
})
return ConstantNode::make(arr); return ConstantNode::make(arr);
} }
...@@ -168,6 +209,25 @@ inline Expr Log(Expr e) { ...@@ -168,6 +209,25 @@ inline Expr Log(Expr e) {
static const Op& op = Op::Get("log"); static const Op& op = Op::Get("log");
return CallNode::make(op, {e}); return CallNode::make(op, {e});
} }
/*!
* \brief Get an immediate scalar from a Constant expr.
*
* \param expr The Constant expr.
* \return A scalar with type T.
*/
template <typename T>
T GetScalarFromConstant(Expr expr) {
const auto* n = expr.as<ConstantNode>();
CHECK(n->is_scalar());
return static_cast<T*>(n->data->data)[0];
}
inline Expr Cast(Expr x, DataType dtype) {
static const Op& op = Op::Get("cast");
auto attrs = make_node<CastAttrs>();
attrs->dtype = dtype;
return CallNode::make(op, {x}, Attrs(attrs), {});
}
inline Expr Negative(Expr x) { inline Expr Negative(Expr x) {
static const Op& op = Op::Get("negative"); static const Op& op = Op::Get("negative");
...@@ -181,12 +241,39 @@ inline Expr Sqrt(Expr x) { ...@@ -181,12 +241,39 @@ inline Expr Sqrt(Expr x) {
} }
inline Expr Relu(Expr x) {
static const Op& op = Op::Get("nn.relu");
return CallNode::make(op, {x}, Attrs(), {});
}
inline Expr Round(Expr x) {
static const Op& op = Op::Get("round");
return CallNode::make(op, {x}, Attrs(), {});
}
inline Expr Clip(Expr x, double a_min, double a_max) {
static const Op& op = Op::Get("clip");
auto attrs = make_node<ClipAttrs>();
attrs->a_min = a_min;
attrs->a_max = a_max;
return CallNode::make(op, {x}, Attrs(attrs), {});
}
inline Expr Add(Expr lhs, Expr rhs) { inline Expr Add(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("add"); static const Op& op = Op::Get("add");
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); return CallNode::make(op, {lhs, rhs}, Attrs(), {});
} }
inline Expr Substract(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("subtract");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
inline Expr Multiply(Expr lhs, Expr rhs) { inline Expr Multiply(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("multiply"); static const Op& op = Op::Get("multiply");
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); return CallNode::make(op, {lhs, rhs}, Attrs(), {});
...@@ -208,6 +295,24 @@ inline Expr OneLike(Expr e) { ...@@ -208,6 +295,24 @@ inline Expr OneLike(Expr e) {
return CallNode::make(op, {e}); return CallNode::make(op, {e});
} }
inline Expr Power(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("power");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
inline Expr RightShift(Expr x, Expr nbit) {
static const Op& op = Op::Get("right_shift");
return CallNode::make(op, {x, nbit}, Attrs(), {});
}
inline Expr LeftShift(Expr x, Expr nbit) {
static const Op& op = Op::Get("left_shift");
return CallNode::make(op, {x, nbit}, Attrs(), {});
}
inline Expr ReshapeLike(Expr lhs, Expr rhs) { inline Expr ReshapeLike(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("reshape_like"); static const Op& op = Op::Get("reshape_like");
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); return CallNode::make(op, {lhs, rhs}, Attrs(), {});
......
/*!
* Copyright (c) 2018 by Contributors
*
* \file quantize.cc
*
* \brief transform a graph to a low-bit graph
* for compression and acceleration.
*/
#include <dmlc/thread_local.h>
#include <tvm/base.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <cmath>
#include <string>
#include <vector>
#include <stack>
#include "pattern_util.h"
#include "quantize.h"
namespace tvm {
namespace relay {
namespace quantize {
/*! \brief Attribute for simulated quantize operator */
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
int kind;
bool sign;
std::string rounding;
TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") {
TVM_ATTR_FIELD(kind)
.describe("kind of field, hint for nbit/dtype configuration.");
TVM_ATTR_FIELD(sign).set_default(true)
.describe("whether to use signed data type.");
TVM_ATTR_FIELD(rounding).set_default("round")
.describe("rounding mode. Can be 'floor', 'ceil', 'round'");
}
};
TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs);
bool SimulatedQuantizeRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 5);
const auto param = attrs.as<SimulatedQuantizeAttrs>();
CHECK(param != nullptr);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data != nullptr);
CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";
reporter->Assign(types[1], TensorTypeNode::make({}, Float(32))); // dom_scale
reporter->Assign(types[2], TensorTypeNode::make({}, Float(32))); // clip_min
reporter->Assign(types[3], TensorTypeNode::make({}, Float(32))); // clip_max
reporter->Assign(types[4], types[0]); // output
return true;
}
RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
.describe(R"code(simulated quantize op)code" TVM_ADD_FILELINE)
.set_num_inputs(4)
.add_argument("data", "Tensor", "The input data.")
.add_argument("dom_scale", "Tensor", "The domain scale of input data. It should be a scalar")
.add_argument("clip_min", "Tensor", "lower bound. It should be a scalar")
.add_argument("clip_max", "Tensor", "upper bound. It should be a scalar")
.set_attrs_type_key("relay.attrs.SimulatedQuantizeAttrs")
.set_support_level(10)
.add_type_rel("SimulatedQuantize", SimulatedQuantizeRel);
TVM_REGISTER_API("relay._quantize.simulated_quantize")
.set_body_typed<Expr(Expr, Expr, Expr, Expr, int, bool, std::string)>(
[](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max,
int kind, bool sign, std::string rounding) {
auto attrs = make_node<SimulatedQuantizeAttrs>();
attrs->kind = kind;
attrs->sign = sign;
attrs->rounding = rounding;
static const Op& op = Op::Get("relay.op.annotation.simulated_quantize");
return CallNode::make(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {});
});
// =============
// annotate pass
Expr QAnnotateExprNode::Realize() const {
const auto& cfg = QConfig::Current();
if (cfg->store_lowbit_output) {
// store low bit output back for VTA
const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
return (*f)(this->expr, static_cast<int>(kQInput));
} else {
return expr;
}
}
QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) {
auto rnode = make_node<QAnnotateExprNode>();
rnode->expr = expr;
rnode->kind = kind;
return QAnnotateExpr(rnode);
}
TVM_REGISTER_API("relay._quantize.make_annotate_expr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = QAnnotateExprNode::make(args[0],
static_cast<QAnnotateKind>(args[1].operator int()));
});
TVM_REGISTER_API("relay._quantize.annotate")
.set_body_typed<Expr(Expr)>([] (const Expr& expr) {
std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
if (e->derived_from<TempExprNode>()) {
const auto* n = e.as<QAnnotateExprNode>();
CHECK(n);
const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
}
return e;
};
return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, nullptr);
});
// =============
// realize pass
Expr QRealizeIntExprNode::Realize() const {
const auto& cfg = QConfig::Current();
Expr data = this->data;
if (cfg->store_lowbit_output) {
data = Cast(data, cfg->dtype_input);
}
// dequantize
data = Cast(data, Float(32));
data = Multiply(data, this->dom_scale);
return data;
}
QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) {
NodePtr<QRealizeIntExprNode> n = make_node<QRealizeIntExprNode>();
n->data = std::move(data);
n->dom_scale = std::move(dom_scale);
n->dtype = std::move(dtype);
return QRealizeIntExpr(n);
}
inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
return CallNode::make(ref_call->op,
args, ref_call->attrs, ref_call->type_args);
}
/* calculate `data * s1 / s2`, use shift if possible */
inline Expr MulAndDiv(Expr data, float s1, float s2) {
// here we assume the dtype of data is dtype activation
const QConfig& cfg = QConfig::Current();
if (s1 == s2) return data;
float factor = s1 / s2;
float shift_factor = std::log2(factor);
CHECK_GT(shift_factor, 0);
if (static_cast<int>(shift_factor) == shift_factor) {
return LeftShift(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(shift_factor)));
} else if (static_cast<int>(factor) == factor) {
return Multiply(data, MakeConstantScalar(cfg->dtype_activation, factor));
} else {
LOG(FATAL) << "fall back to float computation";
data = Cast(data, Float(32));
return Multiply(data, MakeConstantScalar(Float(32), factor));
}
}
Expr QuantizeRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
const QConfig& cfg = QConfig::Current();
// do not handle data type cast
const auto param = ref_call->attrs.as<SimulatedQuantizeAttrs>();
CHECK_EQ(param->rounding, "round");
Expr dom_scale = new_args[1];
Expr clip_min = new_args[2];
Expr clip_max = new_args[3];
float dom_scale_imm = GetScalarFromConstant<float>(dom_scale);
float clip_min_imm = GetScalarFromConstant<float>(clip_min);
float clip_max_imm = GetScalarFromConstant<float>(clip_max);
// x * idom_scale = y * odom_scale
// => y = x * idom_scale / odom_scale
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
Expr data = n->data;
float idom_scale_imm = GetScalarFromConstant<float>(n->dom_scale);
float odom_scale_imm = GetScalarFromConstant<float>(dom_scale);
float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
// int32->int8
CHECK_GT(shift_nbit, 0);
if (static_cast<int>(shift_nbit) == shift_nbit) {
// use shift
if (cfg->round_for_shift) {
float round_bias = std::pow(2.0, shift_nbit - 1);
data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast<int>(round_bias)));
}
data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(shift_nbit)));
data = Clip(data, clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
} else {
// float computation
data = Cast(data, Float(32));
Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale));
Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
}
}
// quantize from real
CHECK(!new_args[0]->derived_from<TempExprNode>());
Expr data = new_args[0];
Expr scaled_data = Multiply(data, MakeConstantScalar(Float(32), 1 / dom_scale_imm));
Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(round_data, dom_scale, Float(32));
}
RELAY_REGISTER_OP("simulated_quantize")
.set_attr<FForwardRewrite>("FQRealizeRewrite", QuantizeRealize);
Expr Conv2dRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 2);
if (!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>()) {
return Expr(nullptr);
}
const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
CHECK(lhs);
const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
CHECK(rhs);
Expr ldata = lhs->data;
if (lhs->dtype != cfg->dtype_input) {
ldata = Cast(ldata, cfg->dtype_input);
}
Expr rdata = Cast(rhs->data, cfg->dtype_weight);
const auto ref_attrs = ref_call->attrs.as<Conv2DAttrs>();
auto attrs = make_node<Conv2DAttrs>();
*attrs = *ref_attrs;
DataType out_dtype = cfg->dtype_activation;
attrs->out_dtype = out_dtype;
Expr ret = CallNode::make(ref_call->op,
{ldata, rdata}, Attrs(attrs), ref_call->type_args);
Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale));
return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
}
RELAY_REGISTER_OP("nn.conv2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", Conv2dRealize);
Expr MulRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 2);
if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
// execute the operation with activation data type.
const auto* lhs = new_args[0].as<QRealizeIntExprNode>();
const auto* rhs = new_args[1].as<QRealizeIntExprNode>();
Expr ldata = lhs->data;
Expr rdata = rhs->data;
DataType dtype = cfg->dtype_activation;
if (lhs->dtype == Float(32)) {
ldata = Cast(ldata, dtype);
} else {
CHECK_EQ(lhs->dtype, dtype);
}
if (rhs->dtype == Float(32)) {
rdata = Cast(rdata, dtype);
} else {
CHECK_EQ(rhs->dtype, dtype);
}
Expr ret = ForwardOp(ref_call, {ldata, rdata});
Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale));
return QRealizeIntExprNode::make(ret, dom_scale, dtype);
}
CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
return Expr(nullptr);
}
RELAY_REGISTER_OP("multiply")
.set_attr<FForwardRewrite>("FQRealizeRewrite", MulRealize);
float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {
if (nptrs.size() == 2) {
// x = a * s1, y = b * s2
// x + y = (a * s1 / s2 + b) * s2, if s1 > s2
// = (a + b * s2 / s1) * s1, if s2 > s1
float s1 = GetScalarFromConstant<float>(nptrs[0]->dom_scale);
float s2 = GetScalarFromConstant<float>(nptrs[1]->dom_scale);
return s1 > s2 ? s2 : s1;
} else {
const QConfig& cfg = QConfig::Current();
float scale = cfg->global_scale;
return scale / std::pow(2.0, cfg->nbit_activation - 1);
}
}
/* \brief Unify the dom scale of arguments */
Array<Expr> UnifyDTypeScale(const Array<Expr>& args,
DataType* dtype_ptr,
Expr* scale_ptr) {
const QConfig& cfg = QConfig::Current();
std::vector<const QRealizeIntExprNode*> nptrs;
Array<Expr> ret;
for (auto arg : args) {
const auto* nptr = arg.as<QRealizeIntExprNode>();
CHECK(nptr);
nptrs.push_back(nptr);
ret.push_back(nptr->data);
}
// unify the data type
DataType dtype = cfg->dtype_activation;
for (size_t i = 0; i < ret.size(); ++i) {
if (nptrs[i]->dtype != dtype) {
ret.Set(i, Cast(ret[i], dtype));
}
}
// unify the dom_scale
float s = ChooseDomScale(nptrs);
Expr dom_scale = MakeConstantScalar(Float(32), s);
for (size_t i = 0; i < ret.size(); ++i) {
float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
LOG(INFO) << "unify data scale from " << cur_s << " to " << s;
ret.Set(i, MulAndDiv(ret[i], cur_s, s));
}
*dtype_ptr = dtype;
*scale_ptr = dom_scale;
return ret;
}
Expr AddRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
CHECK_EQ(new_args.size(), 2);
if (new_args[0].as<QRealizeIntExprNode>() && new_args[1].as<QRealizeIntExprNode>()) {
DataType dtype;
Expr dom_scale;
Array<Expr> ret_args = UnifyDTypeScale(new_args, &dtype, &dom_scale);
Expr ret = ForwardOp(ref_call, ret_args);
return QRealizeIntExprNode::make(ret, dom_scale, dtype);
}
CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
return Expr(nullptr);
}
RELAY_REGISTER_OP("add")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AddRealize);
Expr ConcatenateRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
CHECK_EQ(new_args.size(), 1);
const auto* tuple = new_args[0].as<TupleNode>();
CHECK(tuple);
const Array<Expr>& arr = tuple->fields;
if (arr[0].as<QRealizeIntExprNode>()) {
DataType dtype;
Expr dom_scale;
Array<Expr> ret_args = UnifyDTypeScale(arr, &dtype, &dom_scale);
Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)});
return QRealizeIntExprNode::make(ret, dom_scale, dtype);
} else {
for (auto arg : new_args) {
CHECK(!arg->derived_from<TempExprNode>());
}
return Expr(nullptr);
}
}
RELAY_REGISTER_OP("concatenate")
.set_attr<FForwardRewrite>("FQRealizeRewrite", ConcatenateRealize);
/* \brief forward the original operator */
Expr IdentityRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
Expr ret = ForwardOp(ref_call, {n->data});
return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
}
CHECK(!new_args[0]->derived_from<TempExprNode>());
return Expr(nullptr);
}
RELAY_REGISTER_OP("nn.relu")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
RELAY_REGISTER_OP("strided_slice")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);
Expr MaxPoolRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
Expr data = Cast(n->data, cfg->dtype_input);
Expr ret = ForwardOp(ref_call, {data});
return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input);
}
CHECK(!new_args[0]->derived_from<TempExprNode>());
return Expr(nullptr);
}
RELAY_REGISTER_OP("nn.max_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", MaxPoolRealize);
Expr AvgPoolRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
const QConfig& cfg = QConfig::Current();
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
Expr data = n->data;
if (n->dtype != cfg->dtype_activation) {
data = Cast(n->data, cfg->dtype_activation);
}
Expr ret = ForwardOp(ref_call, {data});
return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_activation);
}
CHECK(!new_args[0]->derived_from<TempExprNode>());
return Expr(nullptr);
}
RELAY_REGISTER_OP("nn.avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
TVM_REGISTER_API("relay._quantize.realize")
.set_body_typed<Expr(Expr)>([](const Expr& e) {
Expr ret = ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr);
return ret;
});
// =============
// qconfig
QConfig qconfig() {
return QConfig(make_node<QConfigNode>());
}
/*! \brief Entry to hold the BuildConfig context stack. */
struct TVMQConfigThreadLocalEntry {
/*! \brief The default build config if the stack is empty */
QConfig default_config;
/*! \brief The current build config context */
std::stack<QConfig> context_stack;
TVMQConfigThreadLocalEntry() :
default_config(qconfig()) {
}
};
/*! \brief Thread local store to hold the BuildConfig context stack. */
typedef dmlc::ThreadLocalStore<TVMQConfigThreadLocalEntry> TVMQConfigThreadLocalStore;
void QConfig::EnterQConfigScope(const QConfig& build_config) {
TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get();
entry->context_stack.push(build_config);
}
void QConfig::ExitQConfigScope() {
TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get();
entry->context_stack.pop();
}
QConfig QConfig::Current() {
TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get();
if (entry->context_stack.size() > 0) {
return entry->context_stack.top();
}
return entry->default_config;
}
TVM_REGISTER_NODE_TYPE(QConfigNode);
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<QConfigNode>([](const QConfigNode *op, IRPrinter *p) {
p->stream << "qconfig(";
p->stream << "nbit_input=" << op->nbit_input << ", ";
p->stream << "nbit_weight=" << op->nbit_weight << ", ";
p->stream << "nbit_activation=" << op->nbit_activation << ", ";
p->stream << "global_scale=" << op->global_scale << ", ";
p->stream << "skip_k_conv==" << op->skip_k_conv << ", ";
p->stream << "round_for_shift==" << op->round_for_shift << ", ";
p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", ";
p->stream << "debug_enabled_ops==" << op->debug_enabled_ops;
p->stream << ")";
});
TVM_REGISTER_API("relay._quantize._GetCurrentQConfig")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = QConfig::Current();
});
TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
QConfig target = args[0];
QConfig::EnterQConfigScope(target);
});
TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
QConfig::ExitQConfigScope();
});
} // namespace quantize
} // namespace relay
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors.
*
* \file tvm/relay/pass/quantize.h
* \brief Header of definitions for quantization
*/
#ifndef TVM_RELAY_PASS_QUANTIZE_H_
#define TVM_RELAY_PASS_QUANTIZE_H_
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <string>
#include "pattern_util.h"
namespace tvm {
namespace relay {
namespace quantize {
/*! \brief Kind of annotate field */
enum QAnnotateKind : int {
kQInput = 1,
kQWeight = 2,
kQActivation = 3,
};
/*!
* \brief TempExpr used during annotate forward rewrite.
*/
class QAnnotateExpr;
/*!
* \brief TempExprNode used during annotate forward rewrite.
*/
class QAnnotateExprNode : public TempExprNode {
public:
/*! \brief The original expression */
Expr expr;
/*! \brief The kind of annotate field */
QAnnotateKind kind;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("expr", &expr);
v->Visit("kind", &kind);
}
TVM_DLL static QAnnotateExpr make(Expr expr, QAnnotateKind kind);
Expr Realize() const final;
static constexpr const char* _type_key = "relay.QAnnotateExpr";
TVM_DECLARE_NODE_TYPE_INFO(QAnnotateExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr);
/*! \brief TempExpr used during realize forward rewrite. */
class QRealizeExpr;
/*! \brief TempExpr representing integer. */
class QRealizeIntExpr;
class QRealizeExprNode : public TempExprNode {
public:
/*! \brief The original expression */
Expr data;
static constexpr const char* _type_key = "relay.quantize.QRealizeExpr";
TVM_DECLARE_BASE_NODE_INFO(QRealizeExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(QRealizeExpr, QRealizeExprNode, TempExpr);
class QRealizeIntExprNode : public QRealizeExprNode {
public:
Expr dom_scale;
/*! \brief current data type */
DataType dtype;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data);
v->Visit("dom_scale", &dom_scale);
v->Visit("dtype", &dtype);
}
Expr Realize() const final;
TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype);
static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr";
TVM_DECLARE_NODE_TYPE_INFO(QRealizeIntExprNode, QRealizeExprNode);
};
RELAY_DEFINE_NODE_REF(QRealizeIntExpr, QRealizeIntExprNode, QRealizeExpr);
class QConfig;
/*!
* \brief Container for build configuration options
*/
class QConfigNode : public Node {
public:
int nbit_input = 8;
int nbit_weight = 8;
int nbit_activation = 32;
DataType dtype_input = Int(8);
DataType dtype_weight = Int(8);
DataType dtype_activation = Int(32);
double global_scale = 8.0;
int skip_k_conv = 1;
bool round_for_shift = true;
bool store_lowbit_output = true;
Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
void VisitAttrs(AttrVisitor* v) final {
v->Visit("nbit_input", &nbit_input);
v->Visit("nbit_weight", &nbit_weight);
v->Visit("nbit_activation", &nbit_activation);
v->Visit("dtype_input", &dtype_input);
v->Visit("dtype_weight", &dtype_weight);
v->Visit("dtype_activation", &dtype_activation);
v->Visit("global_scale", &global_scale);
v->Visit("skip_k_conv", &skip_k_conv);
v->Visit("round_for_shift", &round_for_shift);
v->Visit("store_lowbit_output", &store_lowbit_output);
v->Visit("debug_enabled_ops", &debug_enabled_ops);
}
static constexpr const char* _type_key = "relay.quantize.QConfig";
TVM_DECLARE_NODE_TYPE_INFO(QConfigNode, Node);
};
/*!
* \brief Container for build configuration options
*/
class QConfig : public NodeRef {
public:
QConfig() {}
explicit QConfig(NodePtr<Node> n) : NodeRef(n) {}
const QConfigNode* operator->() const {
return static_cast<const QConfigNode*>(node_.get());
}
QConfigNode* operator->() {
return static_cast<QConfigNode*>(node_.get());
}
/*!
* \brief Push a new BuildConfig context onto the thread local stack.
* \param build_config The configuration to set as the current context.
*/
static void EnterQConfigScope(const QConfig& qconfig);
/*!
* \brief Pop a build config off the thread local context stack, restoring the previous
* configuration as the current context.
*/
static void ExitQConfigScope();
/*!
* \brief Get the current BuildConfig context from thread local storage, or a default
* configuration if a BuildConfig scope has not been entered.
* \return The configuration that is the current context.
*/
static QConfig Current();
using ContainerType = QConfigNode;
};
/*!
* \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the
* context stack when constructed, and pops it when destructed.
*/
struct QConfigContext {
/*!
* \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current
* context. When the BuildConfigContext is destructed, the previous context is restored.
* \param build_config The BuildConfig to set as the new current context.
*/
explicit QConfigContext(const QConfig& qconfig) {
QConfig::EnterQConfigScope(qconfig);
}
/*! \brief Destructor. Pops the context off the thread local stack. */
~QConfigContext() {
QConfig::ExitQConfigScope();
}
};
/*!
* \brief Construct a BuildConfig containing a new BuildConfigNode
* \return The new BuildConfig
*/
TVM_DLL QConfig qconfig();
} // namespace quantize
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_QUANTIZE_H_
import math
import numpy as np
import tvm
from tvm import relay
from tvm.relay import quantize as qtz
def make_dataset(graph, size=100):
args = relay.ir_pass.infer_type(graph).params
def create_arr(var):
ttype = var.type_annotation
np_arr = np.random.uniform(-1.0, 1.0, size=ttype.concrete_shape).astype(ttype.dtype)
return tvm.ndarray.array(np_arr)
params = {}
for arg in args:
if arg.name_hint == 'data':
dataset = [{'data': create_arr(arg)} for _ in range(size)]
else:
params[arg.name_hint] = create_arr(arg)
return dataset, params
def test_simulated_quantize():
data = relay.var("data", relay.ty.TensorType((3, 4, 5, 6), "float32"))
out = qtz._annotate.attach_simulated_quantize(data, 1)
out = relay.ir_pass.infer_type(out)
assert out.checked_type == out.args[0].checked_type
assert out.args[1].checked_type == relay.ty.TensorType(tuple(), "float32")
assert out.args[2].checked_type == relay.ty.TensorType(tuple(), "float32")
assert out.args[3].checked_type == relay.ty.TensorType(tuple(), "float32")
def test_quantize_pass():
def quantize_weight(arr):
maximum = np.amax(np.abs(arr.asnumpy()))
scale = 2**math.ceil(math.log(maximum, 2))
out = np.around(arr.asnumpy() / scale * 128).astype('int8')
out = np.clip(out, -127, 127)
return relay.const(out, 'int8')
n, c, h, w = 1, 3, 224, 224
def make_graph(data):
weight = relay.var("conv_weight")
out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
out = relay.Function(relay.ir_pass.free_vars(out), out)
return out
def make_qgraph(data, weight):
out = data * relay.const(32.0)
out = relay.round(out)
out = relay.clip(out, a_min=-127, a_max=127)
out = out.astype('int8')
out = relay.nn.conv2d(out, weight, kernel_size=(3, 3),
padding=(1, 1), channels=c, out_dtype='int32')
out = out.astype('float32')
out = relay.multiply(out, relay.const(0.00024414062))
out = relay.Function(relay.ir_pass.free_vars(out), out)
return out
data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
graph = make_graph(data)
dataset, params = make_dataset(graph, 10)
with qtz.qconfig(skip_k_conv=0, global_scale=4.0,
round_for_shift=False, store_lowbit_output=False):
qgraph0 = qtz.quantize(graph, params)
qgraph0 = relay.ir_pass.infer_type(qgraph0)
conv_weight = quantize_weight(params['conv_weight'])
qgraph1 = make_qgraph(data, conv_weight)
qgraph1 = relay.ir_pass.infer_type(qgraph1)
graph = relay.create_executor('graph')
res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy())
if __name__ == "__main__":
np.random.seed(42)
test_simulated_quantize()
test_quantize_pass()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment