Commit ecee0a11 by Klin

fix: AlexNet_BN: fold bn to conv in full model

parent 3e850e16
# AlexNet_BN Quantization Notes
+ The structure is essentially the same as AlexNet, with a BN layer added after each conv during training. For quantization, conv, bn, and relu are merged into a single QConvBNReLu layer.
+ After this merge, directly comparing similarity against the conv layers of the original full-precision model is not a like-for-like comparison, and the resulting fit is poor (see the README of the previous commit).
+ Fix: fold the BN parameters of the full-precision model into the corresponding Conv layers, and add the BN parameter counts and FLOPs onto those Conv layers as well.
+ This does not reduce the inference accuracy of the full-precision model (an inference path that skips the BN layers can be added to the full-precision model, so the folded model can be verified by running inference).
+ A scheme that quantizes the Conv and BN layers separately was also tested; its accuracy drop is noticeably larger. Moreover, most current quantization strategies adopt Conv+BN fusion, which reduces computation and makes the model more efficient.
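+ For reference, the per-channel folding rule used by `fold_bn` later in this commit (BN parameters γ, β, running statistics μ, σ², eps ε; conv weight W and bias b) is:

$$
W' = W \cdot \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}}, \qquad
b' = (b - \mu) \cdot \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}} + \beta
$$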
## PTQ
@@ -8,7 +12,7 @@
FP32 acc: 87.09
![image-20230409143615651](image/image-20230409143615651.png)
![image-20230410030841210](image/image-20230410030841210.png)
+ Data fitting:
@@ -22,24 +26,24 @@
- [ ] center and scale
![image-20230409133159255](image/image-20230409133159255.png)
![image-20230410030613387](image/image-20230410030613387.png)
- [x] center and scale
![image-20230409133216797](image/image-20230409133216797.png)
![image-20230410030625395](image/image-20230410030625395.png)
+ js_param - acc_loss (see the fitting sketch after this block)
Rational: Numerator degree 2 / Denominator degree 2
- [ ] center and scale
![image-20230409133305267](image/image-20230409133305267.png)
![image-20230410030654186](image/image-20230410030654186.png)
- [x] center and scale
![image-20230409133331030](image/image-20230409133331030.png)
![image-20230410030707277](image/image-20230410030707277.png)
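A minimal sketch of how a degree-2/degree-2 rational fit like the one above could be reproduced. The actual fitting code is not part of this diff; the arrays, the linearized least-squares approach, and the helper name `rational22` are illustrative, and the eight sample points are the INT_2-INT_9 entries from the js_param_list/acc_loss_list results further down.

```python
import numpy as np

# Linearize y = (a0 + a1*x + a2*x^2) / (1 + b1*x + b2*x^2) as
# a0 + a1*x + a2*x^2 - b1*x*y - b2*x^2*y = y and solve with least squares.
x = np.array([7387.84, 2625.65, 589.92, 139.95, 33.90, 8.36, 2.12, 0.59])        # js_param
y = np.array([0.8852, 0.7713, 0.4392, 0.0628, 0.0138, 0.0022, 0.0007, -0.0001])  # acc_loss

A = np.stack([np.ones_like(x), x, x ** 2, -x * y, -x ** 2 * y], axis=1)
a0, a1, a2, b1, b2 = np.linalg.lstsq(A, y, rcond=None)[0]

def rational22(t):
    return (a0 + a1 * t + a2 * t ** 2) / (1.0 + b1 * t + b2 * t ** 2)

print(rational22(x))  # should roughly track the measured acc_loss values
```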
+ After adding FP3-FP7:
@@ -50,19 +54,19 @@
- [ ] center and scale
![image-20230409133616671](image/image-20230409133616671.png)
![image-20230410030018190](image/image-20230410030018190.png)
- [x] center and scale
![image-20230409143802907](image/image-20230409143802907.png)
![image-20230410030035550](image/image-20230410030035550.png)
+ js_param - acc_loss
Rational: Numerator degree 2 / Denominator degree 2
+ [ ] center and scale
![image-20230409143907871](image/image-20230409143907871.png)
![image-20230410030148120](image/image-20230410030148120.png)
+ [x] center and scale
![image-20230409143845569](image/image-20230409143845569.png)
![image-20230410030206554](image/image-20230410030206554.png)
@@ -69,7 +69,7 @@ class AlexNet_BN(nn.Module):
        x = self.relu7(x)
        x = self.fc3(x)
        return x

    def quantize(self, quant_type, num_bits=8, e_bits=3):
        # e_bits is only used for FLOAT quantization
......
@@ -13,6 +13,7 @@ import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import torch.utils.bottleneck as bn
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
@@ -98,15 +99,21 @@ if __name__ == "__main__":
    model.eval()
    full_acc = full_inference(model, test_loader, device)
    # fold BN parameters into the conv layers of the full-precision model
    model_fold = fold_model(model)
    full_params = []

    layer, par_ratio, flop_ratio = extract_ratio()
    # merge each BN layer's parameter/FLOP share into the preceding conv layer
    par_ratio, flop_ratio = fold_ratio(layer, par_ratio, flop_ratio)

    for name, param in model_fold.named_parameters():
        if 'bn' in name:
            continue
        param_norm = F.normalize(param.data.cpu(), p=2, dim=-1)
        full_params.append(param_norm)
        writer.add_histogram(tag='Full_' + name + '_data', values=param.data)

    gol._init()
    quant_type_list = ['INT', 'POT', 'FLOAT']
    title_list = []
@@ -164,7 +171,7 @@ if __name__ == "__main__":
    js_flops = 0.
    js_param = 0.
    for name, param in model_ptq.named_parameters():
        # skip top-level parameters and the (now folded) BN parameters
        if '.' not in name or 'bn' in name:
            continue
        idx = idx + 1
        prefix = name.split('.')[0]
......
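The body of this loop is elided above; based on the variable names (js_flops, js_param, full_params, par_ratio, flop_ratio), it presumably accumulates a per-layer Jensen-Shannon divergence between full-precision and quantized weights, weighted by each layer's FLOP/parameter share. A self-contained sketch of such a divergence, purely for illustration (the exact metric and weighting in the repository are not shown in this diff):

import torch
import torch.nn.functional as F

def js_div(p_output, q_output, get_softmax=True):
    # JS(P || Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M), with M = (P + Q) / 2
    kl = torch.nn.KLDivLoss(reduction='sum')
    if get_softmax:
        p_output = F.softmax(p_output.flatten(), dim=0)
        q_output = F.softmax(q_output.flatten(), dim=0)
    m = 0.5 * (p_output + q_output)
    return 0.5 * kl(m.log(), p_output) + 0.5 * kl(m.log(), q_output)

# Illustrative accumulation, mirroring the loop above (weighting is an assumption):
#     js = js_div(F.normalize(param.data.cpu(), p=2, dim=-1), full_params[idx])
#     js_flops += js * flop_ratio[layer.index(prefix)]
#     js_param += js * par_ratio[layer.index(prefix)]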
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
7387.838496104369 2625.646833355205 589.9163523840543 139.95492844215636 33.898346427925894 8.359296461665465 2.1216004520452856 0.5949235039400762 0.21309794043284996 0.11825646935172719 0.09429331800099429 0.08841095083362506 0.08699678964334162 0.08663305654197168 0.08647598591627721 7387.827067868839 1617.97485666175 133.61810656803314 131.5187461175448 131.5184480272645 131.52048525420298 131.520379017506 1067.9230850082097 255.60428668878217 239.45983640568826 94.13499025246169 85.9906047976941 36.80986958674102 54.05786859221271 47.867086055234125 9.63279058044633 36.55164430227947 42.24566761329685 35.471179319270334 2.5217779013551413 9.531897021062111 36.55105201198746 38.2852725131297 30.667455239210526 0.7272006067102655 2.4588845270355084 9.531607285329187 36.55102319987964
7398.25896005357 2629.3755068988844 590.6821882672879 140.0731019449853 33.860481696932794 8.284907785159687 2.03806636256361 0.5092280763429845 0.1268424513574334 0.03186369059971782 0.00784113217166064 0.001986723934379166 0.000524832988970629 0.00015508064451377022 4.1298343551726127e-05 7398.228145381178 1620.2558599788433 133.73012416555693 131.6289258948164 131.62859895093086 131.62894506043 131.6290471316483 1069.3904718014871 255.89338502739713 239.72194310589404 94.18685823406595 86.0282131607227 36.776733058300835 54.05171266778348 47.84959128688678 9.560265134591262 36.52222119751017 42.22214674061798 35.43566072515916 2.4388549985136687 9.465166994736311 36.53673909239142 38.2559090619616 30.62508633611525 0.6417076736777798 2.382631215467162 9.478115589519865 36.536710369840804
js_param_list:
7387.838496104369 2625.646833355205 589.9163523840543 139.95492844215636 33.898346427925894 8.359296461665465 2.1216004520452856 0.5949235039400762 0.21309794043284996 0.11825646935172719 0.09429331800099429 0.08841095083362506 0.08699678964334162 0.08663305654197168 0.08647598591627721 7387.827067868839 1617.97485666175 133.61810656803314 131.5187461175448 131.5184480272645 131.52048525420298 131.520379017506 1067.9230850082097 255.60428668878217 239.45983640568826 94.13499025246169 85.9906047976941 36.80986958674102 54.05786859221271 47.867086055234125 9.63279058044633 36.55164430227947 42.24566761329685 35.471179319270334 2.5217779013551413 9.531897021062111 36.55105201198746 38.2852725131297 30.667455239210526 0.7272006067102655 2.4588845270355084 9.531607285329187 36.55102319987964
7398.25896005357 2629.3755068988844 590.6821882672879 140.0731019449853 33.860481696932794 8.284907785159687 2.03806636256361 0.5092280763429845 0.1268424513574334 0.03186369059971782 0.00784113217166064 0.001986723934379166 0.000524832988970629 0.00015508064451377022 4.1298343551726127e-05 7398.228145381178 1620.2558599788433 133.73012416555693 131.6289258948164 131.62859895093086 131.62894506043 131.6290471316483 1069.3904718014871 255.89338502739713 239.72194310589404 94.18685823406595 86.0282131607227 36.776733058300835 54.05171266778348 47.84959128688678 9.560265134591262 36.52222119751017 42.22214674061798 35.43566072515916 2.4388549985136687 9.465166994736311 36.53673909239142 38.2559090619616 30.62508633611525 0.6417076736777798 2.382631215467162 9.478115589519865 36.536710369840804
ptq_acc_list:
10.0 19.92 48.84 81.62 85.89 86.9 87.03 87.1 87.13 87.09 87.08 87.1 87.08 87.09 87.08 10.0 18.75 39.19 40.72 43.55 41.77 41.42 16.37 58.14 70.31 75.27 81.85 82.79 79.64 84.63 86.46 79.42 81.14 85.76 86.84 78.98 37.06 81.02 85.85 86.88 74.87 41.3 36.46
10.33 15.15 48.54 81.99 85.67 86.8 86.95 87.1 87.14 87.08 87.1 87.05 87.08 87.09 87.08 9.98 19.68 41.69 46.99 39.9 39.59 38.99 16.82 58.19 70.2 75.61 81.71 82.49 79.27 84.74 86.46 79.06 81.3 85.63 86.73 77.44 36.37 81.41 85.78 86.93 75.88 41.43 36.61
acc_loss_list:
0.8851762544494202 0.771271098863245 0.43920082673096794 0.06280858881616717 0.013778849466069614 0.002181651165460991 0.0006889424733035052 -0.00011482374555047542 -0.00045929498220222804 0.0 0.0001148237455506386 -0.00011482374555047542 0.0001148237455506386 0.0 0.0001148237455506386 0.8851762544494202 0.7847054770926628 0.5500057411872775 0.5324377081180388 0.49994258812722475 0.5203812148352279 0.5244000459294982 0.8120335285337007 0.33241474336892873 0.192674245033873 0.13572166724078547 0.06016764266850395 0.049374210586749304 0.08554369043518202 0.028246641405442734 0.0072338959696866415 0.08806981283729477 0.06832012860259505 0.015271558158227101 0.0028705936387644964 0.09312205764152026 0.574463198989551 0.06969801354920206 0.014238144448272006 0.002411298656562268 0.14031461706280857 0.5257779308761052 0.5813526237225859
0.881387070846251 0.8260420254908715 0.4426455390974854 0.058560110230795825 0.01630497186818236 0.0033298886209668878 0.0016075324377081246 -0.00011482374555047542 -0.0005741187277528666 0.0001148237455506386 -0.00011482374555047542 0.0004592949822023912 0.0001148237455506386 0.0 0.0001148237455506386 0.8854059019405213 0.7740268687564588 0.5212998047996326 0.46044321965782525 0.5418532552531864 0.5454127913652543 0.5523022160982891 0.8068664599839248 0.33184062464117586 0.19393730623492939 0.13181765989206573 0.061775175106212075 0.05281892295326683 0.08979216902055354 0.026983580204386366 0.0072338959696866415 0.09220346767711564 0.0664829486737858 0.01676426685038475 0.004133654839820868 0.11080491445630962 0.5823860374325411 0.06521988747272944 0.015041910667125987 0.0018371799288092385 0.1287174187622001 0.5242852221839477 0.5796302675393271
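The acc_loss_list values are consistent with a relative drop against the FP32 baseline, i.e. acc_loss = (full_acc - ptq_acc) / full_acc with full_acc = 87.09. A quick check against the INT_4 entry:

full_acc = 87.09
ptq_acc = 48.84                                  # INT_4 entry of ptq_acc_list
acc_loss = (full_acc - ptq_acc) / full_acc
print(acc_loss)                                  # ~0.4392, matching acc_loss_list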
import torch
import torch.nn as nn
def ebit_list(quant_type, num_bits):
    if quant_type == 'FLOAT':
@@ -63,4 +63,82 @@ def build_float_list(num_bits,e_bits):
    plist.append(flt)
    plist.append(-flt)
    plist = torch.Tensor(list(set(plist)))
    return plist

def fold_ratio(layer, par_ratio, flop_ratio):
    # add each BN layer's parameter/FLOP ratio onto the conv layer right before it
    idx = -1
    for name in layer:
        idx = idx + 1
        if 'bn' in name:
            par_ratio[idx - 1] += par_ratio[idx]
            flop_ratio[idx - 1] += flop_ratio[idx]
    return par_ratio, flop_ratio
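
# Tiny illustration of fold_ratio (made-up numbers): each BN entry's share is
# added onto the conv entry right before it, so folded conv layers carry the
# combined parameter/FLOP ratios.
#     layer      = ['conv1', 'bn1', 'conv2', 'bn2']
#     par_ratio  = [0.30, 0.02, 0.60, 0.08]
#     flop_ratio = [0.40, 0.01, 0.55, 0.04]
#     par_ratio, flop_ratio = fold_ratio(layer, par_ratio, flop_ratio)
#     par_ratio  -> [0.32, 0.02, 0.68, 0.08]
#     flop_ratio -> [0.41, 0.01, 0.59, 0.04]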

def fold_model(model):
    # fold each BN module into the conv module that immediately precedes it;
    # fold_bn rewrites the conv's weight/bias in place, so the model is updated
    idx = -1
    module_list = []
    for name, module in model.named_modules():
        idx += 1
        module_list.append(module)
        if 'bn' in name:
            module_list[idx - 1] = fold_bn(module_list[idx - 1], module)
    return model

# Alternative fold_model implementation (left commented out in the source):
# def fold_model(model):
#     last_conv = None
#     last_bn = None
#     for name, module in model.named_modules():
#         if isinstance(module, nn.Conv2d):
#             # a new conv starts; fold any still-pending conv/BN pair first
#             if last_bn is not None:
#                 last_conv = fold_bn(last_conv, last_bn)
#                 last_bn = None
#             last_conv = module
#         elif isinstance(module, nn.BatchNorm2d):
#             # fold this BN layer into the conv layer that precedes it
#             last_bn = module
#             if last_conv is not None:
#                 last_conv = fold_bn(last_conv, last_bn)
#                 last_bn = None
#     # handle a trailing BN layer, if any
#     if last_bn is not None:
#         last_conv = fold_bn(last_conv, last_bn)
#     return model

def fold_bn(conv, bn):
    # BN parameters
    gamma = bn.weight.data
    beta = bn.bias.data
    mean = bn.running_mean
    var = bn.running_var
    eps = bn.eps
    std = torch.sqrt(var + eps)
    feat = bn.num_features

    # conv parameters (the conv may have no bias)
    weight = conv.weight.data
    bias = conv.bias.data if conv.bias is not None else None

    if bn.affine:
        gamma_ = gamma / std
        weight = weight * gamma_.view(feat, 1, 1, 1)
        if bias is not None:
            bias = gamma_ * bias - gamma_ * mean + beta
        else:
            bias = beta - gamma_ * mean
    else:
        gamma_ = 1 / std
        weight = weight * gamma_.view(feat, 1, 1, 1)
        if bias is not None:
            bias = gamma_ * bias - gamma_ * mean
        else:
            bias = -gamma_ * mean

    # write the folded weight and bias back into the conv layer
    conv.weight.data = weight
    if conv.bias is not None:
        conv.bias.data = bias
    else:
        conv.bias = nn.Parameter(bias)
    return conv
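
# Illustrative self-check (not part of the original commit): in eval mode a
# folded conv should reproduce conv -> bn almost exactly, which is the
# "folding does not change full-precision accuracy" claim from the README.
if __name__ == '__main__':
    torch.manual_seed(0)
    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
    bn = nn.BatchNorm2d(8)
    # give BN non-trivial running statistics and affine parameters
    bn.running_mean.uniform_(-0.5, 0.5)
    bn.running_var.uniform_(0.5, 1.5)
    bn.weight.data.uniform_(0.5, 1.5)
    bn.bias.data.uniform_(-0.5, 0.5)
    bn.eval()

    x = torch.randn(2, 3, 32, 32)
    ref = bn(conv(x))            # reference output before folding
    folded = fold_bn(conv, bn)   # mutates conv in place
    print(torch.allclose(ref, folded(x), atol=1e-5))   # expected: True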