kl_divergence.py 2.13 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Find optimal scale for quantization by minimizing KL-divergence"""

19
import ctypes
20 21
import numpy as np

22
from . import _quantize
23 24


25 26
def _find_scale_by_kl(arr, quantized_dtype='int8',
                      num_bins=8001, num_quantized_bins=255):
    """Given a tensor, find the optimal threshold for quantizing it.

    The reference distribution is `q`, and the candidate distribution is `p`.
    `q` is a truncated version of the original distribution.

    The values of ``arr`` are histogrammed symmetrically over
    ``[-thres, thres]`` with ``thres = max(|min(arr)|, |max(arr)|)``; the
    histogram counts and bin edges are then handed to the native
    ``_quantize.FindScaleByKLMinimization`` routine, which performs the
    KL-divergence minimization.

    Parameters
    ----------
    arr : numpy.ndarray
        The tensor values to calibrate on.
    quantized_dtype : str
        Target quantized dtype. When it is ``'uint8'`` and ``arr`` contains no
        negative values, the number of quantized bins is widened to
        ``2 * num_quantized_bins + 1``.
    num_bins : int
        Number of histogram bins for the reference distribution.
    num_quantized_bins : int
        Number of bins of the candidate (quantized) distribution.

    Returns
    -------
    The value returned by ``_quantize.FindScaleByKLMinimization`` --
    presumably the selected scale/threshold; TODO confirm against the
    native implementation.

    Raises
    ------
    TypeError
        If ``arr`` is not a ``numpy.ndarray``.

    Ref:
    http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(arr, np.ndarray):
        raise TypeError("arr must be a numpy.ndarray, got %s"
                        % type(arr).__name__)
    min_val = np.min(arr)
    max_val = np.max(arr)
    thres = max(abs(min_val), abs(max_val))

    if min_val >= 0 and quantized_dtype == 'uint8':
        # We need to move negative bins to positive bins to fit uint8 range.
        num_quantized_bins = num_quantized_bins * 2 + 1

    def get_pointer(buf, ctypes_type):
        """Return a ``c_void_p`` addressing ``buf``'s data as ``ctypes_type``."""
        ptr = buf.ctypes.data_as(ctypes.POINTER(ctypes_type))
        return ctypes.cast(ptr, ctypes.c_void_p)

    hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-thres, thres))

    # The native routine reads these buffers as C `int` and C `float` arrays,
    # so dtype and layout must match exactly. np.histogram yields int64
    # counts and edges whose dtype follows the input (float64 for float64 or
    # integer input), which would otherwise be reinterpreted as float32.
    # Binding the converted arrays to locals also keeps them alive for the
    # whole foreign call rather than relying on short-lived temporaries.
    hist_c = np.ascontiguousarray(hist, dtype=np.intc)
    hist_edges_c = np.ascontiguousarray(hist_edges, dtype=np.float32)
    hist_ptr = get_pointer(hist_c, ctypes.c_int)
    hist_edges_ptr = get_pointer(hist_edges_c, ctypes.c_float)

    return _quantize.FindScaleByKLMinimization(hist_ptr, hist_edges_ptr,
                                               num_bins, num_quantized_bins)