Unverified commit 7eb1f353 by ziheng, committed by GitHub

[QUANTIZE] Refactor quantization codebase and fix model accuracy (#3543)

* Refactor.

* update

* update

* update

* update

* update

* update
parent 60fc9f74
......@@ -52,6 +52,18 @@ namespace relay {
TVM_DLL Kind KindCheck(const Type& t, const Module& mod);
/*!
* \brief Check whether an expression is constant.
*
* If all of the inputs of an expression are constant, the expression
* itself is also constant.
*
* \param e the expression.
*
* \return whether the expression is constant.
*/
TVM_DLL bool ConstantCheck(const Expr& e);
/*!
* \brief Compare two expressions for structural equivalence.
*
* This comparison operator respects scoping and compares
......
......@@ -44,6 +44,19 @@ struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
}
};
/*!
* \brief Annotate an expression to be cast into a specific data type.
*/
struct CastHintAttrs : public tvm::AttrsNode<CastHintAttrs> {
DataType dtype;
TVM_DECLARE_ATTRS(CastHintAttrs, "relay.attrs.CastHintAttrs") {
TVM_ATTR_FIELD(dtype)
.describe(
"The data type denoted to be cast.");
}
};
} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_ATTRS_ANNOTATION_H_
......@@ -91,6 +91,22 @@ def check_kind(t, mod=None):
return _analysis.check_kind(t)
def check_constant(expr):
"""Check whether an expression is constant
Parameters
----------
expr : tvm.relay.Expr
The input expression
Returns
-------
result : bool
Whether the expression is constant.
"""
return _analysis.check_constant(expr)
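For illustration, a hedged usage sketch (assuming this TVM version's relay API, and that the checker recognizes Constant nodes and tuples of constants, as in the fold-constant checker it reuses):

from tvm import relay
from tvm.relay import analysis

c = relay.const(1.0)
t = relay.Tuple([relay.const(1.0), relay.const(2.0)])
x = relay.var("x", shape=(), dtype="float32")

print(analysis.check_constant(c))                # True: a Constant node
print(analysis.check_constant(t))                # True: a tuple of constants
print(analysis.check_constant(relay.add(c, x)))  # False: depends on a free var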
def free_vars(expr):
"""Get free Vars from expression expr in Post DFS order.
......
......@@ -19,5 +19,6 @@
from __future__ import absolute_import as _abs
from .quantize import *
from ._partition import register_partition_function
from ._annotate import register_annotate_function
from .kl_divergence import kl_divergence_scale
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=unused-argument,inconsistent-return-statements
"""Internal module for registering attribute for annotation."""
from __future__ import absolute_import
from ... import target as _target
from .. import expr as _expr
from .. import analysis as _analysis
from ..base import register_relay_node
from ..op import op as _reg
from . import _quantize
from .quantize import _forward_op
def register_partition_function(op_name, frewrite=None, level=10):
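    """Register a rewrite function for an op, stored as the op's
    "FQPartitionRewrite" attribute and used by the QuantizePartition pass."""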
def _register(func):
return _reg._Register(op_name, "FQPartitionRewrite", func, level)
return _register(frewrite) if frewrite is not None else _register
@register_relay_node
class QPartitionExpr(_expr.TempExpr):
def __init__(self, expr):
self.__init_handle_by_constructor__(
_quantize.make_partition_expr, expr)
def partition_expr_check(expr):
if isinstance(expr, QPartitionExpr):
return True, expr.expr
return False, expr
@register_partition_function("nn.conv2d")
def conv2d_partition_function(ref_call, new_args, ctx):
"""Rewrite function for conv2d for partition"""
data_cond, data = partition_expr_check(new_args[0])
kernel_cond, kernel = partition_expr_check(new_args[1])
assert not kernel_cond
if data_cond:
data = new_args[0].realize()
ret = _forward_op(ref_call, [data, kernel])
return QPartitionExpr(ret)
def identity_partition_function(ref_call, new_args, ctx):
cond, expr = partition_expr_check(new_args[0])
if cond:
return QPartitionExpr(_forward_op(ref_call, [expr]))
return None
register_partition_function("clip", identity_partition_function)
register_partition_function("nn.relu", identity_partition_function)
register_partition_function("nn.max_pool2d", identity_partition_function)
def add_partition_generic(ref_call, new_args, ctx):
"""Rewrite function for ewise add for partition for generic devices"""
lhs_cond, lhs = partition_expr_check(new_args[0])
rhs_cond, rhs = partition_expr_check(new_args[1])
if lhs_cond and rhs_cond:
# - introduced by ResNet, for the first residual connection
# ...
# %0 = nn.conv2d(%data, %meta[relay.Constant])
# %1 = add(%0, %meta[relay.Constant])
# %2 = nn.relu(%1)
# %3 = nn.max_pool2d(%2)
# ...
# %9 = nn.conv2d(%8, %meta[relay.Constant])
# %10 = add(%9, %meta[relay.Constant])
# %11 = add(%3, %10) <- need to insert annotations for %3, %10
# ...
lhs = new_args[0].realize()
rhs = new_args[1].realize()
return _forward_op(ref_call, [lhs, rhs])
elif not lhs_cond and rhs_cond:
# - introduced by residual connection in ResNet
# ...
# %13 = nn.conv2d(%12, %meta[relay.Constant])
# %14 = add(%13, %meta[relay.Constant])
# %15 = annotation.cast_hint(%14, 'int8')
# %16 = annotation.stop_fusion(%15)
# %17 = add(%5, %16)
# %18 = nn.relu(%17)
# ...
# %24 = nn.conv2d(%23, %meta[relay.Constant])
# %25 = add(%24, %meta[relay.Constant])
# %26 = add(%18, %25) <- need to insert annotations for %25
# ...
rhs = new_args[1].realize()
return _forward_op(ref_call, [lhs, rhs])
elif lhs_cond and not rhs_cond:
if _analysis.check_constant(rhs):
# - introduced by batch_norm: add(out, bias)
return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
# - introduced by residual connection in MobileNetV2
# ...
# %81 = add(%80, meta[relay.Constant])
# %82 = annotation.cast_hint(%81, 'int8')
# %83 = annotation.stop_fusion(%82)
# %84 = add(%79, %83)
# ...
# %95 = nn.conv2d(%94, %meta[relay.Constant])
# %96 = add(%95, %meta[relay.Constant])
# %97 = add(%96, %84) <- need to insert annotations for %96
# ...
lhs = new_args[0].realize()
return _forward_op(ref_call, [lhs, rhs])
elif not lhs_cond and not rhs_cond:
# trivial case
return None
else:
raise ValueError
# TODO(ziheng) enhance `register_partition_function` to dispatch
# for target automatically
@register_partition_function("add")
def add_partition_function(ref_call, new_args, ctx):
"""Rewrite function for ewise add for partition"""
if 'cuda' in _target.current_target().keys:
#TODO(wuwei/ziheng) cuda specific rules
return add_partition_generic(ref_call, new_args, ctx)
return add_partition_generic(ref_call, new_args, ctx)
@register_partition_function("multiply")
def multiply_partition_function(ref_call, new_args, ctx):
"""Rewrite function for ewise add for partition"""
lhs_cond, lhs = partition_expr_check(new_args[0])
rhs_cond, rhs = partition_expr_check(new_args[1])
if lhs_cond:
# introduced by bn: multiply(out, scale)
return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
assert (not lhs_cond) and (not rhs_cond)
return None
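As an illustration of the registration pattern above, a hedged sketch of an identity-style rule for one more op (the op name and the decision to keep it inside the partition are assumptions for illustration, not part of this change):

# Hypothetical: treat average pooling like clip/nn.relu/nn.max_pool2d above
# and keep it inside the current low-precision partition.
@register_partition_function("nn.avg_pool2d")
def avg_pool2d_partition_function(ref_call, new_args, ctx):
    cond, expr = partition_expr_check(new_args[0])
    if cond:
        # still inside a partition: forward the call and re-wrap the result
        return QPartitionExpr(_forward_op(ref_call, [expr]))
    return None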
......@@ -50,6 +50,12 @@ def kind2str(kind):
return str_map[kind]
def _forward_op(ref_call, args):
"""forward the operator of ref_call with provided arguments"""
return _expr.Call(
ref_call.op, args, ref_call.attrs, ref_call.type_args)
@register_relay_node("relay.quantize.QConfig")
class QConfig(NodeBase):
"""Configure the quantization behavior by setting config variables.
......@@ -74,8 +80,8 @@ class QConfig(NodeBase):
"dtype_activation": "int32",
"global_scale": 8.0,
"skip_conv_layers": [0],
"do_simulation": False,
"round_for_shift": True,
"store_lowbit_output": True,
"debug_enabled_ops": None,
}
......@@ -92,6 +98,7 @@ class QConfig(NodeBase):
self.handle = handle
def guard(self, ref_call):
"""Return true if op is enabled, otherwise return false"""
op_name = ref_call.op.name
if self.debug_enabled_ops is not None:
name_list = [x.value for x in self.debug_enabled_ops]
......@@ -126,9 +133,7 @@ def current_qconfig():
"""Get the current quantization configuration."""
return _quantize._GetCurrentQConfig()
# TODO(tmoreau89, ZihengJiang) the skip parameters are
# hacky - we should explore a more future-proof way to
# skip operators based on pattern matching
def qconfig(**kwargs):
"""Configure the quantization behavior by setting config variables.
......@@ -142,15 +147,14 @@ def qconfig(**kwargs):
skip_conv_layers: list
Specifying which layers to be skipped. Provide a list of indices
that indicate which conv2d layers to leave untouched.
that indicate which conv2d layers to leave untouched. Start from 0.
do_simulation: boolean
Whether to do simulation with float operations only.
round_for_shift: boolean
Whether to add bias for rounding during shift.
store_lowbit_output: boolean
Whether to store low-bit integer back as output before dequantizing.
Some accelerators need this, e.g. VTA.
debug_enabled_ops: None or list of str
Partially quantize specified operators for debugging. The default value
is None, which means it will try to call all operators' annotate rewrite
......@@ -166,35 +170,79 @@ def qconfig(**kwargs):
return _make.node("relay.quantize.QConfig", **node_args)
class AnnotateContext(object):
"""A global singleton annotate scope"""
class QuantizeContext(object):
"""An internal used global context object for annotation,
for putting some state variables like `conv2d_counter`."""
Current = None
def __init__(self):
self.qnode_map = dict()
self._conv2d_counter = 0
self._stop_quantize = False
def check_to_skip(self, ref_call):
"""Check the index of conv2d layer to decide whether to
skip the current operator."""
if self._stop_quantize:
return True
if current_qconfig().skip_conv_layers is not None:
# check skip conv layers
skipped_indices = [int(x) for x in current_qconfig().skip_conv_layers]
if self._conv2d_counter in skipped_indices:
if ref_call.op.name == 'nn.conv2d':
self._conv2d_counter += 1
return True
if ref_call.op.name == 'nn.conv2d':
self._conv2d_counter += 1
return False
def stop_quantize(self):
self._stop_quantize = True
def reset(self):
self._conv2d_counter = 0
self._stop_quantize = False
def __enter__(self):
self._conv2d_counter = 0
self.reset()
return self
def conv2d_counter(self):
"""Get the counter for conv2d."""
return self._conv2d_counter
def count_conv2d(self):
"""Increase the value of the conv2d counter by one."""
self._conv2d_counter += 1
def __exit__(self, ptype, value, traceback):
pass
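To make the counter logic in `check_to_skip` concrete, a standalone sketch (a simplified re-implementation for illustration, ignoring the `_stop_quantize` flag): with `skip_conv_layers=[0]`, every operator seen while the counter is still 0, the first conv2d included, is skipped; once that conv2d bumps the counter, later operators are processed normally.

def should_skip(op_name, conv2d_counter, skipped_indices):
    # mirror of check_to_skip: decide before bumping the counter
    skip = conv2d_counter in skipped_indices
    if op_name == "nn.conv2d":
        conv2d_counter += 1
    return skip, conv2d_counter

counter, trace = 0, []
for op in ["nn.conv2d", "nn.relu", "nn.conv2d", "nn.relu"]:
    skipped, counter = should_skip(op, counter, skipped_indices=[0])
    trace.append((op, skipped))
# trace == [('nn.conv2d', True), ('nn.relu', False),
#           ('nn.conv2d', False), ('nn.relu', False)]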
def annotate_context():
def quantize_context():
"""Get the global singleton scope"""
if AnnotateContext.Current is None:
AnnotateContext.Current = AnnotateContext()
return AnnotateContext.Current
if QuantizeContext.Current is None:
QuantizeContext.Current = QuantizeContext()
return QuantizeContext.Current
def partition():
"""Partition graph into small low-precision sections by `cast_hint` and
`stop_fusion`.
Returns
-------
ret: tvm.relay.Pass
The registered pass for quantization partition.
"""
return _quantize.QuantizePartition()
def annotate():
"""Given a float32 graph, this pass will rewrite the graph and return
a graph which simulates the error brought by the current quantization
scheme.
Returns
-------
ret: tvm.relay.Pass
The registered pass for quantization annotation.
"""
return _quantize.QuantizeAnnotate()
def collect_stats(graph):
......@@ -300,20 +348,8 @@ def calibrate(graph, mod=None, ctx=None, weight_scales='power2', scales=None):
const_params[nclip_max] = _make_const((valid_range - 1))
_analysis.post_order_visit(graph, visit_func)
return _expr.bind(graph, const_params)
def annotate():
"""Given a float32 graph, this pass will rewrite the graph and return
a graph which simulates the error brought by the current quantization
scheme.
Returns
-------
ret: tvm.relay.Pass
The registered pass for quantization annotation.
"""
return _quantize.QuantizeAnnotate()
ret = _expr.bind(graph, const_params)
return ret
def realize():
......@@ -330,17 +366,6 @@ def realize():
return _quantize.QuantizeRealize()
def rewrite_for_vta():
"""Performs rewriting for VTA target.
Returns
-------
ret: tvm.relay.Pass
The registered pass for VTA rewrite.
"""
return _quantize.QuantizeRewriteForVTA()
def _bind_params(func, params):
"""Bind the params to the expression.
"""
......@@ -362,6 +387,25 @@ def _bind_params(func, params):
return _expr.bind(func, bind_dict)
def prerequisite_optimize(graph, params=None):
""" Prerequisite optimization passes for quantization. Perform
"SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"CanonicalizeOps" optimization before quantization. """
optimize = _transform.Sequential([_transform.SimplifyInference(),
_transform.FoldConstant(),
_transform.FoldScaleAxis(),
_transform.CanonicalizeOps(),
_transform.FoldConstant()])
if params:
graph = _bind_params(graph, params)
mod = _module.Module.from_expr(graph)
with _transform.PassContext(opt_level=3):
mod = optimize(mod)
return mod["main"]
def quantize(graph, params=None, dataset=None):
""" The quantization procedure. Before running the three main
procedures of quantization, "annotate", "calibrate" and "realize"
......@@ -385,33 +429,23 @@ def quantize(graph, params=None, dataset=None):
ret: Function
The graph after quantization
"""
if params:
graph = _bind_params(graph, params)
graph = prerequisite_optimize(graph, params)
mod = _module.Module.from_expr(graph)
# Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
# "CanonicalizeOps" optimization before quantization.
optimize = _transform.Sequential([_transform.SimplifyInference(),
_transform.FoldConstant(),
_transform.FoldScaleAxis(),
_transform.CanonicalizeOps(),
_transform.FoldConstant()])
calibrate_pass = _transform.function_pass(calibrate, opt_level=1,
name="QuantizeCalibrate")
# Quantize pass list
quant_passes = [annotate(),
calibrate_pass,
realize(),
_transform.FoldConstant()]
if current_qconfig().store_lowbit_output:
quant_passes = [rewrite_for_vta()] + quant_passes
quant_passes = [partition(),
annotate(),
calibrate_pass]
if not current_qconfig().do_simulation:
quant_passes.append(realize())
quant_passes.append(_transform.FoldConstant())
quantize_seq = _transform.Sequential(quant_passes)
with _transform.PassContext(opt_level=3,
required_pass=["QuantizeAnnotate",
"QuantizeCalibrate",
"QuantizeRealize"]):
mod = optimize(mod)
mod = quantize_seq(mod)
with quantize_context():
mod = quantize_seq(mod)
return mod["main"]
......@@ -83,13 +83,18 @@ TVM_ADD_FILELINE)
return {topi::identity(inputs[0])};
});
Expr ForceCast(Expr data) {
static const Op& op = Op::Get("annotation.force_cast");
return CallNode::make(op, {data}, Attrs{}, {});
// relay.annotation.cast_hint
TVM_REGISTER_NODE_TYPE(CastHintAttrs);
Expr CastHint(Expr data, DataType dtype) {
auto attrs = make_node<CastHintAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("annotation.cast_hint");
return CallNode::make(op, {data}, Attrs{attrs}, {});
}
RELAY_REGISTER_OP("annotation.force_cast")
.describe(R"code(Annotate an expression to force a cast.)code"
RELAY_REGISTER_OP("annotation.cast_hint")
.describe(R"code(Annotate an expression to be cast into specific data type.)code"
TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input data.")
......
......@@ -66,6 +66,13 @@ class ConstantChecker : private ExprVisitor {
}
};
bool ConstantCheck(const Expr& e) {
return ConstantChecker().Check(e);
}
TVM_REGISTER_API("relay._analysis.check_constant")
.set_body_typed(ConstantCheck);
// TODO(tvm-team) consider combine dead-code with constant folder.
// or make a more powerful partial evaluator.
......
......@@ -31,6 +31,7 @@
#include <tvm/data_layout.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/attrs/reduce.h>
......@@ -420,7 +421,7 @@ Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array
Expr StopFusion(Expr data);
Expr ForceCast(Expr data);
Expr CastHint(Expr data, DataType dtype);
} // namespace relay
} // namespace tvm
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
*
* \file annotate.cc
*
* \brief Annotating the graph with simulated quantize operators.
*/
#include <tvm/relay/transform.h>
#include <tvm/relay/analysis.h>
#include "./quantize.h"
namespace tvm {
namespace relay {
namespace quantize {
using namespace relay::transform;
class QAnnotateExpr;
class QAnnotateExprNode : public TempExprNode {
public:
Expr expr;
QAnnotateKind kind;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("expr", &expr);
v->Visit("kind", &kind);
}
TVM_DLL static QAnnotateExpr make(Expr expr, QAnnotateKind kind);
Expr Realize() const final;
static constexpr const char* _type_key = "relay.QAnnotateExpr";
TVM_DECLARE_NODE_TYPE_INFO(QAnnotateExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr);
Expr QAnnotateExprNode::Realize() const {
return expr;
}
QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) {
auto rnode = make_node<QAnnotateExprNode>();
rnode->expr = expr;
rnode->kind = kind;
return QAnnotateExpr(rnode);
}
TVM_REGISTER_API("relay._quantize.make_annotate_expr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = QAnnotateExprNode::make(args[0],
static_cast<QAnnotateKind>(args[1].operator int()));
});
Pass QuantizeAnnotate() {
// TODO(tvm-teams): since partition has added cast_hint in different
// branches, try to remove this in the future.
std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
if (e->derived_from<TempExprNode>()) {
const auto* n = e.as<QAnnotateExprNode>();
CHECK(n);
const PackedFunc* f =
runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
}
return e;
};
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
auto new_params = func->params;
for (const auto& x : FreeVars(func)) {
new_params.push_back(x);
}
return FunctionNode::make(new_params,
func->body,
func->ret_type,
func->type_params,
func->attrs);
};
return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
}
TVM_REGISTER_API("relay._quantize.QuantizeAnnotate")
.set_body_typed(QuantizeAnnotate);
} // namespace quantize
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
*
* \file partition.cc
*
* \brief Partition a graph into sections for quantization.
*/
#include <tvm/relay/transform.h>
#include "../pattern_util.h"
#include "./quantize.h"
namespace tvm {
namespace relay {
namespace quantize {
using namespace relay::transform;
class QPartitionExpr;
class QPartitionExprNode : public TempExprNode {
public:
/*! \brief The original expression */
Expr expr;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("expr", &expr);
}
TVM_DLL static QPartitionExpr make(Expr expr);
Expr Realize() const final;
static constexpr const char* _type_key = "relay.QPartitionExpr";
TVM_DECLARE_NODE_TYPE_INFO(QPartitionExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(QPartitionExpr, QPartitionExprNode, TempExpr);
Expr QPartitionExprNode::Realize() const {
// insert cast hint and stop fusion
const QConfig& cfg = QConfig::Current();
Expr ret = CastHint(this->expr, cfg->dtype_input);
return StopFusion(ret);
}
QPartitionExpr QPartitionExprNode::make(Expr expr) {
auto rnode = make_node<QPartitionExprNode>();
rnode->expr = expr;
return QPartitionExpr(rnode);
}
TVM_REGISTER_API("relay._quantize.make_partition_expr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = QPartitionExprNode::make(args[0]);
});
Pass QuantizePartition() {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
auto ret = Downcast<Function>(
ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr));
return ret;
};
return CreateFunctionPass(pass_func, 1, "QuantizePartition", {});
}
TVM_REGISTER_API("relay._quantize.QuantizePartition")
.set_body_typed(QuantizePartition);
} // namespace quantize
} // namespace relay
} // namespace tvm
......@@ -59,104 +59,8 @@ struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
}
};
/*!
* \brief TempExpr used during annotate forward rewrite.
*/
class QAnnotateExpr;
/*!
* \brief TempExprNode used during annotate forward rewrite.
*/
class QAnnotateExprNode : public TempExprNode {
public:
/*! \brief The original expression */
Expr expr;
/*! \brief The kind of annotate field */
QAnnotateKind kind;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("expr", &expr);
v->Visit("kind", &kind);
}
TVM_DLL static QAnnotateExpr make(Expr expr, QAnnotateKind kind);
Expr Realize() const final;
static constexpr const char* _type_key = "relay.QAnnotateExpr";
TVM_DECLARE_NODE_TYPE_INFO(QAnnotateExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr);
/*!
* \brief TempExpr used to insert `force_cast` for VTA.
*/
class QVTAExpr;
/*!
* \brief TempExprNode used to insert `force_cast` for VTA.
*/
class QVTAExprNode : public TempExprNode {
public:
/*! \brief The original expression */
Expr expr;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("expr", &expr);
}
TVM_DLL static QVTAExpr make(Expr expr);
Expr Realize() const final;
static constexpr const char* _type_key = "relay.QVTAExpr";
TVM_DECLARE_NODE_TYPE_INFO(QVTAExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(QVTAExpr, QVTAExprNode, TempExpr);
/*! \brief TempExpr used during realize forward rewrite. */
class QRealizeExpr;
/*! \brief TempExpr representing integer. */
class QRealizeIntExpr;
class QRealizeExprNode : public TempExprNode {
public:
/*! \brief The original expression */
Expr data;
static constexpr const char* _type_key = "relay.quantize.QRealizeExpr";
TVM_DECLARE_BASE_NODE_INFO(QRealizeExprNode, TempExprNode);
};
RELAY_DEFINE_NODE_REF(QRealizeExpr, QRealizeExprNode, TempExpr);
class QRealizeIntExprNode : public QRealizeExprNode {
public:
Expr dom_scale;
/*! \brief current data type */
DataType dtype;
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data);
v->Visit("dom_scale", &dom_scale);
v->Visit("dtype", &dtype);
}
Expr Realize() const final;
TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype);
static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr";
TVM_DECLARE_NODE_TYPE_INFO(QRealizeIntExprNode, QRealizeExprNode);
};
RELAY_DEFINE_NODE_REF(QRealizeIntExpr, QRealizeIntExprNode, QRealizeExpr);
class QConfig;
/*!
* \brief Container for build configuration options
*/
......@@ -170,8 +74,8 @@ class QConfigNode : public Node {
DataType dtype_activation = Int(32);
double global_scale = 8.0;
Array<Expr> skip_conv_layers = Array<Expr>(NodePtr<Node>(nullptr));
bool do_simulation = false;
bool round_for_shift = true;
bool store_lowbit_output = true;
Array<Expr> debug_enabled_ops = Array<Expr>(NodePtr<Node>(nullptr));
void VisitAttrs(AttrVisitor* v) final {
......@@ -183,8 +87,8 @@ class QConfigNode : public Node {
v->Visit("dtype_activation", &dtype_activation);
v->Visit("global_scale", &global_scale);
v->Visit("skip_conv_layers", &skip_conv_layers);
v->Visit("do_simulation", &do_simulation);
v->Visit("round_for_shift", &round_for_shift);
v->Visit("store_lowbit_output", &store_lowbit_output);
v->Visit("debug_enabled_ops", &debug_enabled_ops);
}
......@@ -250,12 +154,6 @@ struct QConfigContext {
}
};
/*!
* \brief Construct a QConfig containing a new QConfigNode
* \return The new QConfig
*/
TVM_DLL QConfig qconfig();
} // namespace quantize
} // namespace relay
} // namespace tvm
......
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections import namedtuple
import tvm
from tvm import relay
from tvm.relay import quantize as qtz
import mxnet as mx
from mxnet import gluon
import logging
import os
logging.basicConfig(level=logging.INFO)
Config = namedtuple('Config', ['model', 'nbit_input', 'dtype_input', 'nbit_output', 'dtype_output', 'global_scale', 'expected_acc'])
def get_val_data(model_name,
rec_val,
batch_size,
num_workers=4):
rec_val = os.path.expanduser(rec_val)
mean_rgb = [123.68, 116.779, 103.939]
std_rgb = [58.393, 57.12, 57.375]
def batch_fn(batch, ctx):
data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
return data, label
img_size = 299 if model_name == 'inceptionv3' else 224
val_data = mx.io.ImageRecordIter(
path_imgrec = rec_val,
preprocess_threads = num_workers,
shuffle = False,
batch_size = batch_size,
resize = 256,
data_shape = (3, img_size, img_size),
mean_r = mean_rgb[0],
mean_g = mean_rgb[1],
mean_b = mean_rgb[2],
std_r = std_rgb[0],
std_g = std_rgb[1],
std_b = std_rgb[2],
)
return val_data, batch_fn
def get_model(model_name, batch_size, qconfig, target=None, original=False, simulated=False):
gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True)
img_size = 299 if model_name == 'inceptionv3' else 224
data_shape = (batch_size, 3, img_size, img_size)
mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
net = mod['main']
with relay.build_config(opt_level=3):
qfunc = relay.quantize.prerequisite_optimize(net, params=params)
logging.debug('original')
logging.debug(qfunc.astext(show_meta_data=False))
if original:
return qfunc
with qconfig:
logging.debug('current quantize config')
logging.debug(qtz.current_qconfig())
qfunc = qtz.quantize(qfunc)
logging.debug('after quantize')
logging.debug(qfunc.astext(show_meta_data=False))
return qfunc
def eval_acc(model, dataset, batch_fn, target=tvm.target.cuda(), ctx=tvm.gpu(), log_interval=100):
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(model, target)
# create runtime module
m = tvm.contrib.graph_runtime.create(graph, lib, ctx)
m.set_input(**params)
# set up evaluation metric
dataset.reset()
batch_size = dataset.batch_size
acc_top1 = mx.metric.Accuracy()
acc_top5 = mx.metric.TopKAccuracy(5)
acc_top1.reset()
acc_top5.reset()
# Execute
for i, batch in enumerate(dataset):
data, label = batch_fn(batch, [mx.cpu(0)])
m.run(data=data[0].asnumpy())
out_arr = m.get_output(0)
acc_top1.update(label, [mx.nd.array(out_arr.asnumpy())])
acc_top5.update(label, [mx.nd.array(out_arr.asnumpy())])
if not (i + 1) % log_interval:
_, top1 = acc_top1.get()
_, top5 = acc_top5.get()
nsamples = (i + 1) * batch_size
logging.info('[%d samples] validation: acc-top1=%f acc-top5=%f', nsamples, top1, top5)
logging.info('[final] validation: acc-top1=%f acc-top5=%f', top1, top5)
return top1
def test_quantize_acc(cfg, rec_val):
qconfig = qtz.qconfig(skip_conv_layers=[0],
nbit_input=cfg.nbit_input,
nbit_weight=cfg.nbit_input,
global_scale=cfg.global_scale,
dtype_input=cfg.dtype_input,
dtype_weight=cfg.dtype_input,
dtype_activation=cfg.dtype_output,
debug_enabled_ops=None)
model = get_model(cfg.model, 32, qconfig, tvm.target.cuda())
val_data, batch_fn = get_val_data(cfg.model, rec_val=rec_val, batch_size=32)
acc = eval_acc(model, val_data, batch_fn)
assert acc > cfg.expected_acc
return acc
if __name__ == "__main__":
# TODO(for user): replace this line with the path to the ImageNet validation dataset
rec_val = "/scratch/tqchen/imagenet/val.rec"
results = []
configs = [
Config('mobilenetv2_1.0', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=4.0, expected_acc=0.666),
Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=8.0, expected_acc=0.692),
Config('resnet18_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.692),
Config('resnet34_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.733),
Config('resnet50_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.747),
Config('resnet101_v1', nbit_input=8, dtype_input='int8', nbit_output=32, dtype_output='int32', global_scale=8.0, expected_acc=0.756),
# TODO: need to fix accuracy
# Config('mobilenetv2_1.0', nbit_input=8, dtype_input='int8', nbit_output=16, dtype_output='int16', global_scale=4.0),
]
for config in configs:
acc = test_quantize_acc(config, rec_val)
results.append((config, acc))
for res in results:
print(res)
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import math
import numpy as np
import tvm
from tvm import relay
from tvm.relay import quantize as qtz
from tvm.relay import transform
def run_infer_type(expr):
mod = relay.Module.from_expr(expr)
mod = transform.InferType()(mod)
entry = mod["main"]
return entry if isinstance(expr, relay.Function) else entry.body
def make_dataset(graph, size=100):
args = run_infer_type(graph).params
def create_arr(var):
ttype = var.type_annotation
np_arr = np.random.uniform(-1.0, 1.0, size=ttype.concrete_shape).astype(ttype.dtype)
return tvm.ndarray.array(np_arr)
params = {}
for arg in args:
if arg.name_hint == 'data':
dataset = [{'data': create_arr(arg)} for _ in range(size)]
else:
params[arg.name_hint] = create_arr(arg)
return dataset, params
def test_simulated_quantize():
data = relay.var("data", relay.ty.TensorType((3, 4, 5, 6), "float32"))
out = qtz._annotate.attach_simulated_quantize(data, 1)
out = run_infer_type(out)
assert out.checked_type == out.args[0].checked_type
assert out.args[1].checked_type == relay.ty.TensorType(tuple(), "float32")
assert out.args[2].checked_type == relay.ty.TensorType(tuple(), "float32")
assert out.args[3].checked_type == relay.ty.TensorType(tuple(), "float32")
def test_quantize_pass():
def quantize_weight(arr):
maximum = np.amax(np.abs(arr.asnumpy()))
scale = 2**math.ceil(math.log(maximum, 2))
out = np.around(arr.asnumpy() / scale * 128).astype('int8')
out = np.clip(out, -127, 127)
return relay.const(out, 'int8')
n, c, h, w = 1, 3, 224, 224
def make_graph(data):
weight = relay.var("conv_weight")
out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
out = relay.Function(relay.analysis.free_vars(out), out)
return out
def make_qgraph(data, weight):
out = data * relay.const(32.0)
out = relay.round(out)
out = relay.clip(out, a_min=-127, a_max=127)
out = out.astype('int8')
out = relay.nn.conv2d(out, weight, kernel_size=(3, 3),
padding=(1, 1), channels=c, out_dtype='int32')
out = out.astype('float32')
out = relay.multiply(out, relay.const(0.00024414062))
out = relay.Function(relay.analysis.free_vars(out), out)
return out
np.random.seed(42)
data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
graph = make_graph(data)
dataset, params = make_dataset(graph, 10)
with qtz.qconfig(skip_conv_layers=None, global_scale=4.0,
round_for_shift=False, store_lowbit_output=False):
qgraph0 = qtz.quantize(graph, params)
qgraph0 = run_infer_type(qgraph0)
conv_weight = quantize_weight(params['conv_weight'])
qgraph1 = make_qgraph(data, conv_weight)
qgraph1 = run_infer_type(qgraph1)
graph = relay.create_executor('graph')
res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3)
if __name__ == "__main__":
test_simulated_quantize()
test_quantize_pass()
#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
set -e
set -u
export PYTHONPATH=python:topi/python
# Rebuild cython
make cython3
rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
rm -rf topi/python/topi/*.pyc topi/python/topi/*/*.pyc topi/python/topi/*/*/*.pyc topi/python/topi/*/*/*/*.pyc
python3 -m nose -v topi/tests/python/nightly