Commit 06dfc82c by Zhihong Ma
parents 4a099b52 c6cb641c
# AlexNet 量化说明
## ptq部分
+ INT/POT/FLOAT量化都采用相同的框架,可以通过`quant_type`进行确定
+ 量化范围:均采用有符号对称量化,且将zeropoint定为0
+ 量化策略:在第一次forward执行伪量化,统计每层的x和weight范围;后续interface在卷积/池化层中使用量化后的值进行运算。量化均通过放缩至对应范围后从量化点列表取最近点进行实现。
+ bias说明:每种量化模式下,bias采用相同的量化策略(INT:32bit量化,POT:8bit量化,FP8:FP16-E7量化)。bias量化损失对结果影响小,该量化策略也对运算硬件实现影响不大,但是在代码实现上可以更加高效,故采用。(英伟达的量化策略甚至直接舍弃了bias)
+ 关于量化策略设置,可以更改`module.py`中的`bias_qmax`函数和`utils.py`中的`build_bias_list`函数
+ 由于INT量化位宽较高,使用量化表开销过大,直接使用round_操作即可
+ 量化点选择:
+ INT:取INT2-INT16(INT16后相比全精度无损失)
+ POT:取POT2-POT8 (POT8之后容易出现Overflow)
+ FP8:取E1-E6 (E0相当于INT量化,E7相当于POT量化,直接取相应策略效果更好)
+ 支持调整FP的位宽
+ 关于量化点选择,可以更改`utils.py`中的`bit_list`函数
+ 量化结果:
FP32-acc:85.08
<img src="image/table.png" alt="table" style="zoom: 33%;" />
+ 数据拟合:
matlab导入数据,选择列向量
+ js_flops - acc_loss
Rational: Numerator degree 2 / Denominator degree 2
- [ ] center and scale
![fig1](image/fig1.png)
- [x] center and scale
![fig2](image/fig2.png)
+ js_param - acc_loss
Rational: Numerator degree 2 / Denominator degree 2
- [ ] center and scale
![fig3](image/fig3.png)
- [x] center and scale
![fig4](image/fig4.png)
\ No newline at end of file
import sys
import os
# 从get_param.py输出重定向文件val.txt中提取参数量和计算量
def extract_ratio():
fr = open('param_flops.txt','r')
lines = fr.readlines()
layer = []
par_ratio = []
flop_ratio = []
for line in lines:
if '(' in line and ')' in line:
layer.append(line.split(')')[0].split('(')[1])
r1 = line.split('%')[0].split(',')[-1]
r1 = float(r1)
par_ratio.append(r1)
r2 = line.split('%')[-2].split(',')[-1]
r2 = float(r2)
flop_ratio.append(r2)
return layer, par_ratio, flop_ratio
if __name__ == "__main__":
layer, par_ratio, flop_ratio = extract_ratio()
print(layer)
print(par_ratio)
print(flop_ratio)
\ No newline at end of file
from torch.autograd import Function
class FakeQuantize(Function):
@staticmethod
def forward(ctx, x, qparam):
x = qparam.quantize_tensor(x)
x = qparam.dequantize_tensor(x)
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
\ No newline at end of file
from model import *
import torch
from ptflops import get_model_complexity_info
if __name__ == "__main__":
model = AlexNet()
full_file = 'ckpt/cifar10_AlexNet.pt'
model.load_state_dict(torch.load(full_file))
flops, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
# -*- coding: utf-8 -*-
# 用于多个module之间共享全局变量
def _init(): # 初始化
global _global_dict
_global_dict = {}
def set_value(value,is_bias=False):
# 定义一个全局变量
if is_bias:
_global_dict[0] = value
else:
_global_dict[1] = value
def get_value(is_bias=False): # 给bias独立于各变量外的精度
if is_bias:
return _global_dict[0]
else:
return _global_dict[1]
import torch
import torch.nn as nn
import torch.nn.functional as F
from module import *
import module
class AlexNet(nn.Module):
def __init__(self, num_channels=3, num_classes=10):
super(AlexNet, self).__init__()
# original size 32x32
self.conv1 = nn.Conv2d(num_channels, 32, kernel_size=3, padding=1, bias=True)
self.relu1 = nn.ReLU(inplace=True)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) # output[48, 27, 27]
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=True) # output[128, 27, 27]
self.relu2 = nn.ReLU(inplace=True)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) # output[128, 13, 13]
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=True) # output[192, 13, 13]
self.relu3 = nn.ReLU(inplace=True)
self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=True) # output[192, 13, 13]
self.relu4 = nn.ReLU(inplace=True)
self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True) # output[128, 13, 13]
self.relu5 = nn.ReLU(inplace=True)
self.pool5 = nn.MaxPool2d(kernel_size=3, stride=2)
self.drop1 = nn.Dropout(p=0.5)
self.fc1 = nn.Linear(256 * 3 * 3, 1024, bias=True)
self.relu6 = nn.ReLU(inplace=True)
self.drop2 = nn.Dropout(p=0.5)
self.fc2 = nn.Linear(1024, 512, bias=True)
self.relu7 = nn.ReLU(inplace=True)
self.fc3 = nn.Linear(512, num_classes, bias=True)
def forward(self, x):
x = self.conv1(x)
x = self.relu1(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.relu2(x)
x = self.pool2(x)
x = self.conv3(x)
x = self.relu3(x)
x = self.conv4(x)
x = self.relu4(x)
x = self.conv5(x)
x = self.relu5(x)
x = self.pool5(x)
x = torch.flatten(x, start_dim=1)
x = self.drop1(x)
x = self.fc1(x)
x = self.relu6(x)
x = self.drop2(x)
x = self.fc2(x)
x = self.relu7(x)
x = self.fc3(x)
return x
def quantize(self, quant_type, num_bits=8, e_bits=3):
# e_bits仅当使用FLOAT量化时用到
self.qconv1 = QConv2d(quant_type, self.conv1, qi=True, qo=True, num_bits=num_bits, e_bits=e_bits)
self.qrelu1 = QReLU(quant_type, num_bits=num_bits, e_bits=e_bits)
self.qpool1 = QMaxPooling2d(quant_type, kernel_size=2, stride=2, padding=0, num_bits=num_bits, e_bits=e_bits)
self.qconv2 = QConv2d(quant_type, self.conv2, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
self.qrelu2 = QReLU(quant_type, num_bits=num_bits, e_bits=e_bits)
self.qpool2 = QMaxPooling2d(quant_type, kernel_size=2, stride=2, padding=0, num_bits=num_bits, e_bits=e_bits)
self.qconv3 = QConv2d(quant_type, self.conv3, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
self.qrelu3 = QReLU(quant_type, num_bits=num_bits, e_bits=e_bits)
self.qconv4 = QConv2d(quant_type, self.conv4, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
self.qrelu4 = QReLU(quant_type, num_bits=num_bits, e_bits=e_bits)
self.qconv5 = QConv2d(quant_type, self.conv5, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
self.qrelu5 = QReLU(quant_type, num_bits=num_bits, e_bits=e_bits)
self.qpool5 = QMaxPooling2d(quant_type, kernel_size=3, stride=2, padding=0, num_bits=num_bits, e_bits=e_bits)
self.qfc1 = QLinear(quant_type, self.fc1, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
self.qrelu6 = QReLU(quant_type, num_bits=num_bits, e_bits=e_bits)
self.qfc2 = QLinear(quant_type, self.fc2, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
self.qrelu7 = QReLU(quant_type, num_bits=num_bits, e_bits=e_bits)
self.qfc3 = QLinear(quant_type, self.fc3, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
def quantize_forward(self, x):
x = self.qconv1(x)
x = self.qrelu1(x)
x = self.qpool1(x)
x = self.qconv2(x)
x = self.qrelu2(x)
x = self.qpool2(x)
x = self.qconv3(x)
x = self.qrelu3(x)
x = self.qconv4(x)
x = self.qrelu4(x)
x = self.qconv5(x)
x = self.qrelu5(x)
x = self.qpool5(x)
x = torch.flatten(x, start_dim=1)
x = self.drop1(x)
x = self.qfc1(x)
x = self.qrelu6(x)
x = self.drop2(x)
x = self.qfc2(x)
x = self.qrelu7(x)
x = self.qfc3(x)
return x
def freeze(self):
self.qconv1.freeze()
self.qrelu1.freeze(self.qconv1.qo)
self.qpool1.freeze(self.qconv1.qo)
self.qconv2.freeze(self.qconv1.qo)
self.qrelu2.freeze(self.qconv2.qo)
self.qpool2.freeze(self.qconv2.qo)
self.qconv3.freeze(self.qconv2.qo)
self.qrelu3.freeze(self.qconv3.qo)
self.qconv4.freeze(self.qconv3.qo)
self.qrelu4.freeze(self.qconv4.qo)
self.qconv5.freeze(self.qconv4.qo)
self.qrelu5.freeze(self.qconv5.qo)
self.qpool5.freeze(self.qconv5.qo)
self.qfc1.freeze(self.qconv5.qo)
self.qrelu6.freeze(self.qfc1.qo)
self.qfc2.freeze(self.qfc1.qo)
self.qrelu7.freeze(self.qfc2.qo)
self.qfc3.freeze(self.qfc2.qo)
def quantize_inference(self, x):
qx = self.qconv1.qi.quantize_tensor(x)
qx = self.qconv1.quantize_inference(qx)
qx = self.qrelu1.quantize_inference(qx)
qx = self.qpool1.quantize_inference(qx)
qx = self.qconv2.quantize_inference(qx)
qx = self.qrelu2.quantize_inference(qx)
qx = self.qpool2.quantize_inference(qx)
qx = self.qconv3.quantize_inference(qx)
qx = self.qrelu3.quantize_inference(qx)
qx = self.qconv4.quantize_inference(qx)
qx = self.qrelu4.quantize_inference(qx)
qx = self.qconv5.quantize_inference(qx)
qx = self.qrelu5.quantize_inference(qx)
qx = self.qpool5.quantize_inference(qx)
qx = torch.flatten(qx, start_dim=1)
qx = self.qfc1.quantize_inference(qx)
qx = self.qrelu6.quantize_inference(qx)
qx = self.qfc2.quantize_inference(qx)
qx = self.qrelu7.quantize_inference(qx)
qx = self.qfc3.quantize_inference(qx)
out = self.qfc3.qo.dequantize_tensor(qx)
return out
AlexNet(
3.87 M, 100.000% Params, 70.08 MMac, 100.000% MACs,
(conv1): Conv2d(896, 0.023% Params, 917.5 KMac, 1.309% MACs, 3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu1): ReLU(0, 0.000% Params, 32.77 KMac, 0.047% MACs, inplace=True)
(pool1): MaxPool2d(0, 0.000% Params, 32.77 KMac, 0.047% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(18.5 k, 0.478% Params, 4.73 MMac, 6.756% MACs, 32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu2): ReLU(0, 0.000% Params, 16.38 KMac, 0.023% MACs, inplace=True)
(pool2): MaxPool2d(0, 0.000% Params, 16.38 KMac, 0.023% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv3): Conv2d(73.86 k, 1.909% Params, 4.73 MMac, 6.745% MACs, 64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu3): ReLU(0, 0.000% Params, 8.19 KMac, 0.012% MACs, inplace=True)
(conv4): Conv2d(295.17 k, 7.630% Params, 18.89 MMac, 26.955% MACs, 128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu4): ReLU(0, 0.000% Params, 16.38 KMac, 0.023% MACs, inplace=True)
(conv5): Conv2d(590.08 k, 15.252% Params, 37.77 MMac, 53.887% MACs, 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu5): ReLU(0, 0.000% Params, 16.38 KMac, 0.023% MACs, inplace=True)
(pool5): MaxPool2d(0, 0.000% Params, 16.38 KMac, 0.023% MACs, kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
(drop1): Dropout(0, 0.000% Params, 0.0 Mac, 0.000% MACs, p=0.5, inplace=False)
(fc1): Linear(2.36 M, 61.010% Params, 2.36 MMac, 3.368% MACs, in_features=2304, out_features=1024, bias=True)
(relu6): ReLU(0, 0.000% Params, 1.02 KMac, 0.001% MACs, inplace=True)
(drop2): Dropout(0, 0.000% Params, 0.0 Mac, 0.000% MACs, p=0.5, inplace=False)
(fc2): Linear(524.8 k, 13.565% Params, 524.8 KMac, 0.749% MACs, in_features=1024, out_features=512, bias=True)
(relu7): ReLU(0, 0.000% Params, 512.0 Mac, 0.001% MACs, inplace=True)
(fc3): Linear(5.13 k, 0.133% Params, 5.13 KMac, 0.007% MACs, in_features=512, out_features=10, bias=True)
)
from torch.serialization import load
from model import *
from extract_ratio import *
from utils import *
import gol
import openpyxl
import sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
def direct_quantize(model, test_loader,device):
for i, (data, target) in enumerate(test_loader, 1):
data = data.to(device)
output = model.quantize_forward(data).cpu()
if i % 500 == 0:
break
print('direct quantization finish')
def full_inference(model, test_loader, device):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data = data.to(device)
output = model(data).cpu()
pred = output.argmax(dim=1, keepdim=True)
# print(pred)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Full Model Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))
return 100. * correct / len(test_loader.dataset)
def quantize_inference(model, test_loader, device):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data = data.to(device)
output = model.quantize_inference(data).cpu()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('Test set: Quant Model Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))
return 100. * correct / len(test_loader.dataset)
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
if __name__ == "__main__":
batch_size = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=True, download=True,
transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=False, transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
model = AlexNet()
writer = SummaryWriter(log_dir='./log')
full_file = 'ckpt/cifar10_AlexNet.pt'
model.load_state_dict(torch.load(full_file))
model.to(device)
load_ptq = True
ptq_file_prefix = 'ckpt/cifar10_AlexNet_ptq_'
model.eval()
full_acc = full_inference(model, test_loader, device)
full_params = []
layer, par_ratio, flop_ratio = extract_ratio()
for name, param in model.named_parameters():
param_norm = F.normalize(param.data.cpu(),p=2,dim=-1)
full_params.append(param_norm)
writer.add_histogram(tag='Full_' + name + '_data', values=param.data)
gol._init()
quant_type_list = ['INT','POT','FLOAT']
title_list = []
js_flops_list = []
js_param_list = []
ptq_acc_list = []
acc_loss_list = []
for quant_type in quant_type_list:
num_bit_list, e_bit_list = bit_list(quant_type)
# 对一个量化类别,只需设置一次bias量化表
# int由于位宽大,使用量化表开销过大,直接_round即可
if quant_type != 'INT':
bias_list = build_bias_list(quant_type)
gol.set_value(bias_list, is_bias=True)
for num_bits in num_bit_list:
for e_bits in e_bit_list:
model_ptq = AlexNet()
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
title = '%s_%d' % (quant_type, num_bits)
print('\nPTQ: '+title)
title_list.append(title)
# 设置量化表
if quant_type != 'INT':
plist = build_list(quant_type, num_bits, e_bits)
gol.set_value(plist)
# 判断是否需要载入
if load_ptq is True and osp.exists(ptq_file_prefix + title + '.pt'):
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.load_state_dict(torch.load(ptq_file_prefix + title + '.pt'))
model_ptq.to(device)
print('Successfully load ptq model: ' + title)
else:
model_ptq.load_state_dict(torch.load(full_file))
model_ptq.to(device)
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.eval()
direct_quantize(model_ptq, train_loader, device)
torch.save(model_ptq.state_dict(), ptq_file_prefix + title + '.pt')
model_ptq.freeze()
ptq_acc = quantize_inference(model_ptq, test_loader, device)
ptq_acc_list.append(ptq_acc)
acc_loss = (full_acc - ptq_acc) / full_acc
acc_loss_list.append(acc_loss)
idx = -1
# 获取计算量/参数量下的js-div
js_flops = 0.
js_param = 0.
for name, param in model_ptq.named_parameters():
if '.' not in name:
continue
idx = idx + 1
prefix = name.split('.')[0]
if prefix in layer:
layer_idx = layer.index(prefix)
ptq_param = param.data.cpu()
# 取L2范数
ptq_norm = F.normalize(ptq_param,p=2,dim=-1)
writer.add_histogram(tag=title +':'+ name + '_data', values=ptq_param)
js = js_div(ptq_norm,full_params[idx])
js = js.item()
if js < 0.:
js = 0.
js_flops = js_flops + js * flop_ratio[layer_idx]
js_param = js_param + js * flop_ratio[layer_idx]
js_flops_list.append(js_flops)
js_param_list.append(js_param)
print(title + ': js_flops: %f js_param: %f acc_loss: %f' % (js_flops, js_param, acc_loss))
# 写入xlsx
workbook = openpyxl.Workbook()
worksheet = workbook.active
worksheet.cell(row=1,column=1,value='FP32-acc')
worksheet.cell(row=1,column=2,value=full_acc)
worksheet.cell(row=3,column=1,value='title')
worksheet.cell(row=3,column=2,value='js_flops')
worksheet.cell(row=3,column=3,value='js_param')
worksheet.cell(row=3,column=4,value='ptq_acc')
worksheet.cell(row=3,column=5,value='acc_loss')
for i in range(len(title_list)):
worksheet.cell(row=i+4, column=1, value=title_list[i])
worksheet.cell(row=i+4, column=2, value=js_flops_list[i])
worksheet.cell(row=i+4, column=3, value=js_param_list[i])
worksheet.cell(row=i+4, column=4, value=ptq_acc_list[i])
worksheet.cell(row=i+4, column=5, value=acc_loss_list[i])
workbook.save('ptq_result.xlsx')
writer.close()
ft = open('ptq_result.txt','w')
print('title_list:',file=ft)
print(" ".join(title_list),file=ft)
print('js_flops_list:',file=ft)
print(" ".join(str(i) for i in js_flops_list), file=ft)
print('js_param_list:',file=ft)
print(" ".join(str(i) for i in js_param_list), file=ft)
print('ptq_acc_list:',file=ft)
print(" ".join(str(i) for i in ptq_acc_list), file=ft)
print('acc_loss_list:',file=ft)
print(" ".join(str(i) for i in acc_loss_list), file=ft)
ft.close()
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
7507.750226317713 2739.698390971301 602.5613310246055 140.92197221503724 34.51721888016634 8.518508718865842 2.1353732883428638 0.5319411628570782 0.1316271020831477 0.03249564657892055 0.008037284252895557 0.0020460099353784723 0.00041867764927864105 0.0001321614950419231 5.841430176387608e-06 7507.667348902921 1654.3775934528933 136.7401730898288 134.5782970456457 134.57841422062364 134.5783939274636 134.5782945727605 33.31638902152266 32.12034308540418 0.6541880874259414 2.442034364817909 9.688117360231624 37.70544899186622
js_param_list:
7507.750226317713 2739.698390971301 602.5613310246055 140.92197221503724 34.51721888016634 8.518508718865842 2.1353732883428638 0.5319411628570782 0.1316271020831477 0.03249564657892055 0.008037284252895557 0.0020460099353784723 0.00041867764927864105 0.0001321614950419231 5.841430176387608e-06 7507.667348902921 1654.3775934528933 136.7401730898288 134.5782970456457 134.57841422062364 134.5783939274636 134.5782945727605 33.31638902152266 32.12034308540418 0.6541880874259414 2.442034364817909 9.688117360231624 37.70544899186622
ptq_acc_list:
10.0 10.16 51.21 77.39 83.03 84.73 84.84 85.01 85.08 85.07 85.06 85.08 85.08 85.08 85.08 10.0 14.32 72.49 72.65 72.95 72.08 72.23 82.73 83.3 85.01 84.77 59.86 51.87
acc_loss_list:
0.8824635637047484 0.8805829807240245 0.3980959097320169 0.0903855195110484 0.02409496944052653 0.004113775270333736 0.0028208744710859768 0.0008227550540666805 0.0 0.00011753643629531167 0.0002350728725904563 0.0 0.0 0.0 0.0 0.8824635637047484 0.8316878232251997 0.14797837329572172 0.14609779031499756 0.14257169722614005 0.152797367183827 0.15103432063939815 0.027621062529384042 0.020921485660554785 0.0008227550540666805 0.0036436295251528242 0.29642689233662434 0.39033850493653033
from model import *
from utils import *
import gol
import sys
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
def quantize_aware_training(model, device, train_loader, optimizer, epoch):
lossLayer = torch.nn.CrossEntropyLoss()
for batch_idx, (data, target) in enumerate(train_loader, 1):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model.quantize_forward(data)
loss = lossLayer(output, target)
loss.backward()
optimizer.step()
if batch_idx % 50 == 0:
print('Quantize Aware Training Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
def full_inference(model, test_loader):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Full Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))
def quantize_inference(model, test_loader):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model.quantize_inference(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Quant Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))
if __name__ == "__main__":
# arg1 = int(sys.argv[1]) # epoch
# arg2 = int(sys.argv[2]) # bits of quantization
batch_size = 32
seed = 1
epochs1 = 3
epochs2 = 30 # 16~30
lr1 = 0.01
lr2 = 0.001
momentum = 0.5
using_bn = False
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
)
# if using_bn:
# model = NetBN()
# model.load_state_dict(torch.load('ckpt/mnist_cnnbn.pt', map_location='cpu'))
# save_file = "ckpt/mnist_cnnbn_qat.pt"
# else:
# model = Net()
# model.load_state_dict(torch.load('ckpt/mnist_cnn.pt', map_location='cpu'))
# save_file = "ckpt/mnist_cnn_qat.pt"
model = AlexNet()
# model.load_state_dict(torch.load('ckpt/cifar10_AlexNet_t5.pt', map_location='cpu'))
model.load_state_dict(torch.load('ckpt/cifar10_AlexNet_t4.pt'))
save_file = "ckpt/cifar10_AlexNet_qat_e4.pt"
load_quant_model_file = None
# load_quant_model_file = "ckpt/cifar10_AlexNet_qat_ratio_4.pt"
model.to(device)
# 原来是所有参数,包括原有conv的参数
# optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
# 改动后只训练每层的scale,需要在量化后才能指定
model.eval()
full_inference(model, test_loader)
# for param_tensor, param_value in model.state_dict().items():
# print(param_tensor, "\t", param_value)
num_bits = 8
e_bits = 4
gol._init()
plist = build_list(num_bits=16,e_bits=5)
gol.set_value(plist,is_bias=True)
plist = build_list(num_bits,e_bits)
gol.set_value(plist)
plist = build_list(num_bits,e_bits)
# print(plist)
gol.set_value(plist)
model.quantize(num_bits,e_bits)
print('Quantization bit: %d' % num_bits)
params,params_name = model.get_quant_scales()
optimizer1 = optim.SGD(params, lr=lr1, momentum=momentum)
optimizer2 = optim.SGD(params, lr=lr2, momentum=momentum)
# print('--debug--')
# for name in params_name:
# print(name)
# input()
if load_quant_model_file is not None:
model.load_state_dict(torch.load(load_quant_model_file))
print("Successfully load quantized model %s" % load_quant_model_file)
else:
model.train()
for epoch in range(1, epochs1 + 1):
quantize_aware_training(model, device, train_loader, optimizer1, epoch)
# for epoch in range(epochs1 + 1, epochs2 + 1):
# quantize_aware_training(model, device, train_loader, optimizer2, epoch)
model.eval()
torch.save(model.state_dict(), save_file)
model.freeze()
# for name, param in model.named_parameters():
# print(name)
# print(param.data)
# print('----------')
# for param_tensor, param_value in model.state_dict().items():
# print(param_tensor, "\t", param_value)
quantize_inference(model, test_loader)
from model import *
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import os
import os.path as osp
def train(model, device, train_loader, optimizer, epoch):
model.train()
lossLayer = torch.nn.CrossEntropyLoss()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = lossLayer(output, target)
loss.backward()
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
lossLayer = torch.nn.CrossEntropyLoss(reduction='sum')
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += lossLayer(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
test_loss, 100. * correct / len(test_loader.dataset)
))
if __name__ == "__main__":
batch_size = 32
test_batch_size = 32
seed = 1
epochs1 = 15
epochs2 = epochs1+10
epochs3 = epochs2+10
lr1 = 0.01
lr2 = 0.001
lr3 = 0.0001
momentum = 0.5
save_model = True
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=True, download=True,
transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=False, transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
)
model = AlexNet().to(device)
optimizer1 = optim.SGD(model.parameters(), lr=lr1, momentum=momentum)
optimizer2 = optim.SGD(model.parameters(), lr=lr2, momentum=momentum)
optimizer3 = optim.SGD(model.parameters(), lr=lr3, momentum=momentum)
for epoch in range(1, epochs1 + 1):
train(model, device, train_loader, optimizer1, epoch)
test(model, device, test_loader)
for epoch in range(epochs1 + 1, epochs2 + 1):
train(model, device, train_loader, optimizer2, epoch)
test(model, device, test_loader)
for epoch in range(epochs2 + 1, epochs3 + 1):
train(model, device, train_loader, optimizer3, epoch)
test(model, device, test_loader)
if save_model:
if not osp.exists('ckpt'):
os.makedirs('ckpt')
torch.save(model.state_dict(), 'ckpt/cifar10_AlexNet.pt')
\ No newline at end of file
import torch
def bit_list(quant_type):
if quant_type == 'INT':
num_bit_list = list(range(2,17))
e_bit_list = [0]
elif quant_type == 'POT':
num_bit_list = list(range(2,9))
e_bit_list = [0]
else:
num_bit_list = [8]
e_bit_list = list(range(1,7))
return num_bit_list, e_bit_list
def build_bias_list(quant_type):
if quant_type == 'POT':
return build_pot_list(8)
else:
return build_float_list(16,7)
def build_list(quant_type, num_bits, e_bits):
if quant_type == 'POT':
return build_pot_list(num_bits)
else:
return build_float_list(num_bits,e_bits)
def build_pot_list(num_bits):
plist = [0.]
for i in range(-2 ** (num_bits-1) + 2, 1):
# i最高到0,即pot量化最大值为1
plist.append(2. ** i)
plist.append(-2. ** i)
plist = torch.Tensor(list(set(plist)))
# plist = plist.mul(1.0 / torch.max(plist))
return plist
def build_float_list(num_bits,e_bits):
m_bits = num_bits - 1 - e_bits
plist = [0.]
# 相邻尾数的差值
dist_m = 2 ** (-m_bits)
e = -2 ** (e_bits - 1) + 1
for m in range(1, 2 ** m_bits):
frac = m * dist_m # 尾数部分
expo = 2 ** e # 指数部分
flt = frac * expo
plist.append(flt)
plist.append(-flt)
for e in range(-2 ** (e_bits - 1) + 2, 2 ** (e_bits - 1) + 1):
expo = 2 ** e
for m in range(0, 2 ** m_bits):
frac = 1. + m * dist_m
flt = frac * expo
plist.append(flt)
plist.append(-flt)
plist = torch.Tensor(list(set(plist)))
return plist
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment