Commit f641275d by Klin

feat: add ptq:FP3-FP7

parent 06dfc82c
@@ -19,7 +19,7 @@
+ POT: use POT2-POT8 (overflow tends to occur beyond POT8)
+ FP8: use E1-E6 (E0 is equivalent to INT quantization and E7 is equivalent to POT quantization; applying those strategies directly gives better results)
+ The FP bit width is now adjustable (a sketch of the resulting value grid follows this section)
+ To change which quantization points are evaluated, modify the `numbit_list` function in `utils.py`
+ Quantization results:
@@ -31,6 +31,8 @@
Import the data into MATLAB as column vectors
+ Before adding FP3-FP7:
+ js_flops - acc_loss
Rational: Numerator degree 2 / Denominator degree 2
@@ -43,6 +45,7 @@
![fig2](image/fig2.png)
+ js_param - acc_loss
Rational: Numerator degree 2 / Denominator degree 2
@@ -54,3 +57,29 @@
- [x] center and scale
![fig4](image/fig4.png)
+ After adding FP3-FP7:
+ js_flops - acc_loss
Rational: Numerator degree 2 / Denominator degree 2
- [ ] center and scale
![image-20230407010858191](image/fig5.png)
- [x] center and scale
![image-20230407011501987](image/fig6.png)
+ js_param - acc_loss
Rational: Numerator degree 2 / Denominator degree 2
- [ ] center and scale
![image-20230407010945342](image/fig7.png)
- [x] center and scale
![image-20230407010958875](image/fig8.png)
\ No newline at end of file
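For reference, a FLOAT_&lt;num_bits&gt;_E&lt;e_bits&gt; configuration here means one sign bit, `e_bits` exponent bits, and the remaining bits as mantissa, which is why E0 behaves like INT quantization and pushing all non-sign bits into the exponent behaves like POT. Below is a minimal sketch of how such a value grid could be enumerated; the bias choice and the omission of zero/subnormals are simplifications, and the repository's actual `build_list` in `utils.py` may differ.

```python
import torch

def fp_value_grid(num_bits, e_bits):
    """Enumerate the positive values of a toy float format:
    1 sign bit, e_bits exponent bits, (num_bits - 1 - e_bits) mantissa bits.
    Zero and subnormals are omitted for brevity."""
    m_bits = num_bits - 1 - e_bits
    bias = 2 ** (e_bits - 1)              # simple exponent bias (an assumption)
    values = []
    for e in range(2 ** e_bits):
        for m in range(2 ** m_bits):
            frac = 1.0 + m / (2 ** m_bits)    # normalized mantissa in [1, 2)
            values.append(frac * 2.0 ** (e - bias))
    grid = sorted(set(values))
    # PTQ then rounds each tensor entry to the nearest grid point (with sign).
    return torch.tensor(grid + [-v for v in grid])

# Example: the FP7_E3 grid added by this commit
print(fp_value_grid(7, 3))
```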
ykl/AlexNet/image/table.png replaced (113 KB → 18.4 KB)
@@ -116,7 +116,7 @@ if __name__ == "__main__":
    acc_loss_list = []
    for quant_type in quant_type_list:
        num_bit_list = numbit_list(quant_type)
        # For each quantization type, the bias quantization table only needs to be set once.
        # For INT, the bit width is large and a lookup table would be too costly, so _round is applied directly.
@@ -125,6 +125,7 @@ if __name__ == "__main__":
            gol.set_value(bias_list, is_bias=True)
        for num_bits in num_bit_list:
            e_bit_list = ebit_list(quant_type, num_bits)
            for e_bits in e_bit_list:
                model_ptq = AlexNet()
                if quant_type == 'FLOAT':
......
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
7507.750226317713 2739.698390971301 602.5613310246055 140.92197221503724 34.51721888016634 8.518508718865842 2.1353732883428638 0.5319411628570782 0.1316271020831477 0.03249564657892055 0.008037284252895557 0.0020460099353784723 0.00041867764927864105 0.0001321614950419231 5.841430176387608e-06 7507.667348902921 1654.3775934528933 136.7401730898288 134.5782970456457 134.57841422062364 134.5783939274636 134.5782945727605 1054.3432278105702 244.48311696489273 247.89704518368768 87.65672091651302 89.63831617681878 37.95288917539117 48.439491059469624 50.122494451137555 9.763717383777191 37.67666314899965 37.082531966568794 37.162725668876305 2.504495253790035 9.660221946273799 37.70544899186622 33.31638902152266 32.12034308540418 0.6541880874259414 2.442034364817909 9.688117360231624 37.70544899186622
js_param_list:
7507.750226317713 2739.698390971301 602.5613310246055 140.92197221503724 34.51721888016634 8.518508718865842 2.1353732883428638 0.5319411628570782 0.1316271020831477 0.03249564657892055 0.008037284252895557 0.0020460099353784723 0.00041867764927864105 0.0001321614950419231 5.841430176387608e-06 7507.667348902921 1654.3775934528933 136.7401730898288 134.5782970456457 134.57841422062364 134.5783939274636 134.5782945727605 1054.3432278105702 244.48311696489273 247.89704518368768 87.65672091651302 89.63831617681878 37.95288917539117 48.439491059469624 50.122494451137555 9.763717383777191 37.67666314899965 37.082531966568794 37.162725668876305 2.504495253790035 9.660221946273799 37.70544899186622 33.31638902152266 32.12034308540418 0.6541880874259414 2.442034364817909 9.688117360231624 37.70544899186622
ptq_acc_list:
10.0 10.16 51.21 77.39 83.03 84.73 84.84 85.01 85.08 85.07 85.06 85.08 85.08 85.08 85.08 10.0 14.32 72.49 72.65 72.95 72.08 72.23 24.42 66.66 47.53 77.89 76.18 81.78 81.76 81.37 84.11 81.87 82.02 82.5 84.72 84.18 52.15 82.73 83.3 85.01 84.77 59.86 51.87
acc_loss_list:
0.8824635637047484 0.8805829807240245 0.3980959097320169 0.0903855195110484 0.02409496944052653 0.004113775270333736 0.0028208744710859768 0.0008227550540666805 0.0 0.00011753643629531167 0.0002350728725904563 0.0 0.0 0.0 0.0 0.8824635637047484 0.8316878232251997 0.14797837329572172 0.14609779031499756 0.14257169722614005 0.152797367183827 0.15103432063939815 0.7129760225669958 0.21650211565585334 0.44134931828866947 0.08450869769628583 0.10460742830277377 0.03878702397743297 0.039022096850023426 0.04360601786553824 0.011401034320639386 0.03772919605077567 0.035966149506346995 0.030324400564174875 0.004231311706629048 0.010578279266572538 0.3870474847202633 0.027621062529384042 0.020921485660554785 0.0008227550540666805 0.0036436295251528242 0.29642689233662434 0.39033850493653033
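The README's "Rational: Numerator degree 2 / Denominator degree 2" fits are produced in MATLAB after importing the lists above as column vectors. A rough Python equivalent is sketched below, assuming SciPy is available; the monic-denominator form and the small data subset are illustrative only.

```python
import numpy as np
from scipy.optimize import curve_fit

def rat22(x, p1, p2, p3, q1, q2):
    # Rational model: (p1*x^2 + p2*x + p3) / (x^2 + q1*x + q2)
    return (p1 * x**2 + p2 * x + p3) / (x**2 + q1 * x + q2)

# Illustrative subset of the INT results above (js_flops vs. acc_loss)
js_flops = np.array([140.92, 34.52, 8.52, 2.14, 0.53, 0.13, 0.03])
acc_loss = np.array([0.0904, 0.0241, 0.0041, 0.0028, 0.0008, 0.0, 0.0001])

popt, _ = curve_fit(rat22, js_flops, acc_loss, p0=np.ones(5), maxfev=20000)
print('fitted coefficients:', popt)
```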
@@ -9,16 +9,35 @@ import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
def quantize_aware_training(model, device, train_loader, optimizer, epoch):
    lossLayer = torch.nn.CrossEntropyLoss()
    # Accumulate the loss and the gradient of each parameter
    # Initialization
    loss_sum = 0.
    grad_dict = {}
    for name, param in model.named_parameters():
        grad_dict[name] = torch.zeros_like(param)  # param.grad has the same shape as param
    for batch_idx, (data, target) in enumerate(train_loader, 1):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model.quantize_forward(data)
        # The loss computed for a batch is already the mean over that batch
        loss = lossLayer(output, target)
        loss.backward()
        # Accumulate loss and grads
        loss_sum += loss
        for name, param in model.named_parameters():
            if param.grad is not None:
                # print('-------'+name+'-------')
                grad_dict[name] += param.grad
                # print(grad_dict[name])
        # print(grad_dict.items())
        # input()
        optimizer.step()
        if batch_idx % 50 == 0:
@@ -26,6 +45,11 @@ def quantize_aware_training(model, device, train_loader, optimizer, epoch):
                epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
            ))
    # Average the accumulated values over all batches
    for name, grad in grad_dict.items():
        grad_dict[name] = grad / len(train_loader)
    loss_avg = loss_sum / len(train_loader)
    return loss_avg, grad_dict
def full_inference(model, test_loader):
    correct = 0
@@ -36,6 +60,44 @@ def full_inference(model, test_loader):
        correct += pred.eq(target.view_as(pred)).sum().item()
    print('\nTest set: Full Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    lossLayer = torch.nn.CrossEntropyLoss()
    # Accumulate the loss and the gradient of each parameter
    # Initialization
    loss_sum = 0.
    grad_dict = {}
    for name, param in model.named_parameters():
        grad_dict[name] = torch.zeros_like(param)  # param.grad has the same shape as param

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = lossLayer(output, target)
        loss.backward()
        # Accumulate loss and grads
        loss_sum += loss
        for name, param in model.named_parameters():
            if param.grad is not None:
                # print('-------'+name+'-------')
                grad_dict[name] += param.grad
                # print(grad_dict[name])
        optimizer.step()
        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
            ))

    # Average the accumulated values over all batches
    for name, grad in grad_dict.items():
        grad_dict[name] = grad / len(train_loader)
    loss_avg = loss_sum / len(train_loader)
    return loss_avg, grad_dict
def quantize_inference(model, test_loader):
    correct = 0
@@ -48,22 +110,17 @@ def quantize_inference(model, test_loader):
if __name__ == "__main__":
    batch_size = 64
    seed = 1
    epochs = 20
    lr = 0.001
    momentum = 0.5

    torch.manual_seed(seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    writer = SummaryWriter(log_dir='./log/qat')

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('data', train=True, download=True,
                         transform=transforms.Compose([
@@ -81,78 +138,150 @@ if __name__ == "__main__":
        batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
    )
    full_file = 'ckpt/cifar10_AlexNet.pt'
    model = AlexNet()
    model.load_state_dict(torch.load(full_file))
    model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    load_quant_model_file = None
    # load_quant_model_file = "ckpt/cifar10_AlexNet_qat_ratio_4.pt"
    ckpt_prefix = "ckpt/qat/"

    # Train the full-precision model first, logging loss and per-parameter gradient statistics
    loss_sum = 0.
    grad_dict_sum = {}
    grad_dict_avg = {}
    for name, param in model.named_parameters():
        grad_dict_sum[name] = torch.zeros_like(param)
        grad_dict_avg[name] = torch.zeros_like(param)
    for epoch in range(1, epochs+1):
        loss, grad_dict = train(model, device, train_loader, optimizer, epoch)
        # print('loss:%f' % loss_avg)
        writer.add_scalar('Full.loss', loss, epoch)
        for name, grad in grad_dict.items():
            writer.add_histogram('Full.'+name+'_grad', grad, global_step=epoch)
        loss_sum += loss
        loss_avg = loss_sum / epoch
        for name, grad in grad_dict.items():
            grad_dict_sum[name] += grad_dict[name]
            grad_dict_avg[name] = grad_dict_sum[name] / epoch
        ckpt = {
            'epoch' : epoch,
            'loss' : loss,
            'loss_sum' : loss_sum,
            'loss_avg' : loss_avg,
            'grad_dict' : grad_dict,
            'grad_dict_sum' : grad_dict_sum,
            'grad_dict_avg' : grad_dict_avg
        }
        if epoch % 5 == 0:
            subdir = 'epoch_%d/' % epoch
            torch.save(ckpt, ckpt_prefix + subdir + 'full.pt')

    # loss_avg,grad_dict = quantize_aware_training(model_ptq, device, train_loader, optimizer, epoch)
    # print('qat_loss:%f' % loss_avg)
    # for name,grad in grad_dict.items():
    #     writer.add_histogram('qat_'+name+'_grad',grad,global_step=epoch)
    quant_type_list = ['INT', 'POT', 'FLOAT']
    for quant_type in quant_type_list:
        num_bit_list = numbit_list(quant_type)
        # For each quantization type, the bias quantization table only needs to be set once.
        # For INT, the bit width is large and a lookup table would be too costly, so _round is applied directly.
        if quant_type != 'INT':
            bias_list = build_bias_list(quant_type)
            gol.set_value(bias_list, is_bias=True)
        for num_bits in num_bit_list:
            e_bit_list = ebit_list(quant_type, num_bits)
            for e_bits in e_bit_list:
                model_ptq = AlexNet()
                if quant_type == 'FLOAT':
                    title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
                else:
                    title = '%s_%d' % (quant_type, num_bits)
                print('\nQAT: '+title)
                # Set up the quantization value table
                if quant_type != 'INT':
                    plist = build_list(quant_type, num_bits, e_bits)
                    gol.set_value(plist)
                model_ptq.load_state_dict(torch.load(full_file))
                model_ptq.to(device)
                model_ptq.quantize(quant_type, num_bits, e_bits)

                loss_sum = 0.
                grad_dict_sum = {}
                grad_dict_avg = {}
                for name, param in model.named_parameters():
                    grad_dict_sum[name] = torch.zeros_like(param)
                    grad_dict_avg[name] = torch.zeros_like(param)
                for epoch in range(1, epochs+1):
                    loss, grad_dict = train(model, device, train_loader, optimizer, epoch)
                    # print('loss:%f' % loss_avg)
                    writer.add_scalar(title+'.loss', loss, epoch)
                    for name, grad in grad_dict.items():
                        writer.add_histogram(title+'.'+name+'_grad', grad, global_step=epoch)
                    loss_sum += loss
                    loss_avg = loss_sum / epoch
                    for name, grad in grad_dict.items():
                        grad_dict_sum[name] += grad_dict[name]
                        grad_dict_avg[name] = grad_dict_sum[name] / epoch
                    ckpt = {
                        'epoch' : epoch,
                        'loss' : loss,
                        'loss_sum' : loss_sum,
                        'loss_avg' : loss_avg,
                        'grad_dict' : grad_dict,
                        'grad_dict_sum' : grad_dict_sum,
                        'grad_dict_avg' : grad_dict_avg
                    }
                    if epoch % 5 == 0:
                        subdir = 'epoch_%d/' % epoch
                        torch.save(ckpt, ckpt_prefix + subdir + title + '.pt')

    writer.close()
    # # model.eval()
    # # full_inference(model, test_loader)
    # num_bits = 8
    # e_bits = 0
    # gol._init()
    # print("qat: INT8")
    # model.quantize('INT',num_bits,e_bits)
    # print('Quantization bit: %d' % num_bits)
    # if load_quant_model_file is not None:
    #     model.load_state_dict(torch.load(load_quant_model_file))
    #     print("Successfully load quantized model %s" % load_quant_model_file)
    # else:
    #     model.train()
    #     for epoch in range(1, epochs+1):
    #         quantize_aware_training(model, device, train_loader, optimizer, epoch)
    #     # for epoch in range(epochs1 + 1, epochs2 + 1):
    #     #     quantize_aware_training(model, device, train_loader, optimizer2, epoch)
    # model.eval()
    # # torch.save(model.state_dict(), save_file)
    # model.freeze()
    # # for name, param in model.named_parameters():
    # #     print(name)
    # #     print(param.data)
    # #     print('----------')
    # # for param_tensor, param_value in model.state_dict().items():
    # #     print(param_tensor, "\t", param_value)
    # quantize_inference(model, test_loader)
......
import torch
def ebit_list(quant_type, num_bits):
    if quant_type == 'FLOAT':
        e_bit_list = list(range(1,num_bits-1))
    else:
        e_bit_list = [0]
    return e_bit_list

def numbit_list(quant_type):
    if quant_type == 'INT':
        num_bit_list = list(range(2,17))
    elif quant_type == 'POT':
        num_bit_list = list(range(2,9))
    else:
        num_bit_list = list(range(2,9))
    return num_bit_list
def build_bias_list(quant_type):
    if quant_type == 'POT':
......
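For completeness, a small usage sketch (illustrative; it assumes `numbit_list` and `ebit_list` are importable from this `utils.py`) showing how the two helpers combine to enumerate every configuration in the `title_list` above:

```python
from utils import numbit_list, ebit_list

# Mirrors the nested loops in ptq.py / the QAT script above.
for quant_type in ['INT', 'POT', 'FLOAT']:
    for num_bits in numbit_list(quant_type):
        for e_bits in ebit_list(quant_type, num_bits):
            if quant_type == 'FLOAT':
                title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
            else:
                title = '%s_%d' % (quant_type, num_bits)
            print(title)  # INT_2 ... INT_16, POT_2 ... POT_8, FLOAT_3_E1 ... FLOAT_8_E6
```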