Commit efd05aee by Zhihong Ma

fix: fakefreeze --- better curve for ResNet & MobileNetV2

parent 3da1eb1f
## update: <br>2023.4.24<br>
- Problems solved
1. Fixed the abnormal weight-parameter similarity that made curve fitting impossible.
2. Fixed several minor bugs.
- Notes on the approach <br>
The implausible weight-parameter similarity already seen on ResNet was even more pronounced on MobileNetV2, to the point that no curve could be fitted for MobileNet. Since the post-PTQ inference accuracy looked normal, the problem most likely lay in how the weight-parameter similarity was computed.<br>
I carefully checked that BN was folded correctly, that the corresponding ratios were right, and that the JS divergence was computed between matching pairs of weight tensors; all of that was fine. When inspecting the raw weight data, however, I found that a few layers had abnormal values, and it was exactly these layers whose very large JS values distorted the overall JS computation.<br>
Judging from TensorBoard, the weight distributions of these layers should be very close to the full-precision model, but the computed JS divergence did not reflect this. Looking closer, these layers contained many values of 1 or -1, which suggested that the normalize step applied to the weights of both models before comparison was distorting the distributions and causing the problem.<br>
The intent of the normalize step was to bring the weights of the full-precision and quantized models onto the same scale so that the JS divergence could serve as a distance. Instead, we can dequantize the weights of the quantized model back to a scale close to the full-precision model, and only then compute the JS divergence. I call this procedure fakefreeze. In practice it works well: the recomputed JS divergence agrees with what TensorBoard shows when inspecting the distributions directly.
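A minimal sketch of the fakefreeze comparison (illustrative only: the quantization `scale` and the tensor shape are made up, and the repo's real `js_div`/dequantize logic lives in the project code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def js_div(p, q):
    # JS divergence between two weight tensors viewed as distributions
    kl = nn.KLDivLoss(reduction='sum')
    p = F.softmax(p.view(-1), dim=0)
    q = F.softmax(q.view(-1), dim=0)
    log_mean = ((p + q) / 2).log()
    return (kl(log_mean, p) + kl(log_mean, q)) / 2

scale = 0.05                        # hypothetical symmetric quantization scale
w_fp = torch.randn(64, 3, 3, 3)     # full-precision weights
w_int = torch.round(w_fp / scale)   # integer-domain weights after freeze()
w_dq = w_int * scale                # fakefreeze: dequantize back to the FP scale
print(js_div(w_dq, w_fp).item())    # close to 0: same scale, similar distribution
```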
## update: <br>2023.4.23<br>
1. Implemented PTQ quantization for MobileNetV2
2. Some issues remain:<br>
......
## update: <br>2023.4.24<br>
Adopted fakefreeze in ResNet_nobias as well, so PoT no longer needs separate handling.
## update: <br>2023.4.23<br>
Treat PoT quantization separately: first quantize only the weights and look at the relationship between weight-distribution similarity and accuracy. Then fix a PoT weight setting with comparatively high accuracy and quantize the activations with different data representations, observing the relationship between activation bit-width and accuracy (a sketch of this two-stage sweep follows).<br>
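A sketch of the two-stage sweep; every helper here (`quantize_weights_only`, `weight_js`, `quantize_activations`, `evaluate`) is a hypothetical stand-in for the repo's actual PTQ entry points, with dummy bodies so the sketch runs:

```python
# Hypothetical stand-ins so the sketch is self-contained:
def quantize_weights_only(repr_, bits): return (repr_, bits)
def quantize_activations(wbits, repr_, bits): return (wbits, repr_, bits)
def weight_js(m): return 0.0   # placeholder similarity metric
def evaluate(m): return 0.0    # placeholder accuracy

results = {}
# Stage 1: PoT weight-only quantization; record weight similarity vs. accuracy.
for num_bits in [2, 3, 4, 5, 6, 7, 8]:
    m = quantize_weights_only('POT', num_bits)
    results[num_bits] = (weight_js(m), evaluate(m))

# Stage 2: fix a high-accuracy PoT weight bit-width, sweep activation formats.
best_bits = max(results, key=lambda b: results[b][1])
for act_repr in ['INT', 'POT', 'FLOAT']:
    for act_bits in [2, 3, 4, 5, 6, 7, 8]:
        m = quantize_activations(best_bits, act_repr, act_bits)
        print(act_repr, act_bits, evaluate(m))
```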
......
# -*- coding: utf-8 -*-
from torch.serialization import load
from model import *
from extract_ratio import *
from utils import *
import gol
import openpyxl
import sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import torch.utils.bottleneck as bn
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter

def direct_quantize(model, test_loader, device):
    for i, (data, target) in enumerate(test_loader, 1):
        data = data.to(device)
        output = model.quantize_forward(data).cpu()
        if i % 500 == 0:
            break
    print('direct quantization finish')

def full_inference(model, test_loader, device):
    correct = 0
    for i, (data, target) in enumerate(test_loader, 1):
        data = data.to(device)
        output = model(data).cpu()
        pred = output.argmax(dim=1, keepdim=True)
        # print(pred)
        correct += pred.eq(target.view_as(pred)).sum().item()
    print('\nTest set: Full Model Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))
    return 100. * correct / len(test_loader.dataset)

def quantize_inference(model, test_loader, device):
    correct = 0
    for i, (data, target) in enumerate(test_loader, 1):
        data = data.to(device)
        output = model.quantize_inference(data).cpu()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    print('Test set: Quant Model Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))
    return 100. * correct / len(test_loader.dataset)

def js_div(p_output, q_output, get_softmax=True):
    """
    Measures the JS divergence between target and output logits.
    """
    KLDivLoss = nn.KLDivLoss(reduction='sum')
    if get_softmax:
        # F.softmax with an implicit dim is deprecated; reproduce its historical
        # default (dim 0 for 0-, 1- and 3-D inputs, dim 1 otherwise) explicitly
        dim = 0 if p_output.dim() in (0, 1, 3) else 1
        p_output = F.softmax(p_output, dim=dim)
        q_output = F.softmax(q_output, dim=dim)
    log_mean_output = ((p_output + q_output) / 2).log()
    return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output)) / 2

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch FP32 Training')
    parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='ResNet18')
    parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='BATCH SIZE', help='mini-batch size (default: 128)')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='WORKERS', help='number of data loading workers (default: 4)')
    parser.add_argument('-s', '--save', help='Save the output', action='store_true')
    # parser.add_argument('-t', '--test', dest='test', action='store_true', help='test model on test set')

    # training parameters
    args = parser.parse_args()
    batch_size = args.batch_size
    num_workers = args.workers
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('../../project/p/data', train=True, download=False,
                         transform=transforms.Compose([
                             transforms.RandomCrop(32, padding=2),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             transforms.Normalize(
                                 (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                         ])),
        batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('../../project/p/data', train=False, download=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010))
        ])),
        batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
    )
    # model = AlexNet_BN()
    if args.model == 'ResNet18':
        model = resnet18()
    elif args.model == 'ResNet50':
        model = resnet50()
    elif args.model == 'ResNet152':
        model = resnet152()

    writer = SummaryWriter(log_dir='log/' + args.model + '/ptq')

    full_file = 'ckpt/cifar10_' + args.model + '.pt'
    model.load_state_dict(torch.load(full_file))
    model.to(device)

    load_ptq = False
    ptq_file_prefix = 'ckpt/cifar10_' + args.model + '_ptq_'

    model.eval()
    full_acc = full_inference(model, test_loader, device)
    model_fold = fold_model(model)

    full_params = []

    layer, par_ratio, flop_ratio = extract_ratio(args.model)
    # print(layer)
    layer = []
    for name, param in model.named_parameters():
        if 'weight' in name:
            n = name.split('.')
            pre = '.'.join(n[:len(n) - 1])
            # extract the layer name preceding 'weight' (checking for 'weight'
            # avoids collecting the same name again for 'bias')
            layer.append(pre)
            # print(name)
    print('===================')
    # print(layer)
    par_ratio, flop_ratio = fold_ratio(layer, par_ratio, flop_ratio)
    # sys.exit()
    for name, param in model_fold.named_parameters():
        if 'bn' in name or 'sample.1' in name:
            continue
        # param_norm = F.normalize(param.data.cpu(), p=2, dim=-1)
        param_norm = param.data.cpu()
        full_params.append(param_norm)  # only the folded conv params, BN excluded
        writer.add_histogram(tag='Full_' + name + '_data', values=param.data)
    gol._init()
    quant_type_list = ['INT', 'POT', 'FLOAT']
    title_list = []
    js_flops_list = []
    js_param_list = []
    ptq_acc_list = []
    acc_loss_list = []

    for quant_type in quant_type_list:
        num_bit_list = numbit_list(quant_type)
        # the bias quantization table only needs to be set once per quant type;
        # for INT the bit-widths are large enough that a lookup table would be
        # too expensive, so INT simply uses _round directly
        if quant_type != 'INT':
            bias_list = build_bias_list(quant_type)
            gol.set_value(bias_list, is_bias=True)
        for num_bits in num_bit_list:
            e_bit_list = ebit_list(quant_type, num_bits)
            for e_bits in e_bit_list:
                # model_ptq = resnet18()
                if args.model == 'ResNet18':
                    model_ptq = resnet18()
                elif args.model == 'ResNet50':
                    model_ptq = resnet50()
                elif args.model == 'ResNet152':
                    model_ptq = resnet152()

                if quant_type == 'FLOAT':
                    title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
                else:
                    title = '%s_%d' % (quant_type, num_bits)
                print('\nPTQ: ' + title)
                title_list.append(title)

                # set up the quantization table
                if quant_type != 'INT':
                    plist = build_list(quant_type, num_bits, e_bits)
                    gol.set_value(plist)

                # decide whether a cached PTQ model can be loaded
                if load_ptq is True and osp.exists(ptq_file_prefix + title + '.pt'):
                    model_ptq.quantize(quant_type, num_bits, e_bits)
                    model_ptq.load_state_dict(torch.load(ptq_file_prefix + title + '.pt'))
                    model_ptq.to(device)
                    print('Successfully load ptq model: ' + title)
                else:
                    model_ptq.load_state_dict(torch.load(full_file))
                    model_ptq.to(device)
                    model_ptq.quantize(quant_type, num_bits, e_bits)
                    model_ptq.eval()
                    direct_quantize(model_ptq, train_loader, device)
                    if args.save:
                        torch.save(model_ptq.state_dict(), ptq_file_prefix + title + '.pt')

                model_ptq.freeze()
                ptq_acc = quantize_inference(model_ptq, test_loader, device)
                ptq_acc_list.append(ptq_acc)
                acc_loss = (full_acc - ptq_acc) / full_acc
                acc_loss_list.append(acc_loss)

                idx = -1
                model_ptq.fakefreeze()
                # accumulate the JS divergence weighted by FLOPs / parameter count
                js_flops = 0.
                js_param = 0.
                for name, param in model_ptq.named_parameters():
                    # if '.' not in name or 'bn' in name:
                    if 'bn' in name or 'sample.1' in name:
                        continue
                    writer.add_histogram(tag=title + ':' + name + '_data', values=param.data)
                    idx = idx + 1
                    # ResNet names contain several '.'; rebuild the layer prefix
                    # prefix = name.split('.')[0]
                    n = name.split('.')
                    prefix = '.'.join(n[:len(n) - 1])
                    # the ratio is assigned per layer; weight and bias could be
                    # given different weights here (e.g. 8:2) if desired
                    if prefix in layer:
                        layer_idx = layer.index(prefix)
                        ptq_param = param.data.cpu()
                        # ptq_norm = F.normalize(ptq_param, p=2, dim=-1)
                        ptq_norm = ptq_param
                        writer.add_histogram(tag=title + ':' + name + '_data', values=ptq_param)
                        # JS distance between the folded weights of the quantized
                        # and full-precision models
                        js = js_div(ptq_norm, full_params[idx])
                        js = js.item()
                        if js < 0.:
                            js = 0.
                        js_flops = js_flops + js * flop_ratio[layer_idx]
                        js_param = js_param + js * par_ratio[layer_idx]
                js_flops_list.append(js_flops)
                js_param_list.append(js_param)

                print(title + ': js_flops: %f js_param: %f acc_loss: %f' % (js_flops, js_param, acc_loss))
    # write the results to xlsx
    workbook = openpyxl.Workbook()
    worksheet = workbook.active
    worksheet.cell(row=1, column=1, value='FP32-acc')
    worksheet.cell(row=1, column=2, value=full_acc)
    worksheet.cell(row=3, column=1, value='title')
    worksheet.cell(row=3, column=2, value='js_flops')
    worksheet.cell(row=3, column=3, value='js_param')
    worksheet.cell(row=3, column=4, value='ptq_acc')
    worksheet.cell(row=3, column=5, value='acc_loss')
    for i in range(len(title_list)):
        worksheet.cell(row=i + 4, column=1, value=title_list[i])
        worksheet.cell(row=i + 4, column=2, value=js_flops_list[i])
        worksheet.cell(row=i + 4, column=3, value=js_param_list[i])
        worksheet.cell(row=i + 4, column=4, value=ptq_acc_list[i])
        worksheet.cell(row=i + 4, column=5, value=acc_loss_list[i])
    workbook.save('ptq_result_' + args.model + '.xlsx')

    writer.close()

    ft = open('ptq_result_' + args.model + '.txt', 'w')
    print('title_list:', file=ft)
    print(" ".join(title_list), file=ft)
    print('js_flops_list:', file=ft)
    print(" ".join(str(i) for i in js_flops_list), file=ft)
    print('js_param_list:', file=ft)
    print(" ".join(str(i) for i in js_param_list), file=ft)
    print('ptq_acc_list:', file=ft)
    print(" ".join(str(i) for i in ptq_acc_list), file=ft)
    print('acc_loss_list:', file=ft)
    print(" ".join(str(i) for i in acc_loss_list), file=ft)
    ft.close()
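For reference, a toy illustration (made-up numbers) of the weighted aggregation performed in the loop above, where each layer's JS divergence is weighted by that layer's share of the total FLOPs and of the total parameter count:

```python
layer_js = [0.02, 0.10, 0.01]          # per-layer JS divergences
flop_ratio = [0.50, 0.30, 0.20]        # per-layer share of total FLOPs
par_ratio = [0.25, 0.60, 0.15]         # per-layer share of total parameters
js_flops = sum(j * f for j, f in zip(layer_js, flop_ratio))  # 0.042
js_param = sum(j * p for j, p in zip(layer_js, par_ratio))   # 0.0665
print(js_flops, js_param)
```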
@@ -106,7 +106,12 @@ class ResNet(nn.Module):
        # self.qfc1.freeze()

    def fakefreeze(self):
        self.qconvbnrelu1.fakefreeze()
        self.layer1.fakefreeze()
        self.layer2.fakefreeze()
        self.layer3.fakefreeze()
        self.layer4.fakefreeze()
        self.qfc1.fakefreeze()

    def quantize_inference(self, x):
        qx = self.qconvbnrelu1.qi.quantize_tensor(x)
@@ -209,6 +214,16 @@ class BasicBlock(nn.Module):
        self.qrelu1.freeze(qi=self.qelementadd.qo)
        return self.qrelu1.qi  # after the ReLU, its qo can reuse the qi statistics collected by the ReLU

    def fakefreeze(self):
        # qconvbnrelu1 could actually reuse the previous layer's qo, but passing
        # it in felt awkward, so this is not done; still needs careful checking
        self.qconvbnrelu1.fakefreeze()  # should take the last qo of the preceding module
        self.qconvbn1.fakefreeze()
        if self.downsample is not None:
            self.qconvbn2.fakefreeze()  # the downsample branch

    def quantize_inference(self, x):
        # presumably no initial quantize_tensor / dequantize_tensor is needed here:
        # this is not the first or last layer, and as long as every intermediate
        # layer stays in the quantized domain no such conversion is required
        identity = x
@@ -318,6 +333,19 @@ class Bottleneck(nn.Module):
        self.qrelu1.freeze(qi=self.qelementadd.qo)  # qi has to be collected here
        return self.qrelu1.qi  # after the ReLU, its qo can reuse the qi statistics collected by the ReLU

    def fakefreeze(self):
        # qconvbnrelu1 could actually reuse the previous layer's qo, but passing
        # it in felt awkward, so this is not done; still needs careful checking
        self.qconvbnrelu1.fakefreeze()
        self.qconvbnrelu2.fakefreeze()
        self.qconvbn1.fakefreeze()
        if self.downsample is not None:
            self.qconvbn2.fakefreeze()  # the downsample branch

    def quantize_inference(self, x):
        # presumably no initial quantize_tensor / dequantize_tensor is needed here:
        # this is not the first or last layer, and as long as every intermediate
        # layer stays in the quantized domain no such conversion is required
        identity = x
@@ -407,6 +435,9 @@ class MakeLayer(nn.Module):
        return qo  # for use by the subsequent layers

    def fakefreeze(self):
        for _, layer in self.blockdict.items():
            layer.fakefreeze()

    def quantize_inference(self, x):
        # presumably no initial quantize_tensor / dequantize_tensor is needed here:
        # this is not the first or last layer, and as long as every intermediate
        # layer stays in the quantized domain no such conversion is required
......
@@ -9,24 +9,57 @@ from torch.autograd import Variable
from function import FakeQuantize


# get the nearest value in the quantization table
# (previous, unblocked implementation, kept commented for reference)
# def get_nearest_val(quant_type, x, is_bias=False):
#     if quant_type == 'INT':
#         return x.round_()
#     plist = gol.get_value(is_bias)
#     shape = x.shape
#     xhard = x.view(-1)
#     plist = plist.type_as(x)
#     # take the nearest table entry as the index
#     idx = (xhard.unsqueeze(0) - plist.unsqueeze(1)).abs().min(dim=0)[1]
#     xhard = plist[idx].view(shape)
#     xout = (xhard - x).detach() + x
#     return xout
def get_nearest_val(quant_type, x, is_bias=False, block_size=1000000):
    if quant_type == 'INT':
        return x.round_()
    plist = gol.get_value(is_bias)
    shape = x.shape
    xhard = x.view(-1)
    xout = torch.zeros_like(xhard)
    plist = plist.type_as(x)
    # process the flattened tensor in blocks to bound peak memory
    n_blocks = (x.numel() + block_size - 1) // block_size

    for i in range(n_blocks):
        start_idx = i * block_size
        end_idx = min(start_idx + block_size, xhard.numel())
        xblock = xhard[start_idx:end_idx]
        plist_block = plist.unsqueeze(1)
        # index of the nearest table entry for every element of the block
        idx = (xblock.unsqueeze(0) - plist_block).abs().min(dim=0)[1]
        xhard_block = plist[idx].view(xblock.shape)
        # straight-through estimator: forward values snap to the table,
        # gradients pass through unchanged
        xout[start_idx:end_idx] = (xhard_block - xblock).detach() + xblock

    xout = xout.view(shape)
    return xout
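The blocking above bounds the size of the pairwise |x − p| matrix that `min(dim=0)` is computed over. A back-of-envelope sketch (the table length and tensor size are illustrative):

```python
# Peak float32 memory of the (table_len x n_elems) distance matrix:
table_len, n_elems, block = 255, 25_500_000, 1_000_000
print('unblocked: %.1f GiB' % (table_len * n_elems * 4 / 2**30))  # ~24.2 GiB
print('blocked:   %.2f GiB' % (table_len * block * 4 / 2**30))    # ~0.95 GiB per block
```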
# for symmetric signed quantization, get the maximum of the quantization range
def get_qmax(quant_type, num_bits=None, e_bits=None):
    if quant_type == 'INT':
@@ -145,6 +178,9 @@ class QModule(nn.Module):
    def freeze(self):
        pass  # no-op

    def fakefreeze(self):
        pass

    def quantize_inference(self, x):
        raise NotImplementedError('quantize_inference should be implemented.')
@@ -198,6 +234,11 @@ class QConv2d(QModule):
            self.conv_module.bias.data, scale=self.qi.scale * self.qw.scale,
            zero_point=0., qmax=self.bias_qmax, is_bias=True)

    def fakefreeze(self):
        self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
        self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data, scale=self.qi.scale * self.qw.scale, zero_point=0.)
    def forward(self, x):  # forward pass; the input x is floating-point data
        if hasattr(self, 'qi'):
            self.qi.update(x)
@@ -264,6 +305,12 @@ class QLinear(QModule):
            self.fc_module.bias.data, scale=self.qi.scale * self.qw.scale,
            zero_point=0., qmax=self.bias_qmax, is_bias=True)

    def fakefreeze(self):
        self.fc_module.weight.data = self.qw.dequantize_tensor(self.fc_module.weight.data)
        self.fc_module.bias.data = dequantize_tensor(self.fc_module.bias.data, scale=self.qi.scale * self.qw.scale, zero_point=0.)

    def forward(self, x):
        if hasattr(self, 'qi'):
            self.qi.update(x)
@@ -416,6 +463,11 @@ class QConvBNReLU(QModule):
            zero_point=0., qmax=self.bias_qmax, is_bias=True)

    def fakefreeze(self):
        self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
        self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data, scale=self.qi.scale * self.qw.scale, zero_point=0.)

    def forward(self, x):
        if hasattr(self, 'qi'):
@@ -539,6 +591,10 @@ class QConvBN(QModule):
            bias, scale=self.qi.scale * self.qw.scale,
            zero_point=0., qmax=self.bias_qmax, is_bias=True)

    def fakefreeze(self):
        self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
        self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data, scale=self.qi.scale * self.qw.scale, zero_point=0.)

    def forward(self, x):
        if hasattr(self, 'qi'):
@@ -648,8 +704,134 @@ class QAdaptiveAvgPool2d(QModule):
        return x
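The `fakefreeze` methods above invert exactly the affine maps applied in `freeze()`: the weights via `qw.dequantize_tensor`, and the bias via the standalone `dequantize_tensor` with scale `qi.scale * qw.scale` and zero point 0. As a sketch of what that standalone helper presumably computes (the real one is defined elsewhere in this repo and may differ in details):

```python
def dequantize_tensor(q, scale, zero_point):
    # inverse of the freeze-time mapping q = x / scale + zero_point
    return scale * (q - zero_point)
```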
class QConvBNReLU6(QModule):

    def __init__(self, quant_type, conv_module, bn_module, qi=True, qo=True, num_bits=8, e_bits=3):
        super(QConvBNReLU6, self).__init__(quant_type, qi, qo, num_bits, e_bits)
        self.conv_module = conv_module
        self.bn_module = bn_module
        self.qw = QParam(quant_type, num_bits, e_bits)
        self.register_buffer('M', torch.tensor([], requires_grad=False))  # register M as a buffer

    def fold_bn(self, mean, std):
        if self.bn_module.affine:
            gamma_ = self.bn_module.weight / std
            weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
            if self.conv_module.bias is not None:
                bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
            else:
                bias = self.bn_module.bias - gamma_ * mean
        else:
            gamma_ = 1 / std
            weight = self.conv_module.weight * gamma_
            if self.conv_module.bias is not None:
                bias = gamma_ * self.conv_module.bias - gamma_ * mean
            else:
                bias = -gamma_ * mean
        return weight, bias

    def freeze(self, qi=None, qo=None):
        if hasattr(self, 'qi') and qi is not None:
            raise ValueError('qi has been provided in init function.')
        if not hasattr(self, 'qi') and qi is None:
            raise ValueError('qi is not existed, should be provided.')
        if hasattr(self, 'qo') and qo is not None:
            raise ValueError('qo has been provided in init function.')
        if not hasattr(self, 'qo') and qo is None:
            raise ValueError('qo is not existed, should be provided.')
        if qi is not None:
            self.qi = qi
        if qo is not None:
            self.qo = qo
        self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data

        std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
        weight, bias = self.fold_bn(self.bn_module.running_mean, std)
        self.conv_module.weight.data = self.qw.quantize_tensor(weight.data)
        self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
        self.conv_module.bias.data = quantize_tensor(self.quant_type,
                                                     bias, scale=self.qi.scale * self.qw.scale,
                                                     zero_point=0., qmax=self.bias_qmax, is_bias=True)

    def fakefreeze(self):
        self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
        self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data, scale=self.qi.scale * self.qw.scale, zero_point=0.)

    def forward(self, x):
        if hasattr(self, 'qi'):
            self.qi.update(x)
            x = FakeQuantize.apply(x, self.qi)
        if self.training:
            y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
                         stride=self.conv_module.stride,
                         padding=self.conv_module.padding,
                         dilation=self.conv_module.dilation,
                         groups=self.conv_module.groups)
            y = y.permute(1, 0, 2, 3)  # NCHW -> CNHW
            y = y.contiguous().view(self.conv_module.out_channels, -1)  # CNHW -> C, NHW
            # mean = y.mean(1)
            # var = y.var(1)
            mean = y.mean(1).detach()
            var = y.var(1).detach()
            self.bn_module.running_mean = \
                (1 - self.bn_module.momentum) * self.bn_module.running_mean + \
                self.bn_module.momentum * mean
            self.bn_module.running_var = \
                (1 - self.bn_module.momentum) * self.bn_module.running_var + \
                self.bn_module.momentum * var
        else:
            mean = Variable(self.bn_module.running_mean)
            var = Variable(self.bn_module.running_var)

        std = torch.sqrt(var + self.bn_module.eps)
        weight, bias = self.fold_bn(mean, std)
        self.qw.update(weight.data)

        x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
                     stride=self.conv_module.stride,
                     padding=self.conv_module.padding, dilation=self.conv_module.dilation,
                     groups=self.conv_module.groups)

        x = F.relu6(x)

        if hasattr(self, 'qo'):
            self.qo.update(x)
            x = FakeQuantize.apply(x, self.qo)

        return x

    def quantize_inference(self, x):
        # quantize the ReLU6 clip value 6 into the qo domain so the clamp
        # operates on quantized data
        a = torch.tensor(6)
        a = self.qo.quantize_tensor(a)

        x = x - self.qi.zero_point
        x = self.conv_module(x)
        x = self.M * x
        # if self.quant_type is not 'POT':
        #     x = get_nearest_val(self.quant_type, x)
        x = get_nearest_val(self.quant_type, x)
        x = x + self.qo.zero_point  # data now lies in the qo range
        x.clamp_(min=0, max=a.item())
        return x
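A toy numeric check of the requantization identity used in `quantize_inference` above, with hypothetical scales and zero points; it assumes the weight zero point is already folded into the stored integer weights, as `freeze()` does:

```python
import torch

S_i, S_w, S_o, z_i, z_o = 0.1, 0.02, 0.05, 3.0, 5.0  # hypothetical quant params
M = S_i * S_w / S_o
q_i, q_w = torch.tensor(40.), torch.tensor(25.)       # a 1x1 "convolution"
fp = (S_i * (q_i - z_i)) * (S_w * q_w)                # float-domain result
q_o = M * (q_i - z_i) * q_w + z_o                     # integer-domain pipeline
print(torch.allclose(S_o * (q_o - z_o), fp))          # True
```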
class QModule_2(nn.Module):

    def __init__(self, quant_type, qi0=True, qi1=True, qo=True, num_bits=8, e_bits=3):
......
@@ -143,7 +143,8 @@ if __name__ == "__main__":
    for name, param in model_fold.named_parameters():
        if 'bn' in name or 'sample.1' in name:
            continue
        # param_norm = F.normalize(param.data.cpu(), p=2, dim=-1)
        param_norm = param.data.cpu()
        full_params.append(param_norm)  # only the folded conv params, BN excluded
        writer.add_histogram(tag='Full_' + name + '_data', values=param.data)
...@@ -229,7 +230,8 @@ if __name__ == "__main__": ...@@ -229,7 +230,8 @@ if __name__ == "__main__":
layer_idx = layer.index(prefix) layer_idx = layer.index(prefix)
ptq_param = param.data.cpu() ptq_param = param.data.cpu()
# 取L2范数 # 取L2范数
ptq_norm = F.normalize(ptq_param,p=2,dim=-1) # ptq_norm = F.normalize(ptq_param,p=2,dim=-1)
ptq_norm = ptq_param
writer.add_histogram(tag=title +':'+ name + '_data', values=ptq_param) writer.add_histogram(tag=title +':'+ name + '_data', values=ptq_param)
# print(name) # print(name)
# print('=========') # print('=========')
......
title_list (identical in both runs):
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list (old run with normalize):
1354.2945784950603 722.8096011458398 255.38235121030203 77.72514867219756 24.195293094663096 7.6188281643621805 2.711759981703649 1.1327320388891737 0.4918583902845934 0.2514171979338759 0.0851226729374511 0.031643228209198994 0.014960137087802734 0.014775519758112417 0.014717476906251569 1354.017620655594 530.956126028799 29.270594072095776 23.95970497578537 23.946279778497708 23.945813186814195 23.946195581641064 349.9397526929324 98.90322611449565 77.39375850717117 32.951875627562636 26.29951378038865 7.452277149981623 13.696730185700897 13.489537436477727 2.070642776277882 6.647399487862168 7.9287061447075295 9.3465570209075 0.5745562357250971 1.7289203756753833 6.647398112960826 6.036095650090301 7.68762550328311 0.18322236331308744 0.43723723366736766 1.728919472161257 6.887208440038224
js_flops_list (new run with fakefreeze):
17.39068197353568 4.992727157836896 0.7082449333226905 0.15278217922691148 0.035902245013520316 0.008569443717644162 0.0021334381003386507 0.000530096122199641 0.00013202739139047992 4.320540852718224e-05 1.330992682049459e-05 2.0368878546145547e-05 8.560127359019233e-06 1.5414435348480747e-05 1.034017508399343e-05 13.592069723232997 2.603415308985623 1.0163834751275276 1.0004914009898496 1.0267242270943244 1.0377851066533914 1.0093247382340054 2.4882911516734794 0.8719669278759606 0.39824198801738997 0.48712327345618334 0.13437272845045903 0.15283618430614965 0.36366539038391377 0.067001336425203 0.04023162172835145 0.15278158008432788 0.3153947829963702 0.04619076816233325 0.010272439305968415 0.040191841741053155 0.15277940050779587 0.2944708799554367 0.03873975046512612 0.002562897125930476 0.010259629159672047 0.040194168479249566 0.16940471118115202
js_param_list (old run with normalize):
1905.034397666333 1304.431928190669 532.9886077118246 162.16424238527497 45.90216325868581 12.55083851407541 3.737229506223789 1.1881391381377582 0.4333202294403889 0.17966875415589165 0.0686934709653976 0.02711635083384763 0.012422401703485457 0.011924650749714109 0.011811756854569257 1904.9690246678717 1024.4893074529332 44.47069839549869 37.182205895517406 37.17084546399217 37.170360480922994 37.170596963709116 676.9975549510849 192.141761781115 146.37833455189465 54.830161088403294 46.75208855879235 11.459473919635343 17.814252125337802 23.67092841602859 3.110218252992642 10.411529270878503 8.13295996799088 16.519087866858648 0.8965427653389706 2.6972225487456454 10.411531041285842 5.310143987405656 13.920591998362562 0.3160476684964184 0.6824627078929684 2.697221260834613 10.486319878844261
js_param_list (new run with fakefreeze):
13.75589548965823 6.139934297648725 0.9983266223973442 0.21685218331277975 0.05073819961183949 0.012236724314191861 0.0030479968659331298 0.0007642922990751943 0.00018892417520173234 6.914265082028416e-05 2.0654949731546248e-05 1.2782403100869643e-05 9.44194520770214e-06 1.7287066209449194e-05 1.2478257919601581e-05 13.5254986055193 3.463308629604464 1.160512064451701 1.1084529894485604 1.193382464700844 1.2041072939829902 1.2033002903326413 2.4474367498054836 0.7570219992909055 0.5246367584038191 0.3844942455969616 0.18715129322472332 0.13583544042939869 0.2740889994593467 0.10076411331478825 0.03570359182314611 0.13571819932622736 0.2344246202009843 0.07290873555706959 0.008973915201922977 0.03566098435431149 0.13571299092732977 0.21809584818526528 0.0623769948948551 0.00226774688205569 0.00895695770659421 0.03566482115654845 0.18569380791656256
ptq_acc_list (old run):
10.0 10.0 18.77 75.9 86.88 88.89 89.36 89.46 89.51 89.48 89.47 89.5 89.5 89.49 89.48 10.0 11.15 45.68 42.59 32.54 49.07 45.67 13.83 30.06 63.49 63.24 81.67 81.33 74.09 84.67 88.05 80.51 78.4 86.44 89.1 87.93 79.02 78.99 86.51 89.29 89.18 88.01 82.21
ptq_acc_list (new run):
10.0 9.97 23.88 76.77 87.61 88.95 89.11 89.49 89.46 89.51 89.44 89.49 89.5 89.49 89.48 10.0 10.2 41.62 46.35 41.57 33.14 18.97 11.7 27.9 58.44 63.17 81.54 80.64 71.73 85.63 88.05 80.84 76.92 86.23 88.79 87.8 81.29 79.62 86.98 89.18 89.08 88.03 82.63
acc_loss_list (old run):
0.8882431828341529 0.8882431828341529 0.790232454179705 0.15176575771122036 0.029056772463120346 0.006593652212785018 0.0013410818059902162 0.00022351363433180857 -0.00033527045149755406 0.0 0.00011175681716590428 -0.00022351363433164976 -0.00022351363433164976 -0.00011175681716574548 0.0 0.8882431828341529 0.8753911488600804 0.4894948591864104 0.5240277156906571 0.6363433169423335 0.4516092981671882 0.48960661600357624 0.8454403218596335 0.6640590075994636 0.29045596781403665 0.2932498882431828 0.08728207420652662 0.09108180599016547 0.17199374161823872 0.053755029056772485 0.015981224854716213 0.10024586499776485 0.12382655341975858 0.03397407241841759 0.004246759052302298 0.01732230666070627 0.11689763075547617 0.11723290120697372 0.03319177469825658 0.0021233795261510697 0.0033527045149753815 0.016428252123379512 0.08124720607957096
acc_loss_list (new run):
0.8882431828341529 0.8885784532856504 0.7331247206079572 0.14204291461779175 0.02089852481001346 0.00592311130978991 0.004135002235136394 -0.00011175681716574548 0.00022351363433180857 -0.00033527045149755406 0.0004470272686634583 -0.00011175681716574548 -0.00022351363433164976 -0.00011175681716574548 0.0 0.8882431828341529 0.8860080464908359 0.5348681269557444 0.4820071524362986 0.5354269110415736 0.6296379079123826 0.787997317836388 0.8692445239159589 0.6881984801072866 0.3468931604827895 0.2940321859633438 0.08873491282968259 0.09879302637460889 0.19836835046937862 0.04302637460885123 0.015981224854716213 0.09655789003129191 0.140366562360304 0.03632096557890031 0.007711220384443425 0.018775145283862392 0.09152883325882875 0.11019222172552524 0.027939204291461777 0.0033527045149753815 0.004470272686633948 0.016204738489047864 0.07655341975860537
## update: <br>2023.4.24<br>
- Problems solved:
1. Anomalous fitting for PoT <br>
2. Poor curve-fitting quality
- Notes on the approach<br>
While working on the JS-divergence and curve-fitting problems for MobileNetV2, I found that F.normalize strongly affects the similarity computation. Although it brings the quantized weights onto the same scale as the full-precision model, it distorts the data distribution too much, which produced the anomalous relationships and poor fits. The detailed reasoning and analysis are in the MobileNetV2 README; here the same fakefreeze approach is used: dequantize the quantized weights back to the scale of the full-precision weights, then compute the JS divergence directly.
<br>
The ResNet series was re-run; the data and curves for ResNet50 and ResNet152 will be added once those runs finish. <br>
ResNet18:
<img src = "fig/18_flops_f.png" class="h-90 auto">
<img src = "fig/18_params_f.png" class="h-90 auto">
The curve fits improve markedly.
## update: <br>2023.4.23<br>
1. The Conv layers of the ResNet models in this folder have no bias, consistent with the classic ResNet. After the fold-BN step in ConvBN and ConvBNReLU, however, the Conv acquires a bias, and the code was adjusted accordingly.
2. The bias-free ResNet18 reaches slightly higher accuracy than the biased variant, and training is noticeably faster without bias.
......