Commit 55f47e27 by Klin
parents bff8b078 8352e4cf
from model import *
from extract_ratio import *
from utils import *

import argparse
import openpyxl
import os
import os.path as osp

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
def js_div_norm(a, b):
    # JS divergence between two gradient tensors after L2 normalization
    a_norm = F.normalize(a.data, p=2, dim=-1)
    b_norm = F.normalize(b.data, p=2, dim=-1)
    return js_div(a_norm, b_norm)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Loss-Grad Analysis')
    parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='ResNet18',
                        choices=['ResNet18', 'ResNet50', 'ResNet152'])
    args = parser.parse_args()

    wb = openpyxl.Workbook()
    ws = wb.active

    writer = SummaryWriter(log_dir='log/' + args.model + '/qat_loss_grad')

    # layer, par_ratio, flop_ratio = extract_ratio()
    layer, par_ratio, flop_ratio = extract_ratio(args.model)

    if args.model == 'ResNet18':
        model = resnet18()
    elif args.model == 'ResNet50':
        model = resnet50()
    elif args.model == 'ResNet152':
        model = resnet152()

    # Rebuild the layer list from the model's parameter names
    # (the prefix of each weight parameter, e.g. 'layer1.0.conv1').
    layer = []
    for name, param in model.named_parameters():
        if 'weight' in name:
            n = name.split('.')
            pre = '.'.join(n[:-1])
            layer.append(pre)

    # dir_prefix = 'ckpt/qat/epoch_'
    dir_prefix = 'ckpt/qat/' + args.model + '/'

    quant_type_list = ['INT', 'POT', 'FLOAT']

    for epoch in [5, 10, 15, 20]:
        ws_epoch = wb.create_sheet('epoch_%d' % epoch)
        full_state = torch.load(dir_prefix + '%d/' % epoch + 'full.pt')

        ws_epoch.cell(row=1, column=2, value='loss')
        ws_epoch.cell(row=1, column=3, value='loss_sum')
        ws_epoch.cell(row=1, column=4, value='loss_avg')
        ws_epoch.cell(row=2, column=1, value='FP32')
        ws_epoch.cell(row=2, column=2, value=full_state['loss'].cpu().item())
        ws_epoch.cell(row=2, column=3, value=full_state['loss_sum'].cpu().item())
        ws_epoch.cell(row=2, column=4, value=full_state['loss_avg'].cpu().item())

        # full_grad = full_state['grad_dict']
        # full_grad_sum = full_state['grad_dict_sum']
        full_grad_avg = full_state['grad_dict_avg']
        for name, tmpgrad in full_grad_avg.items():
            writer.add_histogram('FULL: ' + name, tmpgrad, global_step=epoch)

        ws_epoch.cell(row=4, column=1, value='title')
        ws_epoch.cell(row=4, column=2, value='loss')
        ws_epoch.cell(row=4, column=3, value='loss_sum')
        ws_epoch.cell(row=4, column=4, value='loss_avg')
        ws_epoch.cell(row=4, column=5, value='js_grad_avg_norm')
        # Legacy fixed per-layer headers (pre-ResNet version), superseded by the loop below:
        # ws_epoch.cell(row=4,column=6,value='conv1.weight')
        # ws_epoch.cell(row=4,column=7,value='conv1.bias')
        # ws_epoch.cell(row=4,column=8,value='conv2.weight')
        # ws_epoch.cell(row=4,column=9,value='conv2.bias')
        # ws_epoch.cell(row=4,column=10,value='conv3.weight')
        # ws_epoch.cell(row=4,column=11,value='conv3.bias')
        # ws_epoch.cell(row=4,column=12,value='conv4.weight')
        # ws_epoch.cell(row=4,column=13,value='conv4.bias')
        # ws_epoch.cell(row=4,column=14,value='conv5.weight')
        # ws_epoch.cell(row=4,column=15,value='conv5.bias')
        # ws_epoch.cell(row=4,column=16,value='fc1.weight')
        # ws_epoch.cell(row=4,column=17,value='fc1.bias')
        # ws_epoch.cell(row=4,column=18,value='fc2.weight')
        # ws_epoch.cell(row=4,column=19,value='fc2.bias')
        # ws_epoch.cell(row=4,column=20,value='fc3.weight')
        # ws_epoch.cell(row=4,column=21,value='fc3.bias')
        cnt = 5
        for n in layer:
            cnt = cnt + 1
            ws_epoch.cell(row=4, column=cnt, value=n)

        currow = 4
        for quant_type in quant_type_list:
            num_bit_list = numbit_list(quant_type)
            for num_bits in num_bit_list:
                e_bit_list = ebit_list(quant_type, num_bits)
                for e_bits in e_bit_list:
                    if quant_type == 'FLOAT':
                        title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
                    else:
                        title = '%s_%d' % (quant_type, num_bits)
                    print('\nAnalyse: ' + title)

                    currow += 1
                    qat_state = torch.load(dir_prefix + '%d/' % epoch + title + '.pt')

                    js_grad_avg_norm = 0.
                    grad_avg = qat_state['grad_dict_avg']
                    for name, tmpgrad in grad_avg.items():
                        writer.add_histogram(title + ': ' + name, tmpgrad, global_step=epoch)

                    colidx = 5
                    for name, _ in full_grad_avg.items():
                        # Strip the trailing '.weight'/'.bias' to recover the layer
                        # name built above (e.g. 'layer1.0.conv1').
                        prefix = '.'.join(name.split('.')[:-1])
                        colidx += 1
                        layer_idx = layer.index(prefix)
                        js_norm = js_div_norm(full_grad_avg[name], grad_avg[name])
                        ws_epoch.cell(row=currow, column=colidx, value=js_norm.cpu().item())
                        # Weight each layer's divergence by its share of the FLOPs.
                        js_grad_avg_norm += flop_ratio[layer_idx] * js_norm

                    ws_epoch.cell(row=currow, column=1, value=title)
                    ws_epoch.cell(row=currow, column=2, value=qat_state['loss'].cpu().item())
                    ws_epoch.cell(row=currow, column=3, value=qat_state['loss_sum'].cpu().item())
                    ws_epoch.cell(row=currow, column=4, value=qat_state['loss_avg'].cpu().item())
                    ws_epoch.cell(row=currow, column=5, value=js_grad_avg_norm.cpu().item())

    wb.save('loss_grad.xlsx')
    writer.close()
\ No newline at end of file
@@ -72,9 +72,9 @@ class ResNet(nn.Module):
         self.layer2.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
         self.layer3.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
         self.layer4.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
-        self.qavgpool1 = QAdaptiveAvgPool2d(quant_type,qi=False,num_bits=num_bits,e_bits=e_bits)
-        # self.qfc1 = QLinear(quant_type, self.fc,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
-        self.qfc1 = QLinear(quant_type, self.fc,qi=True,qo=True,num_bits=num_bits,e_bits=e_bits)
+        self.qavgpool1 = QAdaptiveAvgPool2d(quant_type,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
+        self.qfc1 = QLinear(quant_type, self.fc,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
+        # self.qfc1 = QLinear(quant_type, self.fc,qi=True,qo=True,num_bits=num_bits,e_bits=e_bits)

     def quantize_forward(self, x):
         # for _, layer in self.quantize_layers.items():
@@ -102,8 +102,8 @@ class ResNet(nn.Module):
         qo = self.layer3.freeze(qinput = qo)
         qo = self.layer4.freeze(qinput = qo)
         self.qavgpool1.freeze(qi=qo)
-        # self.qfc1.freeze(qi=self.qavgpool1.qo)
-        self.qfc1.freeze()
+        self.qfc1.freeze(qi=self.qavgpool1.qo)
+        # self.qfc1.freeze()

     def fakefreeze(self):
         pass
...
@@ -16,6 +16,7 @@ def get_nearest_val(quant_type,x,is_bias=False):
         plist = gol.get_value(is_bias)
         # print('get')
         # print(plist)
+        # x = x / 64
         shape = x.shape
         xhard = x.view(-1)
         plist = plist.type_as(x)
@@ -23,6 +24,7 @@ def get_nearest_val(quant_type,x,is_bias=False):
         idx = (xhard.unsqueeze(0) - plist.unsqueeze(1)).abs().min(dim=0)[1]
         xhard = plist[idx].view(shape)
         xout = (xhard - x).detach() + x
+        # xout = xout * 64
         return xout

 # For symmetric signed quantization: get the maximum of the quantization range
@@ -64,7 +66,7 @@ def bias_qmax(quant_type):
     elif quant_type == 'POT':
         return get_qmax(quant_type)
     else:
-        return get_qmax(quant_type, 16, 7) # e7 m9 (e5 is not large enough, letting values overflow past both ends)
+        return get_qmax(quant_type, 16, 7)

 # Converted to FP32; no further limiting needed
@@ -222,7 +224,9 @@ class QConv2d(QModule):
         x = self.conv_module(x)
         x = self.M * x
-        x = get_nearest_val(self.quant_type,x)
+
+        if self.quant_type != 'POT':
+            x = get_nearest_val(self.quant_type,x)
         x = x + self.qo.zero_point
         return x
@@ -279,7 +283,11 @@ class QLinear(QModule):
         x = x - self.qi.zero_point
         x = self.fc_module(x)
         x = self.M * x
-        x = get_nearest_val(self.quant_type,x)
+
+        # skip activation rounding for PoT (see README update 2023.4.17)
+        if self.quant_type != 'POT':
+            x = get_nearest_val(self.quant_type,x)
+
         x = x + self.qo.zero_point
         return x
@@ -453,7 +461,11 @@ class QConvBNReLU(QModule):
         x = x - self.qi.zero_point
         x = self.conv_module(x)
         x = self.M * x
-        x = get_nearest_val(self.quant_type,x)
+
+        # skip activation rounding for PoT (see README update 2023.4.17)
+        if self.quant_type != 'POT':
+            x = get_nearest_val(self.quant_type,x)
+
         x = x + self.qo.zero_point
         x.clamp_(min=0)
         return x
@@ -567,38 +579,36 @@ class QConvBN(QModule):
         x = x - self.qi.zero_point
         x = self.conv_module(x)
         x = self.M * x
-        x = get_nearest_val(self.quant_type,x)
+
+        if self.quant_type != 'POT':
+            x = get_nearest_val(self.quant_type,x)
+
         x = x + self.qo.zero_point
         # x.clamp_(min=0)
         return x

 # TODO: this probably needs a qo
 class QAdaptiveAvgPool2d(QModule):
-    def __init__(self, quant_type, qi=False, qo=True,num_bits=8, e_bits=3):
+    def __init__(self, quant_type, qi=False, qo=True, num_bits=8, e_bits=3):
         super(QAdaptiveAvgPool2d, self).__init__(quant_type,qi,qo,num_bits,e_bits)
+        self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer

     def freeze(self, qi=None, qo=None):
         if hasattr(self, 'qi') and qi is not None:
             raise ValueError('qi has been provided in init function.')
         if not hasattr(self, 'qi') and qi is None:
             raise ValueError('qi is not existed, should be provided.')
-        # if hasattr(self, 'qo') and qo is not None:
-        #     raise ValueError('qo has been provided in init function.')
-        # if not hasattr(self, 'qo') and qo is None:
-        #     raise ValueError('qo is not existed, should be provided.')
         if qi is not None:
             self.qi = qi
-        # if qo is not None:
-        #     self.qo = qo
-    # def fakefreeze(self, qi=None):
-    #     if hasattr(self, 'qi') and qi is not None:
-    #         raise ValueError('qi has been provided in init function.')
-    #     if not hasattr(self, 'qi') and qi is None:
-    #         raise ValueError('qi is not existed, should be provided.')
-    #     if qi is not None:
-    #         self.qi = qi
+        if hasattr(self, 'qo') and qo is not None:
+            raise ValueError('qo has been provided in init function.')
+        if not hasattr(self, 'qo') and qo is None:
+            raise ValueError('qo is not existed, should be provided.')
+        if qo is not None:
+            self.qo = qo
+        self.M.data = (self.qi.scale / self.qo.scale).data

     def forward(self, x):
         if hasattr(self, 'qi'):
@@ -608,17 +618,23 @@ class QAdaptiveAvgPool2d(QModule):
         x = F.adaptive_avg_pool2d(x,(1, 1)) # quantizing both input and output counts as quantizing this op
-        # if hasattr(self, 'qo'):
-        #     self.qo.update(x)
-        #     x = FakeQuantize.apply(x, self.qo)
+        if hasattr(self, 'qo'):
+            self.qo.update(x)
+            x = FakeQuantize.apply(x, self.qo)
         return x

     def quantize_inference(self, x):
-        x = F.adaptive_avg_pool2d(x,(1,1))
-        # x = FakeQuantize.apply(x, self.qo)
-        # x = get_nearest_val(self.quant_type,x) # this may not fit the PoT case; a rescale seems to be missing?
+        x = F.adaptive_avg_pool2d(x,(1, 1)) # quantizing both input and output counts as quantizing this op
+        x = self.M * x
+        if self.quant_type != 'POT':
+            x = get_nearest_val(self.quant_type,x)
         return x
@@ -632,12 +648,7 @@ class QModule_2(nn.Module):
         if qi1:
             self.qi1 = QParam(quant_type,num_bits, e_bits) # qi is already given num_bits and mode here
         if qo:
-            # if num_bits <=9 :
-            #     self.qo = QParam(quant_type,num_bits, e_bits)
-            # if num_bits > 9 and num_bits<13:
-            #     self.qo = QParam(quant_type,8, e_bits)
-            # else:
-            self.qo = QParam(quant_type,num_bits, e_bits)
+            self.qo = QParam(quant_type,num_bits, e_bits) # qo is already given num_bits and mode here

         self.quant_type = quant_type
         self.num_bits = num_bits
@@ -688,6 +699,8 @@ class QElementwiseAdd(QModule_2):
         # Per https://zhuanlan.zhihu.com/p/156835141, these are the coefficients of Eq. (3)
         self.M0.data = self.qi0.scale / self.qo.scale
         self.M1.data = self.qi1.scale / self.qi0.scale
+        # self.M0.data = self.qi0.scale / self.qo.scale
+        # self.M1.data = self.qi1.scale / self.qo.scale

     def forward(self, x0, x1): # forward pass; the inputs are floating-point tensors
         if hasattr(self, 'qi0'):
@@ -706,20 +719,15 @@ class QElementwiseAdd(QModule_2):
         return x

     def quantize_inference(self, x0, x1): # the inputs here are already-quantized qx
-        # x0_d = self.qi0.dequantize_tensor(x0)
-        # x1_d = self.qi1.dequantize_tensor(x1)
-        # print(f"x0={x0_d.reshape(-1)[:10]}")
-        # print(f"x1={x1_d.reshape(-1)[:10]}")
         x0 = x0 - self.qi0.zero_point
         x1 = x1 - self.qi1.zero_point
         x = self.M0 * (x0 + x1*self.M1)
-        x = get_nearest_val(self.quant_type,x)
+        if self.quant_type != 'POT':
+            x = get_nearest_val(self.quant_type,x)
         x = x + self.qo.zero_point
-        # x_d = self.qo.dequantize_tensor(x)
-        # print(f"x={x_d.reshape(-1)[:10]}")
-        # print(f"loss={x_d.reshape(-1)[:10]-(x0_d.reshape(-1)[:10]+x1_d.reshape(-1)[:10])}")
-        # print('=============')
         return x
## update: <br>2023.4.17<br>
- Fixes and additions for the issues raised on 4.12:
  - The sharp accuracy drop at large INT bit widths was caused by a faulty implementation of QAdaptiveAvgPool2d's quantize_inference; after the fix the results match expectations (the requantization step is recapped below).
  - Checked the network for detail errors and fixed them; added x = get_nearest_val(self.quant_type,x) to QElementwiseAdd's quantize_inference.
  - Confirmed that the weighted-sum computation is correct.
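For reference, a short recap of the scale algebra behind the QAdaptiveAvgPool2d fix, assuming symmetric quantization with zero-points omitted, as the code does. With input scale $s_i$ and output scale $s_o$, a real value $x$ has the two integer encodings

$$x = s_i\,q_i = s_o\,q_o \quad\Longrightarrow\quad q_o = \frac{s_i}{s_o}\,q_i = M\,q_i,$$

which is exactly the factor registered by `self.M.data = (self.qi.scale / self.qo.scale).data`; for non-PoT types, $M\,q_i$ is then snapped back onto the output grid by get_nearest_val.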
- New problem found, and a tentative fix:
  - After the changes above, INT and FP quantization match expectations, but PoT quantization showed extremely low post-quantization inference accuracy (10%~20%). Extensive experiments showed that the main problem is the activation quantization each layer applies to its output in quantize_inference, x = get_nearest_val(self.quant_type,x), which introduces a large error in PoT-quantized models.<br>
  I studied this concretely. What affects a quantized model is not only its weight parameters; quantizing each layer's output also introduces error. Under INT and FP quantization this error is small (INT simply rounds the outputs, and FP quantizes with fairly fine precision), but PoT is different: on the ResNet family it generally causes a significant error. In particular, for deeper models (e.g. ResNet50, ResNet152), post-quantization inference accuracy suffers badly for this reason (from 90% => 10%~20%).<br>
  I first suspected this after repeatedly confirming that the program framework had no obvious problems: PoT quantization kept performing poorly even though the weight-parameter similarity was high, which pointed to the input/output data. To verify the idea, I ran an experiment: in the quantize_inference method of each quantized module, I skip x = get_nearest_val(self.quant_type,x) under PoT quantization, i.e. I quantize only the per-layer weights and not each layer's output tensor. With this change, PoT post-quantization inference accuracy improved considerably (reaching 50%~60%), and the fitted curves broadly match expectations; a toy illustration of why PoT activation rounding is so lossy follows below.<br>
  Since the current independent variable is only the distribution similarity of the weight parameters before and after quantization, I did not introduce a further metric for the pre/post-quantization distribution similarity of each layer's output tensor.
<br>
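To make the claim concrete, here is a minimal, self-contained sketch comparing the mean activation rounding error of a uniform INT-style grid against a power-of-two grid. The grids are toy stand-ins, not the repo's build_pot_list / build_float_list:

```python
import torch

# Toy grids for illustration: a 15-level signed power-of-two grid vs. a
# uniform 255-level grid over the same [-1, 1] range.
pot_grid = torch.tensor([0.] + [2.0 ** (-i) for i in range(7)]
                        + [-(2.0 ** (-i)) for i in range(7)])
int_grid = torch.linspace(-1, 1, 255)

def round_to_grid(x, grid):
    # Same nearest-value idea as get_nearest_val: pick the closest grid level.
    idx = (x.view(-1).unsqueeze(0) - grid.unsqueeze(1)).abs().min(dim=0)[1]
    return grid[idx].view(x.shape)

torch.manual_seed(0)
act = torch.rand(10000) * 2 - 1   # stand-in for one layer's output activations
for name, grid in [('INT', int_grid), ('POT', pot_grid)]:
    err = (round_to_grid(act, grid) - act).abs().mean()
    print(f'{name}: mean |rounding error| = {err:.4f}')
# The PoT levels cluster near zero, so activations away from zero are rounded
# much more coarsely; this matches the accuracy collapse described above.
```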
- Experiment in progress:
  Treat PoT quantization separately: first quantize only the weights and examine how weight-distribution similarity relates to acc. Then fix a PoT weight setting with relatively high acc and quantize the activations with different number representations, observing the relation between activation bit width and acc; a toy version of this sweep is sketched below.
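A runnable toy version of that sweep. pot_round and uniform_fake_quant are illustrative stand-ins written for this sketch, not functions from this repo, and the relative output error of one random convolution stands in for accuracy:

```python
import torch
import torch.nn as nn

def pot_round(t, levels=7):
    # Round every element to the nearest signed power of two (toy PoT grid).
    grid = torch.tensor([0.] + [2.0 ** (-i) for i in range(levels)]
                        + [-(2.0 ** (-i)) for i in range(levels)])
    idx = (t.view(-1).unsqueeze(0) - grid.unsqueeze(1)).abs().min(dim=0)[1]
    return grid[idx].view(t.shape)

def uniform_fake_quant(t, num_bits):
    # Symmetric uniform fake quantization of activations.
    s = t.abs().max() / (2 ** (num_bits - 1) - 1)
    return torch.round(t / s) * s

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, 3, bias=False)
with torch.no_grad():
    conv.weight.copy_(pot_round(conv.weight))  # PoT weights, then held fixed
x = torch.randn(1, 3, 32, 32)
ref = conv(x)                                  # weight-only PoT reference
for bits in [2, 4, 6, 8]:                      # sweep the activation bit width
    out = uniform_fake_quant(conv(x), bits)
    rel_err = (out - ref).norm() / ref.norm()
    print(f'{bits}-bit activations: relative output error = {rel_err:.4f}')
```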
- Not yet done:
  - The experiment relating convergence speed to gradient-distribution similarity.
<br>
- Fitted curves:
1. PoT quantizes only the network weights, while INT and FP quantize both weights and activations:<br>
resnet18:
<img src = "fig/18_flops.png" class="h-90 auto">
<img src = "fig/18_params.png" class="h-90 auto">
resnet50:
<img src = "fig/50_flops.png" class="h-90 auto">
<img src = "fig/50_params.png" class="h-90 auto">
resnet152:
<img src = "fig/152_flops.png" class="h-90 auto">
<img src = "fig/152_params.png" class="h-90 auto">
2. With PoT quantization removed:<br>
resnet18:
<img src = "fig/18_flops_nopot.png" class="h-90 auto">
<img src = "fig/18_params_nopot.png" class="h-90 auto">
resnet50:
<img src = "fig/50_flops_nopot.png" class="h-90 auto">
<img src = "fig/50_params_nopot.png" class="h-90 auto">
resnet152:
<img src = "fig/152_flops_nopot.png" class="h-90 auto">
<img src = "fig/152_params_nopot.png" class="h-90 auto">
## update: <br>2023.4.12<br>
- Updated the parts of get_param_flops.py, extract_ratio.py, and ptq.py that compute data-distribution similarity and the FLOPs/parameter-count weights, so that they work with the ResNet family.<br>
(Note: at present, after generating param_flops_ResNet18/50/152.txt, the statistics of the downsample modules, which each appear one extra time inside a layer, must be deleted by hand; since ResNet18 needs only 3 deletions and ResNet50/152 need 4, this was simply done manually.)
...
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+
+def js_div(p_output, q_output, get_softmax=True):
+    """
+    Measure the JS divergence between target and output logits.
+    """
+    KLDivLoss = nn.KLDivLoss(reduction='sum')
+    if get_softmax:
+        # softmax over the last dim, matching the dim=-1 normalization in the analysis script
+        p_output = F.softmax(p_output, dim=-1)
+        q_output = F.softmax(q_output, dim=-1)
+    log_mean_output = ((p_output + q_output)/2).log()
+    return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
 def ebit_list(quant_type, num_bits):
     if quant_type == 'FLOAT':
@@ -22,7 +34,7 @@ def numbit_list(quant_type):

 def build_bias_list(quant_type):
     if quant_type == 'POT':
-        return build_pot_list(8)
+        return build_pot_list(8) #
     else:
         return build_float_list(16,7)
...