Commit c5562a31 by Klin

fix: retrain, re-run PTQ and refit for better fp32 acc

parent 1f5b02d5
ykl/AlexNet/image/AlexNet_table.png: 21.3 KB → 20.9 KB (image replaced)
ykl/AlexNet/image/flops.png: 33 KB → 33.7 KB (image replaced)
ykl/AlexNet/image/param.png: 32.7 KB → 32.2 KB (image replaced)
......@@ -8,6 +8,19 @@ from torch.autograd import Variable
from function import FakeQuantize
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
# Get the nearest quantized value
def get_nearest_val(quant_type,x,is_bias=False):
if quant_type=='INT':
......
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
7507.750225073234 2739.7076966853715 602.5612265737566 140.92196299779113 34.51721972907417 8.518507426398175 2.1353734024227116 0.5319409334311419 0.13162665658268388 0.03249625709236131 0.00803756234599581 0.0020459635613112545 0.00041898539095542464 0.00013219370584495937 5.889692605810293e-06 7507.667349327719 1654.3776698748943 136.74005401701154 134.57824579197353 134.57836801317453 134.5784177398131 134.5783104900291 1054.3432294422712 244.48311706582172 247.89704517757673 87.6567210381318 89.6383158999559 37.95288939287757 48.43949061178506 50.12249448664991 9.763717307655865 37.67666344437851 37.082531959820756 37.162726030896145 2.504495273130303 9.660222213843284 37.67755581522706 33.31639025927671 32.12034332380884 0.654188829204954 2.4420339872993106 9.660223950420662 37.71766886303616
7596.158354929403 3342.630224687211 803.7421515317819 193.09625855378914 47.294430438986815 11.617764000749704 2.8409912757716502 0.7002681399367439 0.17784808389949117 0.0436832697576101 0.011038433846755926 0.0027589181015611444 0.0006869155287621905 0.0001593813201971203 0.0001203343519776774 7596.100728242634 2133.4439203566753 134.29969268901075 130.9402314639402 130.94019523793065 130.9402499980709 130.94030592234444 1277.7023613784083 289.82743245674186 302.10808960554925 92.40277930093731 107.5666593496828 36.700788172354116 44.407474473657686 60.09530043518254 9.507477271994002 36.255231843587175 31.324843649420213 44.77256767415278 2.4592434345083602 9.346483017228227 36.2552369826946 27.3682985947986 38.75922946625283 0.6654846714873894 2.3614977284616923 9.346479578793366 36.25830464235682
js_param_list:
2099.4012777078437 756.8892728120519 165.48621719007886 38.661065835467824 9.465502337873694 2.33808893548631 0.5869253659714873 0.14619333548441132 0.03612248684512918 0.008927375183637454 0.002218583272882395 0.0005696575186261605 0.00011790447981599561 5.3025835391648186e-05 9.207995013302385e-06 2099.3808052778672 455.65423486787154 38.27924274988919 37.69144524899607 37.69150679435307 37.69151359769784 37.691495756398155 292.0014018519822 67.9502067901384 68.50198967225035 24.584935828703593 24.79733981410337 10.623398482526541 13.713032956745856 13.858531189296505 2.7327807059031084 10.548006045861452 10.549392886768755 10.275096064814912 0.7005038658637505 2.704395835634312 10.548258699586787 9.49326525014626 8.881738559127504 0.18278915543612456 0.6834360936949567 2.704395905153418 10.565651241352946
2124.6901484547043 924.4573506960227 220.4801169201279 52.8504167013976 12.953296497807385 3.1853253087773608 0.7775012981543739 0.19165897232008872 0.04881389226260701 0.011927057061170981 0.003025213095666009 0.0007628906685143884 0.00019075741735802466 4.5081513290070486e-05 3.399542619773311e-05 2124.672207113351 587.9132319943658 37.55091744333372 36.635514313284844 36.63549899741198 36.63550255682983 36.63554387096417 352.76533450990684 80.14627267392913 83.49214909064641 25.77374926912755 29.78985840193075 10.26657509432266 12.541439619277627 16.648877037050198 2.658753685911974 10.145569681413992 8.911419624239798 12.401484350536222 0.6875383356517406 2.6148026740520978 10.145568951377017 7.808360107452724 10.735180287226289 0.18568651930732344 0.6608485065237775 2.614801761970877 10.151356149109917
ptq_acc_list:
10.0 10.05 56.03 78.36 83.22 84.34 85.08 84.93 85.08 85.1 85.06 85.07 85.07 85.08 85.08 10.0 15.14 72.69 72.45 72.68 72.07 72.75 24.65 63.25 57.28 77.86 74.77 82.16 81.39 81.22 84.02 81.89 81.97 82.79 84.73 84.16 81.93 82.83 83.41 84.91 84.77 83.97 81.99
10.0 10.1 54.63 76.21 84.14 85.66 86.1 86.1 86.12 86.09 86.09 86.08 86.08 86.05 86.08 10.0 13.77 74.96 75.06 74.89 74.91 74.79 20.04 67.57 50.1 77.82 79.72 82.7 81.61 83.88 85.78 82.82 82.84 84.97 86.16 85.76 82.66 82.83 85.16 86.17 86.22 85.95 82.78
acc_loss_list:
0.8824635637047484 0.8818758815232722 0.34144334743770566 0.07898448519040902 0.021861777150916778 0.008697696285848553 0.0 0.0017630465444286728 0.0 -0.0002350728725904563 0.0002350728725904563 0.00011753643629531167 0.00011753643629531167 0.0 0.0 0.8824635637047484 0.8220498354489891 0.14562764456981667 0.14844851904090262 0.1457451810061118 0.1529149036201223 0.1449224259520451 0.710272684532205 0.2565820404325341 0.32675129290079924 0.0848613070051716 0.12118006582040436 0.03432063939821347 0.04337094499294779 0.04536906440996708 0.01245886224729669 0.037494123178185214 0.03655383168782322 0.026915843911612506 0.004113775270333736 0.01081335213916316 0.03702397743300413 0.026445698166431594 0.019628584861307027 0.001998119417019296 0.0036436295251528242 0.013046544428772913 0.03631875881523276
0.8838289962825279 0.8826672862453532 0.3653578066914498 0.11466078066914503 0.022537174721189566 0.004879182156133849 -0.000232342007434898 -0.000232342007434898 -0.00046468401486996115 -0.00011617100371753155 -0.00011617100371753155 0.0 0.0 0.00034851301115242957 0.0 0.8838289962825279 0.8400325278810409 0.12918215613382905 0.12802044609665422 0.12999535315985128 0.12976301115241637 0.13115706319702594 0.7671933085501857 0.21503252788104096 0.41798327137546465 0.09595724907063204 0.07388475836431226 0.03926579925650552 0.05192843866171003 0.0255576208178439 0.0034851301115241306 0.03787174721189597 0.03763940520446091 0.012894981412639398 -0.0009293680297397572 0.0037174721189590287 0.03973048327137549 0.03775557620817844 0.010687732342007455 -0.0010455390334572886 -0.0016263940520446162 0.0015102230483270847 0.03833643122676577
from model import *
from extract_ratio import *
from utils import *
import openpyxl
import gol
import sys
import argparse
......@@ -11,6 +14,13 @@ import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
def js_div_norm(a,b):
a_norm = F.normalize(a.data,p=2,dim=-1)
b_norm = F.normalize(b.data,p=2,dim=-1)
return js_div(a_norm,b_norm).cpu().item()
def quantize_aware_training(model, device, train_loader, optimizer, epoch):
lossLayer = torch.nn.CrossEntropyLoss()
# accumulate the loss and each parameter's grad
......@@ -115,7 +125,7 @@ if __name__ == "__main__":
batch_size = 32
seed = 1
epochs = 20
epochs = 15
lr = 0.001
momentum = 0.5
......@@ -123,8 +133,13 @@ if __name__ == "__main__":
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
writer = SummaryWriter(log_dir='./log/qat')
wb = openpyxl.Workbook()
ws = wb.active
layer, par_ratio, flop_ratio = extract_ratio()
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=True, download=True,
datasets.CIFAR10('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
......@@ -133,7 +148,7 @@ if __name__ == "__main__":
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=False, transform=transforms.Compose([
datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
......@@ -151,17 +166,16 @@ if __name__ == "__main__":
ckpt_prefix = "ckpt/qat/"
loss_sum = 0.
grad_dict_sum = {}
grad_dict_avg = {}
full_grad_sum = {}
full_grad_avg = {}
for name,param in model.named_parameters():
grad_dict_sum[name] = torch.zeros_like(param)
grad_dict_avg[name] = torch.zeros_like(param)
full_grad_sum[name] = torch.zeros_like(param)
full_grad_avg[name] = torch.zeros_like(param)
for epoch in range(1, epochs+1):
# Train the original (full-precision) model and collect its gradient distribution
loss,grad_dict = train(model, device, train_loader, optimizer, epoch)
loss,full_grad = train(model, device, train_loader, optimizer, epoch)
if epoch == 1:
loss_start = loss
loss_delta = loss - loss_start
# print('loss:%f' % loss_avg)
writer.add_scalar('Full.loss',loss,epoch)
# for name,grad in grad_dict.items():
......@@ -169,30 +183,38 @@ if __name__ == "__main__":
loss_sum += loss
loss_avg = loss_sum / epoch
for name,grad in grad_dict.items():
grad_dict_sum[name] += grad_dict[name]
grad_dict_avg[name] = grad_dict_sum[name] / epoch
if store_qat:
ckpt = {
'epoch' : epoch,
'loss' : loss,
'loss_sum' : loss_sum,
'loss_avg' : loss_avg,
'grad_dict_avg' : grad_dict_avg
}
if epoch % 5 == 0:
subdir = 'epoch_%d/' % epoch
torch.save(ckpt,ckpt_prefix+ subdir +'full.pt')
loss_delta = loss - loss_start
for name,grad in full_grad.items():
full_grad_sum[name] += full_grad[name]
full_grad_avg[name] = full_grad_sum[name] / epoch
# loss_avg,grad_dict = quantize_aware_training(model_ptq, device, train_loader, optimizer, epoch)
# print('qat_loss:%f' % loss_avg)
# for name,grad in grad_dict.items():
# writer.add_histogram('qat_'+name+'_grad',grad,global_step=epoch)
if epoch % 5 == 0:
ws = wb.create_sheet('epoch_%d'%epoch)
ws.cell(row=1,column=2,value='loss')
ws.cell(row=1,column=3,value='loss_sum')
ws.cell(row=1,column=4,value='loss_avg')
ws.cell(row=1,column=5,value='loss_delta')
ws.cell(row=2,column=1,value='FP32')
ws.cell(row=2,column=2,value=loss.item())
ws.cell(row=2,column=3,value=loss_sum.item())
ws.cell(row=2,column=4,value=loss_avg.item())
ws.cell(row=2,column=5,value=loss_delta.item())
ws.cell(row=4,column=1,value='title')
ws.cell(row=4,column=2,value='loss')
ws.cell(row=4,column=3,value='loss_sum')
ws.cell(row=4,column=4,value='loss_avg')
ws.cell(row=4,column=5,value='loss_delta')
ws.cell(row=4,column=6,value='js_grad')
ws.cell(row=4,column=7,value='js_grad_sum')
ws.cell(row=4,column=8,value='js_grad_avg')
quant_type_list = ['INT','POT','FLOAT']
gol._init()
currow=4 # row where the data starts being written
for quant_type in quant_type_list:
num_bit_list = numbit_list(quant_type)
......@@ -214,6 +236,7 @@ if __name__ == "__main__":
if load_qat is True and osp.exists(ckpt_prefix+'epoch_20/'+title+'.pt'):
continue
currow += 1
print('\nQAT: '+title)
model_ptq = AlexNet()
......@@ -221,82 +244,54 @@ if __name__ == "__main__":
if quant_type != 'INT':
plist = build_list(quant_type, num_bits, e_bits)
gol.set_value(plist)
model_ptq.load_state_dict(torch.load(full_file))
# model_ptq.load_state_dict(torch.load(full_file))
model_ptq.to(device)
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.train()
loss_sum = 0.
grad_dict_sum = {}
grad_dict_avg = {}
qat_grad_sum = {}
qat_grad_avg = {}
for name,param in model.named_parameters():
grad_dict_sum[name] = torch.zeros_like(param)
grad_dict_avg[name] = torch.zeros_like(param)
qat_grad_sum[name] = torch.zeros_like(param)
qat_grad_avg[name] = torch.zeros_like(param)
for epoch in range(1, epochs+1):
loss,grad_dict = quantize_aware_training(model_ptq, device, train_loader, optimizer, epoch)
loss,qat_grad = quantize_aware_training(model_ptq, device, train_loader, optimizer, epoch)
# print('loss:%f' % loss_avg)
if epoch == 1:
loss_start = loss
writer.add_scalar(title+'.loss',loss,epoch)
for name,grad in grad_dict.items():
writer.add_histogram(title+'.'+name+'_grad',grad,global_step=epoch)
# for name,grad in qat_grad.items():
# writer.add_histogram(title+'.'+name+'_grad',grad,global_step=epoch)
loss_sum += loss
loss_avg = loss_sum / epoch
loss_avg = loss_sum / epoch
loss_delta = loss-loss_start
for name,param in model.named_parameters():
grad_dict_sum[name] += grad_dict[name]
grad_dict_avg[name] = grad_dict_sum[name] / epoch
ckpt = {
'epoch' : epoch,
'loss' : loss,
'loss_sum' : loss_sum,
'loss_avg' : loss_avg,
# 'grad_dict' : grad_dict,
# 'grad_dict_sum' : grad_dict_sum,
'grad_dict_avg' : grad_dict_avg
}
qat_grad_sum[name] += qat_grad[name]
qat_grad_avg[name] = qat_grad_sum[name] / epoch
if epoch % 5 == 0:
subdir = 'epoch_%d/' % epoch
torch.save(ckpt,ckpt_prefix+subdir + title+'.pt')
ws = wb['epoch_%d'%epoch]
js_grad = 0.
js_grad_sum = 0.
js_grad_avg = 0.
for name,_ in model_ptq.named_parameters():
prefix = name.split('.')[0]
layer_idx = layer.index(prefix)
js_grad += flop_ratio[layer_idx] * js_div_norm(qat_grad[name],full_grad[name])
js_grad_sum += flop_ratio[layer_idx] * js_div_norm(qat_grad_sum[name],full_grad_sum[name])
js_grad_avg += flop_ratio[layer_idx] * js_div_norm(qat_grad_avg[name],full_grad_avg[name])
ws.cell(row=currow,column=1,value=title)
ws.cell(row=currow,column=2,value=loss.item())
ws.cell(row=currow,column=3,value=loss_sum.item())
ws.cell(row=currow,column=4,value=loss_avg.item())
ws.cell(row=currow,column=5,value=loss_delta.item())
ws.cell(row=currow,column=6,value=js_grad)
ws.cell(row=currow,column=7,value=js_grad_sum)
ws.cell(row=currow,column=8,value=js_grad_avg)
wb.remove(wb['Sheet']) # remove the default worksheet by name
wb.save('qat_result.xlsx')
writer.close()
# # model.eval()
# # full_inference(model, test_loader)
# num_bits = 8
# e_bits = 0
# gol._init()
# print("qat: INT8")
# model.quantize('INT',num_bits,e_bits)
# print('Quantization bit: %d' % num_bits)
# if load_quant_model_file is not None:
# model.load_state_dict(torch.load(load_quant_model_file))
# print("Successfully load quantized model %s" % load_quant_model_file)
# else:
# model.train()
# for epoch in range(1, epochs+1):
# quantize_aware_training(model, device, train_loader, optimizer, epoch)
# # for epoch in range(epochs1 + 1, epochs2 + 1):
# # quantize_aware_training(model, device, train_loader, optimizer2, epoch)
# model.eval()
# # torch.save(model.state_dict(), save_file)
# model.freeze()
# # for name, param in model.named_parameters():
# # print(name)
# # print(param.data)
# # print('----------')
# # for param_tensor, param_value in model.state_dict().items():
# # print(param_tensor, "\t", param_value)
# quantize_inference(model, test_loader)
......@@ -46,14 +46,9 @@ def test(model, device, test_loader):
if __name__ == "__main__":
batch_size = 32
test_batch_size = 32
seed = 1
epochs1 = 15
epochs2 = epochs1+10
epochs3 = epochs2+10
lr1 = 0.01
lr2 = 0.001
lr3 = 0.0001
epochs_cfg = [20, 30, 20, 20, 10]
lr_cfg = [0.01, 0.005, 0.001, 0.0005, 0.0001]
momentum = 0.5
save_model = True
......@@ -62,7 +57,7 @@ if __name__ == "__main__":
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=True, download=True,
datasets.CIFAR10('../data', train=True, download=True,
transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
......@@ -73,31 +68,24 @@ if __name__ == "__main__":
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=False, transform=transforms.Compose([
datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
model = AlexNet().to(device)
optimizer1 = optim.SGD(model.parameters(), lr=lr1, momentum=momentum)
optimizer2 = optim.SGD(model.parameters(), lr=lr2, momentum=momentum)
optimizer3 = optim.SGD(model.parameters(), lr=lr3, momentum=momentum)
for epoch in range(1, epochs1 + 1):
train(model, device, train_loader, optimizer1, epoch)
test(model, device, test_loader)
for epoch in range(epochs1 + 1, epochs2 + 1):
train(model, device, train_loader, optimizer2, epoch)
test(model, device, test_loader)
for epoch in range(epochs2 + 1, epochs3 + 1):
train(model, device, train_loader, optimizer3, epoch)
test(model, device, test_loader)
epoch_start = 1
for epochs,lr in zip(epochs_cfg,lr_cfg):
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
epoch_end = epoch_start+epochs
for epoch in range(epoch_start,epoch_end):
train(model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
epoch_start += epochs
if save_model:
if not osp.exists('ckpt'):
......
import torch
import torch.nn as nn
import torch.nn.functional as F
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
def ebit_list(quant_type, num_bits):
if quant_type == 'FLOAT':
......
ykl/AlexNet_BN/image/flops.png: 36.8 KB → 36.8 KB (image replaced)
ykl/AlexNet_BN/image/param.png: 36.2 KB → 36.4 KB (image replaced)
......@@ -8,6 +8,19 @@ from torch.autograd import Variable
from function import FakeQuantize
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
# Get the nearest quantized value
def get_nearest_val(quant_type,x,is_bias=False):
if quant_type=='INT':
......
......@@ -51,18 +51,6 @@ def quantize_inference(model, test_loader, device):
return 100. * correct / len(test_loader.dataset)
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
if __name__ == "__main__":
batch_size = 32
......
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
7398.254588242588 2629.375369920873 590.682242850979 140.07309702607543 33.8605061564944 8.284907955771828 2.0380761928718356 0.5092245742053088 0.1268421502575885 0.03186153637549241 0.00786669870124002 0.0019702839340457524 0.0005105761502405256 0.00014720966403738878 5.009152681172546e-05 7398.228439126393 1620.255743626143 133.73060307719862 131.62858417356293 131.6287169902594 131.62886974968603 131.62916373182327 1069.390418247601 255.8934244865269 239.72195987976787 94.18685986225425 86.02822309304119 36.77674350011185 54.051731941424315 47.84958061374913 9.560250444993518 36.520065154425616 42.22217884431455 35.43566024130792 2.43887086039784 9.461428083033894 36.52006172442561 38.25590959917856 30.625094006227837 0.6417067659753078 2.378241078221592 9.461428733351953 36.54256264242457
7315.84863077462 2447.6990359739557 554.9539055398842 130.3057613397198 31.72854680369856 7.739889012396159 1.9503014204883173 0.47476923180404534 0.11812657336071226 0.03002884209117736 0.007400953973311122 0.001932886049776867 0.00044452340426047595 9.522531481526753e-05 2.9848570843407086e-06 7315.828621431085 1509.6239277040813 132.669444214901 130.7354806274539 130.73505398794512 130.73538243275962 130.7352909353478 1040.798298841082 253.67138749112172 224.5416693145032 97.71660734169551 80.35741771506743 36.57394847264219 58.08625652423495 44.466842189195326 9.423972626618646 36.3279041181201 46.176068440021766 32.78338584270386 2.408595598953836 9.337974512410408 36.32789770432847 41.972310900423636 28.345412246072254 0.6323033918133659 2.3565326800100244 9.337984500288826 36.35255195223587
js_param_list:
2072.6236878099226 729.7443476264982 163.58861069796387 38.80282172149893 9.382636343553832 2.2970053703034847 0.5644822291532099 0.14115822853411863 0.035139363239340055 0.00883850565434335 0.002179718866202066 0.0005459844381131149 0.00014248590416669505 4.154122620906531e-05 1.4063815487843632e-05 2072.605571769845 448.9786991943886 37.51683456097356 36.93614282789083 36.93617886328605 36.93621446606478 36.93632841420433 297.8950402934797 71.5356184108912 66.5931076214962 26.51685077742183 23.910047607709778 10.313899279091611 15.31199014495078 13.292340995682595 2.680379799957696 10.242829742707812 11.995046375299877 9.842202416339607 0.6838684384320987 2.6531284784835787 10.24283079854063 10.878175502415237 8.50330068632705 0.17977809643751386 0.6671260475358367 2.6531289753057634 10.256522418400566
2049.6929693507795 679.6962663951226 153.76103950048937 36.08036236374624 8.775254036917314 2.1428681018671267 0.540685220427713 0.13159647793378718 0.03272457732749732 0.008309703792123481 0.002061388543162421 0.0005373214024740875 0.00012932799821557616 3.358989301840365e-05 1.9910609472068016e-06 2049.682281691373 418.6989301615079 37.19252279367339 36.66040315748872 36.66028729152239 36.660375447269956 36.66034282835435 290.1050259737521 70.90992813610583 62.413211449939176 27.472679564357314 22.33696197321313 10.256081766983586 16.41390223671071 12.354416475776839 2.642681146872554 10.188908715006807 13.078305876994543 9.106313700653434 0.6753022926659958 2.6192195012070716 10.1889069543645 11.896313258304504 7.873233459924485 0.17713436485239206 0.6610392339533913 2.61922152391777 10.201050207834063
ptq_acc_list:
10.0 14.28 50.26 81.87 85.97 86.89 87.13 87.09 87.1 87.08 87.07 87.06 87.07 87.09 87.08 10.0 22.65 39.86 40.7 38.01 41.06 43.97 15.24 57.61 69.92 76.09 81.73 82.71 79.61 84.75 86.32 82.69 80.96 85.47 86.87 86.36 82.57 81.34 86.24 86.95 86.85 86.27 78.21
10.0 17.5 55.9 80.28 86.05 87.16 87.25 87.27 87.38 87.29 87.32 87.32 87.31 87.32 87.3 10.0 23.75 64.18 66.4 65.77 66.53 58.32 19.99 54.54 68.29 72.88 82.37 83.32 77.04 85.29 86.28 83.07 79.34 86.0 87.05 86.34 83.1 79.9 86.01 87.06 87.01 86.26 81.18
acc_loss_list:
0.8851894374282434 0.8360505166475315 0.4229621125143513 0.06004592422502859 0.012973593570608444 0.002411021814006817 -0.00034443168771528286 0.00011481056257165219 0.0 0.00022962112514346753 0.00034443168771528286 0.00045924225028693506 0.00034443168771528286 0.00011481056257165219 0.00022962112514346753 0.8851894374282434 0.7399540757749712 0.5423650975889782 0.5327210103329506 0.5636050516647532 0.5285878300803674 0.4951779563719862 0.825028702640643 0.33857634902411016 0.19724454649827777 0.12640642939150393 0.06165327210103319 0.050401836969001156 0.08599311136624564 0.02698048220436274 0.00895522388059703 0.05063145809414463 0.07049368541905857 0.018714121699196274 0.0026406429391502844 0.00849598163030993 0.05200918484500576 0.0661308840413317 0.009873708381171062 0.0017221584385762512 0.0028702640642939152 0.009529276693455779 0.10206659012629163
0.8854393401305991 0.7995188452285485 0.3596059113300493 0.08030702256845004 0.014205521823805809 0.0014892885783023217 0.000458242639477675 0.0002291213197389189 -0.0010310459388244838 0.0 -0.00034368197960805275 -0.00034368197960805275 -0.0002291213197387561 -0.00034368197960805275 -0.00011456065986929664 0.8854393401305991 0.727918432810173 0.26474968495818535 0.23931721846717835 0.24653454003895073 0.2378279298888762 0.3318822316416543 0.7709932409210678 0.3751861610722878 0.2176652537518616 0.16508191087180674 0.05636384465574523 0.04548058196815228 0.11742467636613586 0.022912131973880166 0.011570626646809544 0.048344598464887305 0.0910757245961737 0.01477832512315278 0.0027494558368657243 0.010883262687593112 0.048000916485279085 0.08466032764348723 0.01466376446328332 0.002634895176996265 0.0032076984763432367 0.011799747966548299 0.0699965631802039
from model import *
from extract_ratio import *
from utils import *
import openpyxl
import gol
import sys
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
def js_div_norm(a,b):
a_norm = F.normalize(a.data,p=2,dim=-1)
b_norm = F.normalize(b.data,p=2,dim=-1)
return js_div(a_norm,b_norm).cpu().item()
def quantize_aware_training(model, device, train_loader, optimizer, epoch):
lossLayer = torch.nn.CrossEntropyLoss()
# accumulate the loss and each parameter's grad
# initialization
loss_sum = 0.
grad_dict = {}
for name,param in model.named_parameters():
grad_dict[name] = torch.zeros_like(param) # param.grad has the same shape as param
for batch_idx, (data, target) in enumerate(train_loader, 1):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model.quantize_forward(data)
# the loss computed over a batch is its mean
loss = lossLayer(output, target)
loss.backward()
# accumulate loss and grads
loss_sum += loss
for name,param in model.named_parameters():
if param.grad is not None:
# print('-------'+name+'-------')
grad_dict[name] += param.grad
# print(grad_dict[name])
# print(grad_dict.items())
# input()
optimizer.step()
if batch_idx % 50 == 0:
print('Quantize Aware Training Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
batch_size = len(train_loader.batch_sampler)
# average the values accumulated over the batches
for name,grad in grad_dict.items():
grad_dict[name] = grad / batch_size
loss_avg = loss_sum / batch_size
return loss_avg, grad_dict
def full_inference(model, test_loader):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Full Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))
def train(model, device, train_loader, optimizer, epoch):
model.train()
lossLayer = torch.nn.CrossEntropyLoss()
# accumulate the loss and each parameter's grad
# initialization
loss_sum = 0.
grad_dict = {}
for name,param in model.named_parameters():
grad_dict[name] = torch.zeros_like(param) # param.grad has the same shape as param
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = lossLayer(output, target)
loss.backward()
# accumulate loss and grads
loss_sum += loss
for name,param in model.named_parameters():
if param.grad is not None:
# print('-------'+name+'-------')
grad_dict[name] += param.grad
# print(grad_dict[name])
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
batch_size = len(train_loader.batch_sampler)
# average the values accumulated over the batches
for name,grad in grad_dict.items():
grad_dict[name] = grad / batch_size
loss_avg = loss_sum / batch_size
return loss_avg, grad_dict
def quantize_inference(model, test_loader):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model.quantize_inference(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Quant Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))
if __name__ == "__main__":
batch_size = 32
seed = 1
epochs = 15
lr = 0.001
momentum = 0.5
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
writer = SummaryWriter(log_dir='./log/qat')
wb = openpyxl.Workbook()
ws = wb.active
layer, par_ratio, flop_ratio = extract_ratio()
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
)
full_file = 'ckpt/cifar10_AlexNet_BN.pt'
model = AlexNet_BN()
# model.load_state_dict(torch.load(full_file))
model.to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
load_qat = False
ckpt_prefix = "ckpt/qat/"
loss_sum = 0.
full_grad_sum = {}
full_grad_avg = {}
for name,param in model.named_parameters():
full_grad_sum[name] = torch.zeros_like(param)
full_grad_avg[name] = torch.zeros_like(param)
for epoch in range(1, epochs+1):
# Train the original (full-precision) model and collect its gradient distribution
loss,full_grad = train(model, device, train_loader, optimizer, epoch)
if epoch == 1:
loss_start = loss
# print('loss:%f' % loss_avg)
writer.add_scalar('Full.loss',loss,epoch)
if epoch == 15:
for name,grad in full_grad.items():
writer.add_histogram('Full.'+name+'_grad',grad)
loss_sum += loss
loss_avg = loss_sum / epoch
loss_delta = loss - loss_start
for name,grad in full_grad.items():
full_grad_sum[name] += full_grad[name]
full_grad_avg[name] = full_grad_sum[name] / epoch
if epoch % 5 == 0:
ws = wb.create_sheet('epoch_%d'%epoch)
ws.cell(row=1,column=2,value='loss')
ws.cell(row=1,column=3,value='loss_sum')
ws.cell(row=1,column=4,value='loss_avg')
ws.cell(row=1,column=5,value='loss_delta')
ws.cell(row=2,column=1,value='FP32')
ws.cell(row=2,column=2,value=loss.item())
ws.cell(row=2,column=3,value=loss_sum.item())
ws.cell(row=2,column=4,value=loss_avg.item())
ws.cell(row=2,column=5,value=loss_delta.item())
ws.cell(row=4,column=1,value='title')
ws.cell(row=4,column=2,value='loss')
ws.cell(row=4,column=3,value='loss_sum')
ws.cell(row=4,column=4,value='loss_avg')
ws.cell(row=4,column=5,value='loss_delta')
ws.cell(row=4,column=6,value='js_grad')
ws.cell(row=4,column=7,value='js_grad_sum')
ws.cell(row=4,column=8,value='js_grad_avg')
quant_type_list = ['INT','POT','FLOAT']
gol._init()
currow=4 # row where the data starts being written
for quant_type in quant_type_list:
num_bit_list = numbit_list(quant_type)
# For each quantization category, the bias quantization table only needs to be set once
# INT bit widths are large, so a lookup table is too costly; just _round directly
if quant_type != 'INT':
bias_list = build_bias_list(quant_type)
gol.set_value(bias_list, is_bias=True)
for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list:
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
title = '%s_%d' % (quant_type, num_bits)
if load_qat is True and osp.exists(ckpt_prefix+'epoch_20/'+title+'.pt'):
continue
currow += 1
print('\nQAT: '+title)
model_ptq = AlexNet_BN()
# set up the quantization table
if quant_type != 'INT':
plist = build_list(quant_type, num_bits, e_bits)
gol.set_value(plist)
# model_ptq.load_state_dict(torch.load(full_file))
model_ptq.to(device)
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.train()
loss_sum = 0.
qat_grad_sum = {}
qat_grad_avg = {}
for name,param in model.named_parameters():
qat_grad_sum[name] = torch.zeros_like(param)
qat_grad_avg[name] = torch.zeros_like(param)
for epoch in range(1, epochs+1):
loss,qat_grad = quantize_aware_training(model_ptq, device, train_loader, optimizer, epoch)
# print('loss:%f' % loss_avg)
if epoch == 1:
loss_start = loss
writer.add_scalar(title+'.loss',loss,epoch)
if epoch == 15:
for name,grad in qat_grad.items():
writer.add_histogram(title+'.'+name+'_grad',grad)
loss_sum += loss
loss_avg = loss_sum / epoch
loss_delta = loss-loss_start
for name,param in model.named_parameters():
qat_grad_sum[name] += qat_grad[name]
qat_grad_avg[name] = qat_grad_sum[name] / epoch
if epoch % 5 == 0:
ws = wb['epoch_%d'%epoch]
js_grad = 0.
js_grad_sum = 0.
js_grad_avg = 0.
for name,_ in model_ptq.named_parameters():
prefix = name.split('.')[0]
layer_idx = layer.index(prefix)
js_grad += flop_ratio[layer_idx] * js_div(qat_grad[name],full_grad[name])
js_grad_sum += flop_ratio[layer_idx] * js_div(qat_grad_sum[name],full_grad_sum[name])
js_grad_avg += flop_ratio[layer_idx] * js_div(qat_grad_avg[name],full_grad_avg[name])
ws.cell(row=currow,column=1,value=title)
ws.cell(row=currow,column=2,value=loss.item())
ws.cell(row=currow,column=3,value=loss_sum.item())
ws.cell(row=currow,column=4,value=loss_avg.item())
ws.cell(row=currow,column=5,value=loss_delta.item())
ws.cell(row=currow,column=6,value=js_grad.item())
ws.cell(row=currow,column=7,value=js_grad_sum.item())
ws.cell(row=currow,column=8,value=js_grad_avg.item())
wb.remove(wb['Sheet']) # remove the default worksheet by name
wb.save('qat_result.xlsx')
writer.close()
......@@ -46,14 +46,9 @@ def test(model, device, test_loader):
if __name__ == "__main__":
batch_size = 32
test_batch_size = 32
seed = 1
epochs1 = 15
epochs2 = epochs1+10
epochs3 = epochs2+10
lr1 = 0.01
lr2 = 0.001
lr3 = 0.0001
epochs_cfg = [15, 20, 20, 20, 10, 10]
lr_cfg = [0.01, 0.005, 0.002, 0.001, 0.0005, 0.0001]
momentum = 0.5
save_model = True
......@@ -62,7 +57,7 @@ if __name__ == "__main__":
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=True, download=True,
datasets.CIFAR10('../data', train=True, download=True,
transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
......@@ -73,31 +68,24 @@ if __name__ == "__main__":
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=False, transform=transforms.Compose([
datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
model = AlexNet_BN().to(device)
optimizer1 = optim.SGD(model.parameters(), lr=lr1, momentum=momentum)
optimizer2 = optim.SGD(model.parameters(), lr=lr2, momentum=momentum)
optimizer3 = optim.SGD(model.parameters(), lr=lr3, momentum=momentum)
for epoch in range(1, epochs1 + 1):
train(model, device, train_loader, optimizer1, epoch)
test(model, device, test_loader)
for epoch in range(epochs1 + 1, epochs2 + 1):
train(model, device, train_loader, optimizer2, epoch)
test(model, device, test_loader)
for epoch in range(epochs2 + 1, epochs3 + 1):
train(model, device, train_loader, optimizer3, epoch)
test(model, device, test_loader)
epoch_start = 1
for epochs,lr in zip(epochs_cfg,lr_cfg):
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
epoch_end = epoch_start+epochs
for epoch in range(epoch_start,epoch_end):
train(model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
epoch_start += epochs
if save_model:
if not osp.exists('ckpt'):
......
# Change notes
## Update: 2023/04/17
+ Specified a new staged learning-rate schedule and retrained the full-precision model to reach a higher accuracy, then re-ran PTQ and the fitting.
## Update: 2023/04/16
+ Added MATLAB fitting and plotting scripts; they support per-model markers, and the fitted curves are smoother than those produced with cftool.
+ Fixed a slip in the js_param computation in ptq.py: the per-layer weight should be par_ratio, not flop_ratio; otherwise the flops and param fits are indistinguishable (see the sketch after this list).
+ The bias_qmax method in module.py should be called for the FLOAT type with num_bits=16 and e_bits=7.
  + The key parameter is e_bits: the fitting outliers were mainly FLOAT_7_E5 / FLOAT_8_E5 / FLOAT_8_E6, whose biases collapse into two extreme clusters, much like the earlier bias-overflow problem seen with INT quantization.
  + e_bits was previously set to 5. Since the bias scale is the product of the input scale and the weight scale, the bias quantization range should be roughly the square of the x/weight range. The largest x/weight range the code currently supports is about $2^{2^{6}}$, so the bias range should reach roughly $2^{2^7}$, i.e. e_bits should be 7 (a short restatement of this range argument also follows the list).
  + After the change the outliers disappear and the fit improves markedly.
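
A minimal sketch of the corrected js_param weighting, written in the same per-layer loop style this commit uses for js_grad. Here `full_params` is a hypothetical dict of the full-precision parameters, and `layer`, `par_ratio` and `js_div` refer to identifiers already used in this repo (from extract_ratio() and utils.py); this illustrates the fix rather than reproducing the actual ptq.py code.

```python
from utils import js_div  # js_div is assumed to live in utils.py, as in the hunks above

def js_param_weighted(model_ptq, full_params, layer, par_ratio):
    """Hypothetical helper: JS divergence between quantized and full-precision
    parameters, weighted per layer by par_ratio (the fix) instead of flop_ratio."""
    js_param = 0.
    for name, param in model_ptq.named_parameters():
        prefix = name.split('.')[0]       # layer name, e.g. 'conv1'
        layer_idx = layer.index(prefix)   # index into the list returned by extract_ratio()
        # before the fix this line used flop_ratio[layer_idx]
        js_param += par_ratio[layer_idx] * js_div(param.data, full_params[name])
    return js_param
```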
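And a one-line restatement of the bias-range argument above, with $R$ denoting the representable magnitude range and $s$ the quantization scale (symbols introduced here only for illustration):

$$
s_{\mathrm{bias}} = s_x \cdot s_w
\;\Rightarrow\;
R_{\mathrm{bias}} \approx R_x \cdot R_w \approx 2^{2^{6}} \cdot 2^{2^{6}} = 2^{2^{6}+2^{6}} = 2^{2^{7}},
$$

which is why the bias format needs e_bits = 7 rather than 5.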
ykl/VGG_16/image/VGG16_table.png: 21.6 KB → 21 KB (image replaced)
ykl/VGG_16/image/flops.png: 33.1 KB → 33.2 KB (image replaced)
ykl/VGG_16/image/param.png: 33.5 KB → 33.4 KB (image replaced)
......@@ -14,7 +14,7 @@ import module
feature_cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
classifier_cfg = [4096, 4096, 'LF']
def make_feature_layers(cfg, batch_norm=False):
def make_feature_layers(cfg, batch_norm=True):
layers = []
names = []
input_channel = 3
......@@ -103,13 +103,14 @@ def quantize_classifier_layers(model,name_list,quant_type,num_bits,e_bits):
return names, layers
def quantize_utils(model,qfeature_name,qclassifier_name,func, x=None):
# This function serves the forward pass of both the original and the quantized layers
def model_utils(model,feature_name,classifier_name,func, x=None):
if func == 'inference':
layer=getattr(model,qfeature_name[0])
layer=getattr(model,feature_name[0])
x = layer.qi.quantize_tensor(x)
last_qo = None
for name in qfeature_name:
for name in feature_name:
layer = getattr(model,name)
if func == 'forward':
x = layer(x)
......@@ -123,7 +124,7 @@ def quantize_utils(model,qfeature_name,qclassifier_name,func, x=None):
if func != 'freeze':
x = torch.flatten(x, start_dim=1)
for name in qclassifier_name:
for name in classifier_name:
layer = getattr(model,name)
if func == 'forward':
x = layer(x)
......@@ -164,26 +165,9 @@ class VGG_16(nn.Module):
# self.fc3 = nn.Linear(4096, num_class)
def forward(self, x):
#feature
for name in self.feature_name:
layer = getattr(self,name)
x = layer(x)
x = torch.flatten(x, start_dim=1)
#classifier
for name in self.classifier_name:
layer = getattr(self,name)
x = layer(x)
# x = self.fc1(x)
# x = self.crelu1(x)
# x = self.drop1(x)
# x = self.fc2(x)
# x = self.crelu2(x)
# x = self.drop2(x)
# x = self.fc3(x)
x = model_utils(self, self.feature_name, self.classifier_name,
func='forward', x=x)
return x
def quantize(self, quant_type, num_bits=8, e_bits=3):
......@@ -207,17 +191,17 @@ class VGG_16(nn.Module):
# self.qfc3 = QLinear(quant_type, self.fc3, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
def quantize_forward(self,x):
x = quantize_utils(self, self.qfeature_name, self.qclassifier_name,
x = model_utils(self, self.qfeature_name, self.qclassifier_name,
func='forward', x=x)
return x
def freeze(self):
quantize_utils(self, self.qfeature_name, self.qclassifier_name,
model_utils(self, self.qfeature_name, self.qclassifier_name,
func='freeze', x=None)
def quantize_inference(self,x):
x = quantize_utils(self, self.qfeature_name, self.qclassifier_name,
x = model_utils(self, self.qfeature_name, self.qclassifier_name,
func='inference', x=x)
return x
\ No newline at end of file
......@@ -8,6 +8,19 @@ from torch.autograd import Variable
from function import FakeQuantize
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
# Get the nearest quantized value
def get_nearest_val(quant_type, x, is_bias=False, block_size=1000000):
if quant_type == 'INT':
......
......@@ -51,18 +51,6 @@ def quantize_inference(model, test_loader, device):
return 100. * correct / len(test_loader.dataset)
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
if __name__ == "__main__":
batch_size = 32
......
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
9536.470545853725 2226.0625961453534 479.08942343671333 110.29737895892738 26.51254658172414 6.543173505823439 1.6082545552680656 0.4011046819184469 0.0995744091981884 0.02512104453474506 0.006191111080978074 0.0015259451834671177 0.0003948419706838088 9.156847371397302e-05 5.103821562683077e-05 9536.458765578662 1346.195594505469 186.04319320819684 184.6681482382002 184.6683487483022 184.66821940371673 184.66795503280395 1162.9049836246736 334.88736537318147 213.59353722936538 162.90832461289847 74.46974838241348 51.093368592351894 114.00984771261629 39.88037443687132 13.180687261808096 50.91668399478627 97.04450826091268 28.84898524647785 3.333529301905296 13.115715295208558 50.91668073149591 90.27653879084389 24.69379955366047 0.8539877927713175 3.2949263309664056 13.115710967364368 50.954444671302134
9678.738526193363 2754.445091033171 599.5177886615332 139.45004407102132 33.61972833441705 8.327290718295096 2.0621958796228017 0.507271590567208 0.12638558978042777 0.03172725299824697 0.00787383157101955 0.0019659234261853076 0.0004929713753570182 0.00013082873931775967 6.58462511070724e-05 9678.721786560653 1669.6727215221636 181.2126746364518 179.2569959125805 179.25701832538817 179.2571686197933 179.2571438439814 1252.5642980710986 326.4085549727165 257.15456948396184 140.83288095342002 91.39911193737521 49.66634546455573 91.6028944861708 50.03401260763902 12.92006159573897 49.428214414171954 75.71805318590293 36.69642341667939 3.284018584054309 12.831006283541836 49.4282155693062 69.7655678600058 31.59066644962877 0.8486327021161753 3.2312358866685056 12.83099812618559 49.46498235942715
js_param_list:
6887.256091122821 1340.784804470707 280.73813097751594 63.94496568453547 15.289532692524366 3.7564939386797525 0.9259647062823152 0.2288835741602492 0.05710934861695347 0.014186045165619268 0.003557508084599861 0.0008813425517560222 0.00022710583614690054 4.224842305566257e-05 1.6705389170873787e-05 6887.256027866801 806.0594459403769 139.7164244353938 139.04019666205338 139.04029041111284 139.04030319644943 139.04024126647542 795.737012530977 249.6261323076044 132.75011877730668 132.00246136063635 45.67073400948295 38.23747085219327 96.59443228186598 23.88552681545614 9.84598569650612 38.15017558575456 83.61518038181232 17.04342608978349 2.4726694252419787 9.813474353111108 38.1501666315099 78.20948388628338 14.518609766822784 0.6269759811734874 2.4532405132850865 9.813473104774046 38.16105197534707
6979.995257775382 1676.8544762024308 357.07416358579195 82.21848562208979 19.71659591761924 4.858750496313133 1.204106769484056 0.2970651538514128 0.07389159465364287 0.01847686680473944 0.004583820583860123 0.0011690277325588727 0.0003001505161697722 8.054369623471168e-05 3.8545500662813255e-05 6979.99738972489 1008.874309342299 135.7925378028479 134.79040735180877 134.7905911406337 134.79058895538347 134.79061404359229 841.4093324566583 236.70954280105926 160.0719446844072 112.93571614580846 56.470927826653764 36.93548609003767 78.37733457244148 30.396125541872564 9.673987287874596 36.81191428618332 66.55321174809056 22.060163046203872 2.446769401900793 9.628841768242337 36.811916984902815 61.86831406132784 18.91259887812061 0.6248787079335093 2.4197557895059423 9.62881847680995 36.822419165985735
ptq_acc_list:
10.0 11.52 61.58 87.57 88.71 89.24 89.5 89.52 89.45 89.46 89.44 89.42 89.45 89.44 89.45 10.0 13.86 65.99 67.14 68.35 65.85 66.85 14.04 62.71 78.91 78.63 87.01 86.71 83.03 88.66 88.78 86.78 84.11 88.73 89.44 88.73 86.65 84.61 88.92 89.4 89.54 88.75 75.87
10.0 11.13 56.25 87.31 89.42 89.68 89.8 89.82 89.8 89.75 89.77 89.81 89.79 89.8 89.8 10.0 19.8 68.55 69.75 70.61 69.72 67.62 12.29 54.35 78.59 80.14 87.26 87.17 83.65 88.89 89.41 87.03 84.94 89.19 89.78 89.18 87.11 85.27 89.42 89.76 89.81 89.41 74.75
acc_loss_list:
0.8881932021466905 0.8711985688729875 0.31149373881932024 0.020907871198568923 0.008161896243291637 0.0022361359570662216 -0.0006708407871198823 -0.0008944543828264568 -0.00011180679785336669 -0.00022361359570657448 0.0 0.00022361359570657448 -0.00011180679785336669 0.0 -0.00011180679785336669 0.8881932021466905 0.8450357781753131 0.26218694096601075 0.24932915921288012 0.23580053667262973 0.2637522361359571 0.25257155635062617 0.8430232558139535 0.2988595706618962 0.1177325581395349 0.12086314847942758 0.02716905187835412 0.030523255813953532 0.07166815742397134 0.008720930232558153 0.007379248658318388 0.029740608228980284 0.059593023255813934 0.007938282647584904 0.0 0.007938282647584904 0.031194096601073258 0.05400268336314846 0.005813953488372049 0.00044722719141314897 -0.0011180679785331902 0.007714669051878329 0.1517218246869409
0.888641425389755 0.8760579064587973 0.3736080178173719 0.027728285077950946 0.004231625835189259 0.0013363028953228323 0.0 -0.00022271714922044567 0.0 0.0005567928730511933 0.00033407572383074765 -0.00011135857461030197 0.00011135857461014371 0.0 0.0 0.888641425389755 0.779510022271715 0.23663697104677062 0.22327394209354118 0.2136971046770601 0.2236080178173719 0.2469933184855233 0.8631403118040089 0.39476614699331847 0.12483296213808456 0.10757238307349662 0.02828507795100214 0.029287305122494382 0.06848552338530058 0.010133630289532257 0.004342984409799561 0.03084632516703782 0.05412026726057906 0.006792873051224938 0.00022271714922044567 0.006904231625835082 0.02995545657015588 0.05044543429844099 0.004231625835189259 0.00044543429844089133 -0.00011135857461030197 0.004342984409799561 0.16759465478841867
from model import *
from extract_ratio import *
from utils import *
import openpyxl
import gol
import sys
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
def js_div_norm(a,b):
a_norm = F.normalize(a.data,p=2,dim=-1)
b_norm = F.normalize(b.data,p=2,dim=-1)
return js_div(a_norm,b_norm).cpu().item()
def quantize_aware_training(model, device, train_loader, optimizer, epoch):
lossLayer = torch.nn.CrossEntropyLoss()
# accumulate the loss and each parameter's grad
# initialization
loss_sum = 0.
grad_dict = {}
for name,param in model.named_parameters():
grad_dict[name] = torch.zeros_like(param) # param.grad has the same shape as param
for batch_idx, (data, target) in enumerate(train_loader, 1):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model.quantize_forward(data)
# the loss computed over a batch is its mean
loss = lossLayer(output, target)
loss.backward()
# accumulate loss and grads
loss_sum += loss
for name,param in model.named_parameters():
if param.grad is not None:
# print('-------'+name+'-------')
grad_dict[name] += param.grad
# print(grad_dict[name])
# print(grad_dict.items())
# input()
optimizer.step()
if batch_idx % 50 == 0:
print('Quantize Aware Training Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
batch_size = len(train_loader.batch_sampler)
# average the values accumulated over the batches
for name,grad in grad_dict.items():
grad_dict[name] = grad / batch_size
loss_avg = loss_sum / batch_size
return loss_avg, grad_dict
def full_inference(model, test_loader):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Full Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))
def train(model, device, train_loader, optimizer, epoch):
model.train()
lossLayer = torch.nn.CrossEntropyLoss()
# accumulate the loss and each parameter's grad
# initialization
loss_sum = 0.
grad_dict = {}
for name,param in model.named_parameters():
grad_dict[name] = torch.zeros_like(param) # param.grad has the same shape as param
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = lossLayer(output, target)
loss.backward()
# accumulate loss and grads
loss_sum += loss
for name,param in model.named_parameters():
if param.grad is not None:
# print('-------'+name+'-------')
grad_dict[name] += param.grad
# print(grad_dict[name])
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
batch_size = len(train_loader.batch_sampler)
# average the values accumulated over the batches
for name,grad in grad_dict.items():
grad_dict[name] = grad / batch_size
loss_avg = loss_sum / batch_size
return loss_avg, grad_dict
def quantize_inference(model, test_loader):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model.quantize_inference(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Quant Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))
if __name__ == "__main__":
batch_size = 32
seed = 1
epochs = 15
lr = 0.001
momentum = 0.5
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
writer = SummaryWriter(log_dir='./log/qat')
wb = openpyxl.Workbook()
ws = wb.active
layer, par_ratio, flop_ratio = extract_ratio()
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
)
full_file = 'ckpt/cifar10_VGG_16.pt'
model = VGG_16()
# model.load_state_dict(torch.load(full_file))
model.to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
load_qat = False
ckpt_prefix = "ckpt/qat/"
loss_sum = 0.
full_grad_sum = {}
full_grad_avg = {}
for name,param in model.named_parameters():
full_grad_sum[name] = torch.zeros_like(param)
full_grad_avg[name] = torch.zeros_like(param)
for epoch in range(1, epochs+1):
# Train the original (full-precision) model and collect its gradient distribution
loss,full_grad = train(model, device, train_loader, optimizer, epoch)
if epoch == 1:
loss_start = loss
# print('loss:%f' % loss_avg)
writer.add_scalar('Full.loss',loss,epoch)
# for name,grad in grad_dict.items():
# writer.add_histogram('Full.'+name+'_grad',grad,global_step=epoch)
loss_sum += loss
loss_avg = loss_sum / epoch
loss_delta = loss - loss_start
for name,grad in full_grad.items():
full_grad_sum[name] += full_grad[name]
full_grad_avg[name] = full_grad_sum[name] / epoch
if epoch % 5 == 0:
ws = wb.create_sheet('epoch_%d'%epoch)
ws.cell(row=1,column=2,value='loss')
ws.cell(row=1,column=3,value='loss_sum')
ws.cell(row=1,column=4,value='loss_avg')
ws.cell(row=1,column=5,value='loss_delta')
ws.cell(row=2,column=1,value='FP32')
ws.cell(row=2,column=2,value=loss.item())
ws.cell(row=2,column=3,value=loss_sum.item())
ws.cell(row=2,column=4,value=loss_avg.item())
ws.cell(row=2,column=5,value=loss_delta.item())
ws.cell(row=4,column=1,value='title')
ws.cell(row=4,column=2,value='loss')
ws.cell(row=4,column=3,value='loss_sum')
ws.cell(row=4,column=4,value='loss_avg')
ws.cell(row=4,column=5,value='loss_delta')
ws.cell(row=4,column=6,value='js_grad')
ws.cell(row=4,column=7,value='js_grad_sum')
ws.cell(row=4,column=8,value='js_grad_avg')
quant_type_list = ['INT','POT','FLOAT']
gol._init()
currow=4 # row where the data starts being written
for quant_type in quant_type_list:
num_bit_list = numbit_list(quant_type)
# For each quantization category, the bias quantization table only needs to be set once
# INT bit widths are large, so a lookup table is too costly; just _round directly
if quant_type != 'INT':
bias_list = build_bias_list(quant_type)
gol.set_value(bias_list, is_bias=True)
for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list:
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
title = '%s_%d' % (quant_type, num_bits)
if load_qat is True and osp.exists(ckpt_prefix+'epoch_20/'+title+'.pt'):
continue
currow += 1
print('\nQAT: '+title)
model_ptq = VGG_16()
# set up the quantization table
if quant_type != 'INT':
plist = build_list(quant_type, num_bits, e_bits)
gol.set_value(plist)
# model_ptq.load_state_dict(torch.load(full_file))
model_ptq.to(device)
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.train()
loss_sum = 0.
qat_grad_sum = {}
qat_grad_avg = {}
for name,param in model.named_parameters():
qat_grad_sum[name] = torch.zeros_like(param)
qat_grad_avg[name] = torch.zeros_like(param)
for epoch in range(1, epochs+1):
loss,qat_grad = quantize_aware_training(model_ptq, device, train_loader, optimizer, epoch)
# print('loss:%f' % loss_avg)
if epoch == 1:
loss_start = loss
writer.add_scalar(title+'.loss',loss,epoch)
# for name,grad in qat_grad.items():
# writer.add_histogram(title+'.'+name+'_grad',grad,global_step=epoch)
loss_sum += loss
loss_avg = loss_sum / epoch
loss_delta = loss-loss_start
for name,param in model.named_parameters():
qat_grad_sum[name] += qat_grad[name]
qat_grad_avg[name] = qat_grad_sum[name] / epoch
if epoch % 5 == 0:
ws = wb['epoch_%d'%epoch]
js_grad = 0.
js_grad_sum = 0.
js_grad_avg = 0.
for name,_ in model_ptq.named_parameters():
prefix = name.split('.')[0]
layer_idx = layer.index(prefix)
js_grad += flop_ratio[layer_idx] * js_div_norm(qat_grad[name],full_grad[name])
js_grad_sum += flop_ratio[layer_idx] * js_div_norm(qat_grad_sum[name],full_grad_sum[name])
js_grad_avg += flop_ratio[layer_idx] * js_div_norm(qat_grad_avg[name],full_grad_avg[name])
ws.cell(row=currow,column=1,value=title)
ws.cell(row=currow,column=2,value=loss.item())
ws.cell(row=currow,column=3,value=loss_sum.item())
ws.cell(row=currow,column=4,value=loss_avg.item())
ws.cell(row=currow,column=5,value=loss_delta.item())
ws.cell(row=currow,column=6,value=js_grad)
ws.cell(row=currow,column=7,value=js_grad_sum)
ws.cell(row=currow,column=8,value=js_grad_avg)
wb.remove(wb['Sheet']) # remove the default worksheet by name
wb.save('qat_result.xlsx')
writer.close()
......@@ -46,14 +46,9 @@ def test(model, device, test_loader):
if __name__ == "__main__":
batch_size = 32
test_batch_size = 32
seed = 1
epochs1 = 15
epochs2 = epochs1+10
epochs3 = epochs2+10
lr1 = 0.01
lr2 = 0.001
lr3 = 0.0001
epochs_cfg = [25, 30, 30, 20, 20, 10, 10]
lr_cfg = [0.01, 0.008, 0.005, 0.002, 0.001, 0.0005, 0.0001]
momentum = 0.5
save_model = True
......@@ -78,26 +73,19 @@ if __name__ == "__main__":
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
model = VGG_16().to(device)
optimizer1 = optim.SGD(model.parameters(), lr=lr1, momentum=momentum)
optimizer2 = optim.SGD(model.parameters(), lr=lr2, momentum=momentum)
optimizer3 = optim.SGD(model.parameters(), lr=lr3, momentum=momentum)
for epoch in range(1, epochs1 + 1):
train(model, device, train_loader, optimizer1, epoch)
test(model, device, test_loader)
for epoch in range(epochs1 + 1, epochs2 + 1):
train(model, device, train_loader, optimizer2, epoch)
test(model, device, test_loader)
for epoch in range(epochs2 + 1, epochs3 + 1):
train(model, device, train_loader, optimizer3, epoch)
test(model, device, test_loader)
epoch_start = 1
for epochs,lr in zip(epochs_cfg,lr_cfg):
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
epoch_end = epoch_start+epochs
for epoch in range(epoch_start,epoch_end):
train(model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
epoch_start += epochs
if save_model:
if not osp.exists('ckpt'):
......
ykl/VGG_19/image/VGG19_table.png: 21.6 KB → 21.4 KB
ykl/VGG_19/image/flops.png: 32.8 KB → 33.4 KB
ykl/VGG_19/image/param.png: 33.6 KB → 34.7 KB
......@@ -8,6 +8,19 @@ from torch.autograd import Variable
from function import FakeQuantize
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output, dim=-1)
q_output = F.softmax(q_output, dim=-1)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
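# Explanatory sketch (tensors below are hypothetical, not part of the original file):
# js_div returns a symmetric, non-negative Jensen-Shannon divergence over softmaxed logits;
# identical inputs give (numerically) zero.
# p = torch.randn(32, 10); q = torch.randn(32, 10)   # hypothetical logit batches
# js_div(p, q)   # positive scalar tensor
# js_div(p, p)   # tensor close to 0.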
# get the nearest quantized value
def get_nearest_val(quant_type, x, is_bias=False, block_size=1000000):
if quant_type == 'INT':
......
......@@ -51,18 +51,6 @@ def quantize_inference(model, test_loader, device):
return 100. * correct / len(test_loader.dataset)
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
if __name__ == "__main__":
batch_size = 32
......
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
10125.068781318083 2125.8746396005995 448.1684275076656 102.72227062531958 24.664131831035423 6.028952317368312 1.4808853001039965 0.3654121470448262 0.09188669130958198 0.022862163476844406 0.005681816830498811 0.0014281973847489284 0.00033902203745254765 0.00010609152120135955 3.4934861119480844e-05 10125.058594081152 1275.5846069054285 202.10990836927274 200.95538533356967 200.95543492335307 200.95532374524822 200.95529898315357 1204.7518156488943 367.2780949517851 207.51154204984104 188.40107410370646 71.77026871896568 55.97867526181043 135.4187670910398 37.94656908671154 14.343459410681755 55.83674829649791 116.40261372885318 27.2535410832288 3.606044025817579 14.290313737791045 55.83674707301358 108.58905190788829 23.270040541541757 0.9179190066917342 3.574594323551281 14.290315204670643 55.87150641050152
10335.410852393428 2702.374876526504 586.8085986640358 136.19783508263313 32.76686433160855 8.108710198358402 2.017359787423094 0.5043751741104074 0.12519450821641712 0.03078375848492338 0.007759517552654125 0.001915080842767542 0.0004650357624047513 0.00010331804808074696 5.051867857804825e-05 10335.404291659883 1639.8833350853 197.21747754065524 195.40539442483635 195.40543311681287 195.40540281991127 195.40546678069262 1301.4075583496397 353.91370485500016 255.05728255399467 161.1495511523927 90.18686133261879 54.04437942140043 108.573018962868 48.961241334358604 14.062757558415433 53.804179496106855 91.04968596750315 35.698608089109584 3.568426481512405 13.974762150284537 53.80416850535668 84.32169873971353 30.648531893335964 0.9205644289947652 3.5162305719167306 13.974779013299235 53.8384368168758
js_param_list:
7919.99983980378 1368.1564713436267 283.1693554968751 64.29426936491545 15.321270528727021 3.745277426600951 0.9178679001253169 0.22807122885447256 0.05690431499169481 0.01419322342297318 0.003568448783068516 0.0008906301386853677 0.00021638563506049427 7.150496949260222e-05 3.2155535421527054e-05 7920.00259798844 818.8837200099067 164.46191241414346 163.8453142768725 163.84534269839182 163.84535520112814 163.84532577449994 912.7525364721757 304.20939172507764 138.97594743194938 167.73729462694857 46.84141610323806 45.957500752525 124.95158122945192 23.983665956476454 11.613224211085821 45.88101341301542 108.73234414542763 16.892438161824266 2.904718992579768 11.584588581117753 45.88101100206524 101.79740801889845 14.310069953741102 0.7348166494910408 2.8876544755152516 11.584596345742247 45.89450884590524
8185.907782380622 1800.9911591065213 381.9598699018292 87.55210816175092 20.97233721922892 5.157556218764395 1.273967796169326 0.3167936166061732 0.07878698713657759 0.01962915205843054 0.0049230819182810885 0.001218414473634084 0.00029934385592657504 6.540185677536515e-05 1.9836379728102926e-05 8185.911496782605 1085.762959717068 161.20373387464056 160.2127414792563 160.2128133174403 160.21278421049763 160.21274821008095 958.9015917875203 279.7889215111697 174.41461406276403 139.022264040724 61.1200446333161 43.83433305556267 98.6990126657353 32.570105813234655 11.541062350360075 43.70912334388383 84.54385769669656 23.46464257859416 2.916976983569494 11.494558828605104 43.70911423954581 78.86080293161805 20.057860207412013 0.7429960203642413 2.8894343747747815 11.494570441006683 43.722130173425036
ptq_acc_list:
10.0 12.95 60.49 87.39 88.96 89.11 89.25 89.3 89.24 89.25 89.25 89.26 89.25 89.25 89.25 10.0 17.62 66.43 66.71 66.14 62.97 65.32 12.22 59.32 80.16 79.22 87.23 86.1 82.78 88.34 88.76 86.04 84.01 88.86 89.23 88.72 86.27 84.73 88.9 89.32 89.33 88.74 73.92
10.0 10.47 69.02 88.3 89.73 89.92 89.99 90.08 90.09 90.08 90.11 90.1 90.11 90.1 90.09 10.0 14.58 73.48 72.92 65.82 72.06 71.49 10.35 50.78 80.42 79.11 87.53 87.77 84.05 89.15 89.39 87.83 84.85 89.4 89.9 89.32 87.61 85.69 89.65 90.18 89.87 89.44 78.02
acc_loss_list:
0.8879551820728291 0.8549019607843137 0.3222408963585434 0.020840336134453775 0.0032492997198880253 0.0015686274509803986 0.0 -0.0005602240896358225 0.00011204481792722819 0.0 0.0 -0.00011204481792722819 0.0 0.0 0.0 0.8879551820728291 0.8025770308123249 0.25568627450980386 0.2525490196078432 0.2589355742296919 0.29445378151260504 0.26812324929972 0.8630812324929972 0.3353501400560224 0.10184873949579835 0.11238095238095239 0.02263305322128847 0.035294117647058885 0.07249299719887954 0.010196078431372511 0.005490196078431315 0.03596638655462178 0.05871148459383748 0.0043697478991596705 0.00022408963585429716 0.005938375350140069 0.03338935574229696 0.05064425770308119 0.003921568627450917 -0.0007843137254901196 -0.0008963585434173479 0.0057142857142857715 0.17176470588235293
0.8890122086570477 0.883795782463929 0.2339622641509434 0.01997780244173138 0.0041065482796891276 0.001997780244173059 0.0012208657047724687 0.0002219755826858604 0.00011098779134285134 0.0002219755826858604 -0.00011098779134300907 0.0 -0.00011098779134300907 0.0 0.00011098779134285134 0.8890122086570477 0.8381798002219756 0.1844617092119866 0.19067702552719193 0.26947835738068815 0.20022197558268584 0.2065482796892342 0.8851276359600444 0.4364039955604883 0.10743618201997773 0.1219755826859045 0.02852386237513866 0.025860155382907864 0.0671476137624861 0.01054384017758034 0.007880133185349542 0.025194228634850123 0.05826859045504995 0.007769145394006534 0.0022197558268589193 0.00865704772475029 0.02763596004439506 0.04894561598224192 0.004994450610432726 -0.000887902330743757 0.0025527192008877888 0.007325194228634813 0.13407325194228634
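For reference, the acc_loss values above are consistent with a relative accuracy drop against the full-precision baseline (about 89.25 and 90.10 for the two rows respectively). A minimal sketch of that relation, with hypothetical variable names:
full_acc = 89.25                              # FP32 baseline accuracy (first row)
ptq_acc = 87.39                               # e.g. the INT_5 entry above
acc_loss = (full_acc - ptq_acc) / full_acc    # ≈ 0.02084, matching acc_loss_list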
from model import *
from extract_ratio import *
from utils import *
import openpyxl
import gol
import sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
def js_div_norm(a,b):
a_norm = F.normalize(a.data,p=2,dim=-1)
b_norm = F.normalize(b.data,p=2,dim=-1)
return js_div(a_norm,b_norm).cpu().item()
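# Note: js_div_norm L2-normalizes both tensors along the last dimension before the
# JS divergence, so the comparison is insensitive to gradient scale; js_div then
# applies softmax internally (get_softmax defaults to True).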
def quantize_aware_training(model, device, train_loader, optimizer, epoch):
lossLayer = torch.nn.CrossEntropyLoss()
# accumulate the loss and each parameter's gradient
# initialization
loss_sum = 0.
grad_dict = {}
for name,param in model.named_parameters():
grad_dict[name] = torch.zeros_like(param) # param.grad has the same shape as param
for batch_idx, (data, target) in enumerate(train_loader, 1):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model.quantize_forward(data)
# the loss for a batch is already averaged over its samples
loss = lossLayer(output, target)
loss.backward()
# accumulate loss and gradients (detach so the running sum does not keep the autograd graph alive)
loss_sum += loss.detach()
for name,param in model.named_parameters():
if param.grad is not None:
# print('-------'+name+'-------')
grad_dict[name] += param.grad
# print(grad_dict[name])
# print(grad_dict.items())
# input()
optimizer.step()
if batch_idx % 50 == 0:
print('Quantize Aware Training Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
num_batches = len(train_loader.batch_sampler)
# average the accumulated values over the number of batches
for name,grad in grad_dict.items():
grad_dict[name] = grad / num_batches
loss_avg = loss_sum / num_batches
return loss_avg, grad_dict
def full_inference(model, test_loader):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Full Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))
def train(model, device, train_loader, optimizer, epoch):
model.train()
lossLayer = torch.nn.CrossEntropyLoss()
# accumulate the loss and each parameter's gradient
# initialization
loss_sum = 0.
grad_dict = {}
for name,param in model.named_parameters():
grad_dict[name] = torch.zeros_like(param) # param.grad has the same shape as param
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = lossLayer(output, target)
loss.backward()
# accumulate loss and gradients (detach so the running sum does not keep the autograd graph alive)
loss_sum += loss.detach()
for name,param in model.named_parameters():
if param.grad is not None:
# print('-------'+name+'-------')
grad_dict[name] += param.grad
# print(grad_dict[name])
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
num_batches = len(train_loader.batch_sampler)
# average the accumulated values over the number of batches
for name,grad in grad_dict.items():
grad_dict[name] = grad / num_batches
loss_avg = loss_sum / num_batches
return loss_avg, grad_dict
def quantize_inference(model, test_loader):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model.quantize_inference(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Quant Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))
if __name__ == "__main__":
batch_size = 32
seed = 1
epochs = 15
lr = 0.001
momentum = 0.5
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
writer = SummaryWriter(log_dir='./log/qat')
wb = openpyxl.Workbook()
ws = wb.active
layer, par_ratio, flop_ratio = extract_ratio()
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
)
full_file = 'ckpt/cifar10_VGG_19.pt'
model = VGG_19()
# model.load_state_dict(torch.load(full_file))
model.to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
load_qat = False
ckpt_prefix = "ckpt/qat/"
loss_sum = 0.
full_grad_sum = {}
full_grad_avg = {}
for name,param in model.named_parameters():
full_grad_sum[name] = torch.zeros_like(param)
full_grad_avg[name] = torch.zeros_like(param)
for epoch in range(1, epochs+1):
# train the full-precision model and collect its gradient statistics
loss,full_grad = train(model, device, train_loader, optimizer, epoch)
if epoch == 1:
loss_start = loss
# print('loss:%f' % loss_avg)
writer.add_scalar('Full.loss',loss,epoch)
# for name,grad in grad_dict.items():
# writer.add_histogram('Full.'+name+'_grad',grad,global_step=epoch)
loss_sum += loss
loss_avg = loss_sum / epoch
loss_delta = loss - loss_start
for name,grad in full_grad.items():
full_grad_sum[name] += full_grad[name]
full_grad_avg[name] = full_grad_sum[name] / epoch
if epoch % 5 == 0:
ws = wb.create_sheet('epoch_%d'%epoch)
ws.cell(row=1,column=2,value='loss')
ws.cell(row=1,column=3,value='loss_sum')
ws.cell(row=1,column=4,value='loss_avg')
ws.cell(row=1,column=5,value='loss_delta')
ws.cell(row=2,column=1,value='FP32')
ws.cell(row=2,column=2,value=loss.item())
ws.cell(row=2,column=3,value=loss_sum.item())
ws.cell(row=2,column=4,value=loss_avg.item())
ws.cell(row=2,column=5,value=loss_delta.item())
ws.cell(row=4,column=1,value='title')
ws.cell(row=4,column=2,value='loss')
ws.cell(row=4,column=3,value='loss_sum')
ws.cell(row=4,column=4,value='loss_avg')
ws.cell(row=4,column=5,value='loss_delta')
ws.cell(row=4,column=6,value='js_grad')
ws.cell(row=4,column=7,value='js_grad_sum')
ws.cell(row=4,column=8,value='js_grad_avg')
quant_type_list = ['INT','POT','FLOAT']
gol._init()
currow=4 # row index where per-configuration results start
for quant_type in quant_type_list:
num_bit_list = numbit_list(quant_type)
# For each quantization type, the bias quantization table only needs to be built once
# INT uses wide bit-widths, so a lookup table is too costly; direct rounding (_round) is enough
if quant_type != 'INT':
bias_list = build_bias_list(quant_type)
gol.set_value(bias_list, is_bias=True)
for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list:
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
title = '%s_%d' % (quant_type, num_bits)
if load_qat is True and osp.exists(ckpt_prefix+'epoch_20/'+title+'.pt'):
continue
currow += 1
print('\nQAT: '+title)
model_ptq = VGG_19()
# set the quantization value table
if quant_type != 'INT':
plist = build_list(quant_type, num_bits, e_bits)
gol.set_value(plist)
# model_ptq.load_state_dict(torch.load(full_file))
model_ptq.to(device)
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.train()
# optimize the quantized model's own parameters; the optimizer created above targets
# `model`, so reusing it here would leave model_ptq's weights frozen
optimizer = optim.SGD(model_ptq.parameters(), lr=lr, momentum=momentum)
loss_sum = 0.
qat_grad_sum = {}
qat_grad_avg = {}
for name,param in model.named_parameters():
qat_grad_sum[name] = torch.zeros_like(param)
qat_grad_avg[name] = torch.zeros_like(param)
for epoch in range(1, epochs+1):
loss,qat_grad = quantize_aware_training(model_ptq, device, train_loader, optimizer, epoch)
# print('loss:%f' % loss_avg)
if epoch == 1:
loss_start = loss
writer.add_scalar(title+'.loss',loss,epoch)
# for name,grad in qat_grad.items():
# writer.add_histogram(title+'.'+name+'_grad',grad,global_step=epoch)
loss_sum += loss
loss_avg = loss_sum / epoch
loss_delta = loss-loss_start
for name,param in model.named_parameters():
qat_grad_sum[name] += qat_grad[name]
qat_grad_avg[name] = qat_grad_sum[name] / epoch
if epoch % 5 == 0:
ws = wb['epoch_%d'%epoch]
js_grad = 0.
js_grad_sum = 0.
js_grad_avg = 0.
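# each layer's gradient divergence below is weighted by its FLOP share (flop_ratio,
# indexed through the parameter-name prefix), so compute-heavy layers dominate
# js_grad / js_grad_sum / js_grad_avg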
for name,_ in model_ptq.named_parameters():
prefix = name.split('.')[0]
layer_idx = layer.index(prefix)
js_grad += flop_ratio[layer_idx] * js_div_norm(qat_grad[name],full_grad[name])
js_grad_sum += flop_ratio[layer_idx] * js_div_norm(qat_grad_sum[name],full_grad_sum[name])
js_grad_avg += flop_ratio[layer_idx] * js_div_norm(qat_grad_avg[name],full_grad_avg[name])
ws.cell(row=currow,column=1,value=title)
ws.cell(row=currow,column=2,value=loss.item())
ws.cell(row=currow,column=3,value=loss_sum.item())
ws.cell(row=currow,column=4,value=loss_avg.item())
ws.cell(row=currow,column=5,value=loss_delta.item())
ws.cell(row=currow,column=6,value=js_grad)
ws.cell(row=currow,column=7,value=js_grad_sum)
ws.cell(row=currow,column=8,value=js_grad_avg)
wb.remove(wb['Sheet']) # remove the default worksheet by name
wb.save('qat_result.xlsx')
writer.close()
......@@ -3,7 +3,6 @@ from model import *
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import os
......@@ -47,16 +46,9 @@ def test(model, device, test_loader):
if __name__ == "__main__":
batch_size = 32
test_batch_size = 32
seed = 1
# epoch = 35
# lr = 0.01
epochs1 = 15
epochs2 = epochs1+10
epochs3 = epochs2+10
lr1 = 0.01
lr2 = 0.001
lr3 = 0.0001
epochs_cfg = [30, 40, 30, 20, 20, 10, 10]
lr_cfg = [0.01, 0.008, 0.005, 0.002, 0.001, 0.0005, 0.0001]
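# each stage in epochs_cfg is trained with the matching lr_cfg entry (see the staged loop below); 160 epochs in total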
momentum = 0.5
save_model = True
......@@ -81,33 +73,19 @@ if __name__ == "__main__":
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
model = VGG_19().to(device)
# optimizer = optim.Adam(model.parameters(), lr=lr)
# lr_scheduler = CosineAnnealingLR(optimizer, T_max=epoch)
# for epoch in range(1, epoch + 1):
# train(model, device, train_loader, optimizer, epoch)
# # lr_scheduler.step()
# test(model, device, test_loader)
optimizer1 = optim.SGD(model.parameters(), lr=lr1, momentum=momentum)
optimizer2 = optim.SGD(model.parameters(), lr=lr2, momentum=momentum)
optimizer3 = optim.SGD(model.parameters(), lr=lr3, momentum=momentum)
for epoch in range(1, epochs1 + 1):
train(model, device, train_loader, optimizer1, epoch)
test(model, device, test_loader)
for epoch in range(epochs1 + 1, epochs2 + 1):
train(model, device, train_loader, optimizer2, epoch)
test(model, device, test_loader)
for epoch in range(epochs2 + 1, epochs3 + 1):
train(model, device, train_loader, optimizer3, epoch)
test(model, device, test_loader)
epoch_start = 1
for epochs,lr in zip(epochs_cfg,lr_cfg):
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
epoch_end = epoch_start+epochs
for epoch in range(epoch_start,epoch_end):
train(model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
epoch_start += epochs
if save_model:
if not osp.exists('ckpt'):
......