Commit 4d897197 by Klin

feat: forbid storing PTQ checkpoints unless store_ptq is set

parent 50852902
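
The change gates PTQ checkpoint saving behind the store_ptq flag in both PTQ scripts. A minimal sketch of the guard (flag, path prefix and calls taken from the diffs below; the surrounding quantization loop is omitted):

    store_ptq = False   # leave False to skip writing PTQ state_dicts
    ...
    direct_quantize(model_ptq, train_loader, device)
    if store_ptq:
        torch.save(model_ptq.state_dict(), ptq_file_prefix + title + '.pt')
    model_ptq.freeze()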
from model import *
from extract_ratio import *
from utils import *

import openpyxl
import os
import os.path as osp

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter


def js_div_norm(a, b):
    # L2-normalize both gradient tensors before measuring their JS divergence
    a_norm = F.normalize(a.data, p=2, dim=-1)
    b_norm = F.normalize(b.data, p=2, dim=-1)
    return js_div(a_norm, b_norm)


if __name__ == "__main__":
    wb = openpyxl.Workbook()
    ws = wb.active
    writer = SummaryWriter(log_dir='./log')

    layer, par_ratio, flop_ratio = extract_ratio()
    dir_prefix = 'ckpt/qat/epoch_'
    quant_type_list = ['INT', 'POT', 'FLOAT']

    for epoch in [5, 10, 15, 20]:
        ws_epoch = wb.create_sheet('epoch_%d' % epoch)
        full_state = torch.load(dir_prefix + '%d/' % epoch + 'full.pt')

        # full-precision baseline: loss statistics and gradient histograms
        ws_epoch.cell(row=1, column=2, value='loss')
        ws_epoch.cell(row=1, column=3, value='loss_sum')
        ws_epoch.cell(row=1, column=4, value='loss_avg')
        ws_epoch.cell(row=2, column=1, value='FP32')
        ws_epoch.cell(row=2, column=2, value=full_state['loss'].cpu().item())
        ws_epoch.cell(row=2, column=3, value=full_state['loss_sum'].cpu().item())
        ws_epoch.cell(row=2, column=4, value=full_state['loss_avg'].cpu().item())

        # full_grad = full_state['grad_dict']
        # full_grad_sum = full_state['grad_dict_sum']
        full_grad_avg = full_state['grad_dict_avg']
        for name, tmpgrad in full_grad_avg.items():
            writer.add_histogram('FULL: ' + name, tmpgrad, global_step=epoch)

        # header of the per-quantization-scheme table
        ws_epoch.cell(row=4, column=1, value='title')
        ws_epoch.cell(row=4, column=2, value='loss')
        ws_epoch.cell(row=4, column=3, value='loss_sum')
        ws_epoch.cell(row=4, column=4, value='loss_avg')
        ws_epoch.cell(row=4, column=5, value='js_grad_avg_norm')
        ws_epoch.cell(row=4, column=6, value='conv1.weight')
        ws_epoch.cell(row=4, column=7, value='conv1.bias')
        ws_epoch.cell(row=4, column=8, value='conv2.weight')
        ws_epoch.cell(row=4, column=9, value='conv2.bias')
        ws_epoch.cell(row=4, column=10, value='conv3.weight')
        ws_epoch.cell(row=4, column=11, value='conv3.bias')
        ws_epoch.cell(row=4, column=12, value='conv4.weight')
        ws_epoch.cell(row=4, column=13, value='conv4.bias')
        ws_epoch.cell(row=4, column=14, value='conv5.weight')
        ws_epoch.cell(row=4, column=15, value='conv5.bias')
        ws_epoch.cell(row=4, column=16, value='fc1.weight')
        ws_epoch.cell(row=4, column=17, value='fc1.bias')
        ws_epoch.cell(row=4, column=18, value='fc2.weight')
        ws_epoch.cell(row=4, column=19, value='fc2.bias')
        ws_epoch.cell(row=4, column=20, value='fc3.weight')
        ws_epoch.cell(row=4, column=21, value='fc3.bias')

        currow = 4
        for quant_type in quant_type_list:
            num_bit_list = numbit_list(quant_type)
            for num_bits in num_bit_list:
                e_bit_list = ebit_list(quant_type, num_bits)
                for e_bits in e_bit_list:
                    if quant_type == 'FLOAT':
                        title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
                    else:
                        title = '%s_%d' % (quant_type, num_bits)
                    print('\nAnalyse: ' + title)

                    currow += 1
                    qat_state = torch.load(dir_prefix + '%d/' % epoch + title + '.pt')

                    js_grad_avg_norm = 0.
                    grad_avg = qat_state['grad_dict_avg']
                    for name, tmpgrad in grad_avg.items():
                        writer.add_histogram(title + ': ' + name, tmpgrad, global_step=epoch)

                    # per-layer JS divergence between FP32 and QAT gradient averages,
                    # accumulated into a FLOP-ratio-weighted overall score
                    colidx = 5
                    for name, _ in full_grad_avg.items():
                        prefix = name.split('.')[0]
                        colidx += 1
                        layer_idx = layer.index(prefix)
                        js_norm = js_div_norm(full_grad_avg[name], grad_avg[name])
                        ws_epoch.cell(row=currow, column=colidx, value=js_norm.cpu().item())
                        js_grad_avg_norm += flop_ratio[layer_idx] * js_norm

                    ws_epoch.cell(row=currow, column=1, value=title)
                    ws_epoch.cell(row=currow, column=2, value=qat_state['loss'].cpu().item())
                    ws_epoch.cell(row=currow, column=3, value=qat_state['loss_sum'].cpu().item())
                    ws_epoch.cell(row=currow, column=4, value=qat_state['loss_avg'].cpu().item())
                    ws_epoch.cell(row=currow, column=5, value=js_grad_avg_norm.cpu().item())

    wb.save('loss_grad.xlsx')
    writer.close()
\ No newline at end of file
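
Note: the js_grad_avg_norm column filled in above is a FLOP-ratio-weighted sum of per-layer JS divergences between the full-precision and QAT gradient averages. A small self-contained sketch of that aggregation, with random tensors and made-up layer names standing in for the stored checkpoints (softmax dim made explicit here):

    import torch
    import torch.nn.functional as F

    def js_div(p, q):
        # JS divergence between two softmax distributions, sum reduction (mirrors js_div in utils)
        kl = torch.nn.KLDivLoss(reduction='sum')
        p, q = F.softmax(p, dim=-1), F.softmax(q, dim=-1)
        m = ((p + q) / 2).log()
        return (kl(m, p) + kl(m, q)) / 2

    # stand-ins for grad_dict_avg entries and per-layer FLOP shares
    full_grad = {'conv1.weight': torch.randn(16, 27), 'fc1.weight': torch.randn(10, 16)}
    qat_grad = {k: v + 0.1 * torch.randn_like(v) for k, v in full_grad.items()}
    flop_ratio = {'conv1': 0.7, 'fc1': 0.3}

    score = 0.
    for name in full_grad:
        a = F.normalize(full_grad[name], p=2, dim=-1)
        b = F.normalize(qat_grad[name], p=2, dim=-1)
        score += flop_ratio[name.split('.')[0]] * js_div(a, b)
    print(score.item())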
@@ -50,20 +50,8 @@ def quantize_inference(model, test_loader, device):
     return 100. * correct / len(test_loader.dataset)
 
-def js_div(p_output, q_output, get_softmax=True):
-    """
-    Function that measures JS divergence between target and output logits:
-    """
-    KLDivLoss = nn.KLDivLoss(reduction='sum')
-    if get_softmax:
-        p_output = F.softmax(p_output)
-        q_output = F.softmax(q_output)
-    log_mean_output = ((p_output + q_output)/2).log()
-    return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
-
 if __name__ == "__main__":
-    batch_size = 64
+    batch_size = 32
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     print(device)
@@ -94,6 +82,7 @@ if __name__ == "__main__":
     model.to(device)
     load_ptq = True
+    store_ptq = False
     ptq_file_prefix = 'ckpt/cifar10_AlexNet_ptq_'
     model.eval()
@@ -152,7 +141,8 @@ if __name__ == "__main__":
                 model_ptq.quantize(quant_type,num_bits,e_bits)
                 model_ptq.eval()
                 direct_quantize(model_ptq, train_loader, device)
-                torch.save(model_ptq.state_dict(), ptq_file_prefix + title + '.pt')
+                if store_ptq:
+                    torch.save(model_ptq.state_dict(), ptq_file_prefix + title + '.pt')
                 model_ptq.freeze()
                 ptq_acc = quantize_inference(model_ptq, test_loader, device)
......
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
7507.750226317713 2739.698390971301 602.5613310246055 140.92197221503724 34.51721888016634 8.518508718865842 2.1353732883428638 0.5319411628570782 0.1316271020831477 0.03249564657892055 0.008037284252895557 0.0020460099353784723 0.00041867764927864105 0.0001321614950419231 5.841430176387608e-06 7507.667348902921 1654.3775934528933 136.7401730898288 134.5782970456457 134.57841422062364 134.5783939274636 134.5782945727605 1054.3432278105702 244.48311696489273 247.89704518368768 87.65672091651302 89.63831617681878 37.95288917539117 48.439491059469624 50.122494451137555 9.763717383777191 37.67666314899965 37.082531966568794 37.162725668876305 2.504495253790035 9.660221946273799 37.70544899186622 33.31638902152266 32.12034308540418 0.6541880874259414 2.442034364817909 9.688117360231624 37.70544899186622
7507.75063214733 2739.6984955933212 602.5612756622503 140.9219606122552 34.51723734934774 8.51850258417184 2.135387955471241 0.5319397685779753 0.13161209621640868 0.03248819537407536 0.008041006527013596 0.002041353551137856 0.00044022134597488164 0.0001237480336660566 2.404456269420319e-07 7507.667754908634 1654.3776843456685 136.74013174028667 134.57831309525932 134.57842088419864 134.57840070743188 134.5783010453692 33.31639395389063 32.12035410835974 0.6541864185268057 2.442042655846908 9.68812852009895 37.70545171879492
js_param_list:
7507.750226317713 2739.698390971301 602.5613310246055 140.92197221503724 34.51721888016634 8.518508718865842 2.1353732883428638 0.5319411628570782 0.1316271020831477 0.03249564657892055 0.008037284252895557 0.0020460099353784723 0.00041867764927864105 0.0001321614950419231 5.841430176387608e-06 7507.667348902921 1654.3775934528933 136.7401730898288 134.5782970456457 134.57841422062364 134.5783939274636 134.5782945727605 1054.3432278105702 244.48311696489273 247.89704518368768 87.65672091651302 89.63831617681878 37.95288917539117 48.439491059469624 50.122494451137555 9.763717383777191 37.67666314899965 37.082531966568794 37.162725668876305 2.504495253790035 9.660221946273799 37.70544899186622 33.31638902152266 32.12034308540418 0.6541880874259414 2.442034364817909 9.688117360231624 37.70544899186622
7507.75063214733 2739.6984955933212 602.5612756622503 140.9219606122552 34.51723734934774 8.51850258417184 2.135387955471241 0.5319397685779753 0.13161209621640868 0.03248819537407536 0.008041006527013596 0.002041353551137856 0.00044022134597488164 0.0001237480336660566 2.404456269420319e-07 7507.667754908634 1654.3776843456685 136.74013174028667 134.57831309525932 134.57842088419864 134.57840070743188 134.5783010453692 33.31639395389063 32.12035410835974 0.6541864185268057 2.442042655846908 9.68812852009895 37.70545171879492
ptq_acc_list:
10.0 10.16 51.21 77.39 83.03 84.73 84.84 85.01 85.08 85.07 85.06 85.08 85.08 85.08 85.08 10.0 14.32 72.49 72.65 72.95 72.08 72.23 24.42 66.66 47.53 77.89 76.18 81.78 81.76 81.37 84.11 81.87 82.02 82.5 84.72 84.18 52.15 82.73 83.3 85.01 84.77 59.86 51.87
10.0 10.16 51.21 77.39 83.03 84.73 84.84 85.01 85.08 85.07 85.06 85.07 85.08 85.08 85.08 10.0 14.32 72.49 72.65 72.95 72.08 72.24 82.73 83.3 85.01 84.77 59.86 51.87
acc_loss_list:
0.8824635637047484 0.8805829807240245 0.3980959097320169 0.0903855195110484 0.02409496944052653 0.004113775270333736 0.0028208744710859768 0.0008227550540666805 0.0 0.00011753643629531167 0.0002350728725904563 0.0 0.0 0.0 0.0 0.8824635637047484 0.8316878232251997 0.14797837329572172 0.14609779031499756 0.14257169722614005 0.152797367183827 0.15103432063939815 0.7129760225669958 0.21650211565585334 0.44134931828866947 0.08450869769628583 0.10460742830277377 0.03878702397743297 0.039022096850023426 0.04360601786553824 0.011401034320639386 0.03772919605077567 0.035966149506346995 0.030324400564174875 0.004231311706629048 0.010578279266572538 0.3870474847202633 0.027621062529384042 0.020921485660554785 0.0008227550540666805 0.0036436295251528242 0.29642689233662434 0.39033850493653033
0.8824635637047484 0.8805829807240245 0.3980959097320169 0.0903855195110484 0.02409496944052653 0.004113775270333736 0.0028208744710859768 0.0008227550540666805 0.0 0.00011753643629531167 0.0002350728725904563 0.00011753643629531167 0.0 0.0 0.0 0.8824635637047484 0.8316878232251997 0.14797837329572172 0.14609779031499756 0.14257169722614005 0.152797367183827 0.150916784203103 0.027621062529384042 0.020921485660554785 0.0008227550540666805 0.0036436295251528242 0.29642689233662434 0.39033850493653033
@@ -45,10 +45,11 @@ def quantize_aware_training(model, device, train_loader, optimizer, epoch):
                 epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
             ))
+    batch_size = len(train_loader.batch_sampler)
     # average the values accumulated over the different batches
     for name,grad in grad_dict.items():
-        grad_dict[name] = grad / len(train_loader)
-    loss_avg = loss_sum / len(train_loader)
+        grad_dict[name] = grad / batch_size
+    loss_avg = loss_sum / batch_size
     return loss_avg, grad_dict
 
 def full_inference(model, test_loader):
@@ -92,10 +93,11 @@ def train(model, device, train_loader, optimizer, epoch):
                 epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
             ))
+    batch_size = len(train_loader.batch_sampler)
     # average the values accumulated over the different batches
     for name,grad in grad_dict.items():
-        grad_dict[name] = grad / len(train_loader)
-    loss_avg = loss_sum / len(train_loader)
+        grad_dict[name] = grad / batch_size
+    loss_avg = loss_sum / batch_size
     return loss_avg, grad_dict
@@ -111,7 +113,7 @@ def quantize_inference(model, test_loader):
 if __name__ == "__main__":
-    batch_size = 64
+    batch_size = 32
     seed = 1
     epochs = 20
     lr = 0.001
@@ -145,8 +147,7 @@ if __name__ == "__main__":
     model.to(device)
     optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
-    load_quant_model_file = None
-    # load_quant_model_file = "ckpt/cifar10_AlexNet_qat_ratio_4.pt"
+    load_qat = False
     ckpt_prefix = "ckpt/qat/"
     loss_sum = 0.
@@ -156,6 +157,7 @@ if __name__ == "__main__":
         grad_dict_sum[name] = torch.zeros_like(param)
         grad_dict_avg[name] = torch.zeros_like(param)
     for epoch in range(1, epochs+1):
+        # train the full-precision model and collect its gradient distribution
         loss,grad_dict = train(model, device, train_loader, optimizer, epoch)
         # print('loss:%f' % loss_avg)
         writer.add_scalar('Full.loss',loss,epoch)
@@ -173,8 +175,6 @@ if __name__ == "__main__":
             'loss' : loss,
             'loss_sum' : loss_sum,
             'loss_avg' : loss_avg,
-            'grad_dict' : grad_dict,
-            'grad_dict_sum' : grad_dict_sum,
             'grad_dict_avg' : grad_dict_avg
         }
         if epoch % 5 == 0:
@@ -188,6 +188,7 @@ if __name__ == "__main__":
     # writer.add_histogram('qat_'+name+'_grad',grad,global_step=epoch)
     quant_type_list = ['INT','POT','FLOAT']
+    gol._init()
     for quant_type in quant_type_list:
         num_bit_list = numbit_list(quant_type)
@@ -200,13 +201,18 @@ if __name__ == "__main__":
         for num_bits in num_bit_list:
             e_bit_list = ebit_list(quant_type,num_bits)
             for e_bits in e_bit_list:
-                model_ptq = AlexNet()
                 if quant_type == 'FLOAT':
                     title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
                 else:
                     title = '%s_%d' % (quant_type, num_bits)
-                print('\nQAT: '+title)
+                if load_qat is True and osp.exists(ckpt_prefix+'epoch_20/'+title+'.pt'):
+                    continue
+                print('\nQAT: '+title)
+                model_ptq = AlexNet()
+                # set up the quantization value table
                 if quant_type != 'INT':
                     plist = build_list(quant_type, num_bits, e_bits)
@@ -214,6 +220,7 @@ if __name__ == "__main__":
                 model_ptq.load_state_dict(torch.load(full_file))
                 model_ptq.to(device)
                 model_ptq.quantize(quant_type,num_bits,e_bits)
+                model_ptq.train()
                 loss_sum = 0.
                 grad_dict_sum = {}
@@ -222,7 +229,7 @@ if __name__ == "__main__":
                     grad_dict_sum[name] = torch.zeros_like(param)
                     grad_dict_avg[name] = torch.zeros_like(param)
                 for epoch in range(1, epochs+1):
-                    loss,grad_dict = train(model, device, train_loader, optimizer, epoch)
+                    loss,grad_dict = quantize_aware_training(model_ptq, device, train_loader, optimizer, epoch)
                     # print('loss:%f' % loss_avg)
                     writer.add_scalar(title+'.loss',loss,epoch)
                     for name,grad in grad_dict.items():
@@ -230,7 +237,7 @@ if __name__ == "__main__":
                     loss_sum += loss
                     loss_avg = loss_sum / epoch
-                    for name,grad in grad_dict.items():
+                    for name,param in model.named_parameters():
                         grad_dict_sum[name] += grad_dict[name]
                         grad_dict_avg[name] = grad_dict_sum[name] / epoch
@@ -239,8 +246,8 @@ if __name__ == "__main__":
                         'loss' : loss,
                         'loss_sum' : loss_sum,
                         'loss_avg' : loss_avg,
-                        'grad_dict' : grad_dict,
-                        'grad_dict_sum' : grad_dict_sum,
+                        # 'grad_dict' : grad_dict,
+                        # 'grad_dict_sum' : grad_dict_sum,
                         'grad_dict_avg' : grad_dict_avg
                     }
                     if epoch % 5 == 0:
......
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+def js_div(p_output, q_output, get_softmax=True):
+    """
+    Function that measures JS divergence between target and output logits:
+    """
+    KLDivLoss = nn.KLDivLoss(reduction='sum')
+    if get_softmax:
+        p_output = F.softmax(p_output)
+        q_output = F.softmax(q_output)
+    log_mean_output = ((p_output + q_output)/2).log()
+    return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
 
 def ebit_list(quant_type, num_bits):
@@ -15,7 +28,8 @@ def numbit_list(quant_type):
     elif quant_type == 'POT':
         num_bit_list = list(range(2,9))
     else:
-        num_bit_list = list(range(2,9))
+        # num_bit_list = list(range(2,9))
+        num_bit_list = [8]
     return num_bit_list
......
@@ -69,7 +69,7 @@ if __name__ == "__main__":
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     print(device)
     train_loader = torch.utils.data.DataLoader(
-        datasets.CIFAR10('data', train=True, download=True,
+        datasets.CIFAR10('../data', train=True, download=True,
                          transform=transforms.Compose([
                              transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
                              transforms.RandomHorizontalFlip(),
@@ -80,7 +80,7 @@ if __name__ == "__main__":
     )
     test_loader = torch.utils.data.DataLoader(
-        datasets.CIFAR10('data', train=False, transform=transforms.Compose([
+        datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
             transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
             transforms.ToTensor(),
             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
@@ -95,6 +95,7 @@ if __name__ == "__main__":
     model.to(device)
     load_ptq = True
+    store_ptq = False
     ptq_file_prefix = 'ckpt/cifar10_AlexNet_BN_ptq_'
     model.eval()
@@ -158,7 +159,8 @@ if __name__ == "__main__":
                 model_ptq.quantize(quant_type,num_bits,e_bits)
                 model_ptq.eval()
                 direct_quantize(model_ptq, train_loader, device)
-                torch.save(model_ptq.state_dict(), ptq_file_prefix + title + '.pt')
+                if store_ptq:
+                    torch.save(model_ptq.state_dict(), ptq_file_prefix + title + '.pt')
                 model_ptq.freeze()
                 ptq_acc = quantize_inference(model_ptq, test_loader, device)
......
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
POT_8
js_flops_list:
7398.25896005357 2629.3755068988844 590.6821882672879 140.0731019449853 33.860481696932794 8.284907785159687 2.03806636256361 0.5092280763429845 0.1268424513574334 0.03186369059971782 0.00784113217166064 0.001986723934379166 0.000524832988970629 0.00015508064451377022 4.1298343551726127e-05 7398.228145381178 1620.2558599788433 133.73012416555693 131.6289258948164 131.62859895093086 131.62894506043 131.6290471316483 1069.3904718014871 255.89338502739713 239.72194310589404 94.18685823406595 86.0282131607227 36.776733058300835 54.05171266778348 47.84959128688678 9.560265134591262 36.52222119751017 42.22214674061798 35.43566072515916 2.4388549985136687 9.465166994736311 36.53673909239142 38.2559090619616 30.62508633611525 0.6417076736777798 2.382631215467162 9.478115589519865 36.536710369840804
131.62910457064856
js_param_list:
7398.25896005357 2629.3755068988844 590.6821882672879 140.0731019449853 33.860481696932794 8.284907785159687 2.03806636256361 0.5092280763429845 0.1268424513574334 0.03186369059971782 0.00784113217166064 0.001986723934379166 0.000524832988970629 0.00015508064451377022 4.1298343551726127e-05 7398.228145381178 1620.2558599788433 133.73012416555693 131.6289258948164 131.62859895093086 131.62894506043 131.6290471316483 1069.3904718014871 255.89338502739713 239.72194310589404 94.18685823406595 86.0282131607227 36.776733058300835 54.05171266778348 47.84959128688678 9.560265134591262 36.52222119751017 42.22214674061798 35.43566072515916 2.4388549985136687 9.465166994736311 36.53673909239142 38.2559090619616 30.62508633611525 0.6417076736777798 2.382631215467162 9.478115589519865 36.536710369840804
131.62910457064856
ptq_acc_list:
10.33 15.15 48.54 81.99 85.67 86.8 86.95 87.1 87.14 87.08 87.1 87.05 87.08 87.09 87.08 9.98 19.68 41.69 46.99 39.9 39.59 38.99 16.82 58.19 70.2 75.61 81.71 82.49 79.27 84.74 86.46 79.06 81.3 85.63 86.73 77.44 36.37 81.41 85.78 86.93 75.88 41.43 36.61
42.38
acc_loss_list:
0.881387070846251 0.8260420254908715 0.4426455390974854 0.058560110230795825 0.01630497186818236 0.0033298886209668878 0.0016075324377081246 -0.00011482374555047542 -0.0005741187277528666 0.0001148237455506386 -0.00011482374555047542 0.0004592949822023912 0.0001148237455506386 0.0 0.0001148237455506386 0.8854059019405213 0.7740268687564588 0.5212998047996326 0.46044321965782525 0.5418532552531864 0.5454127913652543 0.5523022160982891 0.8068664599839248 0.33184062464117586 0.19393730623492939 0.13181765989206573 0.061775175106212075 0.05281892295326683 0.08979216902055354 0.026983580204386366 0.0072338959696866415 0.09220346767711564 0.0664829486737858 0.01676426685038475 0.004133654839820868 0.11080491445630962 0.5823860374325411 0.06521988747272944 0.015041910667125987 0.0018371799288092385 0.1287174187622001 0.5242852221839477 0.5796302675393271
0.5133769663566425
@@ -14,6 +14,7 @@ def numbit_list(quant_type):
         num_bit_list = list(range(2,17))
     elif quant_type == 'POT':
         num_bit_list = list(range(2,9))
+        # num_bit_list = [8]
     else:
         num_bit_list = list(range(2,9))
         # num_bit_list = [8]
......