Commit bff8b078 by Klin

feat: Inception BN: new frame and naive ptq

parent c5562a31
## Inception_BN 量化说明
## 新框架说明
+ 使用cfg_table进行模型快速部署和量化(可应用于其他模型),cfg_table提供整体forward平面结构,包括inc模块。inc模块则由inc_ch_table和inc_cfg_table进行部署量化。其规则可详见文件
+ cfg_table每项对应一个可进行量化融合的模块,如相邻conv bn relu可融合,在cfg_table中表现为['C','BR',...]。从而可以更方便的从该表进行量化和flops/param权重提取
+ 更改fold_ratio方法,以支持cfg_table的同项融合,多考虑了relu层,摆脱原先临近依赖的限制。方案:读取到conv时,获取相同前后缀的层,并相加
+ 更改module,允许量化层传入层的bias为none。原fold_bn已经考虑了,只需要改动freeze即可。
+ 对于Conv,freeze前后都无bias。
+ 对于ConvBN,freeze前后有bias。forward使用临时值,inference使用固定值(固定到相应conv_module)。
+ 由于允许conv.bias=None,相应改变全精度模型fold_bn方法,从而保证量化前后可比参数相同。改写方式同量化层
+ 更改js_div计算方法,一个层如果同时有多个参数,例如weight和bias,应该总共加起来权重为1。当前直接简单取平均(即js除以该层参数量),后续考虑加权。PS: Inception_BN中,外层conv层有bias,Inception模块内由于后接bn层,bias为false
+ 由于named_parameters迭代器长度不固定,需要先将排成固定列表再处理,从而获得同一层参数数,改动见ptq.py。对全精度模型做此操作即可
## ptq部分
+ 量化结果
![Inception_BN_table](image/Inception_BN_table.png)
+ 拟合结果![flops](image/flops.png)
+ ![param](image/param.png)
### debug
+ 观察量化结果可知,POT量化精度损失较大。尝试在Inception BN网络的不同位置反量化,观察POT量化的精度效果。即整体结构如下:量化->前半部分量化层->反量化->后半部分全精度层
+ 模型结构
Inception_BN_cfg_table = [
['C','',True,3,64,3,1,1],
['R'],
['C','',False,64,64,3,1,1],
['R'],
['Inc',0],
['Inc',1],
['MP',3,2,1],
['Inc',2],
['Inc',3],
['Inc',4],
['Inc',5],
['Inc',6],
['MP',3,2,1],
['Inc',7],
['Inc',8],
['AAP',1],
['C','',False,1024,10,1,1,0],
['F']
]
+ 反量化位置:
+ 最后
| Title | ptq2 | ptq3 | ptq4 | ptq5 | ptq6 | ptq7 | ptq8 |
| ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
| acc | 10.00 | 10.03 | 17.55 | 21.00 | 22.80 | 24.38 | 16.54 |
+ AAP层前
| Title | ptq2 | ptq3 | ptq4 | ptq5 | ptq6 | ptq7 | ptq8 |
| ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
| acc | 10.00 | 12.39 | 21.06 | 26.21 | 23.88 | 25.56 | 30.04 |
+ 第一个inc前
| Title | ptq2 | ptq3 | ptq4 | ptq5 | ptq6 | ptq7 | ptq8 |
| ----- | ---- | ----- | ----- | ----- | ----- | ----- | ----- |
| acc | 9.99 | 22.07 | 85.71 | 85.72 | 85.36 | 85.66 | 85.70 |
+ 第二个inc前
| Title | ptq2 | ptq3 | ptq4 | ptq5 | ptq6 | ptq7 | ptq8 |
| ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
| acc | 10.01 | 14.67 | 82.92 | 81.54 | 82.05 | 82.41 | 83.10 |
+ 第二个inc后
| Title | ptq2 | ptq3 | ptq4 | ptq5 | ptq6 | ptq7 | ptq8 |
| ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
| acc | 10.00 | 12.53 | 42.75 | 64.89 | 20.54 | 66.84 | 34.05 |
+ 根据不同反量化位置,初步推断,POT精度损失应该不是量化框架问题,而是该模型本身与量化方式不适配。
\ No newline at end of file
import sys
import os
# 从get_param.py输出重定向文件val.txt中提取参数量和计算量
def extract_ratio():
fr = open('param_flops.txt','r')
lines = fr.readlines()
layer = []
par_ratio = []
flop_ratio = []
for line in lines:
if '(' in line and ')' in line:
layer.append(line.split(')')[0].split('(')[1])
r1 = line.split('%')[0].split(',')[-1]
r1 = float(r1)
par_ratio.append(r1)
r2 = line.split('%')[-2].split(',')[-1]
r2 = float(r2)
flop_ratio.append(r2)
return layer, par_ratio, flop_ratio
if __name__ == "__main__":
layer, par_ratio, flop_ratio = extract_ratio()
print(layer)
print(par_ratio)
print(flop_ratio)
\ No newline at end of file
from torch.autograd import Function
class FakeQuantize(Function):
@staticmethod
def forward(ctx, x, qparam):
x = qparam.quantize_tensor(x)
x = qparam.dequantize_tensor(x)
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
\ No newline at end of file
from model import *
import torch
from ptflops import get_model_complexity_info
if __name__ == "__main__":
model = Inception_BN()
full_file = 'ckpt/cifar10_Inception_BN.pt'
model.load_state_dict(torch.load(full_file))
flops, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
# -*- coding: utf-8 -*-
# 用于多个module之间共享全局变量
def _init(): # 初始化
global _global_dict
_global_dict = {}
def set_value(value,is_bias=False):
# 定义一个全局变量
if is_bias:
_global_dict[0] = value
else:
_global_dict[1] = value
def get_value(is_bias=False): # 给bias独立于各变量外的精度
if is_bias:
return _global_dict[0]
else:
return _global_dict[1]
from torch.serialization import load
from model import *
from extract_ratio import *
from utils import *
import gol
import openpyxl
import sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import torch.utils.bottleneck as bn
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
def direct_quantize(model, test_loader,device):
for i, (data, target) in enumerate(test_loader, 1):
data = data.to(device)
output = model.quantize_forward(data).cpu()
if i % 500 == 0:
break
print('direct quantization finish')
def full_inference(model, test_loader, device):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data = data.to(device)
output = model(data).cpu()
pred = output.argmax(dim=1, keepdim=True)
# print(pred)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Full Model Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))
return 100. * correct / len(test_loader.dataset)
def quantize_inference(model, test_loader, device):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data = data.to(device)
output = model.quantize_inference(data).cpu()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('Test set: Quant Model Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))
return 100. * correct / len(test_loader.dataset)
if __name__ == "__main__":
batch_size = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=True, download=True,
transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
model = Inception_BN()
writer = SummaryWriter(log_dir='./log')
full_file = 'ckpt/cifar10_Inception_BN.pt'
model.load_state_dict(torch.load(full_file))
model.to(device)
load_ptq = True
store_ptq = False
ptq_file_prefix = 'ckpt/cifar10_Inception_BN_ptq_'
model.eval()
full_acc = full_inference(model, test_loader, device)
# 传入后可变
fold_model(model)
layer, par_ratio, flop_ratio = extract_ratio()
par_ratio, flop_ratio = fold_ratio(layer, par_ratio, flop_ratio)
full_names = []
full_params = []
for name, param in model.named_parameters():
if 'conv' in name or 'fc' in name:
full_names.append(name)
param_norm = F.normalize(param.data.cpu(),p=2,dim=-1)
full_params.append(param_norm)
writer.add_histogram(tag='Full_' + name + '_data', values=param.data)
#统计每个参数对应层的参数个数
full_par_num=[]
for name in full_names:
prefix = name.rsplit('.',1)[0]
cnt = 0
for str in full_names:
sprefix = str.rsplit('.',1)[0]
if prefix == sprefix:
cnt += 1
full_par_num.append(cnt)
# print(full_names)
# print(full_par_num)
# print('-------')
# input()
gol._init()
# quant_type_list = ['INT','POT','FLOAT']
# quant_type_list = ['INT']
quant_type_list = ['POT']
title_list = []
js_flops_list = []
js_param_list = []
ptq_acc_list = []
acc_loss_list = []
for quant_type in quant_type_list:
num_bit_list = numbit_list(quant_type)
# 对一个量化类别,只需设置一次bias量化表
# int由于位宽大,使用量化表开销过大,直接_round即可
if quant_type != 'INT':
bias_list = build_bias_list(quant_type)
gol.set_value(bias_list, is_bias=True)
for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list:
model_ptq = Inception_BN()
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
title = '%s_%d' % (quant_type, num_bits)
print('\nPTQ: '+title)
title_list.append(title)
# 设置量化表
if quant_type != 'INT':
plist = build_list(quant_type, num_bits, e_bits)
gol.set_value(plist)
# 判断是否需要载入
if load_ptq is True and osp.exists(ptq_file_prefix + title + '.pt'):
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.load_state_dict(torch.load(ptq_file_prefix + title + '.pt'))
model_ptq.to(device)
print('Successfully load ptq model: ' + title)
else:
model_ptq.load_state_dict(torch.load(full_file))
model_ptq.to(device)
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.eval()
direct_quantize(model_ptq, train_loader, device)
if store_ptq:
torch.save(model_ptq.state_dict(), ptq_file_prefix + title + '.pt')
model_ptq.freeze()
ptq_acc = quantize_inference(model_ptq, test_loader, device)
ptq_acc_list.append(ptq_acc)
acc_loss = (full_acc - ptq_acc) / full_acc
acc_loss_list.append(acc_loss)
# 获取计算量/参数量下的js-div
js_flops = 0.
js_param = 0.
for name, param in model_ptq.named_parameters():
if 'conv' not in name and 'fc' not in name:
continue
prefix = name.rsplit('.',1)[0]
layer_idx = layer.index(prefix)
name_idx = full_names.index(name)
layer_idx = layer.index(prefix)
ptq_param = param.data.cpu()
# 取L2范数
ptq_norm = F.normalize(ptq_param,p=2,dim=-1)
writer.add_histogram(tag=title +':'+ name + '_data', values=ptq_param)
js = js_div(ptq_norm,full_params[name_idx])
js /= full_par_num[name_idx]
js = js.item()
if js < 0.:
js = 0.
js_flops = js_flops + js * flop_ratio[layer_idx]
js_param = js_param + js * par_ratio[layer_idx]
js_flops_list.append(js_flops)
js_param_list.append(js_param)
print(title + ': js_flops: %f js_param: %f acc_loss: %f' % (js_flops, js_param, acc_loss))
# 写入xlsx
workbook = openpyxl.Workbook()
worksheet = workbook.active
worksheet.cell(row=1,column=1,value='FP32-acc')
worksheet.cell(row=1,column=2,value=full_acc)
worksheet.cell(row=3,column=1,value='title')
worksheet.cell(row=3,column=2,value='js_flops')
worksheet.cell(row=3,column=3,value='js_param')
worksheet.cell(row=3,column=4,value='ptq_acc')
worksheet.cell(row=3,column=5,value='acc_loss')
for i in range(len(title_list)):
worksheet.cell(row=i+4, column=1, value=title_list[i])
worksheet.cell(row=i+4, column=2, value=js_flops_list[i])
worksheet.cell(row=i+4, column=3, value=js_param_list[i])
worksheet.cell(row=i+4, column=4, value=ptq_acc_list[i])
worksheet.cell(row=i+4, column=5, value=acc_loss_list[i])
workbook.save('ptq_result.xlsx')
writer.close()
ft = open('ptq_result.txt','w')
print('title_list:',file=ft)
print(" ".join(title_list),file=ft)
print('js_flops_list:',file=ft)
print(" ".join(str(i) for i in js_flops_list), file=ft)
print('js_param_list:',file=ft)
print(" ".join(str(i) for i in js_param_list), file=ft)
print('ptq_acc_list:',file=ft)
print(" ".join(str(i) for i in ptq_acc_list), file=ft)
print('acc_loss_list:',file=ft)
print(" ".join(str(i) for i in acc_loss_list), file=ft)
ft.close()
title_list:
POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8
js_flops_list:
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch
class BasicConv2d(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class Inception(nn.Module):
def __init__(self,channel,batch_norm=False):
super(Inception, self).__init__()
if batch_norm==False:
self.branch1x1=nn.Conv2d(channel[0],channel[1],kernel_size=(1,1),stride=1)
self.branch3x3_1=nn.Conv2d(channel[0],channel[2],kernel_size=(1,1),stride=1)
self.branch3x3_2=nn.Conv2d(channel[2],channel[3],kernel_size=(3,3),stride=1,padding=1)
self.branch5x5_1=nn.Conv2d(channel[0],channel[4],kernel_size=(1,1),stride=1)
self.branch5x5_2=nn.Conv2d(channel[4],channel[5],kernel_size=(5,5),stride=1,padding=2)
self.branchM_1=nn.MaxPool2d(kernel_size=3,stride=1,padding=1)
self.branchM_2=nn.Conv2d(channel[0],channel[6],kernel_size=(1,1),stride=1)
else:
self.branch1x1=BasicConv2d(channel[0],channel[1],kernel_size=(1,1),stride=1)
self.branch3x3_1=BasicConv2d(channel[0],channel[2],kernel_size=(1,1),stride=1)
self.branch3x3_2=BasicConv2d(channel[2],channel[3],kernel_size=(3,3),stride=1,padding=1)
self.branch5x5_1=BasicConv2d(channel[0],channel[4],kernel_size=(1,1),stride=1)
self.branch5x5_2=BasicConv2d(channel[4],channel[5],kernel_size=(5,5),stride=1,padding=2)
self.branchM_1=nn.MaxPool2d(kernel_size=3,stride=1,padding=1)
self.branchM_2=BasicConv2d(channel[0],channel[6],kernel_size=(1,1),stride=1)
self.relu=nn.ReLU(True)
def forward(self,x):
branch1x1=self.relu(self.branch1x1(x))
branch3x3_1=self.relu(self.branch3x3_1(x))
branch3x3_2=self.relu(self.branch3x3_2(branch3x3_1))
branch5x5_1=self.relu(self.branch5x5_1(x))
branch5x5_2=self.relu(self.branch5x5_2(branch5x5_1))
branchM_1=self.relu(self.branchM_1(x))
branchM_2=self.relu(self.branchM_2(branchM_1))
outputs = [branch1x1, branch3x3_2, branch5x5_2, branchM_2]
return torch.cat(outputs,1)
channel=[
[192, 64, 96,128, 16, 32, 32],#3a
[256,128,128,192, 32, 96, 64],#3b
[480,192, 96,208, 16, 48, 64],#4a
[512,160,112,224, 24, 64, 64],#4b
[512,128,128,256, 24, 64, 64],#4c
[512,112,144,288, 32, 64, 64],#4d
[528,256,160,320, 32,128,128],#4e
[832,256,160,320, 32,128,128],#5a
[832,384,192,384, 48,128,128] #5b
]
class InceptionNet(nn.Module):
def __init__(self,num_classes=1000,batch_norm=False):
super(InceptionNet, self).__init__()
if num_classes==10:
channel[0][0]=64
self.begin=nn.Sequential(
nn.Conv2d(3,64,kernel_size=3,stride=1),
nn.ReLU(True),
nn.Conv2d(64,64,kernel_size=3,stride=1),
nn.ReLU(True)
)
self.auxout1=nn.Sequential(
nn.Conv2d(512,512,kernel_size=5,stride=3), #4x4x512
nn.ReLU(True),
nn.Conv2d(512,128,kernel_size=1), #4x4x128
nn.ReLU(True),
nn.Conv2d(128, 10,kernel_size=4) #1x1x10
)
self.auxout2=nn.Sequential(
nn.Conv2d(528,528,kernel_size=5,stride=3), #4x4x528,
nn.ReLU(True),
nn.Conv2d(528,128,kernel_size=1), #4x4x128,
nn.ReLU(True),
nn.Conv2d(128, 10,kernel_size=4) #1x1x10
)
else:
self.begin=nn.Sequential(
nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3),
nn.ReLU(True),
nn.MaxPool2d(kernel_size=3,stride=2,padding=1),
nn.Conv2d(64,192,kernel_size=3,stride=1,padding=1),
nn.ReLU(True),
nn.MaxPool2d(kernel_size=3,stride=2,padding=1),
)
self.auxout1=nn.Sequential(
nn.Conv2d(512,512,kernel_size=5,stride=3),#4x4x512
nn.ReLU(True),
nn.Conv2d(512,128,kernel_size=1), #4x4x128
nn.ReLU(True)
)
self.auxout12=nn.Sequential(
nn.Linear(2048,1024),
nn.Dropout(0.5),
nn.linear(1024,num_classes)
)
self.auxout2=nn.Sequential(
nn.Conv2d(528,528,kernel_size=5,stride=3),#4x4x528
nn.ReLU(True),
nn.Conv2d(528,128,kernel_size=1), #4x4x128
nn.ReLU(True)
)
self.auxout22=nn.Sequential(
nn.Linear(2048,1024),
nn.Dropout(0.5),
nn.linear(1024,num_classes)
)
self.maxpool=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
self.inception3a=Inception(channel[0],batch_norm)
self.inception3b=Inception(channel[1],batch_norm)
self.inception4a=Inception(channel[2],batch_norm)
self.inception4b=Inception(channel[3],batch_norm)
self.inception4c=Inception(channel[4],batch_norm)
self.inception4d=Inception(channel[5],batch_norm)
self.inception4e=Inception(channel[6],batch_norm)
self.inception5a=Inception(channel[7],batch_norm)
self.inception5b=Inception(channel[8],batch_norm)
self.avgpool=nn.AdaptiveAvgPool2d((1,1))
self.conv1x1=nn.Conv2d(1024,num_classes,kernel_size=1)
self._initialize_weights()
'''
#follow the original papar,but for the computation ,I do not use it
self.drop=nn.Dropout()
self.linear=nn.Linear(1024,1000)
'''
def _initialize_weights(self):
for m in self.modules():
if isinstance(m,nn.Conv2d):
nn.init.kaiming_normal_(m.weight,mode='fan_out',nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias,0)
elif isinstance(m,nn.BatchNorm2d):
nn.init.constant_(m.weight,1)
nn.init.constant_(m.bias,0)
elif isinstance(m,nn.Linear):
nn.init.normal_(m.weight,0,0.01)
nn.init.constant_(m.bias,0)
def forward(self,x):
x=self.begin(x)
x=self.inception3a(x)
x=self.inception3b(x)
x=self.maxpool(x)
x=self.inception4a(x)
auxout1=self.auxout1(x)
auxout1=auxout1.view(auxout1.size(0),-1)
#if you use this network to train on ImageNet you should add this code
#auxout1=self.auxout12(auxout1)
x=self.inception4b(x)
x=self.inception4c(x)
x=self.inception4d(x)
auxout2=self.auxout2(x)
auxout2=auxout2.view(auxout2.size(0),-1)
#if you use this network to train on ImageNet you should add this code
#auxout2=self.auxout22(auxout2)
x=self.inception4e(x)
x=self.maxpool(x)
x=self.inception5a(x)
x=self.inception5b(x)
x=self.avgpool(x)
outputs=self.conv1x1(x)
outputs=outputs.view(outputs.size(0),-1)
return outputs,auxout1,auxout2
if __name__ == '__main__':
net=InceptionNet(num_classes=10,batch_norm=True)
print(net)
from model import *
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import os
import os.path as osp
def train(model, device, train_loader, optimizer, epoch):
model.train()
lossLayer = torch.nn.CrossEntropyLoss()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = lossLayer(output, target)
loss.backward()
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
lossLayer = torch.nn.CrossEntropyLoss(reduction='sum')
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += lossLayer(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
test_loss, 100. * correct / len(test_loader.dataset)
))
if __name__ == "__main__":
batch_size = 32
seed = 1
epochs_cfg = [20, 30, 30, 20, 20, 10, 10]
lr_cfg = [0.01, 0.008, 0.005, 0.002, 0.001, 0.0005, 0.0001]
momentum = 0.5
save_model = True
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=True, download=True,
transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
model = Inception_BN().to(device)
epoch_start = 1
for epochs,lr in zip(epochs_cfg,lr_cfg):
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
epoch_end = epoch_start+epochs
for epoch in range(epoch_start,epoch_end):
train(model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
epoch_start += epochs
if save_model:
if not osp.exists('ckpt'):
os.makedirs('ckpt')
torch.save(model.state_dict(), 'ckpt/cifar10_Inception_BN.pt')
\ No newline at end of file
import torch
import torch.nn as nn
def ebit_list(quant_type, num_bits):
if quant_type == 'FLOAT':
e_bit_list = list(range(1,num_bits-1))
else:
e_bit_list = [0]
return e_bit_list
def numbit_list(quant_type):
if quant_type == 'INT':
num_bit_list = list(range(2,17))
elif quant_type == 'POT':
num_bit_list = list(range(2,9))
else:
num_bit_list = list(range(2,9))
# num_bit_list = [8]
return num_bit_list
def build_bias_list(quant_type):
if quant_type == 'POT':
return build_pot_list(8)
else:
return build_float_list(16,7)
def build_list(quant_type, num_bits, e_bits):
if quant_type == 'POT':
return build_pot_list(num_bits)
else:
return build_float_list(num_bits,e_bits)
def build_pot_list(num_bits):
plist = [0.]
for i in range(-2 ** (num_bits-1) + 2, 1):
# i最高到0,即pot量化最大值为1
plist.append(2. ** i)
plist.append(-2. ** i)
plist = torch.Tensor(list(set(plist)))
# plist = plist.mul(1.0 / torch.max(plist))
return plist
def build_float_list(num_bits,e_bits):
m_bits = num_bits - 1 - e_bits
plist = [0.]
# 相邻尾数的差值
dist_m = 2 ** (-m_bits)
e = -2 ** (e_bits - 1) + 1
for m in range(1, 2 ** m_bits):
frac = m * dist_m # 尾数部分
expo = 2 ** e # 指数部分
flt = frac * expo
plist.append(flt)
plist.append(-flt)
for e in range(-2 ** (e_bits - 1) + 2, 2 ** (e_bits - 1) + 1):
expo = 2 ** e
for m in range(0, 2 ** m_bits):
frac = 1. + m * dist_m
flt = frac * expo
plist.append(flt)
plist.append(-flt)
plist = torch.Tensor(list(set(plist)))
return plist
#此处不必cfg,直接取同前缀同后缀即可。将relu一起考虑进去
def fold_ratio(layer, par_ratio, flop_ratio):
idx = -1
for name in layer:
if 'conv' in name:
conv_idx = layer.index(name)
[prefix,suffix] = name.split('conv')
bn_name = prefix+'bn'+suffix
relu_name = prefix+'relu'+suffix
if bn_name in layer:
bn_idx = layer.index(bn_name)
par_ratio[conv_idx]+=par_ratio[bn_idx]
flop_ratio[conv_idx]+=flop_ratio[bn_idx]
if relu_name in layer:
relu_idx = layer.index(relu_name)
par_ratio[conv_idx]+=par_ratio[relu_idx]
flop_ratio[conv_idx]+=flop_ratio[bn_idx]
return par_ratio,flop_ratio
def fold_model(model):
for name, module in model.named_modules():
if 'conv' in name:
[prefix,suffix] = name.split('conv')
bn_name = prefix+'bn'+suffix
if hasattr(model,bn_name):
bn_layer = getattr(model,bn_name)
fold_bn(module,bn_layer)
def fold_bn(conv, bn):
# 获取 BN 层的参数
mean = bn.running_mean
var = bn.running_var
eps = bn.eps
std = torch.sqrt(var + eps)
if bn.affine:
gamma_ = bn.weight / std
weight = conv.weight * gamma_.view(conv.out_channels, 1, 1, 1)
if conv.bias is not None:
bias = gamma_ * conv.bias - gamma_ * mean + bn.bias
else:
bias = bn.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = conv.weight * gamma_
if conv.bias is not None:
bias = gamma_ * conv.bias - gamma_ * mean
else:
bias = -gamma_ * mean
# 设置新的 weight 和 bias
conv.weight.data = weight.data
if conv.bias is not None:
conv.bias.data = bias.data
else:
conv.bias = torch.nn.Parameter(bias)
# 改动说明
## update:2023/04/17
## update: 2023/04/22
+ 添加Inception BN模型,对框架改动如下
+ 使用cfg_table进行模型快速部署和量化(可应用于其他模型),cfg_table提供整体forward平面结构,包括inc模块。inc模块则由inc_ch_table和inc_cfg_table进行部署量化。其规则可详见文件
+ cfg_table每项对应一个可进行量化融合的模块,如相邻conv bn relu可融合,在cfg_table中表现为['C','BR',...]。从而可以更方便的从该表进行量化和flops/param权重提取
+ 更改fold_ratio方法,以支持cfg_table的同项融合,多考虑了relu层,摆脱原先临近依赖的限制。方案:读取到conv时,获取相同前后缀的层,并相加
+ 更改module,允许量化层传入层的bias为none。原fold_bn已经考虑了,只需要改动freeze即可。
+ 对于Conv,freeze前后都无bias。
+ 对于ConvBN,freeze前后有bias。forward使用临时值,inference使用固定值(固定到相应conv_module)。
+ 由于允许conv.bias=None,相应改变全精度模型fold_bn方法,从而保证量化前后可比参数相同。改写方式同量化层
+ 更改js_div计算方法,一个层如果同时有多个参数,例如weight和bias,应该总共加起来权重为1。当前直接简单取平均(即js除以该层参数量),后续考虑加权。PS: Inception_BN中,外层conv层有bias,Inception模块内由于后接bn层,bias为false
+ 由于named_parameters迭代器长度不固定,需要先将排成固定列表再处理,从而获得同一层参数数,改动见ptq.py。对全精度模型做此操作即可
+ 新框架中的model_utils方法可以通过调整反量化位置来进行bug的确定。经过当前实验,可以初步判断精度问题出现在inception结构中,具体信息见Inception_BN相关部分。经过排查,量化框架本身并未出现问题,问题可能在于该模型参数分布与POT集中分布的不适配。
## update: 2023/04/17
+ 指定了新的梯度学习率方案,对全精度模型重新训练以达到更高的acc,并重新进行ptq和fit
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment