Commit 217de88a by Zhihong Ma

fix: LSTM-OCR fp32 & naive fake quantization

parent 70477dab
......@@ -15,8 +15,8 @@ class seq_mnist(Dataset):
self.labels = []
self.input_lengths = np.ones(1, dtype=np.int32) * (28 * self.trainer_params.word_size)
self.label_lengths = np.ones(1, dtype=np.int32) * (self.trainer_params.word_size)
self.build_dataset()
# self.load_dataset()
# self.build_dataset()
self.load_dataset()
def build_dataset(self):
imgs = []
......@@ -57,12 +57,13 @@ class seq_mnist(Dataset):
def load_dataset(self):
self.images = np.load('data/images{}.npy'.format(self.suffix))
self.labels = np.load('data/labels{}.npy'.format(self.suffix))
print("Successfully load dataset!")
if self.trainer_params.quantize_input:
self.images = self.quantize_tensor_image(self.images)
self.images = np.asarray(self.images)
# No need to handle this separately here
# No need to handle this separately here; adjust during PTQ instead
def quantize_tensor_image(self, tensor_image):
frac_bits = self.trainer_params.recurrent_activation_bit_width-1
prescale = 2**frac_bits
......
import sys
import os
# Extract per-layer parameter and FLOP counts from val.txt, the redirected output of get_param.py
def extract_ratio(md='ResNet18'):
fr = open('param_flops_' + md + '.txt','r')
lines = fr.readlines()
layer = []
par_ratio = []
flop_ratio = []
for line in lines:
# if '(' in line and ')' in line:
if 'Conv' in line or 'BatchNorm2d' in line or 'Linear' in line:
layer.append(line.split(':')[1].split('(')[0])
r1 = line.split('%')[0].split(',')[-1]
r1 = float(r1)
par_ratio.append(r1)
r2 = line.split('%')[-2].split(',')[-1]
r2 = float(r2)
flop_ratio.append(r2)
return layer, par_ratio, flop_ratio
if __name__ == "__main__":
layer, par_ratio, flop_ratio = extract_ratio()
print(len(layer))
print(len(par_ratio))
print(len(flop_ratio))
\ No newline at end of file
from torch.autograd import Function
class FakeQuantize(Function):
@staticmethod
def forward(ctx, x, qparam):
x = qparam.quantize_tensor(x)
x = qparam.dequantize_tensor(x)
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
\ No newline at end of file
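# --- Editor's illustrative sketch (not part of the original commit) ---
# FakeQuantize is a straight-through estimator: forward quantizes and immediately
# dequantizes x through a QParam-like object, backward passes the gradient through
# unchanged. A minimal usage sketch, assuming the QParam class from module.py:
if __name__ == "__main__":
    import torch
    from module import QParam  # assumed module layout
    x = torch.randn(4, 8, requires_grad=True)
    qp = QParam(quant_type='INT', num_bits=8, e_bits=3)
    qp.update(x)                           # collect min/max statistics first
    y = FakeQuantize.apply(x, qp)          # forward: fake-quantized copy of x
    y.sum().backward()                     # backward: gradient flows straight through
    print(x.grad.unique())                 # expect tensor([1.])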
from model import *
import torch
from ptflops import get_model_complexity_info
import argparse
def get_children(model: torch.nn.Module):
# get children from model!
# To keep the parameters updatable later, an nn.ModuleList would be needed to hold them
# children = nn.ModuleList(model.children())
# print(children)
# so that the contained modules can still be updated afterwards
# flatt_children = nn.ModuleList()
children = list(model.children())
# flatt_children = nn.ModuleList()
flatt_children = []
if len(children) == 0:
# if model has no children; model is last child! :O
return model
else:
# look for children from children... to the last child!
for child in children:
try:
flatt_children.extend(get_children(child))
except TypeError:
flatt_children.append(get_children(child))
# print(flatt_children)
return flatt_children
# Helper that yields all leaf submodules, skipping wrapper containers
def get_all_child_modules(module):
for name, child in module.named_children():
if isinstance(child, nn.Sequential):
yield from get_all_child_modules(child)
elif len(list(child.children())) > 0:
yield from child.children()
else:
yield child
def filter_fn(module, n_inp, outp_shape):
# if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.ReLU,torch.nn.BatchNorm2d,torch.nn.Linear,torch.nn.AdaptiveAvgPool2d)):
if 'conv' in module or 'bn' in module or 'fc' in module or 'avg' in module or 'relu' in module:
return True
return False
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Model Analysis --- params & flops')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='ResNet18') # default matches the capitalized names checked below
args = parser.parse_args()
if args.model == 'ResNet18':
model = resnet18()
elif args.model == 'ResNet50':
model = resnet50()
elif args.model == 'ResNet152':
model = resnet152()
full_file = 'ckpt/cifar10_' + args.model + '.pt'
model.load_state_dict(torch.load(full_file))
# flat = get_children(model)
# print(flat)
# flat = get_children(model)
# new_model = nn.Sequential(*flat)
flops, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
class GlobalVariables:
SELF_INPLANES = 0
\ No newline at end of file
# -*- coding: utf-8 -*-
# Share global variables across multiple modules
def _init(): # initialize
global _global_dict
_global_dict = {}
def set_value(value,is_bias=False):
# set a global value
if is_bias:
_global_dict[0] = value
else:
_global_dict[1] = value
def get_value(is_bias=False): # bias gets its own precision, independent of the other variables
if is_bias:
return _global_dict[0]
else:
return _global_dict[1]
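# --- Editor's illustrative sketch (not part of the original commit) ---
# gol shares two global value grids across modules: slot 0 for biases, slot 1 for
# weights/activations. A typical PTQ run calls _init() once and then set_value()
# with the grid that get_nearest_val() in module.py later searches.
if __name__ == "__main__":
    import torch
    _init()
    plist = torch.tensor([2.0 ** (-i) for i in range(8)])  # example power-of-two grid
    set_value(plist)                  # grid for weights/activations (is_bias=False)
    set_value(plist, is_bias=True)    # separate grid for biases
    assert get_value() is plist and get_value(is_bias=True) is plist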
......@@ -36,7 +36,7 @@ def non_or_str(value):
if __name__ == '__main__':
# Training settings
parser = argparse.ArgumentParser(description='PyTorch Quantized BiLSTM Sequential MNIST Example')
parser = argparse.ArgumentParser(description='PyTorch BiLSTM Sequential MNIST Example')
parser.add_argument('--params', '-p', type=str, default="default_trainer_params.json", help='Path to params JSON file. Default ignored when resuming.')
# This can be changed; the original version apparently supported multi-machine training
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
......@@ -69,18 +69,7 @@ if __name__ == '__main__':
parser.add_argument('--reduce_bidirectional', type=str)
parser.add_argument('--recurrent_bias_enabled', type=bool)
parser.add_argument('--checkpoint_interval', type=int)
parser.add_argument('--recurrent_weight_bit_width', type=int)
parser.add_argument('--recurrent_weight_quantization', type=str)
parser.add_argument('--recurrent_bias_bit_width', type=int)
parser.add_argument('--recurrent_bias_quantization', type=str)
parser.add_argument('--recurrent_activation_bit_width', type=int)
parser.add_argument('--recurrent_activation_quantization', type=str)
parser.add_argument('--internal_activation_bit_width', type=int)
parser.add_argument('--fc_weight_bit_width', type=int)
parser.add_argument('--fc_weight_quantization', type=str)
parser.add_argument('--fc_bias_bit_width', type=int)
parser.add_argument('--fc_bias_quantization', type=str)
parser.add_argument('--quantize_input', type=bool)
args = parser.parse_args()
if args.export:
......@@ -92,18 +81,23 @@ if __name__ == '__main__':
# resume directly from a checkpoint
if (args.resume or args.eval or args.export) and args.params == "default_trainer_params.json":
package = torch.load(args.resume, map_location=lambda storage, loc: storage)
trainer_params = package['trainer_params']
# if (args.resume or args.eval) and args.params == "default_trainer_params.json":
# package = torch.load(args.resume, map_location=lambda storage, loc: storage)
# trainer_params = package['trainer_params']
# retrain from scratch
else:
# else:
with open(args.params) as d:
trainer_params = json.load(d, object_hook=ascii_encode_dict)
# trainer_params = json.load(d, object_hook=ascii_encode_dict)
trainer_params = json.load(d)
trainer_params = objdict(trainer_params)
for k in trainer_params.keys():
print(k, trainer_params[k])
if trainer_params[k] == 'LSTM':
print("LSTM YES")
elif trainer_params[k] == 'CONCAT':
print("CONCAT YES")
# args is still useful: the defaults in trainer_params and the options exposed through args are largely complementary
trainer = Seq_MNIST_Trainer(trainer_params, args)
......
......@@ -32,7 +32,7 @@ import math
import numpy
import torch
import torch.nn as nn
from module import *
from functools import partial
# from quantization.modules.rnn import QuantizedLSTM
......@@ -90,13 +90,13 @@ class BiLSTM(nn.Module):
def __init__(self, trainer_params):
super(BiLSTM, self).__init__()
self.trainer_params = trainer_params
# print(self.trainer_params.reduce_bidirectional)
self.trainer_params.reduce_bidirectional = 'CONCAT'
print(f"self.trainer_params.reduce_bidirectional:{self.trainer_params.reduce_bidirectional}")
# self.trainer_params.reduce_bidirectional = 'CONCAT'
if self.trainer_params.bidirectional and self.trainer_params.reduce_bidirectional == 'CONCAT':
self.reduce_factor = 2
else:
self.reduce_factor = 1
# if self.trainer_params.bidirectional and self.trainer_params.reduce_bidirectional == 'CONCAT':
# self.reduce_factor = 2
# else:
# self.reduce_factor = 1
# For an LSTM, the arguments in parentheses configure the class's inputs
# self.recurrent_layer = self.recurrent_layer_type(input_size=self.trainer_params.input_size,
# hidden_size=self.trainer_params.num_units,
......@@ -105,39 +105,59 @@ class BiLSTM(nn.Module):
# bidirectional=self.trainer_params.bidirectional,
# bias=self.trainer_params.recurrent_bias_enabled)
self.recurrent_layer = nn.LSTM(input_size=self.trainer_params.input_size,
# self.recurrent_layer = nn.LSTM(input_size=self.trainer_params.input_size,
# hidden_size=self.trainer_params.num_units,
# num_layers=self.trainer_params.num_layers,
# batch_first=False,
# bidirectional=self.trainer_params.bidirectional,
# bias=self.trainer_params.recurrent_bias_enabled)
self.lstm_layers = nn.ModuleList()
# Create the first LSTM layer and add it to the ModuleList
lstm = nn.LSTM( input_size=self.trainer_params.input_size,
hidden_size=self.trainer_params.num_units,
num_layers=1,
batch_first=False,
bidirectional=self.trainer_params.bidirectional,
bias=self.trainer_params.recurrent_bias_enabled)
self.lstm_layers.append(lstm)
# Create LSTM layers 2 through num_layers and add them to the ModuleList
for i in range(1, self.trainer_params.num_layers):
lstm = nn.LSTM(input_size=self.trainer_params.num_units * 2 if self.trainer_params.bidirectional else self.trainer_params.num_units,
hidden_size=self.trainer_params.num_units,
num_layers=self.trainer_params.num_layers,
num_layers=1,
batch_first=False,
bidirectional=self.trainer_params.bidirectional,
bias=self.trainer_params.recurrent_bias_enabled)
self.lstm_layers.append(lstm)
self.batch_norm_fc = FusedBatchNorm1dLinear(
trainer_params,
nn.BatchNorm1d(self.reduce_factor * self.trainer_params.num_units),
# QuantizedLinear(
# bias=True,
# self.batch_norm_fc = FusedBatchNorm1dLinear(
# trainer_params,
# nn.BatchNorm1d(self.reduce_factor * self.trainer_params.num_units),
# nn.Linear(
# in_features=self.reduce_factor * self.trainer_params.num_units,
# out_features=trainer_params.num_classes,
# bias_bit_width=self.trainer_params.fc_bias_bit_width,
# bias_q_type=self.trainer_params.fc_bias_quantization,
# weight_bit_width=self.trainer_params.fc_weight_bit_width,
# weight_q_type=self.trainer_params.fc_weight_quantization)
nn.Linear(
# bias=True )
# )
self.fc1 = nn.Linear(
in_features=self.reduce_factor * self.trainer_params.num_units,
out_features=trainer_params.num_classes,
bias=True )
)
# self.output_layer = nn.Sequential(SequenceWise(self.batch_norm_fc), nn.LogSoftmax(dim=2))
self.output_layer = nn.Sequential(SequenceWise(self.fc1), nn.LogSoftmax(dim=2))
self.output_layer = nn.Sequential(SequenceWise(self.batch_norm_fc), nn.LogSoftmax(dim=2))
# @property
# def reduce_factor(self):
# if self.trainer_params.bidirectional and self.trainer_params.reduce_bidirectional == 'CONCAT':
# return 2
# else:
# return 1
@property
def reduce_factor(self):
if self.trainer_params.bidirectional and self.trainer_params.reduce_bidirectional == 'CONCAT':
return 2
else:
return 1
# @property
# def recurrent_layer_type(self):
......@@ -163,7 +183,24 @@ class BiLSTM(nn.Module):
def forward(self, x):
# Apparently because there is only a single LSTM cell now (num_layers = 1), h goes unused
x, h = self.recurrent_layer(x)
# x, h = self.recurrent_layer(x)
h_n = []
c_n = []
# Iterate over each LSTM in the ModuleList and run the forward pass layer by layer
for i, lstm in enumerate(self.lstm_layers):
# For layers after the first, use the previous layer's hidden and cell states as this layer's initial state
if i > 0:
x, (h, c) = lstm(x, (h_n[-1], c_n[-1]))
else:
x, (h, c) = lstm(x)
# Append this layer's hidden and cell states to the lists, to be fed to the next LSTM layer
h_n.append(h)
c_n.append(c)
if self.trainer_params.bidirectional:
if self.trainer_params.reduce_bidirectional == 'SUM':
x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1)
......@@ -175,6 +212,78 @@ class BiLSTM(nn.Module):
x = self.output_layer(x)
return x
def quantize(self, quant_type, num_bits=8, e_bits=3):
self.qlstm_layers = nn.ModuleDict()
for i, lstm in enumerate(self.lstm_layers):
# For layers after the first, the previous layer's hidden and cell states serve as this layer's initial state
if i > 0:
self.qlstm_layers[str(i)] = QLSTM(quant_type=quant_type,lstm_module=lstm,qix=False,qih=False,qic=False,qox=True,qoh=True,qoc=True,num_bits=num_bits,e_bits=e_bits)
# The first LSTM layer has no incoming h and c, so qih and qic are False; it does take x, so qix is set to True
else:
self.qlstm_layers[str(i)] = QLSTM(quant_type=quant_type,lstm_module=lstm,qix=True,qih=False,qic=False,qox=True,qoh=True,qoc=True,num_bits=num_bits,e_bits=e_bits)
self.qfc1 = QLinear(quant_type, self.fc1,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
# for name,layer in self.qlstm_layers.items():
# print(f"name:{name}")
def quantize_forward(self, x):
for name, layer in self.qlstm_layers.items():
if name == '0':
x,(h,c) = layer(x)
else:
x,(h,c) = layer(x,h,c)
t, n = x.size(0), x.size(1)
x = x.view(t * n, -1)
x = self.qfc1(x)
x = x.view(t, n, -1)
x = F.log_softmax(x,dim=2)
# out = F.softmax(x, dim=1)
# return out
return x
def freeze(self):
for name, layer in self.qlstm_layers.items():
if name == '0':
layer.freeze(flag=0)
else:
layer.freeze(qix = self.qlstm_layers[str(int(name)-1)].qox, qih=self.qlstm_layers[str(int(name)-1)].qoh, qic=self.qlstm_layers[str(int(name)-1)].qoc,flag=1)
self.qfc1.freeze(qi=self.qlstm_layers[name].qox)
def quantize_inference(self, x):
# First fake-quantize x (to match the LSTM's fake quantization)
x = FakeQuantize.apply(x,self.qlstm_layers['0'].qix)
for name, layer in self.qlstm_layers.items():
if name == '0':
x,(h,c) = layer.quantize_inference(x)
else:
x,(h,c) = layer.quantize_inference(x,h,c)
t, n = x.size(0), x.size(1)
x = x.view(t * n, -1)
# The modified QLinear.quantize_inference quantizes its input x, so the result has to be dequantized here.
x = self.qfc1.quantize_inference(x)
x = self.qfc1.qo.dequantize_tensor(x)
x = x.view(t, n, -1)
x = F.log_softmax(x,dim=2)
return x
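# --- Editor's note: illustrative PTQ calling order (not part of the original commit) ---
# The intended flow, mirroring the ptq script further below: load the fp32 BiLSTM,
# attach the quantized wrappers with quantize(), calibrate with quantize_forward(),
# call freeze() to chain each layer's qo into the next layer's qi, then run
# quantize_inference(). Roughly (checkpoint/loader names are assumptions):
#   model = BiLSTM(trainer_params); model.load_state_dict(torch.load(ckpt_path))
#   model.quantize(quant_type='INT', num_bits=8, e_bits=3)
#   for data, *_ in calib_loader: model.quantize_forward(data)
#   model.freeze()
#   out = model.quantize_inference(x)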
def export(self, output_path, simd_factor, pe):
if self.trainer_params.neuron_type == 'QLSTM':
assert(self.trainer_params.input_size % simd_factor == 0)
......
import math
import numpy as np
import gol
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from function import FakeQuantize
# Snap each value to the nearest entry of the quantization grid
# def get_nearest_val(quant_type,x,is_bias=False):
# if quant_type=='INT':
# return x.round_()
# plist = gol.get_value(is_bias)
# # print('get')
# # print(plist)
# # x = x / 64
# shape = x.shape
# xhard = x.view(-1)
# plist = plist.type_as(x)
# # take the nearest power as the index
# idx = (xhard.unsqueeze(0) - plist.unsqueeze(1)).abs().min(dim=0)[1]
# xhard = plist[idx].view(shape)
# xout = (xhard - x).detach() + x
# # xout = xout * 64
# return xout
def get_nearest_val(quant_type, x, is_bias=False, block_size=1000000):
if quant_type == 'INT':
return x.round_()
plist = gol.get_value(is_bias)
shape = x.shape
# xhard = x.view(-1)
xhard = x.reshape(-1)
xout = torch.zeros_like(xhard)
plist = plist.type_as(x)
n_blocks = (x.numel() + block_size - 1) // block_size
for i in range(n_blocks):
start_idx = i * block_size
end_idx = min(start_idx + block_size, xhard.numel())
block_size_i = end_idx - start_idx
# print(x.numel())
# print(block_size_i)
# print(start_idx)
# print(end_idx)
xblock = xhard[start_idx:end_idx]
# xblock = xblock.view(shape[start_idx:end_idx])
plist_block = plist.unsqueeze(1) #.expand(-1, block_size_i)
idx = (xblock.unsqueeze(0) - plist_block).abs().min(dim=0)[1]
# print(xblock.shape)
xhard_block = plist[idx].view(xblock.shape)
xout[start_idx:end_idx] = (xhard_block - xblock).detach() + xblock
# xout = xout.view(shape)
xout = xout.reshape(shape)
return xout
# For symmetric signed quantization, return the maximum of the quantization range
def get_qmax(quant_type,num_bits=None, e_bits=None):
if quant_type == 'INT':
qmax = 2. ** (num_bits - 1) - 1
elif quant_type == 'POT':
qmax = 1
else: #FLOAT
m_bits = num_bits - 1 - e_bits
dist_m = 2 ** (-m_bits)
e = 2 ** (e_bits - 1)
expo = 2 ** e
m = 2 ** m_bits -1
frac = 1. + m * dist_m
qmax = frac * expo
return qmax
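# --- Editor's worked example (not part of the original commit) ---
# For quant_type='FLOAT' with num_bits=8, e_bits=3:
#   m_bits = 8 - 1 - 3 = 4, dist_m = 2**-4 = 0.0625
#   e = 2**(3-1) = 4, expo = 2**4 = 16
#   m = 2**4 - 1 = 15, frac = 1 + 15 * 0.0625 = 1.9375
#   qmax = 1.9375 * 16 = 31.0
# For 'INT' with num_bits=8 this gives 2**7 - 1 = 127; for 'POT' it is fixed at 1.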
# Signed quantization is used throughout, so zero_point is always 0
def calcScaleZeroPoint(min_val, max_val, qmax):
scale = torch.max(max_val.abs(),min_val.abs()) / qmax
zero_point = torch.tensor(0.)
return scale, zero_point
# Quantize the input; both input and output are tensors
def quantize_tensor(quant_type, x, scale, zero_point, qmax, is_bias=False):
# The quantized range is determined directly by the bit width
qmin = -qmax
q_x = zero_point + x / scale
q_x.clamp_(qmin, qmax)
q_x = get_nearest_val(quant_type, q_x, is_bias)
return q_x
# Bias uses a different precision; num_bits/e_bits must be specified per quantization type
def bias_qmax(quant_type):
if quant_type == 'INT':
return get_qmax(quant_type, 64)
elif quant_type == 'POT':
return get_qmax(quant_type)
else:
return get_qmax(quant_type, 16, 7)
# Convert back to FP32; no further clamping is needed
def dequantize_tensor(q_x, scale, zero_point):
return scale * (q_x - zero_point)
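# --- Editor's worked example (not part of the original commit) ---
# Symmetric signed INT8 round trip: with |x| at most 2.54 and qmax = 127,
# scale = 2.54 / 127 = 0.02 and zero_point = 0, so x = 1.0 maps to
# q_x = round(1.0 / 0.02) = 50 and dequantizes back to 50 * 0.02 = 1.0:
#   scale, zp = calcScaleZeroPoint(torch.tensor(-2.54), torch.tensor(2.54), 127.)
#   q = quantize_tensor('INT', torch.tensor([1.0]), scale, zp, 127.)   # tensor([50.])
#   dequantize_tensor(q, scale, zp)                                    # tensor([1.])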
class QParam(nn.Module):
def __init__(self,quant_type, num_bits=8, e_bits=3):
super(QParam, self).__init__()
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.qmax = get_qmax(quant_type, num_bits, e_bits)
scale = torch.tensor([], requires_grad=False)
zero_point = torch.tensor([], requires_grad=False)
min = torch.tensor([], requires_grad=False)
max = torch.tensor([], requires_grad=False)
# Registering these as buffers lets them be saved in the state_dict
self.register_buffer('scale', scale)
self.register_buffer('zero_point', zero_point)
self.register_buffer('min', min)
self.register_buffer('max', max)
# Update the observed range and the quantization parameters
def update(self, tensor):
if self.max.nelement() == 0 or self.max.data < tensor.max().data:
self.max.data = tensor.max().data
self.max.clamp_(min=0)
if self.min.nelement() == 0 or self.min.data > tensor.min().data:
self.min.data = tensor.min().data
self.min.clamp_(max=0)
self.scale, self.zero_point = calcScaleZeroPoint(self.min, self.max, self.qmax)
def quantize_tensor(self, tensor):
return quantize_tensor(self.quant_type, tensor, self.scale, self.zero_point, self.qmax)
def dequantize_tensor(self, q_x):
return dequantize_tensor(q_x, self.scale, self.zero_point)
# This method makes sure the parameters can be restored from a state_dict
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
key_names = ['scale', 'zero_point', 'min', 'max']
for key in key_names:
value = getattr(self, key)
value.data = state_dict[prefix + key].data
state_dict.pop(prefix + key)
# The return value of this method is what printing the object shows
def __str__(self):
info = 'scale: %.10f ' % self.scale
info += 'zp: %.6f ' % self.zero_point
info += 'min: %.6f ' % self.min
info += 'max: %.6f' % self.max
return info
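# --- Editor's illustrative sketch (not part of the original commit) ---
# QParam.update keeps a running range over calibration batches (max clamped to >= 0,
# min clamped to <= 0 so the symmetric range always covers zero) and refreshes
# scale/zero_point after every call:
#   qp = QParam('INT', num_bits=8, e_bits=3)
#   qp.update(torch.tensor([-1.0, 2.0]))   # min=-1, max=2, scale = 2/127
#   qp.update(torch.tensor([-3.0, 1.0]))   # min=-3, max stays 2, scale = 3/127
#   print(qp)                              # printed via __str__ above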
# Base class for the concrete quantized layers; qi and qo quantize the input and output respectively
class QModule(nn.Module):
def __init__(self,quant_type, qi=True, qo=True, num_bits=8, e_bits=3):
super(QModule, self).__init__()
if qi:
self.qi = QParam(quant_type,num_bits, e_bits)
if qo:
self.qo = QParam(quant_type,num_bits, e_bits)
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.bias_qmax = bias_qmax(quant_type)
def freeze(self):
pass # no-op
def fakefreeze(self):
pass
def quantize_inference(self, x):
raise NotImplementedError('quantize_inference should be implemented.')
"""
QModule 量化卷积
:quant_type: 量化类型
:conv_module: 卷积模块
:qi: 是否量化输入特征图
:qo: 是否量化输出特征图
:num_bits: 8位bit数
"""
class QConv2d(QModule):
def __init__(self, quant_type, conv_module, qi=True, qo=True, num_bits=8, e_bits=3):
super(QConv2d, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.qw = QParam(quant_type, num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
# freeze fixes the truly quantized weight values and writes them back into the original full-precision layer, which makes divergence computation easier
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
# Inputs coming from pooling or activation need no extra min/max statistics; the same output range is shared
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
# Following https://zhuanlan.zhihu.com/p/156835141, this is the coefficient in Eq. (3)
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
self.conv_module.weight.data = self.qw.quantize_tensor(self.conv_module.weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
self.conv_module.bias.data = quantize_tensor(self.quant_type,
self.conv_module.bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0.,qmax=self.bias_qmax, is_bias=True)
def fakefreeze(self):
self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data,scale=self.qi.scale * self.qw.scale, zero_point=0.)
def forward(self, x): # forward pass; the input tensor x holds floating-point data
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi) # fake-quantize the input tensor x
# Update qw before the forward pass so the weight is quantized with the correct scale
self.qw.update(self.conv_module.weight.data)
# Note: this mainly collects per-layer x and weight ranges; the bias is not quantized here
# tmp_wgt = FakeQuantize.apply(self.conv_module.weight, self.qw)
# x = F.conv2d(x, tmp_wgt, self.conv_module.bias,
# stride=self.conv_module.stride,
# padding=self.conv_module.padding, dilation=self.conv_module.dilation,
# groups=self.conv_module.groups)
x = F.conv2d(x, FakeQuantize.apply(self.conv_module.weight, self.qw), self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
# Uses the formula q_a = M * (Σ (q_w - Z_w)(q_x - Z_x) + q_b)
def quantize_inference(self, x): # the input here is the already-quantized qx
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
# if self.quant_type is 'INT':
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
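# --- Editor's derivation note (not part of the original commit) ---
# With symmetric weights/bias (their zero points are 0), the fp32 conv y = W*x + b
# becomes, in the integer domain,
#   q_y = (s_w * s_x / s_y) * ( conv(q_x - Z_x, q_W) + q_b ) + Z_y
# The prefactor s_w * s_x / s_y is exactly self.M; the bias was pre-quantized with
# scale s_w * s_x in freeze(), and get_nearest_val snaps the rescaled result back
# onto the quantization grid.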
class QLinear(QModule):
def __init__(self, quant_type, fc_module, qi=True, qo=True, num_bits=8, e_bits=3):
super(QLinear, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.fc_module = fc_module
self.qw = QParam(quant_type, num_bits, e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
self.fc_module.weight.data = self.qw.quantize_tensor(self.fc_module.weight.data)
self.fc_module.weight.data = self.fc_module.weight.data - self.qw.zero_point
self.fc_module.bias.data = quantize_tensor(self.quant_type,
self.fc_module.bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax, is_bias=True)
def fakefreeze(self):
self.fc_module.weight.data = self.qw.dequantize_tensor(self.fc_module.weight.data)
self.fc_module.bias.data = dequantize_tensor(self.fc_module.bias.data, scale=self.qi.scale * self.qw.scale, zero_point=0.)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
self.qw.update(self.fc_module.weight.data)
# tmp_wgt = FakeQuantize.apply(self.fc_module.weight, self.qw)
# x = F.linear(x, tmp_wgt, self.fc_module.bias)
x = F.linear(x, FakeQuantize.apply(self.fc_module.weight, self.qw), self.fc_module.bias)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
# This bridges the fp32-scale x coming out of the LSTM
x = self.qi.quantize_tensor(x)
x = x - self.qi.zero_point
x = self.fc_module(x)
x = self.M * x
# if self.quant_type is 'INT':
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
class QReLU(QModule):
def __init__(self,quant_type, qi=False, qo=True, num_bits=8, e_bits=3):
super(QReLU, self).__init__(quant_type, qi, qo, num_bits, e_bits)
def freeze(self, qi=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if qi is not None:
self.qi = qi
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
x = F.relu(x)
return x
def quantize_inference(self, x):
x = x.clone()
# x[x < self.qi.zero_point] = self.qi.zero_point
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
a = self.qi.zero_point.float().to(device)
x[x < a] = a
return x
class QMaxPooling2d(QModule):
def __init__(self, quant_type, kernel_size=3, stride=1, padding=0, qi=False, qo=True, num_bits=8,e_bits=3):
super(QMaxPooling2d, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
def freeze(self, qi=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if qi is not None:
self.qi = qi
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding)
return x
def quantize_inference(self, x):
return F.max_pool2d(x, self.kernel_size, self.stride, self.padding)
class QConvBNReLU(QModule):
def __init__(self, quant_type, conv_module, bn_module, qi=True, qo=True, num_bits=8, e_bits=3):
super(QConvBNReLU, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.bn_module = bn_module
self.qw = QParam(quant_type, num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = self.qw.quantize_tensor(weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
if self.conv_module.bias is None:
self.conv_module.bias = nn.Parameter(quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True))
else:
self.conv_module.bias.data = quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
def fakefreeze(self):
self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data,scale=self.qi.scale * self.qw.scale, zero_point=0.)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
self.qw.update(weight.data)
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
x = F.relu(x)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
# if self.quant_type is 'INT':
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
x.clamp_(min=0)
return x
class QConvBN(QModule):
def __init__(self, quant_type, conv_module, bn_module, qi=True, qo=True, num_bits=8, e_bits=3):
super(QConvBN, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.bn_module = bn_module
self.qw = QParam(quant_type, num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = self.qw.quantize_tensor(weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
if self.conv_module.bias is None:
self.conv_module.bias = nn.Parameter(quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True))
else:
self.conv_module.bias.data = quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
def fakefreeze(self):
self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data,scale=self.qi.scale * self.qw.scale, zero_point=0.)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
self.qw.update(weight.data)
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
# x = F.relu(x)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
# if self.quant_type is 'INT':
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
# x.clamp_(min=0)
return x
# TODO: this probably needs a qo as well
class QAdaptiveAvgPool2d(QModule):
def __init__(self, quant_type, qi=False, qo=True, num_bits=8, e_bits=3):
super(QAdaptiveAvgPool2d, self).__init__(quant_type,qi,qo,num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if qi is not None:
self.qi = qi
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qo is not None:
self.qo = qo
self.M.data = (self.qi.scale / self.qo.scale).data
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi) # As with ReLU: first update qi's scale, then put x on the PoT grid (usually the previous layer's qo is True, so x is already on that grid)
x = F.adaptive_avg_pool2d(x,(1, 1)) # quantizing both the input and the output counts as quantizing the pooling
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = F.adaptive_avg_pool2d(x,(1, 1)) # quantizing both the input and the output counts as quantizing the pooling
x = self.M * x
# if self.quant_type is 'INT':
x = get_nearest_val(self.quant_type,x)
return x
class QConvBNReLU6(QModule):
def __init__(self, quant_type, conv_module, bn_module, qi=True, qo=True, num_bits=8, e_bits=3):
super(QConvBNReLU6, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.bn_module = bn_module
self.qw = QParam(quant_type, num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = self.qw.quantize_tensor(weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
self.conv_module.bias.data = quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
def fakefreeze(self):
self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data,scale=self.qi.scale * self.qw.scale, zero_point=0.)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
self.qw.update(weight.data)
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
x = F.relu6(x)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
a = torch.tensor(6)
a = self.qo.quantize_tensor(a)
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
# if self.quant_type is not 'POT':
# x = get_nearest_val(self.quant_type,x)
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point # the data now lives in the qo range
x.clamp_(min=0, max=a.item())
return x
class QModule_2(nn.Module):
def __init__(self,quant_type, qi0=True, qi1=True, qo=True, num_bits=8, e_bits=3):
super(QModule_2, self).__init__()
if qi0:
self.qi0 = QParam(quant_type,num_bits, e_bits) # qi0 is already configured with num_bits and the quantization mode here
if qi1:
self.qi1 = QParam(quant_type,num_bits, e_bits) # qi1 is already configured with num_bits and the quantization mode here
if qo:
self.qo = QParam(quant_type,num_bits, e_bits) # qo is already configured with num_bits and the quantization mode here
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.bias_qmax = bias_qmax(quant_type)
def freeze(self):
pass
def fakefreeze(self):
pass
def quantize_inference(self, x):
raise NotImplementedError('quantize_inference should be implemented.')
class QElementwiseAdd(QModule_2):
def __init__(self, quant_type, qi0=True, qi1=True, qo=True, num_bits=8, e_bits=3):
super(QElementwiseAdd, self).__init__(quant_type, qi0, qi1, qo, num_bits, e_bits)
self.register_buffer('M0', torch.tensor([], requires_grad=False)) # register M0 as a buffer
self.register_buffer('M1', torch.tensor([], requires_grad=False)) # register M1 as a buffer
def freeze(self, qi0=None, qi1=None ,qo=None):
if hasattr(self, 'qi0') and qi0 is not None:
raise ValueError('qi0 has been provided in init function.')
if not hasattr(self, 'qi0') and qi0 is None:
raise ValueError('qi0 is not existed, should be provided.')
if hasattr(self, 'qi1') and qi1 is not None:
raise ValueError('qi1 has been provided in init function.')
if not hasattr(self, 'qi1') and qi1 is None:
raise ValueError('qi1 is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
# Inputs coming from pooling or activation need no extra min/max statistics; the same output range is shared
if qi0 is not None:
self.qi0 = qi0
if qi1 is not None:
self.qi1 = qi1
if qo is not None:
self.qo = qo
# Following https://zhuanlan.zhihu.com/p/156835141, this is the coefficient in Eq. (3)
self.M0.data = self.qi0.scale / self.qo.scale
self.M1.data = self.qi1.scale / self.qi0.scale
# self.M0.data = self.qi0.scale / self.qo.scale
# self.M1.data = self.qi1.scale / self.qo.scale
def forward(self, x0, x1): # forward pass; the input tensors x0 and x1 hold floating-point data
if hasattr(self, 'qi0'):
self.qi0.update(x0)
x0 = FakeQuantize.apply(x0, self.qi0) # fake-quantize the input tensor x0
if hasattr(self, 'qi1'):
self.qi1.update(x1)
x1 = FakeQuantize.apply(x1, self.qi1) # fake-quantize the input tensor x1
x = x0 + x1
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x0, x1): # the inputs here are already-quantized tensors
x0 = x0 - self.qi0.zero_point
x1 = x1 - self.qi1.zero_point
x = self.M0 * (x0 + x1*self.M1)
# if self.quant_type is 'INT':
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
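# --- Editor's derivation note (not part of the original commit) ---
# For y = x0 + x1 with x0 = s0*(q0 - Z0) and x1 = s1*(q1 - Z1), requantizing to the
# output scale s_o gives
#   q_y = (s0/s_o) * ( (q0 - Z0) + (s1/s0) * (q1 - Z1) ) + Z_y
# i.e. M0 = s0/s_o and M1 = s1/s0 as set in freeze() above. The commented-out
# alternative (M0 = s0/s_o, M1 = s1/s_o) would instead compute
#   q_y = M0*(q0 - Z0) + M1*(q1 - Z1) + Z_y.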
class QModule_3(nn.Module):
def __init__(self,quant_type, qix=True, qih=True, qic=True, qox=True, qoh=True, qoc=True, num_bits=8, e_bits=3):
super(QModule_3, self).__init__()
if qix:
self.qix = QParam(quant_type,num_bits, e_bits)
if qox:
self.qox = QParam(quant_type,num_bits, e_bits)
if qih:
self.qih = QParam(quant_type,num_bits, e_bits)
if qoh:
self.qoh = QParam(quant_type,num_bits, e_bits)
if qic:
self.qic = QParam(quant_type,num_bits, e_bits)
if qoc:
self.qoc = QParam(quant_type,num_bits, e_bits)
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.bias_qmax = bias_qmax(quant_type)
def freeze(self):
pass # no-op
def fakefreeze(self):
pass
def quantize_inference(self, x):
raise NotImplementedError('quantize_inference should be implemented.')
class QLSTM(QModule_3):
def __init__(self, quant_type, lstm_module, qix=True, qih=True, qic=True, qox=True, qoh=True, qoc=True, num_bits=8, e_bits=3):
super(QLSTM, self).__init__(quant_type, qix, qih, qic, qox, qoh, qoc, num_bits, e_bits)
self.lstm_module = lstm_module
self.qwih = QParam(quant_type, num_bits,e_bits)
self.qwhh = QParam(quant_type, num_bits,e_bits)
# self.qbih = QParam(quant_type, num_bits,e_bits)
# self.qbhh = QParam(quant_type, num_bits,e_bits)
# self.register_buffer('Mi', torch.tensor([], requires_grad=False)) # register Mi as a buffer
# self.register_buffer('Mh', torch.tensor([], requires_grad=False)) # register Mh as a buffer
# if self.lstm_module.bidirectional:
# self.qwihr = QParam(quant_type, num_bits,e_bits)
# self.qwhhr = QParam(quant_type, num_bits,e_bits)
# self.qbihr = QParam(quant_type, num_bits,e_bits)
# self.qbhhr = QParam(quant_type, num_bits,e_bits)
def freeze(self, qix=None, qih=None, qic=None,qox=None, qoh=None, qoc=None, flag=0):
if hasattr(self, 'qix') and qix is not None:
raise ValueError('qix has been provided in init function.')
if not hasattr(self, 'qix') and qix is None:
raise ValueError('qix is not existed, should be provided.')
if hasattr(self, 'qox') and qox is not None:
raise ValueError('qox has been provided in init function.')
if not hasattr(self, 'qox') and qox is None:
raise ValueError('qox is not existed, should be provided.')
if hasattr(self, 'qih') and qih is not None:
raise ValueError('qih has been provided in init function.')
if not hasattr(self, 'qih') and qih is None and flag==1: # not the first LSTM layer
raise ValueError('qih is not existed, should be provided.')
if hasattr(self, 'qoh') and qoh is not None:
raise ValueError('qoh has been provided in init function.')
if not hasattr(self, 'qoh') and qoh is None:
raise ValueError('qoh is not existed, should be provided.')
if hasattr(self, 'qic') and qic is not None:
raise ValueError('qic has been provided in init function.')
if not hasattr(self, 'qic') and qic is None and flag==1: # not the first LSTM layer
raise ValueError('qic is not existed, should be provided.')
if hasattr(self, 'qoc') and qoc is not None:
raise ValueError('qoc has been provided in init function.')
if not hasattr(self, 'qoc') and qoc is None:
raise ValueError('qoc is not existed, should be provided.')
if qix is not None:
self.qix = qix
if qox is not None:
self.qox = qox
# Avoid the first LSTM layer being given qih/qic even though it has no h or c
if qih is not None and flag==1:
self.qih = qih
if qoh is not None:
self.qoh = qoh
if qic is not None and flag==1:
self.qic = qic
if qoc is not None:
self.qoc = qoc
# Two issues are involved here: 1. the sequence output vs. the hidden-state output, 2. bidirectionality
# The sequence output and the hidden-state output can be handled separately
# Bidirectionality is trickier: after quantization it also involves how SUM vs. CONCAT merges the two directions
# self.Mi.data = (self.qwih.scale * self.qi.scale / self.qo.scale).data
# self.Mh.data = (self.qwhh.scale * self.qi.scale / self.qo.scale).data
# fake-quantize the weights
self.lstm_module.weight_ih_l0.data = FakeQuantize.apply(self.lstm_module.weight_ih_l0.data,self.qwih)
self.lstm_module.weight_hh_l0.data = FakeQuantize.apply(self.lstm_module.weight_hh_l0.data,self.qwhh)
# fake-quantize the bias
self.lstm_module.bias_ih_l0.data = quantize_tensor(self.quant_type,self.lstm_module.bias_ih_l0.data,scale=self.qix.scale*self.qwih.scale,zero_point=0,qmax=self.bias_qmax,is_bias=True)
self.lstm_module.bias_ih_l0.data = dequantize_tensor(self.lstm_module.bias_ih_l0.data,scale=self.qix.scale*self.qwih.scale,zero_point=0)
# The first layer has no qih and needs special handling
if flag==1:
self.lstm_module.bias_hh_l0.data = quantize_tensor(self.quant_type,self.lstm_module.bias_hh_l0.data,scale=self.qih.scale*self.qwhh.scale,zero_point=0,qmax=self.bias_qmax,is_bias=True)
self.lstm_module.bias_hh_l0.data = dequantize_tensor(self.lstm_module.bias_hh_l0.data,scale=self.qih.scale*self.qwhh.scale,zero_point=0)
def forward(self, x, h=None, c=None):
if hasattr(self, 'qix'):
self.qix.update(x)
x = FakeQuantize.apply(x, self.qix)
if hasattr(self, 'qih') and h is not None: # accommodate the first LSTM layer, which has no h / qih
self.qih.update(h)
h = FakeQuantize.apply(h, self.qih)
if hasattr(self, 'qic') and c is not None: # accommodate the first LSTM layer, which has no c / qic
self.qic.update(c)
c = FakeQuantize.apply(c, self.qic)
self.qwih.update(self.lstm_module.weight_ih_l0.data)
self.qwhh.update(self.lstm_module.weight_hh_l0.data)
layer = nn.LSTM(input_size=self.lstm_module.input_size,
hidden_size=self.lstm_module.hidden_size,
num_layers=1,
batch_first=False,
bidirectional=self.lstm_module.bidirectional,
bias=True)
layer.weight_ih_l0.data = FakeQuantize.apply(self.lstm_module.weight_ih_l0.data,self.qwih)
layer.weight_hh_l0.data = FakeQuantize.apply(self.lstm_module.weight_hh_l0.data,self.qwhh)
layer.bias_ih_l0.data = self.lstm_module.bias_ih_l0.data
layer.bias_hh_l0.data = self.lstm_module.bias_hh_l0.data
if h is None:
x, (h, c) = layer(x)
else:
x, (h, c) = layer(x, (h, c))
if hasattr(self, 'qox'):
self.qox.update(x)
x = FakeQuantize.apply(x, self.qox)
if hasattr(self, 'qoh'):
self.qoh.update(h)
h = FakeQuantize.apply(h, self.qoh)
if hasattr(self, 'qoc'):
self.qoc.update(c)
c = FakeQuantize.apply(c, self.qoc)
return x,(h,c)
def quantize_inference(self, x, h=None, c=None):
# freeze used fake quantization, so compute directly here; no scale conversion is needed
if h is None:
x, (h, c) = self.lstm_module(x)
if hasattr(self, 'qox'):
x = FakeQuantize.apply(x, self.qox)
if hasattr(self, 'qoh'):
h = FakeQuantize.apply(h, self.qoh)
if hasattr(self, 'qoc'):
c = FakeQuantize.apply(c, self.qoc)
else:
x, (h, c) = self.lstm_module(x, (h, c))
if hasattr(self, 'qox'):
x = FakeQuantize.apply(x, self.qox)
if hasattr(self, 'qoh'):
h = FakeQuantize.apply(h, self.qoh)
if hasattr(self, 'qoc'):
c = FakeQuantize.apply(c, self.qoc)
return x,(h,c)
# new modules for full-precision model - fold bn
# inference presumably needs the corresponding adaptation as well
class ConvBNReLU(nn.Module):
def __init__(self,conv_module, bn_module):
super(ConvBNReLU, self).__init__()
self.conv_module = conv_module
self.bn_module = bn_module
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self):
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = weight.data
if self.conv_module.bias is None:
self.conv_module.bias = nn.Parameter(bias)
else:
self.conv_module.bias.data = bias
def fakefreeze(self):
pass
def forward(self, x):
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
x = F.conv2d(x, weight, bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
x = F.relu(x)
return x
def quantize_inference(self, x):
x = self.conv_module(x)
x.clamp_(min=0)
return x
class ConvBN(nn.Module):
def __init__(self,conv_module, bn_module):
super(ConvBN, self).__init__()
self.conv_module = conv_module
self.bn_module = bn_module
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self):
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = weight.data
if self.conv_module.bias is None:
self.conv_module.bias = nn.Parameter(bias)
else:
self.conv_module.bias.data = bias
def fakefreeze(self):
pass
def forward(self, x):
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
x = F.conv2d(x, weight, bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
return x
def quantize_inference(self, x):
x = self.conv_module(x)
return x
class ConvBNReLU6(nn.Module):
def __init__(self,conv_module, bn_module):
super(ConvBNReLU6, self).__init__()
self.conv_module = conv_module
self.bn_module = bn_module
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self):
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = weight.data
if self.conv_module.bias is None:
self.conv_module.bias = nn.Parameter(bias)
else:
self.conv_module.bias.data = bias
def fakefreeze(self):
pass
def forward(self, x):
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
x = F.conv2d(x, weight, bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
x = F.relu6(x)
return x
def quantize_inference(self, x):
x = self.conv_module(x)
x.clamp_(min=0,max=6)
return x
\ No newline at end of file
# -*- coding: utf-8 -*-
from torch.serialization import load
from model import *
# from extract_ratio import *
from utils import *
import gol
import openpyxl
import sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import torch.utils.bottleneck as bn
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
import json
from decoder import seq_mnist_decoder
from data import seq_mnist_train, seq_mnist_val
from torch.utils.data import DataLoader
import random
class objdict(dict):
def __getattr__(self, name):
if name in self:
return self[name]
else:
raise AttributeError("No such attribute: " + name)
def __setattr__(self, name, value):
self[name] = value
def __delattr__(self, name):
if name in self:
del self[name]
else:
raise AttributeError("No such attribute: " + name)
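# --- Editor's illustrative sketch (not part of the original commit) ---
# objdict exposes dict keys as attributes, so the JSON trainer params can be read as
# trainer_params.num_units instead of trainer_params['num_units']:
#   p = objdict({'num_units': 128})
#   p.word_size = 5                       # __setattr__ writes into the dict
#   print(p.num_units, p['word_size'])    # -> 128 5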
# def direct_quantize(model, test_loader,device):
# for i, (data, target) in enumerate(test_loader, 1):
# data = data.to(device)
# output = model.quantize_forward(data).cpu()
# if i % 500 == 0:
# break
# print('direct quantization finish')
# def full_inference(model, test_loader, device):
# correct = 0
# for i, (data, target) in enumerate(test_loader, 1):
# data = data.to(device)
# output = model(data).cpu()
# pred = output.argmax(dim=1, keepdim=True)
# correct += pred.eq(target.view_as(pred)).sum().item()
# print('\nTest set: Full Model Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))
# return 100. * correct / len(test_loader.dataset)
# def quantize_inference(model, test_loader, device):
# correct = 0
# for i, (data, target) in enumerate(test_loader, 1):
# data = data.to(device)
# output = model.quantize_inference(data).cpu()
# pred = output.argmax(dim=1, keepdim=True)
# correct += pred.eq(target.view_as(pred)).sum().item()
# print('Test set: Quant Model Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))
# return 100. * correct / len(test_loader.dataset)
def direct_quantize(model, val_loader , val_data ,args , trainer_params, decoder, criterion):
model.eval()
loss_value = 0
for i, (item) in enumerate(val_loader):
data, labels, output_len, lab_len = item
data = Variable(data.transpose(1,0), requires_grad=False)
labels = Variable(labels.view(-1), requires_grad=False)
output_len = Variable(output_len.view(-1), requires_grad=False)
lab_len = Variable(lab_len.view(-1), requires_grad=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = data.to(device)
output = model.quantize_forward(data)
# if i % 500 == 0:
# # break
print('direct quantization finish')
# loss_value /= (len(val_data)//trainer_params.test_batch_size)
# # loss_value = loss_value[0]
# loss_value = loss_value.item()
# print("Average Loss Value for Val Data is = {:.4f}\n".format(float(loss_value)))
def full_inference(model, val_loader , val_data ,args , trainer_params, decoder, criterion):
model.eval()
loss_value = 0
for i, (item) in enumerate(val_loader):
data, labels, output_len, lab_len = item
data = Variable(data.transpose(1,0), requires_grad=False)
labels = Variable(labels.view(-1), requires_grad=False)
output_len = Variable(output_len.view(-1), requires_grad=False)
lab_len = Variable(lab_len.view(-1), requires_grad=False)
# data = data.cuda()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = data.to(device)
output = model(data)
index = random.randint(0,trainer_params.test_batch_size-1)
label = labels[index*trainer_params.word_size:(index+1)*trainer_params.word_size].data.numpy()
label = label-1
prediction = decoder.decode(output[:,index,:], output_len[index], lab_len[index])
accuracy = decoder.hit(prediction, label)
print("Sample Label = {}".format(decoder.to_string(label)))
print("Sample Prediction = {}".format(decoder.to_string(prediction)))
print("Full Model Accuracy on Sample = {:.2f}%\n\n".format(accuracy))
loss = criterion(output, labels, output_len, lab_len)
# loss_value += loss.data.numpy()
loss_value += loss.cpu().data.numpy()
loss_value /= (len(val_data)//trainer_params.test_batch_size)
# loss_value = loss_value[0]
loss_value = loss_value.item()
print("Full Model Average Loss Value for Val Data is = {:.4f}\n".format(float(loss_value)))
def quantize_inference(model, val_loader , val_data ,args , trainer_params, decoder, criterion):
model.eval()
loss_value = 0
for i, (item) in enumerate(val_loader):
data, labels, output_len, lab_len = item
data = Variable(data.transpose(1,0), requires_grad=False)
labels = Variable(labels.view(-1), requires_grad=False)
output_len = Variable(output_len.view(-1), requires_grad=False)
lab_len = Variable(lab_len.view(-1), requires_grad=False)
# data = data.cuda()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = data.to(device)
output = model.quantize_inference(data)
index = random.randint(0,trainer_params.test_batch_size-1)
label = labels[index*trainer_params.word_size:(index+1)*trainer_params.word_size].data.numpy()
label = label-1
prediction = decoder.decode(output[:,index,:], output_len[index], lab_len[index])
accuracy = decoder.hit(prediction, label)
print("Sample Label = {}".format(decoder.to_string(label)))
print("Sample Prediction = {}".format(decoder.to_string(prediction)))
print("Quantize Model Accuracy on Sample = {:.2f}%\n\n".format(accuracy))
loss = criterion(output, labels, output_len, lab_len)
# loss_value += loss.data.numpy()
loss_value += loss.cpu().data.numpy()
loss_value /= (len(val_data)//trainer_params.test_batch_size)
# loss_value = loss_value[0]
loss_value = loss_value.item()
print("Quantize Model Average Loss Value for Val Data is = {:.4f}\n".format(float(loss_value)))
def js_div(p_output, q_output, get_softmax=True):
"""
Measure the JS divergence between two distributions given as logits.
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output, dim=-1)
q_output = F.softmax(q_output, dim=-1)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='LSTM-OCR PTQ evaluation')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='LSTM-OCR')
parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='BATCH SIZE', help='mini-batch size (default: 128)')
parser.add_argument('-j','--workers', default=4, type=int, metavar='WORKERS',help='number of data loading workers (default: 4)')
parser.add_argument('-s', '--save', action='store_true', help='save the PTQ state dict for each quantization setting')
# parser.add_argument('-t', '--test', dest='test', action='store_true', help='test model on test set')
parser.add_argument('--params', '-p', type=str, default="default_trainer_params.json", help='Path to params JSON file. Default ignored when resuming.')
# training parameters
args = parser.parse_args()
with open(args.params) as d:
trainer_params = json.load(d)
# trainer_params = json.load(d, object_hook=ascii_encode_dict)
trainer_params = objdict(trainer_params)
batch_size = args.batch_size
num_workers = args.workers
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
labels = [i for i in range(trainer_params.num_classes-1)]
decoder = seq_mnist_decoder(labels=labels)
criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)
random.seed(trainer_params.random_seed)
torch.manual_seed(trainer_params.random_seed)
# if args.cuda:
torch.cuda.manual_seed_all(trainer_params.random_seed)
train_data = seq_mnist_train(trainer_params)
val_data = seq_mnist_val(trainer_params)
train_loader = DataLoader(train_data, batch_size=trainer_params.batch_size, \
shuffle=True, num_workers=trainer_params.num_workers)
val_loader = DataLoader(val_data, batch_size=trainer_params.test_batch_size, \
shuffle=False, num_workers=trainer_params.num_workers)
if args.model == 'LSTM-OCR':
model = BiLSTM(trainer_params)
# writer = SummaryWriter(log_dir='log/' + args.model + '/ptq')
save_dir = 'ckpt'
full_file = save_dir + '/mnist_' + trainer_params.reduce_bidirectional +'_' + str(trainer_params.bidirectional) + '.pt'
model.load_state_dict(torch.load(full_file))
model.to(device)
load_ptq = False
ptq_file_prefix = 'ckpt/mnist_' + trainer_params.reduce_bidirectional +'_' + str(trainer_params.bidirectional) + '_ptq_'
model.eval()
full_acc = full_inference(model, val_loader, val_data, args, trainer_params, decoder,criterion)
# model_fold = fold_model(model) #
# full_params = []
# layer, par_ratio, flop_ratio = extract_ratio(args.model)
# layer = []
# for name, param in model.named_parameters():
# if 'weight' in name:
# n = name.split('.')
# pre = '.'.join(n[:len(n)-1])
# # extract the name before 'weight' (i.e. the layer name; checking for 'weight' avoids extracting it again for the bias)
# layer.append(pre)
# print('===================')
# par_ratio, flop_ratio = fold_ratio(layer, par_ratio, flop_ratio)
# for name, param in model_fold.named_parameters():
# if 'bn' in name or 'sample.1' in name:
# continue
# param_norm = param.data.cpu()
# full_params.append(param_norm) # 没统计bn的 只统计了conv的 而且还是fold后的
# writer.add_histogram(tag='Full_' + name + '_data', values=param.data)
gol._init()
quant_type_list = ['INT','POT','FLOAT']
title_list = []
js_flops_list = []
js_param_list = []
ptq_acc_list = []
acc_loss_list = []
for quant_type in quant_type_list:
num_bit_list = numbit_list(quant_type)
# for each quantization type, the bias quantization table only needs to be set once
# for INT, the bit-width is large, so a lookup table would be too costly; plain rounding is used instead
if quant_type != 'INT':
bias_list = build_bias_list(quant_type)
gol.set_value(bias_list, is_bias=True)
for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list:
# model_ptq = resnet18()
if args.model == 'LSTM-OCR':
model_ptq = BiLSTM(trainer_params)
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
title = '%s_%d' % (quant_type, num_bits)
print('\nPTQ: '+title)
title_list.append(title)
# set up the quantization value table
if quant_type != 'INT':
plist = build_list(quant_type, num_bits, e_bits)
gol.set_value(plist)
# decide whether a saved PTQ model should be loaded
if load_ptq is True and osp.exists(ptq_file_prefix + title + '.pt'):
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.load_state_dict(torch.load(ptq_file_prefix + title + '.pt'))
model_ptq.to(device)
print('Successfully loaded ptq model: ' + title)
else:
model_ptq.load_state_dict(torch.load(full_file))
model_ptq.to(device)
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.eval()
direct_quantize(model_ptq, val_loader, val_data, args, trainer_params, decoder,criterion)
# if args.save == True:
# torch.save(model_ptq.state_dict(), ptq_file_prefix + title + '.pt')
model_ptq.freeze()
quantize_inference(model_ptq, val_loader, val_data, args, trainer_params, decoder,criterion)
# ptq_acc = quantize_inference(model_ptq, val_loader, val_data, args, trainer_params, decoder,criterion)
# ptq_acc_list.append(ptq_acc)
# acc_loss = (full_acc - ptq_acc) / full_acc
# acc_loss_list.append(acc_loss)
# idx = -1
# accumulate the JS divergence weighted by FLOPs / parameter ratios
js_flops = 0.
js_param = 0.
# for name, param in model_ptq.named_parameters():
# # if '.' not in name or 'bn' in name:
# if 'bn' in name or 'sample.1' in name:
# continue
# writer.add_histogram(tag=title +':'+ name + '_data', values=param.data)
# idx = idx + 1
# # ResNet layer names contain several '.', so the prefix has to be re-joined
# # prefix = name.split('.')[0]
# n = name.split('.')
# prefix = '.'.join(n[:len(n) - 1])
# # weight vs bias 1:1? The ratio is assigned per layer; different weights (e.g. 8:2) could be given to weight and bias here
# if prefix in layer:
# layer_idx = layer.index(prefix)
# ptq_param = param.data.cpu()
# # take the L2 norm
# # ptq_norm = F.normalize(ptq_param,p=2,dim=-1)
# ptq_norm = ptq_param
# writer.add_histogram(tag=title +':'+ name + '_data', values=ptq_param)
# # print(name)
# # print('=========')
# # print(ptq_norm)
# # print('=========')
# # print(full_params[idx])
# js = js_div(ptq_norm,full_params[idx]) # JS distance between the folded model before and after quantization
# js = js.item()
# if js < 0.:
# js = 0.
# js_flops = js_flops + js * flop_ratio[layer_idx]
# js_param = js_param + js * par_ratio[layer_idx]
# js_flops_list.append(js_flops)
# js_param_list.append(js_param)
# print(title + ': js_flops: %f js_param: %f acc_loss: %f' % (js_flops, js_param, acc_loss))
sys.exit()  # stop here for now; the xlsx/js bookkeeping below is not yet used for the LSTM experiments
# write the results to an xlsx workbook
workbook = openpyxl.Workbook()
worksheet = workbook.active
worksheet.cell(row=1,column=1,value='FP32-acc')
worksheet.cell(row=1,column=2,value=full_acc)
worksheet.cell(row=3,column=1,value='title')
worksheet.cell(row=3,column=2,value='js_flops')
worksheet.cell(row=3,column=3,value='js_param')
worksheet.cell(row=3,column=4,value='ptq_acc')
worksheet.cell(row=3,column=5,value='acc_loss')
for i in range(len(title_list)):
worksheet.cell(row=i+4, column=1, value=title_list[i])
worksheet.cell(row=i+4, column=2, value=js_flops_list[i])
worksheet.cell(row=i+4, column=3, value=js_param_list[i])
worksheet.cell(row=i+4, column=4, value=ptq_acc_list[i])
worksheet.cell(row=i+4, column=5, value=acc_loss_list[i])
workbook.save('ptq_result_' + args.model + '.xlsx')
# writer.close()  # the SummaryWriter above is commented out, so closing it here would fail
ft = open('ptq_result_' + args.model + '.txt','w')
print('title_list:',file=ft)
print(" ".join(title_list),file=ft)
print('js_flops_list:',file=ft)
print(" ".join(str(i) for i in js_flops_list), file=ft)
print('js_param_list:',file=ft)
print(" ".join(str(i) for i in js_param_list), file=ft)
print('ptq_acc_list:',file=ft)
print(" ".join(str(i) for i in ptq_acc_list), file=ft)
print('acc_loss_list:',file=ft)
print(" ".join(str(i) for i in acc_loss_list), file=ft)
ft.close()
## update 2023.5.4
### Improved the fp32 model and ran a first PTQ experiment (only a naive fake quantization of part of the parameters)
1. Improvements to the fp32 model
- Support multi-layer LSTMs, organized with nn.ModuleList
2. First attempt at PTQ
- Gained a more detailed understanding of the LSTM structure, data flow and inputs/outputs. The main problems encountered for PTQ are:
- A BiLSTM produces outputs in both directions, which have to be merged by SUM or CONCAT, so PTQ may need extra modules to handle this.
- The computation inside the LSTM is fairly complex, as shown in the figure: <img src = "fig/math.png" class="h-90 auto">
First, there are several gates i, f, g, each with a Wx+b+Wh+b structure (the standard cell equations are written out after this list). Whether their scales should be handled jointly (i.e. one shared scale, or several scales with a rescale at the additions) is an open question.<br>The surrounding sigmoid/tanh is another problem (it makes it hard to guarantee the correctness of the PTQ result across layers purely through scale transformations).<br>The * and + used to update c' and h' are a further issue.
- Inside the LSTM, the per-gate weights are packed into `weight_ih_l[k]`, `weight_hh_l[k]`, `bias_ih_l[k]` and `bias_hh_l[k]`, with each gate occupying different rows. Should these packed matrices be split apart, quantized per gate, and then recombined?
- Since it is not yet clear how to handle the issues above, a simple exploratory experiment was run instead (see the sketch after this list):
- Only the unidirectional LSTM is handled for now; the SUM and CONCAT needed for the bidirectional case are ignored.
- `weight_ih_l[k]`, `weight_hh_l[k]`, `bias_ih_l[k]` and `bias_hh_l[k]` are each quantized as a whole tensor, without splitting every tensor into four blocks and quantizing `weight_ir_l[k]`, `weight_hr_l[k]`, `bias_r_l[k]`, `weight_if_l[k]`, `weight_hf_l[k]`, `bias_f_l[k]`, `weight_ii_l[k]`, `weight_hi_l[k]` separately.
- Weights, biases and the output tensor of each layer are fake-quantized, which avoids scale transformations between layers.<br><br>Consequences:<br> (1) After freeze, the stored weights are no longer the quantized integer values but those values mapped back through the scale.
<br>(2) Quantization of tanh and sigmoid is not considered; they are not matrix multiplications and it is not yet clear how to eliminate their scales. Only the rounding and overflow errors caused by quantizing the weight and output tensors are simulated.
<br>(3) This may differ from an LSTM quantization that could be deployed directly on hardware. (I am not yet sure how LSTM quantization is done in practice; there is not much material about it online.)
<br>(4) It also differs slightly from the PTQ of the earlier networks: the order of the scale-related computations is a bit different, but the underlying principle is similar.
- To be improved:
- A more faithful simulation of LSTM PTQ that accounts for sigmoid, tanh and the various multiply/add combinations
- PTQ for the BiLSTM (with the current simplified scheme this is easy, since no scale issues have to be handled; SUM or CONCAT can be applied directly)
- Possibly split `weight_ih_l[k]`, `weight_hh_l[k]`, `bias_ih_l[k]`, `bias_hh_l[k]` apart, fake-quantize each gate's weights separately, and recombine them
- Use a more complex dataset
- Measure similarity
- Find a better metric for accuracy
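
For reference, the gate computations mentioned above (the content of fig/math.png) are the standard LSTM cell equations, written here in the usual PyTorch notation, where the W and b blocks correspond to the rows of `weight_ih_l[k]`/`weight_hh_l[k]` and `bias_ih_l[k]`/`bias_hh_l[k]`:

```math
\begin{aligned}
i_t &= \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi})\\
f_t &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf})\\
g_t &= \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg})\\
o_t &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho})\\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t\\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```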
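
The exploratory experiment above can be summarized by the following minimal sketch. It is illustrative only: `fake_quantize_tensor` and `naive_fake_quantize_lstm` are hypothetical helper names (not the QParam/FakeQuantize machinery in this repo), the 8-bit width is an assumption, and the fake quantization of the per-layer output tensors is omitted; it only shows whole-tensor fake quantization of the packed `weight_ih_l*`/`weight_hh_l*`/bias tensors of a plain `nn.LSTM`.

```python
import torch
import torch.nn as nn

def fake_quantize_tensor(t, num_bits=8):
    """Hypothetical helper: symmetric uniform quantize -> dequantize of a whole tensor."""
    qmax = 2. ** (num_bits - 1) - 1
    max_abs = t.abs().max()
    if max_abs == 0:
        return t
    scale = max_abs / qmax
    q = torch.clamp(torch.round(t / scale), -qmax, qmax)  # rounding + clipping error
    return q * scale                                       # back to fp32 (fake quantization)

def naive_fake_quantize_lstm(lstm, num_bits=8):
    """Quantize each packed parameter tensor as a whole -- no per-gate split."""
    with torch.no_grad():
        for name, param in lstm.named_parameters():  # weight_ih_l*, weight_hh_l*, bias_ih_l*, bias_hh_l*
            param.copy_(fake_quantize_tensor(param, num_bits))

if __name__ == "__main__":
    lstm = nn.LSTM(input_size=32, hidden_size=128, num_layers=1)  # unidirectional, as in the experiment above
    naive_fake_quantize_lstm(lstm, num_bits=8)
    x = torch.randn(28, 4, 32)   # (seq_len, batch, input_size)
    out, _ = lstm(x)             # forward pass sees only the fake-quantized weights
    print(out.shape)             # torch.Size([28, 4, 128])
```

Because the quantized values are immediately mapped back to fp32, the rounding and clipping error shows up directly in the forward pass without any cross-layer scale bookkeeping, which is exactly the simplification described above.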
## update 2023.5.2
basic version: FP32 model with a single LSTM cell, trained on sequentialized MNIST; recorded here only as a starting point for later changes.
\ No newline at end of file
import os
import math
import numpy
import torch
import torch.nn as nn
from model import *
import argparse
import json
# input_size = 32
# num_units = 128
# num_layers = 1
# bidirectional = True
# recurrent_bias_enabled = True
# lstm1 = nn.LSTM(input_size=input_size,
# hidden_size=num_units,
# num_layers=num_layers,
# batch_first=False,
# bidirectional= bidirectional,
# bias= recurrent_bias_enabled)
# lstm2 = nn.LSTM(input_size=input_size,
# hidden_size=num_units,
# num_layers=num_layers + 1,
# batch_first=False,
# bidirectional= bidirectional,
# bias= recurrent_bias_enabled)
# print("LSTM1:")
# for name,params in lstm1.named_parameters():
# print(f"name:{name},params:{params.shape}")
# print("=============================================")
# print("LSTM2:")
# for name,params in lstm2.named_parameters():
# print(f"name:{name},params:{params.shape}")
class objdict(dict):
def __getattr__(self, name):
if name in self:
return self[name]
else:
raise AttributeError("No such attribute: " + name)
def __setattr__(self, name, value):
self[name] = value
def __delattr__(self, name):
if name in self:
del self[name]
else:
raise AttributeError("No such attribute: " + name)
parser = argparse.ArgumentParser(description='PyTorch BiLSTM Sequential MNIST Example')
parser.add_argument('--params', '-p', type=str, default="default_trainer_params.json", help='Path to params JSON file. Default ignored when resuming.')
args = parser.parse_args()
with open(args.params) as d:
trainer_params = json.load(d)
# trainer_params = json.load(d, object_hook=ascii_encode_dict)
trainer_params = objdict(trainer_params)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 加入设备选择
model = BiLSTM(trainer_params).to(device)
model.quantize('INT',8,0)
......@@ -38,12 +38,29 @@ class Seq_MNIST_Trainer():
self.prev_loss = 10000
self.model = BiLSTM(trainer_params)
# self.criterion = wp.CTCLoss(size_average=False)
self.criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)
self.labels = [i for i in range(trainer_params.num_classes-1)]
self.decoder = seq_mnist_decoder(labels=self.labels)
self.optimizer = optim.Adam(self.model.parameters(), lr=trainer_params.lr)
# self.criterion = wp.CTCLoss(size_average=False)
# defaults to false
# if args.init_bn_fc_fusion:
# # defaults to false; presumably used to record whether BN and FC have already been fused
# if not trainer_params.prefused_bn_fc:
# self.model.batch_norm_fc.init_fusion() # fuse bn-fc
# self.trainer_params.prefused_bn_fc = True # already fused
# else:
# raise Exception("BN and FC are already fused.")
# fuse first, then load the fused weights
if args.eval or args.resume :
save_dir = 'ckpt'
full_file = save_dir + '/mnist_' + self.trainer_params.reduce_bidirectional +'_' + str(self.trainer_params.bidirectional) + '.pt'
self.model.load_state_dict(torch.load(full_file))
print("load Model from existing file finished!")
if args.cuda:
# torch.cuda.set_device(args.gpus)
......@@ -53,29 +70,6 @@ class Seq_MNIST_Trainer():
self.model = self.model.to(device)
self.criterion = self.criterion.to(device)
if args.resume or args.eval or args.export:
print("Loading model from {}".format(args.resume))
package = torch.load(args.resume, map_location=lambda storage, loc: storage)
self.model.load_state_dict(package['state_dict'])
self.optimizer.load_state_dict(package['optim_dict'])
self.starting_epoch = package['starting_epoch']
self.prev_loss = package['prev_loss']
if args.cuda:
for state in self.optimizer.state.values():
for k, v in state.items():
if torch.is_tensor(v):
# state[k] = v.cuda()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state[k] = v.to(device)
# defaults to false
if args.init_bn_fc_fusion:
# defaults to false; presumably used to record whether BN and FC have already been fused
if not trainer_params.prefused_bn_fc:
self.model.batch_norm_fc.init_fusion()
self.trainer_params.prefused_bn_fc = True
else:
raise Exception("BN and FC are already fused.")
def serialize(self, model, trainer_params, optimizer, starting_epoch, prev_loss):
package = {'state_dict': model.state_dict(),
......@@ -86,11 +80,16 @@ class Seq_MNIST_Trainer():
}
return package
# save a checkpoint
def save_model(self, epoch, name):
path = self.args.experiments + '/' + name
print("Model saved at: {}\n".format(path))
torch.save(self.serialize(model=self.model, trainer_params=self.trainer_params,
optimizer=self.optimizer, starting_epoch=epoch + 1, prev_loss=self.prev_loss), path)
def save_model(self):
save_dir = 'ckpt'
if not os.path.isdir(save_dir):
os.makedirs(save_dir, mode=0o777)
os.chmod(save_dir, mode=0o777)
# path = self.args.experiments + '/' + name
torch.save(self.model.state_dict(), save_dir + '/mnist_' + self.trainer_params.reduce_bidirectional +'_' + str(self.trainer_params.bidirectional) + '.pt')
# print("Model saved at: {}\n".format(path))
# torch.save(self.serialize(model=self.model, trainer_params=self.trainer_params,
# optimizer=self.optimizer, starting_epoch=epoch + 1, prev_loss=self.prev_loss), path)
def train(self, epoch):
......@@ -163,9 +162,9 @@ class Seq_MNIST_Trainer():
print("Average Loss Value for Val Data is = {:.4f}\n".format(float(loss_value)))
if loss_value < self.prev_loss and save_model_flag:
self.prev_loss = loss_value
self.save_model(epoch, "best.tar")
elif save_model_flag:
self.save_model(epoch, "checkpoint.tar")
self.save_model()
# elif save_model_flag:
# self.save_model(epoch, "checkpoint.tar")
def eval_model(self):
self.test()
......@@ -173,7 +172,7 @@ class Seq_MNIST_Trainer():
def train_model(self):
for epoch in range(self.starting_epoch, self.args.epochs + 1):
self.train(epoch)
self.test(epoch=epoch, save_model_flag=False)
acc = self.test(epoch=epoch, save_model_flag=True) # originally the model was not saved by default (to be revisited for the PTQ experiments)
if epoch%20==0:
self.optimizer.param_groups[0]['lr'] = self.optimizer.param_groups[0]['lr']*0.98
......
import torch
import torch.nn as nn
import torch.nn.functional as F
def js_div(p_output, q_output, get_softmax=True):
"""
Measure the JS divergence between two distributions given as logits.
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output, dim=-1)
q_output = F.softmax(q_output, dim=-1)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
def ebit_list(quant_type, num_bits):
if quant_type == 'FLOAT':
e_bit_list = list(range(1,num_bits-1))
else:
e_bit_list = [0]
return e_bit_list
def numbit_list(quant_type):
if quant_type == 'INT':
num_bit_list = list(range(2,17))
elif quant_type == 'POT':
num_bit_list = list(range(2,9))
else:
num_bit_list = list(range(2,9))
# num_bit_list = [8]
return num_bit_list
def build_bias_list(quant_type):
if quant_type == 'POT':
return build_pot_list(8) #
else:
return build_float_list(16,7)
def build_list(quant_type, num_bits, e_bits):
if quant_type == 'POT':
return build_pot_list(num_bits)
else:
return build_float_list(num_bits,e_bits)
def build_pot_list(num_bits):
plist = [0.]
for i in range(-2 ** (num_bits-1) + 2, 1):
# i goes up to 0, i.e. the maximum POT quantization value is 1
plist.append(2. ** i)
plist.append(-2. ** i)
plist = torch.Tensor(list(set(plist)))
# plist = plist.mul(1.0 / torch.max(plist))
return plist
def build_float_list(num_bits,e_bits):
m_bits = num_bits - 1 - e_bits
plist = [0.]
# step between adjacent mantissa values
dist_m = 2 ** (-m_bits)
e = -2 ** (e_bits - 1) + 1
for m in range(1, 2 ** m_bits):
frac = m * dist_m # mantissa part
expo = 2 ** e # exponent part
flt = frac * expo
plist.append(flt)
plist.append(-flt)
for e in range(-2 ** (e_bits - 1) + 2, 2 ** (e_bits - 1) + 1):
expo = 2 ** e
for m in range(0, 2 ** m_bits):
frac = 1. + m * dist_m
flt = frac * expo
plist.append(flt)
plist.append(-flt)
plist = torch.Tensor(list(set(plist)))
return plist
def fold_ratio(layer, par_ratio, flop_ratio):
idx = -1
for name in layer:
idx = idx + 1
# layer is extracted from 'for name, param in model.named_parameters()', so it always contains the downsample layers
if 'bn' in name or 'sample.1' in name:
par_ratio[idx-1] += par_ratio[idx]
flop_ratio[idx-1] += flop_ratio[idx]
return par_ratio,flop_ratio
def fold_model(model):
idx = -1
module_list = []
# print('fold model:')
for name, module in model.named_modules():
# print(name+'-- +')
idx += 1
module_list.append(module)
# the conv inside downsample used to be overlooked here, so some layers were not fused
if 'bn' in name or 'sample.1' in name:
# print(name+'-- *')
module_list[idx-1] = fold_bn(module_list[idx-1],module) # modified in place here
return model
# def fold_model(model):
# last_conv = None
# last_bn = None
# for name, module in model.named_modules():
# if isinstance(module, nn.Conv2d):
# # when a conv layer is reached, fold the pending BN into the previous conv layer
# if last_bn is not None:
# last_conv = fold_bn(last_conv, last_bn)
# last_bn = None
# last_conv = module
# elif isinstance(module, nn.BatchNorm2d):
# # if the current module is a BN layer, "fold" it into the previous conv layer
# last_bn = module
# if last_conv is not None:
# last_conv = fold_bn(last_conv, last_bn)
# last_bn = None
# # handle the final BN layer
# if last_bn is not None:
# last_conv = fold_bn(last_conv, last_bn)
# return model
def fold_bn(conv, bn):
# fetch the BN layer parameters
gamma = bn.weight.data
beta = bn.bias.data
mean = bn.running_mean
var = bn.running_var
eps = bn.eps
std = torch.sqrt(var + eps)
feat = bn.num_features
# fetch the conv layer parameters
weight = conv.weight.data
if conv.bias is not None:
bias = conv.bias.data
if bn.affine:
gamma_ = gamma / std
weight = weight * gamma_.view(feat, 1, 1, 1)
if conv.bias is not None:
bias = gamma_ * bias - gamma_ * mean + beta
else:
bias = beta - gamma_ * mean
else:
gamma_ = 1 / std
weight = weight * gamma_
if conv.bias is not None:
bias = gamma_ * bias - gamma_ * mean
else:
bias = -gamma_ * mean
# write back the new weight and bias
conv.weight.data = weight
# handles the case where conv.bias is None
if conv.bias is None:
conv.bias = nn.Parameter(bias)
else:
conv.bias.data = bias
return conv
\ No newline at end of file