# _annotate.py
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
#pylint: disable=unused-argument,inconsistent-return-statements
18 19
"""Internal module for registering attribute for annotation."""
from __future__ import absolute_import
20
import warnings
21 22

import topi
23
from ..._ffi.function import register_func
24
from .. import expr as _expr
25
from .. import analysis as _analysis
26 27 28
from .. import op as _op
from ..op import op as _reg
from ..base import register_relay_node
29 30 31
from . import _quantize
from .quantize import QAnnotateKind, current_qconfig, quantize_context
from .quantize import _forward_op
32 33 34 35 36 37 38 39 40 41 42


@_reg.register_compute("relay.op.annotation.simulated_quantize")
def simulated_quantize_compute(attrs, inputs, out_type, target):
    """Compute definition for the simulated_quantize operator.

    Parameters
    ----------
    attrs : Attrs
        Operator attributes; ``kind``, ``sign`` and ``rounding`` are read.

    inputs : list of Tensor
        Exactly four tensors: data, scale, clip_min and clip_max.

    out_type : Type
        Unused.

    target : Target
        Unused.

    Returns
    -------
    list of Tensor
        A single-element list: ``topi.identity(data)`` for the IDENTITY kind,
        otherwise ``round(clip(data / scale, clip_min, clip_max)) * scale``.
    """
    # Only the signed, "round" rounding configuration is handled here.
    assert len(inputs) == 4
    assert attrs.sign
    assert attrs.rounding == "round"

    data, scale, clip_min, clip_max = inputs

    if attrs.kind == QAnnotateKind.IDENTITY:
        return [topi.identity(data)]

    # Scale into the quantized domain, saturate to [clip_min, clip_max] and
    # round, which simulates the rounding error of quantization.
    clipped = topi.maximum(topi.minimum(topi.divide(data, scale), clip_max), clip_min)
    rounded = topi.round(clipped)

    # Scale back to recover the (now rounded) data.
    return [topi.multiply(rounded, scale)]


# Register the injective schedule and the ELEMWISE op pattern for
# simulated_quantize.
_reg.register_schedule("relay.op.annotation.simulated_quantize", _reg.schedule_injective)
_reg.register_pattern("relay.op.annotation.simulated_quantize", _reg.OpPattern.ELEMWISE)


@register_relay_node
class QAnnotateExpr(_expr.TempExpr):
    """Temporary expression that tags an expression with an annotation kind.

    The annotation rewrite functions wrap their results in this node so that
    later rewrites can see which quantization field an expression is in.

    Parameters
    ----------
    expr : Expr
        The original relay IR expression.

    kind : QAnnotateKind
        The annotation field that ``expr`` belongs to.
    """

    def __init__(self, expr, kind):
        constructor = _quantize.make_annotate_expr
        self.__init_handle_by_constructor__(constructor, expr, kind)


def _get_expr_kind(anno):
    """Split ``anno`` into its underlying expression and annotation kind.

    Returns ``(anno.expr, anno.kind)`` when ``anno`` is a QAnnotateExpr and
    ``(anno, None)`` for any other expression.
    """
    if not isinstance(anno, QAnnotateExpr):
        return anno, None
    return anno.expr, anno.kind


def register_annotate_function(op_name, frewrite=None, level=10):
    """Register an annotation rewrite function for an operator.

    Usable either as a direct call, ``register_annotate_function(name, f)``,
    or as a decorator factory, ``@register_annotate_function(name)``.

    Parameters
    ----------
    op_name : str
        The name of the operator.

    frewrite : function, optional
        Rewrite function taking ``(ref_call, new_args, ctx)``. When omitted,
        a decorator performing the registration is returned instead.

    level : int, optional
        The priority level, forwarded to ``_reg._Register``.

    Returns
    -------
    function
        The guarded rewrite function when ``frewrite`` is given, otherwise a
        decorator that registers its argument and returns the guarded wrapper.
    """
    def _forward_unannotated(ref_call, new_args, ctx):
        # Strip any QAnnotateExpr wrappers and re-emit the call as is.
        plain_args = [_get_expr_kind(arg)[0] for arg in new_args]
        return _forward_op(ref_call, plain_args)

    def _register(func):
        """Wrap ``func`` with the qconfig guard and register the wrapper."""
        def frewrite_with_guard(ref_call, new_args, ctx):
            # Annotate only the calls the current qconfig's guard accepts;
            # anything else is forwarded without annotation.
            if current_qconfig().guard(ref_call):
                return func(ref_call, new_args, ctx)
            return _forward_unannotated(ref_call, new_args, ctx)

        _reg._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level)
        return frewrite_with_guard

    if frewrite is None:
        return _register
    return _register(frewrite)


def attach_simulated_quantize(data, kind, sign=True, rounding="round"):
    """Attach a simulated_quantize operation after the ``data`` expression.

    Existing nodes are reused where possible:

    * if ``data`` already is a simulated_quantize call with the same kind,
      sign and rounding, it is returned unchanged;
    * otherwise nodes are memoized in the current quantize context's
      ``qnode_map`` under the key ``(data, kind, sign, rounding)``.

    Parameters
    ----------
    data : Expr
        The original data expression.

    kind : QAnnotateKind
        The annotation field to quantize into.

    sign : bool, optional
        Forwarded to the simulated_quantize call.

    rounding : str, optional
        Rounding mode forwarded to the simulated_quantize call.

    Returns
    -------
    Expr
        A simulated_quantize call applied to ``data``.
    """
    sq_op = _op.get("relay.op.annotation.simulated_quantize")
    already_quantized = isinstance(data, _expr.Call) and data.op == sq_op
    if already_quantized:
        attrs = data.attrs
        if attrs.kind == kind and attrs.sign == sign and attrs.rounding == rounding:
            return data

    qctx = quantize_context()
    cache_key = (data, kind, sign, rounding)
    if cache_key in qctx.qnode_map:
        return qctx.qnode_map[cache_key]

    # Fresh free variables for the domain scale and the clip bounds.
    dom_scale, clip_min, clip_max = (
        _expr.var(var_name) for var_name in ("dom_scale", "clip_min", "clip_max"))
    qnode = _quantize.simulated_quantize(
        data, dom_scale, clip_min, clip_max, kind, sign, rounding)
    qctx.qnode_map[cache_key] = qnode
    return qnode


register_func("relay.quantize.attach_simulated_quantize", attach_simulated_quantize)


@register_annotate_function("nn.contrib_conv2d_NCHWc")
def conv2d_nchwc_rewrite(ref_call, new_args, ctx):
    """Warn that NCHWc-layout conv2d is not annotated for quantization.

    Emits a warning and returns None, i.e. no annotation is produced for
    this call.
    """
    message = ("NCHWc layout Conv2D detected, please use a lower "
               "optimization level before applying the quantization "
               "pass as quantization will have no effect here...")
    warnings.warn(message)
    return None


@register_annotate_function("nn.conv2d")
def conv2d_rewrite(ref_call, new_args, ctx):
    """Annotate conv2d for quantization.

    The data operand (lhs) is quantized to the INPUT field when it is
    unannotated or in the ACTIVATION field. The weight operand (rhs) must be
    unannotated and is quantized to the WEIGHT field. The result is tagged
    ACTIVATION. Returns None when the quantize context asks to skip the call.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    data, data_kind = _get_expr_kind(new_args[0])
    weight, weight_kind = _get_expr_kind(new_args[1])

    if data_kind is None or data_kind == QAnnotateKind.ACTIVATION:
        data = attach_simulated_quantize(data, QAnnotateKind.INPUT)

    # The weight is expected to arrive unannotated.
    assert weight_kind is None
    weight = attach_simulated_quantize(weight, QAnnotateKind.WEIGHT)

    return QAnnotateExpr(_forward_op(ref_call, [data, weight]),
                         QAnnotateKind.ACTIVATION)


# TODO(tmoreau89,ziheng) need to include an option to turn off dense quant
# @register_annotate_function("nn.dense")
def dense_rewrite(ref_call, new_args, ctx):
    """Annotate dense for quantization.

    The data operand (lhs) is quantized to the INPUT field when it is
    unannotated or in the ACTIVATION field. The weight operand (rhs) must be
    unannotated and is quantized to the WEIGHT field. The result is tagged
    ACTIVATION. Returns None when the quantize context asks to skip the call.

    Registration is currently commented out (see the line above the def),
    so this function is not registered as written.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    data, data_kind = _get_expr_kind(new_args[0])
    weight, weight_kind = _get_expr_kind(new_args[1])

    needs_input_quant = data_kind is None or data_kind == QAnnotateKind.ACTIVATION
    if needs_input_quant:
        data = attach_simulated_quantize(data, QAnnotateKind.INPUT)

    # The weight is expected to arrive unannotated.
    assert weight_kind is None
    weight = attach_simulated_quantize(weight, QAnnotateKind.WEIGHT)

    out = _forward_op(ref_call, [data, weight])
    return QAnnotateExpr(out, QAnnotateKind.ACTIVATION)


@register_annotate_function("multiply")
def multiply_rewrite(ref_call, new_args, ctx):
    """Annotate multiply for quantization.

    Supported case: lhs is in the ACTIVATION or INPUT field and rhs is
    unannotated. An ACTIVATION lhs is quantized to INPUT, rhs is quantized
    to WEIGHT, and the product is tagged ACTIVATION.

    Returns None when the quantize context asks to skip the call or when
    neither operand is annotated.

    Raises
    ------
    ValueError
        For any other combination of operand annotation kinds.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
    rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

    if lhs_kind is None and rhs_kind is None:
        return None

    if lhs_kind in [QAnnotateKind.ACTIVATION, QAnnotateKind.INPUT] and rhs_kind is None:
        # quantize lhs to INPUT field
        if lhs_kind == QAnnotateKind.ACTIVATION:
            lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
        # quantize rhs to WEIGHT field
        rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
        expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
        return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)

    # Previously a bare ValueError with no message; report the offending
    # kinds so unsupported graphs can be diagnosed.
    raise ValueError(
        "multiply: unsupported annotation kinds (lhs={}, rhs={}); expected lhs "
        "in ACTIVATION/INPUT and an unannotated rhs".format(lhs_kind, rhs_kind))


@register_annotate_function("add")
def add_rewrite(ref_call, new_args, ctx):
    """Annotate add for quantization.

    Cases, by the annotation kinds of (lhs, rhs):

    * (None, None): nothing to annotate; returns None.
    * (None, INPUT or ACTIVATION): lhs is quantized to INPUT; result is INPUT.
    * (annotated, None): rhs is quantized to WEIGHT if it is a constant
      (e.g. the constant introduced by batch_norm), otherwise to INPUT;
      result is ACTIVATION.
    * (INPUT, INPUT): forwarded as is; result is INPUT.
    * (ACTIVATION, ACTIVATION): rhs is quantized to INPUT; result is
      ACTIVATION.
    * (ACTIVATION, INPUT) or (INPUT, ACTIVATION): forwarded as is; result is
      ACTIVATION.

    Returns None when the quantize context asks to skip the call.

    Raises
    ------
    AssertionError
        If lhs is unannotated and rhs has a kind other than INPUT/ACTIVATION.
    ValueError
        For any other unsupported combination of annotation kinds.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    lhs_expr, lhs_kind = _get_expr_kind(new_args[0])
    rhs_expr, rhs_kind = _get_expr_kind(new_args[1])

    if lhs_kind is None and rhs_kind is None:
        # trivial case
        return None

    if lhs_kind is None and rhs_kind is not None:
        # quantize lhs to INPUT field if it is normal expression
        assert rhs_kind in [QAnnotateKind.INPUT, QAnnotateKind.ACTIVATION], (
            "add: unexpected rhs annotation kind {}".format(rhs_kind))
        lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
        expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
        # NOTE(review): unlike the mirrored (annotated, None) case below, the
        # result is tagged INPUT even when rhs is ACTIVATION -- confirm this
        # asymmetry is intended.
        return QAnnotateExpr(expr, QAnnotateKind.INPUT)

    if lhs_kind is not None and rhs_kind is None:
        if _analysis.check_constant(rhs_expr):
            # - introduced by batch_norm: add(out, const)
            rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
        else:
            rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
        expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
        return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)

    # Both operands are annotated from here on (every other case returned).
    if lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.INPUT:
        expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
        return QAnnotateExpr(expr, QAnnotateKind.INPUT)
    if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.ACTIVATION:
        rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
        expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
        return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
    if (lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind == QAnnotateKind.INPUT) or \
            (lhs_kind == QAnnotateKind.INPUT and rhs_kind == QAnnotateKind.ACTIVATION):
        expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
        return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)

    # Previously a bare ValueError(); include the kinds for diagnosis.
    raise ValueError(
        "add: unsupported annotation kinds (lhs={}, rhs={})".format(lhs_kind, rhs_kind))


def identity_rewrite(ref_call, new_args, ctx):
    """Forward the call while keeping its operand's annotation kind.

    Returns None when the quantize context asks to skip the call or when the
    operand is not annotated. Otherwise the call is re-emitted on the
    unwrapped operand and the result is tagged with the operand's kind.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    inner, inner_kind = _get_expr_kind(new_args[0])
    if inner_kind is not None:
        return QAnnotateExpr(_forward_op(ref_call, [inner]), inner_kind)
    return None


# Ops whose annotation only propagates the operand's kind (identity_rewrite).
register_annotate_function("clip", identity_rewrite)
register_annotate_function("nn.relu", identity_rewrite)
register_annotate_function("strided_slice", identity_rewrite)
register_annotate_function("nn.avg_pool2d", identity_rewrite)
register_annotate_function("annotation.stop_fusion", identity_rewrite)


def pool2d_rewrite(ref_call, new_args, ctx):
    """Annotate max_pool2d for quantization.

    Returns None when the quantize context asks to skip the call or the
    operand is not annotated. An ACTIVATION operand is first quantized to the
    INPUT field; the pooled result is tagged INPUT.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    operand, operand_kind = _get_expr_kind(new_args[0])
    if operand_kind is None:
        return None

    if operand_kind == QAnnotateKind.ACTIVATION:
        operand = attach_simulated_quantize(operand, QAnnotateKind.INPUT)

    pooled = _forward_op(ref_call, [operand])
    return QAnnotateExpr(pooled, QAnnotateKind.INPUT)


register_annotate_function("nn.max_pool2d", pool2d_rewrite)


@register_annotate_function("annotation.cast_hint")
def cast_hint_rewrite(ref_call, new_args, ctx):
    """Rewrite cast_hint so that its result is in the INPUT field.

    * If the quantize context asks to skip the call, the unwrapped operand is
      returned (the cast_hint call itself is not re-emitted).
    * An unannotated operand is returned unchanged.
    * Otherwise an ACTIVATION operand is quantized to INPUT, the call is
      re-emitted, and the result is tagged INPUT.
    """
    # Unwrap before the skip check: the skip path returns the unwrapped operand.
    operand, operand_kind = _get_expr_kind(new_args[0])

    if quantize_context().check_to_skip(ref_call):
        return operand

    if operand_kind is None:
        return new_args[0]

    if operand_kind == QAnnotateKind.ACTIVATION:
        operand = attach_simulated_quantize(operand, QAnnotateKind.INPUT)

    return QAnnotateExpr(_forward_op(ref_call, [operand]), QAnnotateKind.INPUT)


@register_annotate_function("concatenate")
def concatenate_rewrite(ref_call, new_args, ctx):
    """Annotate concatenate for quantization.

    Returns None when the quantize context asks to skip the call or when none
    of the concatenated inputs is annotated. Otherwise every unannotated input
    is quantized to the ACTIVATION field, so that all inputs are annotated,
    and the concatenated result is tagged ACTIVATION.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    # Unwrap each tuple field once into an (expr, kind) pair.
    fields = [_get_expr_kind(x) for x in new_args[0]]

    # Nothing to annotate if no input carries an annotation.
    if all(kind is None for _, kind in fields):
        return None

    # Bring plain inputs into the ACTIVATION field; keep annotated ones as is.
    expr_list = [
        expr if kind is not None
        else attach_simulated_quantize(expr, QAnnotateKind.ACTIVATION)
        for expr, kind in fields
    ]
    expr = _forward_op(ref_call, [_expr.Tuple(expr_list)])
    return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)


@register_annotate_function("nn.global_avg_pool2d")
def global_avg_pool2d_rewrite(ref_call, new_args, ctx):
    """Rewrite global_avg_pool2d and stop quantization after it.

    Returns None when the quantize context asks to skip the call or the
    operand is not annotated. Otherwise the steps are:

    * the call is re-emitted on ``new_args[0].realize()``;
    * the quantize context is told to stop quantizing;
    * the plain (not QAnnotateExpr-wrapped) result is returned.
    """
    if quantize_context().check_to_skip(ref_call):
        return None

    # Only the kind is needed; the operand itself is realized from
    # new_args[0] below.
    _, x_kind = _get_expr_kind(new_args[0])
    if x_kind is None:
        return None

    expr = _forward_op(ref_call, [new_args[0].realize()])

    # stop quantize after global_avg_pool2d
    quantize_context().stop_quantize()
    return expr