# _annotate.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#pylint: disable=unused-argument,inconsistent-return-statements
"""Internal module for registering attribute for annotation."""
from __future__ import absolute_import
import warnings
import topi
from ..._ffi.function import register_func
from .. import expr as _expr
from .. import analysis as _analysis
from .. import op as _op
from ..op import op as _reg
from ..base import register_relay_node
from . import _quantize
from .quantize import QAnnotateKind, current_qconfig, quantize_context
from .quantize import _forward_op


@_reg.register_compute("relay.op.annotation.simulated_quantize")
def simulated_quantize_compute(attrs, inputs, out_type, target):
    """Compute definition for ``relay.op.annotation.simulated_quantize``.

    Emulates quantization in floating point: the data is divided by the
    scale, clipped to ``[clip_min, clip_max]``, rounded, and multiplied back
    by the scale. The result therefore carries the clipping and rounding
    error of the quantized representation.

    Parameters
    ----------
    attrs : Attrs
        Op attributes. Only ``sign=True`` and ``rounding="round"`` are
        accepted; ``kind == QAnnotateKind.IDENTITY`` selects a pass-through.
    inputs : list of Tensor
        Exactly ``[data, scale, clip_min, clip_max]``.
    out_type, target
        Unused.

    Returns
    -------
    list of Tensor
        A single-element list holding the simulated-quantized tensor.
    """
    assert len(inputs) == 4
    assert attrs.sign
    assert attrs.rounding == "round"

    data, scale, clip_min, clip_max = inputs

    # IDENTITY annotations leave the data untouched.
    if attrs.kind == QAnnotateKind.IDENTITY:
        return [topi.identity(data)]

    # Scale into the quantized domain, saturate to the clip range, then round.
    quantized = topi.round(
        topi.maximum(topi.minimum(topi.divide(data, scale), clip_max), clip_min))

    # Scale back to the original domain.
    return [topi.multiply(quantized, scale)]


# simulated_quantize is registered as an element-wise op that uses the
# generic injective schedule.
_reg.register_schedule(
    "relay.op.annotation.simulated_quantize", _reg.schedule_injective)
_reg.register_pattern(
    "relay.op.annotation.simulated_quantize", _reg.OpPattern.ELEMWISE)

# annotation.cast_hint uses the same injective schedule.
_reg.register_schedule("annotation.cast_hint", _reg.schedule_injective)


@register_relay_node
class QAnnotateExpr(_expr.TempExpr):
    """Temporary expression pairing a Relay expr with its annotation kind.

    Parameters
    ----------
    expr : Expr
        The original Relay IR expression.

    kind : QAnnotateKind
        The annotation field this expression belongs to.
    """
    def __init__(self, expr, kind):
        constructor = _quantize.make_annotate_expr
        self.__init_handle_by_constructor__(constructor, expr, kind)


def _get_expr_kind(anno):
    """Split ``anno`` into an ``(expr, kind)`` pair.

    A QAnnotateExpr yields its wrapped expression and annotation kind; any
    other expression is returned unchanged with a kind of ``None``.
    """
    if not isinstance(anno, QAnnotateExpr):
        return anno, None
    return anno.expr, anno.kind


def register_annotate_function(op_name, frewrite=None, level=10):
    """Register an annotation rewrite function for an operator.

    Usable directly, ``register_annotate_function(name, func)``, or as a
    decorator, ``@register_annotate_function(name)``.

    The registered callback is wrapped in a guard. When
    ``current_qconfig().guard(ref_call)`` is falsy, the wrapper does not call
    the rewrite. Instead it strips any QAnnotateExpr wrappers from the
    arguments and forwards the original op.

    Parameters
    ----------
    op_name : str
        The name of the operator.

    frewrite : function, optional
        Rewrite function with signature ``(ref_call, new_args, ctx)``. When
        omitted, a decorator is returned instead.

    level : int, optional
        The priority level passed to the registry.

    Returns
    -------
    function
        The guarded rewrite if ``frewrite`` was given, otherwise a decorator
        that registers and returns the guarded rewrite.
    """
    def _strip_and_forward(ref_call, new_args, ctx):
        # Drop annotation wrappers and re-emit the original op unchanged.
        plain_args = [_get_expr_kind(arg)[0] for arg in new_args]
        return _forward_op(ref_call, plain_args)

    def _register(func):
        """Wrap ``func`` with the qconfig guard and register it."""
        def frewrite_with_guard(ref_call, new_args, ctx):
            if current_qconfig().guard(ref_call):
                return func(ref_call, new_args, ctx)
            return _strip_and_forward(ref_call, new_args, ctx)

        _reg._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level)
        return frewrite_with_guard

    if frewrite is None:
        return _register
    return _register(frewrite)


def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
    """Attach a simulated quantize operation after input data expr.

    If ``data`` is already a simulated_quantize call with the same ``kind``,
    ``sign`` and ``rounding`` attributes, it is returned as-is. Otherwise the
    new node is memoized in ``quantize_context().qnode_map``, keyed by
    ``(data, kind, sign, rounding)``, so repeated requests share one node.

    Parameters
    ---------
    data: Expr
        The original data expr.

    kind: QAnnotateKind
        The kind of annotation field.

    sign: bool, optional
        Forwarded to ``_quantize.simulated_quantize``.

    rounding: str, optional
        Forwarded to ``_quantize.simulated_quantize``.

    Returns
    -------
    Expr
        The simulated_quantize node wrapping ``data``. Its scale and clip
        bounds are fresh free variables ``dom_scale``, ``clip_min`` and
        ``clip_max``.
    """
    quantize_op = _op.get("relay.op.annotation.simulated_quantize")
    already_quantized = (
        isinstance(data, _expr.Call)
        and data.op == quantize_op
        and data.attrs.kind == kind
        and data.attrs.sign == sign
        and data.attrs.rounding == rounding)
    if already_quantized:
        return data

    qctx = quantize_context()
    key = (data, kind, sign, rounding)
    if key in qctx.qnode_map:
        return qctx.qnode_map[key]

    # Argument order fixes variable creation order: dom_scale, clip_min, clip_max.
    qnode = _quantize.simulated_quantize(
        data,
        _expr.var("dom_scale"),
        _expr.var("clip_min"),
        _expr.var("clip_max"),
        kind, sign, rounding)
    qctx.qnode_map[key] = qnode
    return qnode

# Register attach_simulated_quantize with the FFI function registry so it
# can be retrieved by this name.
register_func("relay.quantize.attach_simulated_quantize",
              attach_simulated_quantize)


@register_annotate_function("nn.contrib_conv2d_NCHWc")
def conv2d_nchwc_rewrite(ref_call, new_args, ctx):
    """Warn that NCHWc conv2d is not quantized, and annotate nothing.

    Returns ``None``, like the other rewrites in this module do when they
    leave a call unannotated.
    """
    message = ("NCHWc layout Conv2D detected, please use a lower "
               "optimization level before applying the quantization "
               "pass as quantization will have no effect here...")
    warnings.warn(message)
    return None


@register_annotate_function("nn.conv2d")
def conv2d_rewrite(ref_call, new_args, ctx):
    """Annotate nn.conv2d for quantization.

    The data operand (lhs) is quantized into the INPUT field when it is
    unannotated or of kind ACTIVATION. The kernel operand (rhs), which must
    be unannotated, is quantized into the WEIGHT field. The result is tagged
    ACTIVATION. Returns ``None`` when the quantize context skips this call.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    data, data_kind = _get_expr_kind(new_args[0])
    kernel, kernel_kind = _get_expr_kind(new_args[1])

    # Raw or ACTIVATION-kind data is requantized to INPUT; other kinds pass through.
    if data_kind is None or data_kind == QAnnotateKind.ACTIVATION:
        data = attach_simulated_quantize(data, QAnnotateKind.INPUT)

    # The kernel is expected to carry no annotation yet.
    assert kernel_kind is None
    kernel = attach_simulated_quantize(kernel, QAnnotateKind.WEIGHT)

    conv = _forward_op(ref_call, [data, kernel])
    return QAnnotateExpr(conv, QAnnotateKind.ACTIVATION)


# TODO(tmoreau89,ziheng) need to include an option to turn off dense quant
# @register_annotate_function("nn.dense")
def dense_rewrite(ref_call, new_args, ctx):
    """Annotate nn.dense for quantization (currently not registered).

    The data operand (lhs) is quantized into the INPUT field when it is
    unannotated or of kind ACTIVATION. The weight operand (rhs), which must
    be unannotated, is quantized into the WEIGHT field. The result is tagged
    ACTIVATION. Returns ``None`` when the quantize context skips this call.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    data, data_kind = _get_expr_kind(new_args[0])
    weight, weight_kind = _get_expr_kind(new_args[1])

    # Raw or ACTIVATION-kind data is requantized to INPUT; other kinds pass through.
    if data_kind is None or data_kind == QAnnotateKind.ACTIVATION:
        data = attach_simulated_quantize(data, QAnnotateKind.INPUT)

    # The weight is expected to carry no annotation yet.
    assert weight_kind is None
    weight = attach_simulated_quantize(weight, QAnnotateKind.WEIGHT)

    product = _forward_op(ref_call, [data, weight])
    return QAnnotateExpr(product, QAnnotateKind.ACTIVATION)


@register_annotate_function("multiply")
def multiply_rewrite(ref_call, new_args, ctx):
    """Rewrite function for multiply.

    Supported case: the lhs is annotated as INPUT or ACTIVATION and the rhs
    is a plain expression.

    - An ACTIVATION lhs is requantized into the INPUT field.
    - The rhs is quantized to WEIGHT if it is a constant, otherwise to INPUT.
    - The result is tagged ACTIVATION.

    Returns
    -------
    None when the call is skipped or neither operand is annotated;
    otherwise a QAnnotateExpr of kind ACTIVATION.

    Raises
    ------
    ValueError
        For any other combination of operand annotation kinds.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
    rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

    if lhs_kind is None and rhs_kind is None:
        return None

    if lhs_kind in [QAnnotateKind.ACTIVATION, QAnnotateKind.INPUT] and rhs_kind is None:
        # quantize lhs to INPUT field
        if lhs_kind == QAnnotateKind.ACTIVATION:
            lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
        # A constant rhs is treated as a weight; anything else as an input.
        if _analysis.check_constant(rhs_expr):
            rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
        else:
            rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
        expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
        return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)

    # Name the offending kinds so the unsupported combination is visible in
    # the traceback (same exception type as before).
    raise ValueError(
        "multiply_rewrite: unsupported operand annotation kinds "
        "(lhs={}, rhs={})".format(lhs_kind, rhs_kind))


@register_annotate_function("add")
def add_rewrite(ref_call, new_args, ctx):
    """Rewrite function for add.

    Cases handled, by operand annotation kinds:

    - both plain: no rewrite (returns None);
    - plain lhs, INPUT/ACTIVATION rhs: lhs quantized to INPUT, result INPUT;
    - annotated lhs, plain rhs: rhs quantized to WEIGHT if constant, else
      INPUT; result ACTIVATION;
    - INPUT + INPUT: result INPUT;
    - ACTIVATION + ACTIVATION: rhs requantized to INPUT, result ACTIVATION;
    - ACTIVATION + INPUT (either order): result ACTIVATION.

    Returns ``None`` when the quantize context skips this call.

    Raises
    ------
    ValueError
        For any other combination of annotated operand kinds.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
    rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

    if lhs_kind is None and rhs_kind is None:
        # trivial case
        return None

    if lhs_kind is None and rhs_kind is not None:
        # quantize lhs to INPUT field if it is normal expression
        assert rhs_kind in [QAnnotateKind.INPUT, QAnnotateKind.ACTIVATION]
        lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
        expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
        return QAnnotateExpr(expr, QAnnotateKind.INPUT)

    if lhs_kind is not None and rhs_kind is None:
        if _analysis.check_constant(rhs_expr):
            # - introduced by batch_norm: add(out, const)
            rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
        else:
            rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
        expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
        return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)

    if lhs_kind is not None and rhs_kind is not None:
        if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT:
            expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
            return QAnnotateExpr(expr, QAnnotateKind.INPUT)
        if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
            rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
            expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
            return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
        if (lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT) or \
            (lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.ACTIVATION):
            expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
            return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)

    # Name the offending kinds so the unsupported combination is visible in
    # the traceback (same exception type as before).
    raise ValueError(
        "add_rewrite: unsupported operand annotation kinds "
        "(lhs={}, rhs={})".format(lhs_kind, rhs_kind))


def identity_rewrite(ref_call, new_args, ctx):
    """Forward the op unchanged, keeping the input's annotation kind.

    The output is tagged with the same QAnnotateKind as the single input.
    Returns ``None`` when the call is skipped or the input is not annotated.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    arg, arg_kind = _get_expr_kind(new_args[0])
    if arg_kind is not None:
        return QAnnotateExpr(_forward_op(ref_call, [arg]), arg_kind)
    return None


# Ops whose output annotation simply follows their input.
for _op_name in ("clip",
                 "nn.relu",
                 "strided_slice",
                 "nn.avg_pool2d",
                 "annotation.stop_fusion"):
    register_annotate_function(_op_name, identity_rewrite)
del _op_name


def pool2d_rewrite(ref_call, new_args, ctx):
    """Annotate max_pool2d.

    An ACTIVATION-kind input is first requantized into the INPUT field. The
    pooled result is always tagged INPUT. Returns ``None`` when the call is
    skipped or the input is not annotated.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    arg, arg_kind = _get_expr_kind(new_args[0])
    if arg_kind is None:
        return None

    if arg_kind == QAnnotateKind.ACTIVATION:
        arg = attach_simulated_quantize(arg, QAnnotateKind.INPUT)
    pooled = _forward_op(ref_call, [arg])
    return QAnnotateExpr(pooled, QAnnotateKind.INPUT)


register_annotate_function("nn.max_pool2d", pool2d_rewrite)


@register_annotate_function("annotation.cast_hint")
def cast_hint_rewrite(ref_call, new_args, ctx):
    """Rewrite annotation.cast_hint to force a cast into the INPUT field.

    - When the call is skipped, the wrapped expression is returned without
      the hint op.
    - An unannotated argument is returned as-is, so the hint op is again not
      re-emitted.
    - Otherwise an ACTIVATION argument is requantized to INPUT, the hint op
      is forwarded, and the result is tagged INPUT.
    """
    arg, arg_kind = _get_expr_kind(new_args[0])

    if quantize_context().check_to_skip(ref_call):
        return arg

    if arg_kind is None:
        return new_args[0]

    if arg_kind == QAnnotateKind.ACTIVATION:
        arg = attach_simulated_quantize(arg, QAnnotateKind.INPUT)
    return QAnnotateExpr(_forward_op(ref_call, [arg]), QAnnotateKind.INPUT)


@register_annotate_function("concatenate")
def concatenate_rewrite(ref_call, new_args, ctx):
    """Annotate concatenate.

    If no field of the input tuple is annotated, nothing is rewritten and
    ``None`` is returned. Otherwise every unannotated field is quantized
    into the ACTIVATION field, the concatenation is re-emitted over the
    resulting tuple, and the result is tagged ACTIVATION.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    # (expr, kind) for each field of the concatenated tuple.
    split_fields = [_get_expr_kind(field) for field in new_args[0]]

    if all(kind is None for _, kind in split_fields):
        return None

    # Annotate the raw fields so every operand of the concatenation is quantized.
    fields = [
        attach_simulated_quantize(expr, QAnnotateKind.ACTIVATION) if kind is None else expr
        for expr, kind in split_fields
    ]
    concat = _forward_op(ref_call, [_expr.Tuple(fields)])
    return QAnnotateExpr(concat, QAnnotateKind.ACTIVATION)


@register_annotate_function("nn.global_avg_pool2d")
def global_avg_pool2d_rewrite(ref_call, new_args, ctx):
    """Rewrite function for global_avg_pool2d that stops quantization.

    If the input is annotated, it is realized via ``new_args[0].realize()``
    and the op is forwarded on that realized value. Then
    ``quantize_context().stop_quantize()`` is called, so quantization stops
    after this op.

    Returns
    -------
    None when the call is skipped or the input is not annotated; otherwise
    the forwarded (unannotated) expression.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    # Only the kind is needed; the annotated argument itself is realized below.
    _, x_kind = _get_expr_kind(new_args[0])
    if x_kind is None:
        return None

    expr = _forward_op(ref_call, [new_args[0].realize()])

    # stop quantize after global_avg_pool2d
    quantize_context().stop_quantize()
    return expr