Unverified Commit 741b6bbe by ziheng Committed by GitHub

[OPT] Low-bit Quantization (#2116)

* [QUANTIZE] Quantization implementation.

* Update.

* Update.

* Update.

* Update.
parent da972bdf
......@@ -139,7 +139,6 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
}
};
struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> {
Array<Integer> axes;
......@@ -151,16 +150,16 @@ struct SliceLikeAttrs : public tvm::AttrsNode<SliceLikeAttrs> {
}
};
// Clip
/*! \brief Attributes for Clip operator */
struct ClipAttrs : public tvm::AttrsNode<ClipAttrs> {
double a_min;
double a_max;
TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") {
TVM_ATTR_FIELD(a_min)
.describe("The minimum clip value.");
TVM_ATTR_FIELD(a_max)
.describe("The maximum clip value.");
TVM_ATTR_FIELD(a_min)
.describe("The minimum clip value.");
TVM_ATTR_FIELD(a_max)
.describe("The maximum clip value.");
}
};
......
......@@ -551,6 +551,7 @@ inline ValueType OpMap<ValueType>::get(const Expr& expr,
return map_.get<ValueType>(expr, def_value);
}
/*!
* \brief Check that an expression is a "primtive operator".
*
......
......@@ -8,7 +8,7 @@ from . import expr
from . import expr_functor
from . import module
from . import ir_pass
from .build_module import build, build_config, create_executor
from .build_module import build, build_config, create_executor, optimize
from . import parser
from . import debug
......@@ -23,6 +23,7 @@ from . import vision
from . import image
from . import frontend
from . import backend
from . import quantize
from .scope_builder import ScopeBuilder
......
......@@ -129,7 +129,7 @@ def _bind_params_by_name(func, params):
return expr.bind(func, bind_dict)
def optimize(func, target, params=None):
def optimize(func, target=None, params=None):
"""Perform target invariant optimizations.
Parameters
......@@ -400,7 +400,7 @@ class GraphExecutor(_interpreter.Executor):
graph_json, mod, params = build(func, target=self.target)
gmodule = _graph_rt.create(graph_json, mod, self.ctx)
if params:
gmodule.set_input(*params)
gmodule.set_input(**params)
def _graph_wrapper(*args, **kwargs):
args = self._convert_args(func, args, kwargs)
......
#pylint: disable=wildcard-import, redefined-builtin
"""Automatic quantization utilities."""
from __future__ import absolute_import as _abs
from .quantize import *
from ._annotate import register_annotate_function
#pylint: disable=unused-argument
"""Internal module for registering attribute for annotation."""
from __future__ import absolute_import
import topi
from . import _quantize
from .quantize import QAnnotateKind, current_qconfig
from .quantize import _conv_counter, _set_conv_counter
from .. import expr as _expr
from .. import op as _op
from ..op import op as _reg
from ..base import register_relay_node
from ..._ffi.function import register_func
@_reg.register_compute("relay.op.annotation.simulated_quantize")
def simulated_quantize_compute(attrs, inputs, out_type, target):
"""Compiler for simulated_quantize."""
assert len(inputs) == 4
assert attrs.sign
assert attrs.rounding == "round"
data, scale, clip_min, clip_max = inputs
# simulate rounding error
scaled_data = topi.divide(data, scale)
clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min)
round_data = topi.round(clipped_data)
# recover data
rdata = topi.multiply(round_data, scale)
return [rdata]
_reg.register_schedule("relay.op.annotation.simulated_quantize",
_reg.schedule_injective)
_reg.register_pattern("relay.op.annotation.simulated_quantize",
_reg.OpPattern.OPAQUE)
@register_relay_node
class QAnnotateExpr(_expr.TempExpr):
"""A special kind of Expr for Annotating.
Parameters
---------
expr: Expr
the original relay ir expr.
kind: QAnnotateKind
the kind of annotation field.
"""
def __init__(self, expr, kind):
self.__init_handle_by_constructor__(
_quantize.make_annotate_expr, expr, kind)
def _forward_op(ref_call, args):
"""forward the operator of ref_call with provided arguments"""
return _expr.Call(
ref_call.op, args, ref_call.attrs, ref_call.type_args)
def _get_expr_kind(anno):
"""Get the expression and QAnnotateKind from QAnnotateExpr or Expr"""
if isinstance(anno, QAnnotateExpr):
return anno.expr, anno.kind
return anno, None
def register_annotate_function(op_name, frewrite=None, level=10):
"""register a rewrite function for operator, used by annotation.
Parameters
---------
op_name: str
The name of operation
frewrite : function, optional
The function to be registered.
level : int, optional
The priority level
"""
def default_rewrite(ref_call, new_args, ctx):
# recover from QAnnotateExpr
args = [_get_expr_kind(x)[0] for x in new_args]
return _forward_op(ref_call, args)
def _register(func):
"""internal register function"""
def frewrite_with_guard(ref_call, new_args, ctx):
if not current_qconfig().guard(ref_call):
return default_rewrite(ref_call, new_args, ctx)
return func(ref_call, new_args, ctx)
_op.op._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level)
return frewrite_with_guard
return _register(frewrite) if frewrite is not None else _register
@register_func("relay.quantize.attach_simulated_quantize")
def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
"""Attach a simulated quantize operation after input data expr.
Parameters
---------
data: Expr
the original data expr.
kind: QAnnotateKind
the kind of annotation field.
"""
dom_scale = _expr.var("dom_scale")
clip_min = _expr.var("clip_min")
clip_max = _expr.var("clip_max")
return _quantize.simulated_quantize(
data, dom_scale, clip_min, clip_max, kind, sign, rounding)
@register_annotate_function("nn.conv2d")
def conv2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for conv2d. Lhs of conv will be quantized to
input field, and rhs of conv will be quantized to weight field.
Output would be in activation field"""
cnt = _conv_counter()
if cnt < current_qconfig().skip_k_conv:
_set_conv_counter(cnt + 1)
return None
_set_conv_counter(cnt + 1)
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
assert rhs_kind is None
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
@register_annotate_function("multiply")
def multiply_rewrite(ref_call, new_args, ctx):
"""Rewrite function for multiply."""
if _conv_counter() <= current_qconfig().skip_k_conv:
return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
if lhs_kind is None and rhs_kind is None:
return None
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind is None:
# quantize lhs to INPUT field
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
# quantize rhs to WEIGHT field
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
raise ValueError
@register_annotate_function("add")
def add_rewrite(ref_call, new_args, ctx):
"""Rewrite function for add."""
if _conv_counter() <= current_qconfig().skip_k_conv:
return None
lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
rhs_expr, rhs_kind = _get_expr_kind(new_args[1])
if lhs_kind is None and rhs_kind is None:
return None
if lhs_kind is None and rhs_kind is not None:
# quantize lhs to INPUT field if it is normal expression
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
if lhs_kind is not None and rhs_kind is None:
if isinstance(rhs_expr, _expr.Constant):
# quantize rhs to WEIGHT field if it is Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
else:
# quantize rhs to INPUT field if it is not Constant
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
def identity_rewrite(ref_call, new_args, ctx):
"""Simply forward the original operation"""
if _conv_counter() <= current_qconfig().skip_k_conv:
return None
x_expr, x_kind = _get_expr_kind(new_args[0])
if x_kind is None:
return None
ret_expr = _forward_op(ref_call, [x_expr])
return QAnnotateExpr(ret_expr, x_kind)
register_annotate_function("nn.relu", identity_rewrite)
register_annotate_function("strided_slice", identity_rewrite)
register_annotate_function("nn.avg_pool2d", identity_rewrite)
def pool2d_rewrite(ref_call, new_args, ctx):
"""Rewrite function for max pool2d"""
if _conv_counter() <= current_qconfig().skip_k_conv:
return None
expr, x_kind = _get_expr_kind(new_args[0])
if x_kind is None:
return None
if x_kind == QAnnotateKind.ACTIVATION:
expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT)
register_annotate_function("nn.max_pool2d", pool2d_rewrite)
@register_annotate_function("concatenate")
def concatenate_rewrite(ref_call, new_args, ctx):
"""Rewrite function for concatenate"""
if _conv_counter() <= current_qconfig().skip_k_conv:
return None
input_tuple = new_args[0]
expr_list = [_get_expr_kind(x)[0] for x in input_tuple]
kind_list = [_get_expr_kind(x)[1] for x in input_tuple]
# make sure the inputs of concatenate are all normal
# expression or annotate expression
if kind_list[0] is None:
for k in kind_list:
assert k is None
return None
for k in kind_list:
assert k is not None
expr = _forward_op(ref_call, [_expr.Tuple(expr_list)])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
#pylint: disable=unused-argument
"""Internal module for quantization."""
from __future__ import absolute_import
from tvm._ffi.function import _init_api
_init_api("relay._quantize", __name__)
#pylint: disable=unused-argument
"""Automatic quantization toolkit."""
from __future__ import absolute_import
import numpy as np
from . import _quantize
from .. import expr as _expr
from .. import ir_pass as _ir_pass
from .. import build_module as _build
from .. import op as _op
from ... import make as _make
from ..base import NodeBase, register_relay_node
class QAnnotateKind(object):
"""Denote the kind of annotation field, corresponding
to different nbit configure."""
INPUT = 1
WEIGHT = 2
ACTIVATION = 3
def kind2str(kind):
"""Convert a `QAnnotateKind` to string"""
str_map = {
QAnnotateKind.INPUT: "input",
QAnnotateKind.WEIGHT: "weight",
QAnnotateKind.ACTIVATION: "activation",
}
assert kind in str_map
return str_map[kind]
@register_relay_node("relay.quantize.QConfig")
class QConfig(NodeBase):
"""Configure the quantization behavior by setting config variables.
Note
----
This object is backed by node system in C++, with arguments that can be
exchanged between python and C++.
Do not construct directly, use qconfig instead.
The fields that are backed by the C++ node are immutable once an instance
is constructed. See _node_defaults for the fields.
"""
_node_defaults = {
"nbit_input": 8,
"nbit_weight": 8,
"nbit_activation": 32,
"dtype_input": "int8",
"dtype_weight": "int8",
"dtype_activation": "int32",
"global_scale": 8.0,
"skip_k_conv": 1,
"round_for_shift": True,
"store_lowbit_output": True,
"debug_enabled_ops": None,
}
# pylint: disable=no-member
def __init__(self, handle):
"""Initialize the function with handle
Parameters
----------
handle : SymbolHandle
the handle to the underlying C++ Symbol
"""
super(QConfig, self).__init__(handle)
self.handle = handle
def guard(self, ref_call):
op_name = ref_call.op.name
if self.debug_enabled_ops is not None:
name_list = [x.value for x in self.debug_enabled_ops]
if op_name not in name_list:
return False
return True
def get_nbit_by_kind(self, kind):
name = kind2str(kind)
return getattr(self, 'nbit_' + name)
def get_dtype_by_kind(self, kind):
name = kind2str(kind)
return getattr(self, 'dtype_' + name)
def __enter__(self):
# pylint: disable=protected-access
_quantize._EnterQConfigScope(self)
return self
def __exit__(self, ptype, value, trace):
_quantize._ExitQConfigScope(self)
def __setattr__(self, name, value):
if name in QConfig._node_defaults:
raise AttributeError(
"'%s' object cannot set attribute '%s'" % (str(type(self)), name))
return super(QConfig, self).__setattr__(name, value)
def current_qconfig():
"""Get the current quantization configuration."""
return _quantize._GetCurrentQConfig()
def qconfig(**kwargs):
"""Configure the quantization behavior by setting config variables.
Parameters
---------
nbit_dict: dict of QAnnotateKind -> int
Number of bit for every kind of annotate field.
global_scale: float
The global scale for calibration.
skip_k_conv: int
The number of skipped conv2d.
round_for_shift: boolean
Whether to add bias for rounding during shift.
store_lowbit_output: boolean
Whether to store low-bit integer back as output before dequantizing.
Some accelerators need this, e.g. VTA.
Returns
-------
config: QConfig
The quantization configuration
"""
node_args = {k: v if k not in kwargs else kwargs[k]
for k, v in QConfig._node_defaults.items()}
return _make.node("relay.quantize.QConfig", **node_args)
CONV_COUNTER = 0
def _conv_counter():
"""Get the global counter for conv2d."""
return CONV_COUNTER
def _set_conv_counter(n):
"""Set the value of the global conv2d counter."""
global CONV_COUNTER
CONV_COUNTER = n
def annotate(graph):
"""Given a float32 graph, annotate will rewrite the graph
and return back a graph which simulates the error brought by
current quantization scheme.
Parameters
---------
graph: Function
The original graph
Returns
-------
ret: Function
The graph after annotation
"""
_set_conv_counter(0) # reset counter
return _quantize.annotate(graph)
def calibrate(graph, dataset=None):
"""The calibrate procedure will try to calculate the content of
dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
operator.
Parameters
---------
graph: Function
The simulation graph after annotation.
dataset: list of dict of Var -> NDArray
The calibration dataset.
Returns
-------
ret: Function
The graph after calibration
"""
def power2_scale(arr):
"""calculate weight scale with nearest mode-2 scale"""
val = np.amax(np.abs(arr.asnumpy()))
return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0
cfg = current_qconfig()
const_params = {}
quantize_op = _op.get("relay.op.annotation.simulated_quantize")
def visit_func(expr):
"""Internal visit function"""
if isinstance(expr, _expr.Call) and expr.op == quantize_op:
_, ndom_scale, nclip_min, nclip_max = expr.args
attrs = expr.attrs
kind = attrs.kind
nbit = cfg.get_nbit_by_kind(kind)
valid_bit = nbit - attrs.sign
if kind == QAnnotateKind.WEIGHT:
var = expr.args[0]
assert isinstance(var, _expr.Constant)
scale = power2_scale(var.data)
else:
scale = cfg.global_scale
def _make_const(val):
return _expr.const(val, 'float32')
valid_range = 2**valid_bit
const_params[ndom_scale] = _make_const(scale / valid_range)
const_params[nclip_min] = _make_const(- (valid_range - 1))
const_params[nclip_max] = _make_const((valid_range - 1))
_ir_pass.post_order_visit(graph, visit_func)
return _expr.bind(graph, const_params)
def realize(graph):
"""The realize pass will transform the simulated quantized
graph, which computes with float32 actually, to a real low-bit
integer graph. It will replace the simulated_quantize with
several fine-grained operators like add, multiply, and shift
as more as possible for performance (fusion, etc.)
Parameters
---------
graph: Function
The simulated graph after calibrating.
Returns
-------
ret: Function
The graph after realization
"""
return _quantize.realize(graph)
def quantize(graph, params=None, dataset=None):
""" The quantization procedure. Before running the three main
procedure of quantization, "annotate", "calibrate" and "realize"
, we need to do "SimplifyInference", "FoldScaleAxis", "FoldConstant"
first for optimizing.
Parameters
---------
graph: Function
The original graph.
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
dataset: list of dict of Var -> NDArray
The calibration dataset.
Returns
-------
ret: Function
The graph after quantization
"""
opt_passes = ["SimplifyInference",
"FoldScaleAxis",
"FoldConstant",
"CanonicalizeOps"]
with _build.build_config(add_pass=opt_passes):
graph = _build.optimize(graph, params=params)
graph = annotate(graph)
graph = calibrate(graph, dataset)
graph = realize(graph)
graph = _ir_pass.fold_constant(graph)
return graph
......@@ -228,11 +228,11 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
void ExprVisitor::VisitType(const Type& t) { return; }
// visitor to implement apply
class ExprApplyVisit : public ExprVisitor {
public:
explicit ExprApplyVisit(std::function<void(const Expr&)> f) : f_(f) {}
void VisitExpr(const Expr& e) final {
if (visited_.count(e.get()) != 0) return;
visited_.insert(e.get());
......@@ -257,7 +257,6 @@ TVM_REGISTER_API("relay._ir_pass.post_order_visit")
});
});
// Implement bind.
class ExprBinder : public ExprMutator {
public:
......
......@@ -1601,7 +1601,6 @@ RELAY_REGISTER_OP("slice_like")
.set_attr<FTVMCompute>("FTVMCompute", SliceLikeCompute)
.set_attr<TOpPattern>("TOpPattern", kInjective);
// relay.layout_transform
Array<Tensor> LayoutTransformCompute(const Attrs& attrs,
const Array<Tensor>& inputs,
......
......@@ -103,8 +103,10 @@ This function takes a tensor, a minimum value `a_min`, and a maximum value `a_ma
.set_attr<TOpPattern>("TOpPattern", kElemWise)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout)
.set_attrs_type_key("relay.attrs.ClipAttrs")
.set_support_level(3);
RELAY_REGISTER_UNARY_OP("floor")
.describe(R"code(Returns the floor of input array, computed element-wise.
)code" TVM_ADD_FILELINE)
......
......@@ -10,6 +10,7 @@
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/nn.h>
#include <string>
......@@ -20,6 +21,45 @@ namespace tvm {
namespace relay {
/*!
* \brief Dispatch DataType to the C++ data type
* during runtime.
*/
#define TVM_DTYPE_DISPATCH(type, DType, ...) \
if (type == Float(64)) { \
typedef double DType; \
{__VA_ARGS__} \
} else if (type == Float(32)) { \
typedef float DType; \
{__VA_ARGS__} \
} else if (type == Int(64)) { \
typedef int64_t DType; \
{__VA_ARGS__} \
} else if (type == Int(32)) { \
typedef int32_t DType; \
{__VA_ARGS__} \
} else if (type == Int(16)) { \
typedef int16_t DType; \
{__VA_ARGS__} \
} else if (type == Int(8)) { \
typedef int8_t DType; \
{__VA_ARGS__} \
} else if (type == UInt(64)) { \
typedef uint64_t DType; \
{__VA_ARGS__} \
} else if (type == UInt(32)) { \
typedef uint32_t DType; \
{__VA_ARGS__} \
} else if (type == UInt(16)) { \
typedef uint16_t DType; \
{__VA_ARGS__} \
} else if (type == UInt(8)) { \
typedef uint8_t DType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "unknown data type " << type; \
}
/*!
* \brief Try to match lhs and rhs via broadcasting rule, such that:
*
* rhs matches the dimension of lhs specified by lhs_axes
......@@ -145,9 +185,10 @@ inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) {
*/
template<typename T>
inline Constant MakeConstantScalar(DataType dtype, T value) {
CHECK_EQ(sizeof(T) * 8, dtype.bits()) << "data type mismatch";
runtime::NDArray arr = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0});
*static_cast<T*>(arr->data) = value;
TVM_DTYPE_DISPATCH(dtype, DType, {
*static_cast<DType*>(arr->data) = value;
})
return ConstantNode::make(arr);
}
......@@ -168,6 +209,25 @@ inline Expr Log(Expr e) {
static const Op& op = Op::Get("log");
return CallNode::make(op, {e});
}
/*!
* \brief Get an immediate scalar from a Constant expr.
*
* \param expr The Constant expr.
* \return A scalar with type T.
*/
template <typename T>
T GetScalarFromConstant(Expr expr) {
const auto* n = expr.as<ConstantNode>();
CHECK(n->is_scalar());
return static_cast<T*>(n->data->data)[0];
}
inline Expr Cast(Expr x, DataType dtype) {
static const Op& op = Op::Get("cast");
auto attrs = make_node<CastAttrs>();
attrs->dtype = dtype;
return CallNode::make(op, {x}, Attrs(attrs), {});
}
inline Expr Negative(Expr x) {
static const Op& op = Op::Get("negative");
......@@ -181,12 +241,39 @@ inline Expr Sqrt(Expr x) {
}
inline Expr Relu(Expr x) {
static const Op& op = Op::Get("nn.relu");
return CallNode::make(op, {x}, Attrs(), {});
}
inline Expr Round(Expr x) {
static const Op& op = Op::Get("round");
return CallNode::make(op, {x}, Attrs(), {});
}
inline Expr Clip(Expr x, double a_min, double a_max) {
static const Op& op = Op::Get("clip");
auto attrs = make_node<ClipAttrs>();
attrs->a_min = a_min;
attrs->a_max = a_max;
return CallNode::make(op, {x}, Attrs(attrs), {});
}
inline Expr Add(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("add");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
inline Expr Substract(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("subtract");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
inline Expr Multiply(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("multiply");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
......@@ -208,6 +295,24 @@ inline Expr OneLike(Expr e) {
return CallNode::make(op, {e});
}
inline Expr Power(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("power");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
}
inline Expr RightShift(Expr x, Expr nbit) {
static const Op& op = Op::Get("right_shift");
return CallNode::make(op, {x, nbit}, Attrs(), {});
}
inline Expr LeftShift(Expr x, Expr nbit) {
static const Op& op = Op::Get("left_shift");
return CallNode::make(op, {x, nbit}, Attrs(), {});
}
inline Expr ReshapeLike(Expr lhs, Expr rhs) {
static const Op& op = Op::Get("reshape_like");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
......
/*!
* Copyright (c) 2018 by Contributors.
*
* \file tvm/relay/pass/quantize.h
* \brief Header of definitions for quantization
*/
#ifndef TVM_RELAY_PASS_QUANTIZE_H_
#define TVM_RELAY_PASS_QUANTIZE_H_
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <string>
#include "pattern_util.h"
namespace tvm {
namespace relay {
namespace quantize {
/*! \brief Kind of annotate field */
enum QAnnotateKind : int {
kQInput = 1,
kQWeight = 2,
kQActivation = 3,
};
/*!
* \brief TempExpr used during annotate forward rewrite.
*/
class QAnnotateExpr;
/*!
* \brief TempExprNode used during annotate forward rewrite.
*/
class QAnnotateExprNode : public TempExprNode {
public:
/*! \brief The original expression */
Expr expr;
/*! \brief The kind of annotate field */
QAnnotateKind kind;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("expr", &expr);
v->Visit("kind", &kind);
}
TVM_DLL static QAnnotateExpr make(Expr expr, QAnnotateKind kind);
Expr Realize() const final;
static constexpr const char* _type_key = "relay.QAnnotateExpr";
TVM_DECLARE_NODE_TYPE_INFO(QAnnotateExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr);
/*! \brief TempExpr used during realize forward rewrite. */
class QRealizeExpr;
/*! \brief TempExpr representing integer. */
class QRealizeIntExpr;
class QRealizeExprNode : public TempExprNode {
public:
/*! \brief The original expression */
Expr data;
static constexpr const char* _type_key = "relay.quantize.QRealizeExpr";
TVM_DECLARE_BASE_NODE_INFO(QRealizeExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(QRealizeExpr, QRealizeExprNode, TempExpr);
class QRealizeIntExprNode : public QRealizeExprNode {
public:
Expr dom_scale;
/*! \brief current data type */
DataType dtype;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data);
v->Visit("dom_scale", &dom_scale);
v->Visit("dtype", &dtype);
}
Expr Realize() const final;
TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype);
static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr";
TVM_DECLARE_NODE_TYPE_INFO(QRealizeIntExprNode, QRealizeExprNode);
};
RELAY_DEFINE_NODE_REF(QRealizeIntExpr, QRealizeIntExprNode, QRealizeExpr);
class QConfig;
/*!
* \brief Container for build configuration options
*/
class QConfigNode : public Node {
public:
int nbit_input = 8;
int nbit_weight = 8;
int nbit_activation = 32;
DataType dtype_input = Int(8);
DataType dtype_weight = Int(8);
DataType dtype_activation = Int(32);
double global_scale = 8.0;
int skip_k_conv = 1;
bool round_for_shift = true;
bool store_lowbit_output = true;
Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
void VisitAttrs(AttrVisitor* v) final {
v->Visit("nbit_input", &nbit_input);
v->Visit("nbit_weight", &nbit_weight);
v->Visit("nbit_activation", &nbit_activation);
v->Visit("dtype_input", &dtype_input);
v->Visit("dtype_weight", &dtype_weight);
v->Visit("dtype_activation", &dtype_activation);
v->Visit("global_scale", &global_scale);
v->Visit("skip_k_conv", &skip_k_conv);
v->Visit("round_for_shift", &round_for_shift);
v->Visit("store_lowbit_output", &store_lowbit_output);
v->Visit("debug_enabled_ops", &debug_enabled_ops);
}
static constexpr const char* _type_key = "relay.quantize.QConfig";
TVM_DECLARE_NODE_TYPE_INFO(QConfigNode, Node);
};
/*!
* \brief Container for build configuration options
*/
class QConfig : public NodeRef {
public:
QConfig() {}
explicit QConfig(NodePtr<Node> n) : NodeRef(n) {}
const QConfigNode* operator->() const {
return static_cast<const QConfigNode*>(node_.get());
}
QConfigNode* operator->() {
return static_cast<QConfigNode*>(node_.get());
}
/*!
* \brief Push a new BuildConfig context onto the thread local stack.
* \param build_config The configuration to set as the current context.
*/
static void EnterQConfigScope(const QConfig& qconfig);
/*!
* \brief Pop a build config off the thread local context stack, restoring the previous
* configuration as the current context.
*/
static void ExitQConfigScope();
/*!
* \brief Get the current BuildConfig context from thread local storage, or a default
* configuration if a BuildConfig scope has not been entered.
* \return The configuration that is the current context.
*/
static QConfig Current();
using ContainerType = QConfigNode;
};
/*!
* \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the
* context stack when constructed, and pops it when destructed.
*/
struct QConfigContext {
/*!
* \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current
* context. When the BuildConfigContext is destructed, the previous context is restored.
* \param build_config The BuildConfig to set as the new current context.
*/
explicit QConfigContext(const QConfig& qconfig) {
QConfig::EnterQConfigScope(qconfig);
}
/*! \brief Destructor. Pops the context off the thread local stack. */
~QConfigContext() {
QConfig::ExitQConfigScope();
}
};
/*!
* \brief Construct a BuildConfig containing a new BuildConfigNode
* \return The new BuildConfig
*/
TVM_DLL QConfig qconfig();
} // namespace quantize
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_PASS_QUANTIZE_H_
import math
import numpy as np
import tvm
from tvm import relay
from tvm.relay import quantize as qtz
def make_dataset(graph, size=100):
args = relay.ir_pass.infer_type(graph).params
def create_arr(var):
ttype = var.type_annotation
np_arr = np.random.uniform(-1.0, 1.0, size=ttype.concrete_shape).astype(ttype.dtype)
return tvm.ndarray.array(np_arr)
params = {}
for arg in args:
if arg.name_hint == 'data':
dataset = [{'data': create_arr(arg)} for _ in range(size)]
else:
params[arg.name_hint] = create_arr(arg)
return dataset, params
def test_simulated_quantize():
data = relay.var("data", relay.ty.TensorType((3, 4, 5, 6), "float32"))
out = qtz._annotate.attach_simulated_quantize(data, 1)
out = relay.ir_pass.infer_type(out)
assert out.checked_type == out.args[0].checked_type
assert out.args[1].checked_type == relay.ty.TensorType(tuple(), "float32")
assert out.args[2].checked_type == relay.ty.TensorType(tuple(), "float32")
assert out.args[3].checked_type == relay.ty.TensorType(tuple(), "float32")
def test_quantize_pass():
def quantize_weight(arr):
maximum = np.amax(np.abs(arr.asnumpy()))
scale = 2**math.ceil(math.log(maximum, 2))
out = np.around(arr.asnumpy() / scale * 128).astype('int8')
out = np.clip(out, -127, 127)
return relay.const(out, 'int8')
n, c, h, w = 1, 3, 224, 224
def make_graph(data):
weight = relay.var("conv_weight")
out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
out = relay.Function(relay.ir_pass.free_vars(out), out)
return out
def make_qgraph(data, weight):
out = data * relay.const(32.0)
out = relay.round(out)
out = relay.clip(out, a_min=-127, a_max=127)
out = out.astype('int8')
out = relay.nn.conv2d(out, weight, kernel_size=(3, 3),
padding=(1, 1), channels=c, out_dtype='int32')
out = out.astype('float32')
out = relay.multiply(out, relay.const(0.00024414062))
out = relay.Function(relay.ir_pass.free_vars(out), out)
return out
data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
graph = make_graph(data)
dataset, params = make_dataset(graph, 10)
with qtz.qconfig(skip_k_conv=0, global_scale=4.0,
round_for_shift=False, store_lowbit_output=False):
qgraph0 = qtz.quantize(graph, params)
qgraph0 = relay.ir_pass.infer_type(qgraph0)
conv_weight = quantize_weight(params['conv_weight'])
qgraph1 = make_qgraph(data, conv_weight)
qgraph1 = relay.ir_pass.infer_type(qgraph1)
graph = relay.create_executor('graph')
res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy())
if __name__ == "__main__":
np.random.seed(42)
test_simulated_quantize()
test_quantize_pass()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment