Commit 33ab3c60 by Wuwei Lin Committed by ziheng

[Relay][Quantization] KL-divergence-based per-layer calibration (#3538)

* [Relay][Quantization] Support floating-point scale

* [Relay][Quantization] KL-divergence calibration on dataset

* Fix unhandled LeftShift case in QuantizeRealize

* Fix lint

* drop QBias

* fix lint

* address comments

* address comments

* Update comments

* address comments

* lint

* kQIdentity = 0
parent 5357f49b
...@@ -20,3 +20,4 @@ from __future__ import absolute_import as _abs ...@@ -20,3 +20,4 @@ from __future__ import absolute_import as _abs
from .quantize import * from .quantize import *
from ._annotate import register_annotate_function from ._annotate import register_annotate_function
from .kl_divergence import kl_divergence_scale
...@@ -39,6 +39,9 @@ def simulated_quantize_compute(attrs, inputs, out_type, target): ...@@ -39,6 +39,9 @@ def simulated_quantize_compute(attrs, inputs, out_type, target):
data, scale, clip_min, clip_max = inputs data, scale, clip_min, clip_max = inputs
if attrs.kind == QAnnotateKind.IDENTITY:
return [topi.identity(data)]
# simulate rounding error # simulate rounding error
scaled_data = topi.divide(data, scale) scaled_data = topi.divide(data, scale)
clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min) clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min)
...@@ -52,7 +55,7 @@ def simulated_quantize_compute(attrs, inputs, out_type, target): ...@@ -52,7 +55,7 @@ def simulated_quantize_compute(attrs, inputs, out_type, target):
_reg.register_schedule("relay.op.annotation.simulated_quantize", _reg.register_schedule("relay.op.annotation.simulated_quantize",
_reg.schedule_injective) _reg.schedule_injective)
_reg.register_pattern("relay.op.annotation.simulated_quantize", _reg.register_pattern("relay.op.annotation.simulated_quantize",
_reg.OpPattern.OPAQUE) _reg.OpPattern.ELEMWISE)
@register_relay_node @register_relay_node
...@@ -251,7 +254,7 @@ def add_rewrite(ref_call, new_args, ctx): ...@@ -251,7 +254,7 @@ def add_rewrite(ref_call, new_args, ctx):
if lhs_kind is None and rhs_kind is not None: if lhs_kind is None and rhs_kind is not None:
# quantize lhs to INPUT field if it is normal expression # quantize lhs to INPUT field if it is normal expression
assert rhs_kind == QAnnotateKind.INPUT assert rhs_kind in [QAnnotateKind.INPUT, QAnnotateKind.ACTIVATION]
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.INPUT) return QAnnotateExpr(expr, QAnnotateKind.INPUT)
...@@ -275,7 +278,8 @@ def add_rewrite(ref_call, new_args, ctx): ...@@ -275,7 +278,8 @@ def add_rewrite(ref_call, new_args, ctx):
rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT: if (lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT) or \
(lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.ACTIVATION):
expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
raise ValueError() raise ValueError()
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Find optimal scale for quantization by minimizing KL-divergence"""
try:
from scipy import stats
except ImportError:
stats = None
import numpy as np
def _smooth_distribution(p, eps=0.0001):
"""Given a discrete distribution (may have not been normalized to 1),
smooth it by replacing zeros with eps multiplied by a scaling factor and taking the
corresponding amount off the non-zero values.
Ref: http://hanj.cs.illinois.edu/cs412/bk3/KL-divergence.pdf
"""
is_zeros = (p == 0).astype(np.float32)
is_nonzeros = (p != 0).astype(np.float32)
n_zeros = is_zeros.sum()
n_nonzeros = p.size - n_zeros
if not n_nonzeros:
raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
eps1 = eps * float(n_zeros) / float(n_nonzeros)
assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1)
hist = p.astype(np.float32)
hist += eps * is_zeros + (-eps1) * is_nonzeros
assert (hist <= 0).sum() == 0
return hist
# pylint: disable=invalid-name
def kl_divergence_scale(arr, quantized_dtype='int8', num_bins=8001, num_quantized_bins=255):
"""Given a tensor, find the optimal threshold for quantizing it.
The reference distribution is `q`, and the candidate distribution is `p`.
`q` is a truncated version of the original distribution.
Ref:
http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
"""
assert isinstance(arr, np.ndarray)
min_val = np.min(arr)
max_val = np.max(arr)
th = max(abs(min_val), abs(max_val))
if min_val >= 0 and quantized_dtype in ['uint8']:
# We need to move negative bins to positive bins to fit uint8 range.
num_quantized_bins = num_quantized_bins * 2 + 1
hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-th, th))
zero_bin_idx = num_bins // 2
num_half_quantized_bins = num_quantized_bins // 2
thresholds = np.zeros(num_bins // 2 + 1 - num_quantized_bins // 2)
divergence = np.zeros_like(thresholds)
quantized_bins = np.zeros(num_quantized_bins, dtype=np.int32)
# i means the number of bins on half axis excluding the zero bin.
for i in range(num_quantized_bins // 2,
num_bins // 2 + 1):
p_bin_idx_start = zero_bin_idx - i
p_bin_idx_stop = zero_bin_idx + i + 1
thresholds[i - num_half_quantized_bins] = hist_edges[p_bin_idx_stop]
sliced_nd_hist = hist[p_bin_idx_start:p_bin_idx_stop]
# generate reference distribution p
p = sliced_nd_hist.copy()
assert p.size % 2 == 1
assert p.size >= num_quantized_bins
# put left outlier count in p[0]
left_outlier_count = np.sum(hist[0:p_bin_idx_start])
p[0] += left_outlier_count
# put right outlier count in p[-1]
right_outlier_count = np.sum(hist[p_bin_idx_stop:])
p[-1] += right_outlier_count
# is_nonzeros[k] indicates whether hist[k] is nonzero
is_nonzeros = (p != 0).astype(np.int32)
# calculate how many bins should be merged to generate quantized distribution q
num_merged_bins = sliced_nd_hist.size // num_quantized_bins
# merge hist into num_quantized_bins bins
for j in range(num_quantized_bins):
start = j * num_merged_bins
stop = start + num_merged_bins
quantized_bins[j] = sliced_nd_hist[start:stop].sum()
quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum()
# expand quantized_bins into p.size bins
q = np.zeros(sliced_nd_hist.size, dtype=np.float32)
for j in range(num_quantized_bins):
start = j * num_merged_bins
if j == num_quantized_bins - 1:
stop = len(is_nonzeros)
else:
stop = start + num_merged_bins
norm = is_nonzeros[start:stop].sum()
if norm != 0:
q[start:stop] = float(quantized_bins[j]) / float(norm)
q[p == 0] = 0
p = _smooth_distribution(p)
# There is a chance that q is an invalid probability distribution.
try:
q = _smooth_distribution(q)
except ValueError:
divergence[i - num_half_quantized_bins] = float("inf")
divergence[i - num_half_quantized_bins] = stats.entropy(p, q)
min_divergence_idx = np.argmin(divergence)
opt_th = thresholds[min_divergence_idx]
return opt_th
...@@ -32,6 +32,7 @@ from ..base import NodeBase, register_relay_node ...@@ -32,6 +32,7 @@ from ..base import NodeBase, register_relay_node
class QAnnotateKind(object): class QAnnotateKind(object):
"""Denote the kind of annotation field, corresponding """Denote the kind of annotation field, corresponding
to different nbit configure.""" to different nbit configure."""
IDENTITY = 0
INPUT = 1 INPUT = 1
WEIGHT = 2 WEIGHT = 2
ACTIVATION = 3 ACTIVATION = 3
...@@ -43,6 +44,7 @@ def kind2str(kind): ...@@ -43,6 +44,7 @@ def kind2str(kind):
QAnnotateKind.INPUT: "input", QAnnotateKind.INPUT: "input",
QAnnotateKind.WEIGHT: "weight", QAnnotateKind.WEIGHT: "weight",
QAnnotateKind.ACTIVATION: "activation", QAnnotateKind.ACTIVATION: "activation",
QAnnotateKind.IDENTITY: "identity"
} }
assert kind in str_map assert kind in str_map
return str_map[kind] return str_map[kind]
...@@ -195,7 +197,26 @@ def annotate_context(): ...@@ -195,7 +197,26 @@ def annotate_context():
return AnnotateContext.Current return AnnotateContext.Current
def calibrate(graph, mod=None, ctx=None): def collect_stats(graph):
"""Given an annotated graph, create a profile graph to collect profile data from the
calibration dataset. This pass collects simulated_quantize op input into a tuple.
Simulated_quantize ops are rewritten to identity mode. The tuple is the output of the profile
graph.
Parameters
----------
graph: Function
The simulation graph after annotation.
Returns
-------
ret: Function
The profile graph which outputs a tuple of profile data.
"""
return _quantize.CollectStats(graph)
def calibrate(graph, mod=None, ctx=None, weight_scales='power2', scales=None):
"""The calibrate procedure will try to calculate the content of """The calibrate procedure will try to calculate the content of
dom_scale, nbit, clip_min, clip_max for every `simulated_quantize` dom_scale, nbit, clip_min, clip_max for every `simulated_quantize`
operator. operator.
...@@ -211,6 +232,16 @@ def calibrate(graph, mod=None, ctx=None): ...@@ -211,6 +232,16 @@ def calibrate(graph, mod=None, ctx=None):
ctx: tvm.relay.PassContext ctx: tvm.relay.PassContext
The pass context used for calibration. The pass context used for calibration.
weight_scales: 'power2' or 'max'.
The way to calculate scales for weights (annotated with QAnnotateKind.WEIGHT).
power2: Find the maximum of the absolute value of the tensor, and then round up to power
of two.
max: Find the maximum of the absolute value of the tensor.
scales: List[float]
Pre-calculated scales for input and activations. Length and the order of elements of the
scales list should match the output tuple of the profile graph created by collect_stats.
Returns Returns
------- -------
ret: Function ret: Function
...@@ -221,12 +252,20 @@ def calibrate(graph, mod=None, ctx=None): ...@@ -221,12 +252,20 @@ def calibrate(graph, mod=None, ctx=None):
val = np.amax(np.abs(arr.asnumpy())) val = np.amax(np.abs(arr.asnumpy()))
return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0 return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0
def max_scale(arr):
"""calculate weight scale with maximum absolute value"""
val = np.amax(np.abs(arr.asnumpy()))
return val
scale_idx = 0
cfg = current_qconfig() cfg = current_qconfig()
const_params = {} const_params = {}
quantize_op = _op.get("relay.op.annotation.simulated_quantize") quantize_op = _op.get("relay.op.annotation.simulated_quantize")
def visit_func(expr): def visit_func(expr):
"""Internal visit function""" """Internal visit function"""
nonlocal scale_idx
if isinstance(expr, _expr.Call) and expr.op == quantize_op: if isinstance(expr, _expr.Call) and expr.op == quantize_op:
_, ndom_scale, nclip_min, nclip_max = expr.args _, ndom_scale, nclip_min, nclip_max = expr.args
attrs = expr.attrs attrs = expr.attrs
...@@ -234,11 +273,21 @@ def calibrate(graph, mod=None, ctx=None): ...@@ -234,11 +273,21 @@ def calibrate(graph, mod=None, ctx=None):
nbit = cfg.get_nbit_by_kind(kind) nbit = cfg.get_nbit_by_kind(kind)
valid_bit = nbit - attrs.sign valid_bit = nbit - attrs.sign
if kind in [QAnnotateKind.WEIGHT]:
if kind == QAnnotateKind.WEIGHT: if all([isinstance(arg, _expr.Constant)
for arg in [ndom_scale, nclip_min, nclip_max]]):
return
var = expr.args[0] var = expr.args[0]
assert isinstance(var, _expr.Constant) assert isinstance(var, _expr.Constant)
scale = power2_scale(var.data) if weight_scales == 'max':
scale = max_scale(var.data)
elif weight_scales == 'power2':
scale = power2_scale(var.data)
else:
raise ValueError('{} not supported'.format(weight_scales))
elif scales is not None:
scale = scales[scale_idx]
scale_idx += 1
else: else:
scale = cfg.global_scale scale = cfg.global_scale
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
*
* \file calibrate.cc
*
* \brief Create profile graph and calibrate on dataset
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include "./quantize.h"
namespace tvm {
namespace relay {
namespace quantize {
class StatsCollector : private ExprMutator {
public:
Expr Collect(const Expr& expr) {
auto new_e = this->Mutate(expr);
const FunctionNode* func = new_e.as<FunctionNode>();
CHECK(func) << "Input shoule be Function";
Expr new_body = TupleNode::make(std::move(profile_data_));
return FunctionNode::make(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
func->attrs);
}
private:
Array<Expr> profile_data_;
Expr VisitExpr_(const CallNode* call) {
static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
Expr new_e = ExprMutator::VisitExpr_(call);
const CallNode* new_call = new_e.as<CallNode>();
CHECK(new_call);
if (new_call->op.same_as(simulated_quantize)) {
auto attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
// rewrite the annotation
auto new_attrs = make_node<SimulatedQuantizeAttrs>();
const Expr& quantize_input = new_call->args[0]; // expression being quantized
auto placeholder = MakeConstantScalar(Float(32), 0.); // unused argument
Array<Expr> new_args{quantize_input, placeholder, placeholder, placeholder};
new_attrs->kind = QAnnotateKind::kQIdentity;
new_attrs->sign = attrs->sign;
new_attrs->rounding = attrs->rounding;
Expr identity_quantize = CallNode::make(new_call->op, new_args, Attrs{new_attrs}, {});
// add non-const expressions to profile data
if (attrs->kind != QAnnotateKind::kQWeight) {
CHECK(!quantize_input.as<ConstantNode>());
profile_data_.push_back(identity_quantize);
}
return identity_quantize;
} else {
return new_e;
}
}
};
/*
* \brief Given an annotated graph, create a profile graph to collect profile data from the
* calibration dataset.
*
* This pass collects simulated_quantize op into a tuple. Simulated_quantize ops are rewritten to
* identity mode. The tuple is the output of the profile graph. Both input and output of this pass
* are relay::Function.
*
* \param expr The simulation graph after annotation.
* \return The profile graph.
*/
Expr CollectStats(const Expr& expr) {
return StatsCollector().Collect(expr);
}
TVM_REGISTER_API("relay._quantize.CollectStats")
.set_body_typed(CollectStats);
} // namespace quantize
} // namespace relay
} // namespace tvm
...@@ -36,8 +36,8 @@ ...@@ -36,8 +36,8 @@
#include <vector> #include <vector>
#include <stack> #include <stack>
#include <utility> #include <utility>
#include "pattern_util.h" #include "../pattern_util.h"
#include "quantize.h" #include "./quantize.h"
namespace tvm { namespace tvm {
...@@ -46,22 +46,6 @@ namespace quantize { ...@@ -46,22 +46,6 @@ namespace quantize {
using namespace relay::transform; using namespace relay::transform;
/*! \brief Attribute for simulated quantize operator */
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
int kind;
bool sign;
std::string rounding;
TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") {
TVM_ATTR_FIELD(kind)
.describe("kind of field, hint for nbit/dtype configuration.");
TVM_ATTR_FIELD(sign).set_default(true)
.describe("whether to use signed data type.");
TVM_ATTR_FIELD(rounding).set_default("round")
.describe("rounding mode. Can be 'floor', 'ceil', 'round'");
}
};
TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs); TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs);
bool SimulatedQuantizeRel(const Array<Type>& types, bool SimulatedQuantizeRel(const Array<Type>& types,
...@@ -166,23 +150,22 @@ inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) { ...@@ -166,23 +150,22 @@ inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
/* calculate `data * s1 / s2`, use shift if possible */ /* calculate `data * s1 / s2`, use shift if possible */
inline Expr MulAndDiv(Expr data, float s1, float s2) { inline Expr MulAndDiv(Expr data, float s1, float s2, DataType dtype) {
// here we assume the dtype of data is dtype activation // here we assume the dtype of data is dtype activation
const QConfig& cfg = QConfig::Current();
if (s1 == s2) return data; if (s1 == s2) return data;
float factor = s1 / s2; float factor = s1 / s2;
float shift_factor = std::log2(factor); float shift_factor = std::log2(factor);
CHECK_GT(shift_factor, 0); CHECK_GT(shift_factor, 0);
if (static_cast<int>(shift_factor) == shift_factor) { if (static_cast<int>(shift_factor) == shift_factor) {
return LeftShift(data, MakeConstantScalar(cfg->dtype_activation, return LeftShift(data, MakeConstantScalar(dtype,
static_cast<int>(shift_factor))); static_cast<int>(shift_factor)));
} else if (static_cast<int>(factor) == factor) { } else if (static_cast<int>(factor) == factor) {
return Multiply(data, MakeConstantScalar(cfg->dtype_activation, factor)); return Multiply(data, MakeConstantScalar(dtype, factor));
} else { } else {
LOG(FATAL) << "fall back to float computation";
data = Cast(data, Float(32)); data = Cast(data, Float(32));
return Multiply(data, MakeConstantScalar(Float(32), factor)); data = Multiply(data, MakeConstantScalar(Float(32), factor));
return Cast(Round(data), dtype);
} }
} }
...@@ -216,15 +199,21 @@ Expr QuantizeRealize(const Call& ref_call, ...@@ -216,15 +199,21 @@ Expr QuantizeRealize(const Call& ref_call,
} }
float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm); float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
CHECK_GT(shift_nbit, 0); CHECK_NE(shift_nbit, 0);
if (static_cast<int>(shift_nbit) == shift_nbit) { if (static_cast<int>(shift_nbit) == shift_nbit) {
// use right shift if (shift_nbit > 0) {
if (cfg->round_for_shift) { // use right shift
float round_bias = std::pow(2.0, shift_nbit - 1); if (cfg->round_for_shift) {
data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast<int>(round_bias))); float round_bias = std::pow(2.0, shift_nbit - 1);
data = Add(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(round_bias)));
}
data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(shift_nbit)));
} else {
data = LeftShift(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(shift_nbit)));
} }
data = RightShift(data, MakeConstantScalar(cfg->dtype_activation,
static_cast<int>(shift_nbit)));
data = Clip(data, clip_min_imm, clip_max_imm); data = Clip(data, clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype); return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
} else { } else {
...@@ -338,15 +327,11 @@ Expr MulRealize(const Call& ref_call, ...@@ -338,15 +327,11 @@ Expr MulRealize(const Call& ref_call,
Expr rdata = rhs->data; Expr rdata = rhs->data;
DataType dtype = cfg->dtype_activation; DataType dtype = cfg->dtype_activation;
if (lhs->dtype == Float(32)) { if (lhs->dtype != dtype) {
ldata = Cast(ldata, dtype); ldata = Cast(ldata, dtype);
} else {
CHECK_EQ(lhs->dtype, dtype);
} }
if (rhs->dtype == Float(32)) { if (rhs->dtype != dtype) {
rdata = Cast(rdata, dtype); rdata = Cast(rdata, dtype);
} else {
CHECK_EQ(rhs->dtype, dtype);
} }
Expr ret = ForwardOp(ref_call, {ldata, rdata}); Expr ret = ForwardOp(ref_call, {ldata, rdata});
...@@ -418,7 +403,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args ...@@ -418,7 +403,7 @@ Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args
Expr dom_scale = MakeConstantScalar(Float(32), s); Expr dom_scale = MakeConstantScalar(Float(32), s);
for (size_t i = 0; i < ret.size(); ++i) { for (size_t i = 0; i < ret.size(); ++i) {
float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale); float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);
ret.Set(i, MulAndDiv(ret[i], cur_s, s)); ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype));
} }
*dtype_ptr = dtype; *dtype_ptr = dtype;
......
...@@ -23,13 +23,13 @@ ...@@ -23,13 +23,13 @@
* \file tvm/relay/pass/quantize.h * \file tvm/relay/pass/quantize.h
* \brief Header of definitions for quantization * \brief Header of definitions for quantization
*/ */
#ifndef TVM_RELAY_PASS_QUANTIZE_H_ #ifndef TVM_RELAY_PASS_QUANTIZE_QUANTIZE_H_
#define TVM_RELAY_PASS_QUANTIZE_H_ #define TVM_RELAY_PASS_QUANTIZE_QUANTIZE_H_
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/expr.h> #include <tvm/relay/expr.h>
#include <string> #include <string>
#include "pattern_util.h" #include "../pattern_util.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -37,9 +37,26 @@ namespace quantize { ...@@ -37,9 +37,26 @@ namespace quantize {
/*! \brief Kind of annotate field */ /*! \brief Kind of annotate field */
enum QAnnotateKind : int { enum QAnnotateKind : int {
kQIdentity = 0,
kQInput = 1, kQInput = 1,
kQWeight = 2, kQWeight = 2,
kQActivation = 3, kQActivation = 3
};
/*! \brief Attribute for simulated quantize operator */
struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
int kind;
bool sign;
std::string rounding;
TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") {
TVM_ATTR_FIELD(kind)
.describe("kind of field, hint for nbit/dtype configuration.");
TVM_ATTR_FIELD(sign).set_default(true)
.describe("whether to use signed data type.");
TVM_ATTR_FIELD(rounding).set_default("round")
.describe("rounding mode. Can be 'floor', 'ceil', 'round'");
}
}; };
/*! /*!
...@@ -242,4 +259,4 @@ TVM_DLL QConfig qconfig(); ...@@ -242,4 +259,4 @@ TVM_DLL QConfig qconfig();
} // namespace quantize } // namespace quantize
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_PASS_QUANTIZE_H_ #endif // TVM_RELAY_PASS_QUANTIZE_QUANTIZE_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment