Commit 55f47e27 by Klin
parents bff8b078 8352e4cf
from model import *
from extract_ratio import *
from utils import *
import argparse
import openpyxl
import os
import os.path as osp
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
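# js_div_norm: JS divergence between two L2-normalized tensors, used below to compare per-layer gradient distributions.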
def js_div_norm(a,b):
a_norm = F.normalize(a.data,p=2,dim=-1)
b_norm = F.normalize(b.data,p=2,dim=-1)
return js_div(a_norm,b_norm)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Loss-Grad Analysis')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='ResNet18')
args = parser.parse_args()
wb = openpyxl.Workbook()
ws = wb.active
writer = SummaryWriter(log_dir='log/' + args.model + '/qat_loss_grad')
# layer, par_ratio, flop_ratio = extract_ratio()
layer, par_ratio, flop_ratio = extract_ratio(args.model)
if args.model == 'ResNet18':
model = resnet18()
elif args.model == 'ResNet50':
model = resnet50()
elif args.model == 'ResNet152':
model = resnet152()
layer = []
for name, param in model.named_parameters():
if 'weight' in name:
n = name.split('.')
pre = '.'.join(n[:len(n)-1])
layer.append(pre)
# dir_prefix = 'ckpt/qat/epoch_'
dir_prefix = 'ckpt/qat/'+ args.model + '/'
quant_type_list = ['INT','POT','FLOAT']
for epoch in [5,10,15,20]:
ws_epoch = wb.create_sheet('epoch_%d'%epoch)
full_state = torch.load(dir_prefix+'epoch_%d/'%epoch + 'full.pt')
ws_epoch.cell(row=1,column=2,value='loss')
ws_epoch.cell(row=1,column=3,value='loss_sum')
ws_epoch.cell(row=1,column=4,value='loss_avg')
ws_epoch.cell(row=2,column=1,value='FP32')
ws_epoch.cell(row=2,column=2,value=full_state['loss'].cpu().item())
ws_epoch.cell(row=2,column=3,value=full_state['loss_sum'].cpu().item())
ws_epoch.cell(row=2,column=4,value=full_state['loss_avg'].cpu().item())
# full_grad = full_state['grad_dict']
# full_grad_sum = full_state['grad_dict_sum']
full_grad_avg = full_state['grad_dict_avg']
for name,tmpgrad in full_grad_avg.items():
writer.add_histogram('FULL: '+name,tmpgrad,global_step=epoch)
ws_epoch.cell(row=4,column=1,value='title')
ws_epoch.cell(row=4,column=2,value='loss')
ws_epoch.cell(row=4,column=3,value='loss_sum')
ws_epoch.cell(row=4,column=4,value='loss_avg')
ws_epoch.cell(row=4,column=5,value='js_grad_avg_norm')
# ws_epoch.cell(row=4,column=6,value='conv1.weight')
# ws_epoch.cell(row=4,column=7,value='conv1.bias')
# ws_epoch.cell(row=4,column=8,value='conv2.weight')
# ws_epoch.cell(row=4,column=9,value='conv2.bias')
# ws_epoch.cell(row=4,column=10,value='conv3.weight')
# ws_epoch.cell(row=4,column=11,value='conv3.bias')
# ws_epoch.cell(row=4,column=12,value='conv4.weight')
# ws_epoch.cell(row=4,column=13,value='conv4.bias')
# ws_epoch.cell(row=4,column=14,value='conv5.weight')
# ws_epoch.cell(row=4,column=15,value='conv5.bias')
# ws_epoch.cell(row=4,column=16,value='fc1.weight')
# ws_epoch.cell(row=4,column=17,value='fc1.bias')
# ws_epoch.cell(row=4,column=18,value='fc2.weight')
# ws_epoch.cell(row=4,column=19,value='fc2.bias')
# ws_epoch.cell(row=4,column=20,value='fc3.weight')
# ws_epoch.cell(row=4,column=21,value='fc3.bias')
cnt = 5
for n in layer:
cnt = cnt + 1
ws_epoch.cell(row=4,column=cnt,value=n)
currow=4
for quant_type in quant_type_list:
num_bit_list = numbit_list(quant_type)
for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list:
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
title = '%s_%d' % (quant_type, num_bits)
print('\nAnalyse: '+title)
currow += 1
qat_state=torch.load(dir_prefix+'epoch_%d/'%epoch+title+'.pt')
js_grad_avg_norm=0.
grad_avg = qat_state['grad_dict_avg']
for name,tmpgrad in grad_avg.items():
writer.add_histogram(title+': '+name,tmpgrad,global_step=epoch)
colidx=5
for name,_ in full_grad_avg.items():
prefix = '.'.join(name.split('.')[:-1]) # module prefix, matching how the layer list was built above
colidx += 1
layer_idx = layer.index(prefix)
js_norm = js_div_norm(full_grad_avg[name],grad_avg[name])
ws_epoch.cell(row=currow,column=colidx,value=js_norm.cpu().item())
js_grad_avg_norm += flop_ratio[layer_idx] * js_norm
ws_epoch.cell(row=currow,column=1,value=title)
ws_epoch.cell(row=currow,column=2,value=qat_state['loss'].cpu().item())
ws_epoch.cell(row=currow,column=3,value=qat_state['loss_sum'].cpu().item())
ws_epoch.cell(row=currow,column=4,value=qat_state['loss_avg'].cpu().item())
ws_epoch.cell(row=currow,column=5,value=js_grad_avg_norm.cpu().item())
wb.save('loss_grad.xlsx')
writer.close()
\ No newline at end of file
......@@ -72,9 +72,9 @@ class ResNet(nn.Module):
self.layer2.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
self.layer3.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
self.layer4.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
self.qavgpool1 = QAdaptiveAvgPool2d(quant_type,qi=False,num_bits=num_bits,e_bits=e_bits)
# self.qfc1 = QLinear(quant_type, self.fc,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qfc1 = QLinear(quant_type, self.fc,qi=True,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qavgpool1 = QAdaptiveAvgPool2d(quant_type,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qfc1 = QLinear(quant_type, self.fc,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
# self.qfc1 = QLinear(quant_type, self.fc,qi=True,qo=True,num_bits=num_bits,e_bits=e_bits)
def quantize_forward(self, x):
# for _, layer in self.quantize_layers.items():
......@@ -102,8 +102,8 @@ class ResNet(nn.Module):
qo = self.layer3.freeze(qinput = qo)
qo = self.layer4.freeze(qinput = qo)
self.qavgpool1.freeze(qi=qo)
# self.qfc1.freeze(qi=self.qavgpool1.qo)
self.qfc1.freeze()
self.qfc1.freeze(qi=self.qavgpool1.qo)
# self.qfc1.freeze()
def fakefreeze(self):
pass
......
......@@ -16,6 +16,7 @@ def get_nearest_val(quant_type,x,is_bias=False):
plist = gol.get_value(is_bias)
# print('get')
# print(plist)
# x = x / 64
shape = x.shape
xhard = x.view(-1)
plist = plist.type_as(x)
......@@ -23,6 +24,7 @@ def get_nearest_val(quant_type,x,is_bias=False):
idx = (xhard.unsqueeze(0) - plist.unsqueeze(1)).abs().min(dim=0)[1]
xhard = plist[idx].view(shape)
xout = (xhard - x).detach() + x
# xout = xout * 64
return xout
# For symmetric signed quantization, get the maximum value of the quantization range
......@@ -64,7 +66,7 @@ def bias_qmax(quant_type):
elif quant_type == 'POT':
return get_qmax(quant_type)
else:
return get_qmax(quant_type, 16, 7) # e7 m9 (e5 is not large enough and the data overflows on both sides)
return get_qmax(quant_type, 16, 7)
# Converted to FP32, no further clamping is needed
......@@ -222,7 +224,9 @@ class QConv2d(QModule):
x = self.conv_module(x)
x = self.M * x
if self.quant_type != 'POT':
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
......@@ -279,7 +283,11 @@ class QLinear(QModule):
x = x - self.qi.zero_point
x = self.fc_module(x)
x = self.M * x
if self.quant_type != 'POT':
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
......@@ -453,7 +461,11 @@ class QConvBNReLU(QModule):
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
if self.quant_type != 'POT':
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
x.clamp_(min=0)
return x
......@@ -567,38 +579,36 @@ class QConvBN(QModule):
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
if self.quant_type != 'POT':
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
# x.clamp_(min=0)
return x
# TODO: this should probably have a qo
class QAdaptiveAvgPool2d(QModule):
def __init__(self, quant_type, qi=False, qo=True,num_bits=8, e_bits=3):
def __init__(self, quant_type, qi=False, qo=True, num_bits=8, e_bits=3):
super(QAdaptiveAvgPool2d, self).__init__(quant_type,qi,qo,num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
# if hasattr(self, 'qo') and qo is not None:
# raise ValueError('qo has been provided in init function.')
# if not hasattr(self, 'qo') and qo is None:
# raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
# if qo is not None:
# self.qo = qo
# def fakefreeze(self, qi=None):
# if hasattr(self, 'qi') and qi is not None:
# raise ValueError('qi has been provided in init function.')
# if not hasattr(self, 'qi') and qi is None:
# raise ValueError('qi is not existed, should be provided.')
# if qi is not None:
# self.qi = qi
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qo is not None:
self.qo = qo
self.M.data = (self.qi.scale / self.qo.scale).data
def forward(self, x):
if hasattr(self, 'qi'):
......@@ -609,16 +619,22 @@ class QAdaptiveAvgPool2d(QModule):
x = F.adaptive_avg_pool2d(x,(1, 1)) # quantizing the input and output is what quantizes this layer
# if hasattr(self, 'qo'):
# self.qo.update(x)
# x = FakeQuantize.apply(x, self.qo)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = F.adaptive_avg_pool2d(x,(1,1))
# x = FakeQuantize.apply(x, self.qo)
# x = get_nearest_val(self.quant_type,x) # this may not fit the PoT case; is a rescaling missing?
x = F.adaptive_avg_pool2d(x,(1, 1)) # quantizing the input and output is what quantizes this layer
x = self.M * x
if self.quant_type != 'POT':
x = get_nearest_val(self.quant_type,x)
return x
......@@ -632,12 +648,7 @@ class QModule_2(nn.Module):
if qi1:
self.qi1 = QParam(quant_type,num_bits, e_bits) # qi1 already gets num_bits and the quantization mode assigned here
if qo:
# if num_bits <=9 :
# self.qo = QParam(quant_type,num_bits, e_bits)
# if num_bits > 9 and num_bits<13:
# self.qo = QParam(quant_type,8, e_bits) # qo already gets num_bits and the quantization mode assigned here
# else:
self.qo = QParam(quant_type,num_bits, e_bits)
self.qo = QParam(quant_type,num_bits, e_bits) # qo already gets num_bits and the quantization mode assigned here
self.quant_type = quant_type
self.num_bits = num_bits
......@@ -688,6 +699,8 @@ class QElementwiseAdd(QModule_2):
# Following https://zhuanlan.zhihu.com/p/156835141, these are the coefficients of Eq. (3)
self.M0.data = self.qi0.scale / self.qo.scale
self.M1.data = self.qi1.scale / self.qi0.scale
# self.M0.data = self.qi0.scale / self.qo.scale
# self.M1.data = self.qi1.scale / self.qo.scale
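# Derivation (using r = s * (q - z) for each tensor and r_out = r0 + r1):
#   q_out = (s0*(q0 - z0) + s1*(q1 - z1)) / s_out + z_out
#         = M0 * ((q0 - z0) + M1 * (q1 - z1)) + z_out   with M0 = s0/s_out, M1 = s1/s0,
#   which is what quantize_inference below computes before rounding and adding the zero point.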
def forward(self, x0, x1): # forward pass; x0 and x1 are floating-point input tensors
if hasattr(self, 'qi0'):
......@@ -706,20 +719,15 @@ class QElementwiseAdd(QModule_2):
return x
def quantize_inference(self, x0, x1): # the inputs here are already-quantized tensors (qx)
# x0_d = self.qi0.dequantize_tensor(x0)
# x1_d = self.qi1.dequantize_tensor(x1)
# print(f"x0={x0_d.reshape(-1)[:10]}")
# print(f"x1={x1_d.reshape(-1)[:10]}")
x0 = x0 - self.qi0.zero_point
x1 = x1 - self.qi1.zero_point
x = self.M0 * (x0 + x1*self.M1)
if self.quant_type != 'POT':
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
# x_d = self.qo.dequantize_tensor(x)
# print(f"x={x_d.reshape(-1)[:10]}")
# print(f"loss={x_d.reshape(-1)[:10]-(x0_d.reshape(-1)[:10]+x1_d.reshape(-1)[:10])}")
# print('=============')
return x
from model import *
from utils import *
import gol
import sys
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingLR
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
def quantize_aware_training(model, device, train_loader, optimizer, epoch):
lossLayer = torch.nn.CrossEntropyLoss()
# accumulate the loss and the gradient of every parameter
# initialization
loss_sum = 0.
grad_dict = {}
# iterates over all weights and biases
for name,param in model.named_parameters():
grad_dict[name] = torch.zeros_like(param) # param.grad has the same shape as param
for batch_idx, (data, target) in enumerate(train_loader, 1):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model.quantize_forward(data)
# the loss for a batch is already averaged over the batch
loss = lossLayer(output, target)
loss.backward()
# accumulate loss and grads
loss_sum += loss.detach() # detach so the autograd graph is not retained across the epoch
for name,param in model.named_parameters():
if param.grad is not None:
# print('-------'+name+'-------')
grad_dict[name] += param.grad
# print(grad_dict[name])
# print(grad_dict.items())
# input()
optimizer.step()
if batch_idx % 50 == 0:
print('Quantize Aware Training Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
batch_size = len(train_loader.batch_sampler) # number of batches (despite the name)
# average the accumulated values over the batches
for name,grad in grad_dict.items():
grad_dict[name] = grad / batch_size
loss_avg = loss_sum / batch_size
return loss_avg, grad_dict
def full_inference(model, test_loader):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Full Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))
def train(model, device, train_loader, optimizer, epoch):
model.train()
lossLayer = torch.nn.CrossEntropyLoss()
# accumulate the loss and the gradient of every parameter
# initialization
loss_sum = 0.
grad_dict = {}
for name,param in model.named_parameters():
grad_dict[name] = torch.zeros_like(param) # param.grad has the same shape as param
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = lossLayer(output, target)
loss.backward()
# accumulate loss and grads
loss_sum += loss.detach() # detach so the autograd graph is not retained across the epoch
for name,param in model.named_parameters():
if param.grad is not None:
# print('-------'+name+'-------')
grad_dict[name] += param.grad # accumulate over the batches
# print(grad_dict[name])
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
batch_size = len(train_loader.batch_sampler) # number of batches (despite the name)
# average the accumulated values over the batches
for name,grad in grad_dict.items():
grad_dict[name] = grad / batch_size # average over the batches
loss_avg = loss_sum / batch_size
return loss_avg, grad_dict
def quantize_inference(model, test_loader):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model.quantize_inference(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Quant Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='QAT Training')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='ResNet18')
parser.add_argument('-e','--epochs', default=20, type=int, metavar='EPOCHS', help='number of total epochs to run')
parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='BATCH SIZE', help='mini-batch size (default: 128)')
parser.add_argument('-j','--workers', default=1, type=int, metavar='WORKERS',help='number of data loading workers (default: 1)')
parser.add_argument('-lr', '--learning-rate', default=0.001, type=float, metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('-wd','--weight_decay',default=0.0001,type=float,metavar='WD',help='weight decay',dest='wd')
parser.add_argument('-t', '--test', dest='test', action='store_true', help='test model on test set')
args = parser.parse_args()
batch_size = args.batch_size
seed = 1
epochs = args.epochs
lr = args.lr
# momentum = 0.5
weight_decay = args.wd
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
writer = SummaryWriter(log_dir='log/' + args.model + '/qat')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../../project/p/data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=args.workers, pin_memory=False
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../../project/p/data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=args.workers, pin_memory=False
)
if args.model == 'ResNet18':
model = resnet18()
elif args.model == 'ResNet50':
model = resnet50()
elif args.model == 'ResNet152':
model = resnet152()
# full_file = 'ckpt/cifar10_' + args.model + '.pt'
# model.load_state_dict(torch.load(full_file))
model.to(device)
# optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
optimizer = optim.Adam(model.parameters(), lr=lr)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
load_qat = False
ckpt_prefix = 'ckpt/qat/'+ args.model + '/'
loss_sum = 0.
grad_dict_sum = {}
grad_dict_avg = {}
for name,param in model.named_parameters():
grad_dict_sum[name] = torch.zeros_like(param)
grad_dict_avg[name] = torch.zeros_like(param)
# full precision from scratch
for epoch in range(1, epochs+1):
# train the full-precision model and record its gradient distribution
loss,grad_dict = train(model, device, train_loader, optimizer, epoch)
# print('loss:%f' % loss_avg)
writer.add_scalar('Full.loss',loss,epoch)
for name,grad in grad_dict.items():
writer.add_histogram('Full.'+name+'_grad',grad,global_step=epoch)
loss_sum += loss
loss_avg = loss_sum / epoch
for name,grad in grad_dict.items():
grad_dict_sum[name] += grad_dict[name] # accumulate over epochs
grad_dict_avg[name] = grad_dict_sum[name] / epoch # average over epochs
ckpt = {
'epoch' : epoch,
'loss' : loss,
'loss_sum' : loss_sum,
'loss_avg' : loss_avg,
'grad_dict_avg' : grad_dict_avg
}
if epoch % 5 == 0:
subdir = 'epoch_%d/' % epoch
if not os.path.isdir(ckpt_prefix+ subdir):
os.makedirs(ckpt_prefix+ subdir, mode=0o777)
os.chmod(ckpt_prefix+ subdir, mode=0o777)
torch.save(ckpt,ckpt_prefix+ subdir +'full.pt')
lr_scheduler.step()
# loss_avg,grad_dict = quantize_aware_training(model_ptq, device, train_loader, optimizer, epoch)
# print('qat_loss:%f' % loss_avg)
# for name,grad in grad_dict.items():
# writer.add_histogram('qat_'+name+'_grad',grad,global_step=epoch)
# QAT from scratch
quant_type_list = ['INT','POT','FLOAT']
gol._init()
for quant_type in quant_type_list:
num_bit_list = numbit_list(quant_type)
# For each quantization type, the bias quantization table only needs to be set once
# INT has a large bit width, so a lookup table would be too costly; plain rounding is enough
if quant_type != 'INT':
bias_list = build_bias_list(quant_type)
gol.set_value(bias_list, is_bias=True)
for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list:
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
title = '%s_%d' % (quant_type, num_bits)
if load_qat is True and osp.exists(ckpt_prefix+'epoch_20/'+title+'.pt'):
continue
print('\nQAT: '+title)
# model_ptq = AlexNet()
if args.model == 'ResNet18':
model_ptq = resnet18()
elif args.model == 'ResNet50':
model_ptq = resnet50()
elif args.model == 'ResNet152':
model_ptq = resnet152()
# set the quantization table
if quant_type != 'INT':
plist = build_list(quant_type, num_bits, e_bits)
gol.set_value(plist)
# model_ptq.load_state_dict(torch.load(full_file))
model_ptq.to(device)
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.train()
# the optimizer and scheduler above were built for the full-precision model; rebuild them for model_ptq so QAT actually updates its parameters
optimizer = optim.Adam(model_ptq.parameters(), lr=lr)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
loss_sum = 0.
grad_dict_sum = {}
grad_dict_avg = {}
for name,param in model_ptq.named_parameters():
grad_dict_sum[name] = torch.zeros_like(param)
grad_dict_avg[name] = torch.zeros_like(param)
for epoch in range(1, epochs+1):
loss,grad_dict = quantize_aware_training(model_ptq, device, train_loader, optimizer, epoch)
# print('loss:%f' % loss_avg)
writer.add_scalar(title+'.loss',loss,epoch)
for name,grad in grad_dict.items():
writer.add_histogram(title+'.'+name+'_grad',grad,global_step=epoch)
loss_sum += loss
loss_avg = loss_sum / epoch
for name,grad in grad_dict.items():
grad_dict_sum[name] += grad_dict[name]
grad_dict_avg[name] = grad_dict_sum[name] / epoch
ckpt = {
'epoch' : epoch,
'loss' : loss,
'loss_sum' : loss_sum,
'loss_avg' : loss_avg,
# 'grad_dict' : grad_dict,
# 'grad_dict_sum' : grad_dict_sum,
'grad_dict_avg' : grad_dict_avg
}
if epoch % 5 == 0:
subdir = 'epoch_%d/' % epoch
if not os.path.isdir(ckpt_prefix+ subdir):
os.makedirs(ckpt_prefix+ subdir, mode=0o777)
os.chmod(ckpt_prefix+ subdir, mode=0o777)
torch.save(ckpt,ckpt_prefix+subdir + title+'.pt')
lr_scheduler.step()
writer.close()
# # model.eval()
# # full_inference(model, test_loader)
# num_bits = 8
# e_bits = 0
# gol._init()
# print("qat: INT8")
# model.quantize('INT',num_bits,e_bits)
# print('Quantization bit: %d' % num_bits)
# if load_quant_model_file is not None:
# model.load_state_dict(torch.load(load_quant_model_file))
# print("Successfully load quantized model %s" % load_quant_model_file)
# else:
# model.train()
# for epoch in range(1, epochs+1):
# quantize_aware_training(model, device, train_loader, optimizer, epoch)
# # for epoch in range(epochs1 + 1, epochs2 + 1):
# # quantize_aware_training(model, device, train_loader, optimizer2, epoch)
# model.eval()
# # torch.save(model.state_dict(), save_file)
# model.freeze()
# # for name, param in model.named_parameters():
# # print(name)
# # print(param.data)
# # print('----------')
# # for param_tensor, param_value in model.state_dict().items():
# # print(param_tensor, "\t", param_value)
# quantize_inference(model, test_loader)
## update: <br>2023.4.17<br>
- Fixes and additions for the issues raised on 4.12:
- The sharp accuracy drop for INT quantization at larger bit widths was caused by a faulty implementation of quantize_inference in QAdaptiveAvgPool2d; after the fix the results match expectations.
- Checked and fixed small errors in the network, and added x = get_nearest_val(self.quant_type,x) to quantize_inference in QElementwiseAdd.
- Confirmed that the weighted-similarity computation is correct.
- Newly found problems and attempted fixes:
- After the changes above, INT and FP quantization behave as expected, but PoT quantization shows extremely low post-quantization inference accuracy (10%~20%). After extensive experiments, I found that the main cause is the activation quantization applied to each layer's output in quantize_inference, x = get_nearest_val(self.quant_type,x), which introduces a large error for PoT-quantized models.<br>
I investigated this further. What affects a quantized model is not only its weight parameters: quantizing each layer's output activations also introduces error. For INT and FP quantization this error is small (INT simply rounds the outputs, and FP quantization is fine-grained enough that the error stays small), but PoT quantization is different and generally shows a significant error on the ResNet family. For deeper models in particular (e.g. ResNet50, ResNet152), quantized inference accuracy suffers badly for this reason (from 90% down to 10%~20%).<br>
I first suspected this after repeatedly confirming that the framework itself had no obvious problems: PoT quantization kept performing poorly even though the weight-parameter similarity stayed high, which pointed to the input/output data. To verify the idea, I ran an experiment in which the quantize_inference method of every quantized module skips x = get_nearest_val(self.quant_type,x) for PoT quantization, i.e. only each layer's weights are quantized and its output tensor is left unquantized (see the sketch after this list). With this change, PoT post-quantization accuracy improved considerably (to 50%~60%), and the fitted curves largely match expectations.<br>
Since the only independent variable at the moment is the similarity of the weight-parameter distributions before and after quantization, I did not introduce an additional metric to describe the similarity of each layer's output tensors before and after quantization.
<br>
- Experiments in progress:
Treat PoT quantization separately: first quantize only the weights and examine the relationship between weight-distribution similarity and accuracy. Then keep a PoT weight assignment with relatively high accuracy fixed and quantize the activations with different data representations, to study the relationship between activation bit width and accuracy.
- Remaining work:
- Experiments on convergence speed vs. gradient-distribution similarity
<br> <br> <br> <br>
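A minimal sketch of the weight-only PoT handling described above (the standalone helper name requantize_output is mine, for illustration only; get_nearest_val, M and qo refer to the repo's quantized modules, whose committed quantize_inference methods inline the same logic):

```python
# Sketch: requantize a layer's integer-domain output, skipping activation rounding for PoT.
def requantize_output(module, x):
    x = module.M * x                        # rescale the accumulator into the output scale
    if module.quant_type != 'POT':          # INT / FLOAT: snap outputs to the quantization grid
        x = get_nearest_val(module.quant_type, x)
    # POT: no get_nearest_val here, so only the weights remain quantized
    x = x + module.qo.zero_point
    return x
```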
- Fitted curves:
1. PoT quantization quantizes only the network weights; INT and FP quantize both weights and activations:<br>
resnet18:
<img src = "fig/18_flops.png" class="h-90 auto">
<img src = "fig/18_params.png" class="h-90 auto">
resnet50:
<img src = "fig/50_flops.png" class="h-90 auto">
<img src = "fig/50_params.png" class="h-90 auto">
resnet152:
<img src = "fig/152_flops.png" class="h-90 auto">
<img src = "fig/152_params.png" class="h-90 auto">
2. With PoT quantization removed:<br>
resnet18:
<img src = "fig/18_flops_nopot.png" class="h-90 auto">
<img src = "fig/18_params_nopot.png" class="h-90 auto">
resnet50:
<img src = "fig/50_flops_nopot.png" class="h-90 auto">
<img src = "fig/50_params_nopot.png" class="h-90 auto">
resnet152:
<img src = "fig/152_flops_nopot.png" class="h-90 auto">
<img src = "fig/152_params_nopot.png" class="h-90 auto">
## update: <br>2023.4.12<br>
- Updated the parts of get_param_flops.py, extract_ratio.py and ptq.py that compute data-distribution similarity and the FLOPs/parameter-count weighting so that they work with the ResNet family.<br>
(Note: for now, after generating param_flops_ResNet18/50/152.txt, the statistics of the downsample modules that are counted twice inside a layer have to be deleted by hand; ResNet18 needs 3 deletions and ResNet50/152 need 4 each, so this was done manually.)
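For reference, the gradient-similarity metric written to the js_grad_avg_norm column by loss_grad.py in this commit can be summarized as (my own notation, not from the repo):

$$\mathrm{js\_grad\_avg\_norm}=\sum_{p}\mathrm{flop\_ratio}_{\ell(p)}\cdot\mathrm{JS}\!\left(\hat g^{\,\mathrm{FP32}}_{p},\ \hat g^{\,\mathrm{QAT}}_{p}\right)$$

where the sum runs over all parameters p, ℓ(p) is the layer containing p, and each ĝ_p is the epoch-averaged gradient after L2 normalization along its last dimension; js_div additionally applies a softmax to both arguments before computing the Jensen-Shannon divergence.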
......
import torch
import torch.nn as nn
import torch.nn.functional as F
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
def ebit_list(quant_type, num_bits):
if quant_type == 'FLOAT':
......@@ -22,7 +34,7 @@ def numbit_list(quant_type):
def build_bias_list(quant_type):
if quant_type == 'POT':
return build_pot_list(8)
return build_pot_list(8) #
else:
return build_float_list(16,7)
......