# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Find optimal scale for quantization by minimizing KL-divergence"""

import ctypes
import numpy as np

from . import _quantize


def _get_pointer(buf, ctypes_type):
    """Return the data address of NumPy array `buf` as a ``ctypes.c_void_p``.

    The address is first typed as ``POINTER(ctypes_type)`` and then cast to
    an untyped void pointer. The caller must keep `buf` alive for as long as
    the returned pointer is in use.
    """
    typed_ptr = buf.ctypes.data_as(ctypes.POINTER(ctypes_type))
    return ctypes.cast(typed_ptr, ctypes.c_void_p)


def _find_scale_by_kl(arr, quantized_dtype='int8',
                      num_bins=8001, num_quantized_bins=255):
    """Given a tensor, find the optimal threshold for quantizing it.

    The reference distribution is `q`, and the candidate distribution is `p`.
    `q` is a truncated version of the original distribution.

    Ref:
    http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf

    Parameters
    ----------
    arr : numpy.ndarray
        Values whose distribution is analysed. A symmetric histogram over
        ``[-thres, thres]`` is built from it, where ``thres`` is the largest
        absolute value in `arr`.
    quantized_dtype : str
        Target quantized type. Only ``'uint8'`` changes behaviour here: when
        every value in `arr` is non-negative, the number of quantized bins
        is widened to ``num_quantized_bins * 2 + 1``.
    num_bins : int
        Number of histogram bins used to approximate the distribution.
    num_quantized_bins : int
        Number of quantized bins forwarded to the KL-minimization routine.

    Returns
    -------
    The value produced by ``_quantize.FindScaleByKLMinimization`` for the
    histogram (presumably the selected scale/threshold -- confirm against
    the backend implementation).
    """
    assert isinstance(arr, np.ndarray)

    min_val = np.min(arr)
    max_val = np.max(arr)
    thres = max(abs(min_val), abs(max_val))

    if min_val >= 0 and quantized_dtype in ['uint8']:
        # We need to move negative bins to positive bins to fit uint8 range.
        num_quantized_bins = num_quantized_bins * 2 + 1

    hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-thres, thres))

    # The pointers below are declared as `int` / `float`, so the underlying
    # buffers must really hold 32-bit integers / 32-bit floats and be
    # contiguous. np.histogram may return float64 edges (e.g. for float64
    # input), which would otherwise be reinterpreted byte-wise as float32.
    # Binding the converted arrays to locals also guarantees that they stay
    # alive until the native call has returned.
    hist_buf = np.ascontiguousarray(hist, dtype=np.int32)
    hist_edges_buf = np.ascontiguousarray(hist_edges, dtype=np.float32)

    hist_ptr = _get_pointer(hist_buf, ctypes.c_int)
    hist_edges_ptr = _get_pointer(hist_edges_buf, ctypes.c_float)

    return _quantize.FindScaleByKLMinimization(hist_ptr, hist_edges_ptr,
                                               num_bins, num_quantized_bins)