Commit f4b96743 by Klin
parents 38c41268 91d53d31
# -*- coding: utf-8 -*-
import numpy
import numpy as np
import torch
import sys
from mmd_loss import *
from collections import OrderedDict
d1 = sys.argv[1] # bit
d2 = sys.argv[2] # epoch
# d1=4
# d2=5
sum=0
flag=0
total_quan_list=list()
total_base_list=list()
# CNN FLOPs = Cout * Hout * Wout * (2 * Cin * K * K) if bias is counted, otherwise subtract 1
# FCN FLOPs = Cout * Cin if bias is counted, otherwise subtract 1
# The related relu/pool layers are also included
# MAdd
# weight0 =np.array( [ 705600.0+4704.0+ 3528.0 , 480000.0+ 1600.0 + 1200.0 , 95880.0 + 120.0,
# 20076.0 + 84.0 , 1670.0 ])
# weight1 = np.array([705600.0, 480000.0, 95880.0,
#                      20076.0, 1670.0])
# flops
weight_f0= np.array([357504+4704+4704, 241600+1600+1600,48000+120,10080+84,840])
weight_f1=np.array([357504, 241600,48000,10080,840])
summary_quan_dict=OrderedDict()
summary_base_dict=OrderedDict()
losses=[]
# Outer level: one dict per epoch; inner level: the grads of each network layer
for i in range(int(d2)):
total_quan_list.append(torch.load('./project/p/checkpoint/cifar-10_lenet_bn_quant/' + str(d1) + '/ckpt_cifar-10_lenet_bn_quant_'+str(i+1)+'.pth'))
#total_quan_list.append(torch.load('checkpoint/cifar-10_lenet_bn/full' + '/ckpt_cifar-10_lenet_bn_' + str(d2) + '.pth'))
total_base_list.append(torch.load('./project/p/checkpoint/cifar-10_lenet_bn/full' + '/ckpt_cifar-10_lenet_bn_' + str(i+1) + '.pth'))
for k, _ in total_base_list[i]['grads'].items():
if flag == 0:
summary_quan_dict[k] = total_quan_list[i]['grads'][k].reshape(1,-1)
summary_base_dict[k] = total_base_list[i]['grads'][k].reshape(1,-1)
else :
# Entries in the dict cannot be modified in place; they have to be reassigned
a=summary_quan_dict[k]
b=total_quan_list[i]['grads'][k].reshape(1,-1)
c=np.vstack((a,b))
summary_quan_dict[k] = c
a = summary_base_dict[k]
b = total_base_list[i]['grads'][k].reshape(1,-1)
c = np.vstack((a, b))
summary_base_dict[k] = c
flag = 1
cnt = 0
flag = 0
for k, _ in summary_quan_dict.items():
if flag == 0:
sum += 0.99*weight_f1[cnt] * MK_MMD(source=summary_base_dict[k], target=summary_quan_dict[k]) # weight
else:
sum += 0.01*weight_f1[cnt] * MK_MMD(source=summary_base_dict[k], target=summary_quan_dict[k]) #bias
if flag == 1:
cnt = cnt + 1
flag = 0
else:
flag=1
sum=sum/(weight_f0.sum()*2)
print(sum)
f = open('./project/p/lenet_ptq_similarity.txt','a')
f.write('bit:' + str(d1) + ' epoch_num:' + str(d2) +': '+str(sum)+'\n')
f.close()
# for k,v in summary_base_dict.items():
# if k== 'conv_layers.conv1.weight':
# print(v)
# print('===========')
# print(summary_quan_dict[k])
\ No newline at end of file
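Usage sketch for the script above (its file name is not shown in this dump; the first argument is the bit-width and the second the number of epochs to compare, matching the commented defaults d1=4, d2=5):
python <this_script>.py 4 5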
import sys
import os
# Extract parameter counts and FLOPs from val.txt, the redirected output of get_param.py
def extract_ratio(md='ResNet18'):
fr = open('param_flops_' + md + '.txt','r')
lines = fr.readlines()
layer = []
par_ratio = []
flop_ratio = []
for line in lines:
# if '(' in line and ')' in line:
if 'Conv' in line or 'BatchNorm2d' in line or 'Linear' in line:
layer.append(line.split(':')[1].split('(')[0])
r1 = line.split('%')[0].split(',')[-1]
r1 = float(r1)
par_ratio.append(r1)
r2 = line.split('%')[-2].split(',')[-1]
r2 = float(r2)
flop_ratio.append(r2)
return layer, par_ratio, flop_ratio
if __name__ == "__main__":
layer, par_ratio, flop_ratio = extract_ratio()
print(len(layer))
print(len(par_ratio))
print(len(flop_ratio))
\ No newline at end of file
from torch.autograd import Function
class FakeQuantize(Function):
@staticmethod
def forward(ctx, x, qparam):
x = qparam.quantize_tensor(x)
x = qparam.dequantize_tensor(x)
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
\ No newline at end of file
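A minimal usage sketch for FakeQuantize above. The toy qparam object below is illustrative only and stands in for the repo's QParam; the point is that forward() simulates quantize-then-dequantize while backward() acts as a straight-through estimator, passing gradients through unchanged.
import torch
class _ToyQParam:
    scale, zero_point, qmax = 0.02, 0., 255.  # 8-bit-style parameters, illustration only
    def quantize_tensor(self, x):
        return torch.clamp(torch.round(x / self.scale) + self.zero_point, 0., self.qmax)
    def dequantize_tensor(self, q):
        return (q - self.zero_point) * self.scale
w = torch.randn(4, requires_grad=True)
w_sim = FakeQuantize.apply(w, _ToyQParam())  # simulated-quantization forward pass
w_sim.sum().backward()
print(w.grad)  # all ones: the quantization step is transparent to the gradient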
from model import *
import torch
from ptflops import get_model_complexity_info
import argparse
def get_children(model: torch.nn.Module):
# get children from model!
# To keep the parameters updatable later, they would need to be held in an nn.ModuleList
# children = nn.ModuleList(model.children())
# print(children)
# makes it convenient to update the contained modules later
# flatt_children = nn.ModuleList()
children = list(model.children())
# flatt_children = nn.ModuleList()
flatt_children = []
if len(children) == 0:
# if model has no children; model is last child! :O
return model
else:
# look for children from children... to the last child!
for child in children:
try:
flatt_children.extend(get_children(child))
except TypeError:
flatt_children.append(get_children(child))
# print(flatt_children)
return flatt_children
# Helper that yields all child modules, excluding wrapper containers
def get_all_child_modules(module):
for name, child in module.named_children():
if isinstance(child, nn.Sequential):
yield from get_all_child_modules(child)
elif len(list(child.children())) > 0:
yield from child.children()
else:
yield child
def filter_fn(module, n_inp, outp_shape):
# if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.ReLU,torch.nn.BatchNorm2d,torch.nn.Linear,torch.nn.AdaptiveAvgPool2d)):
if 'conv' in module or 'bn' in module or 'fc' in module or 'avg' in module or 'relu' in module:
return True
return False
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Model Analysis --- params & flops')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='MobileNetV2')
args = parser.parse_args()
# if args.model == 'ResNet18':
# model = resnet18()
# elif args.model == 'ResNet50':
# model = resnet50()
# elif args.model == 'ResNet152':
# model = resnet152()
if args.model == 'MobileNetV2':
model = MobileNetV2()
full_file = 'ckpt/cifar10_' + args.model + '.pt'
model.load_state_dict(torch.load(full_file))
# flat = get_children(model)
# print(flat)
# flat = get_children(model)
# new_model = nn.Sequential(*flat)
flops, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict
def get_model_histogram(model):
"""
Description:
- get gradient histograms (and raw gradients) from the model, stored in OrderedDicts
Args:
- model: (torch.nn.Module), torch model
Returns:
- (gradshisto, grads): per-parameter histograms and raw gradient arrays, as OrderedDicts
"""
gradshisto = OrderedDict()
grads = OrderedDict()
for name, params in model.named_parameters():
grad = params.grad
if grad is not None:
tmp = {}
params_np = grad.cpu().numpy()
histogram, bins = np.histogram(params_np.flatten(),bins=20)
tmp['histogram'] = list(histogram)
tmp['bins'] = list(bins)
gradshisto[name] = tmp
grads[name] = params_np
return gradshisto,grads
def get_model_norm_gradient(model):
"""
Description:
- get gradient norms from the model and store them in an OrderedDict
Args:
- model: (torch.nn.Module), torch model
Returns:
- grads in an OrderedDict
"""
grads = OrderedDict()
for name, params in model.named_parameters():
grad = params.grad
if grad is not None:
grads[name] = grad.norm().item()
return grads
def get_grad_histogram(grads_sum):
gradshisto = OrderedDict()
# grads = OrderedDict()
for name, params in grads_sum.items():
grad = params
if grad is not None:
tmp = {}
#params_np = grad.cpu().numpy()
params_np = grad
histogram, bins = np.histogram(params_np.flatten(),bins=20)
tmp['histogram'] = list(histogram)
tmp['bins'] = list(bins)
gradshisto[name] = tmp # one histogram per layer (tmp holds the data describing it)
# grads[name] = params_np
return gradshisto
\ No newline at end of file
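A minimal usage sketch for the helpers above (the toy model and shapes are illustrative, not from this repo):
import torch
import torch.nn as nn
toy = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
toy(torch.randn(16, 8)).sum().backward()   # populate .grad on every parameter
hists, grads = get_model_histogram(toy)    # 20-bin histogram + raw gradient array per parameter
norms = get_model_norm_gradient(toy)       # scalar gradient norm per parameter
print(list(hists['0.weight'].keys()))      # ['histogram', 'bins']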
class GlobalVariables:
SELF_INPLANES = 0
\ No newline at end of file
# -*- coding: utf-8 -*-
# Used to share global variables across multiple modules
def _init(): # initialization
global _global_dict
_global_dict = {}
def set_value(value,is_bias=False):
# define a global variable
if is_bias:
_global_dict[0] = value
else:
_global_dict[1] = value
def get_value(is_bias=False): # bias gets its own precision, independent of the other variables
if is_bias:
return _global_dict[0]
else:
return _global_dict[1]
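A minimal usage sketch, assuming the module above is saved as gol.py (the file name is not shown in this dump):
import gol
gol._init()                          # create the shared dict once, in the entry script
gol.set_value(8)                     # bit-width shared by weights/activations
gol.set_value(16, is_bias=True)      # bias gets its own, higher precision
print(gol.get_value(is_bias=True))   # -> 16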
from model import *
from extract_ratio import *
from utils import *
import argparse
import openpyxl
import os
import os.path as osp
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
def js_div_norm(a,b):
a_norm = F.normalize(a.data,p=2,dim=-1)
b_norm = F.normalize(b.data,p=2,dim=-1)
return js_div(a_norm,b_norm)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Loss-Grad Analysis')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='ResNet18')
args = parser.parse_args()
wb = openpyxl.Workbook()
ws = wb.active
writer = SummaryWriter(log_dir='log/' + args.model + '/qat_loss_grad')
# layer, par_ratio, flop_ratio = extract_ratio()
layer, par_ratio, flop_ratio = extract_ratio(args.model)
if args.model == 'ResNet18':
model = resnet18()
elif args.model == 'ResNet50':
model = resnet50()
elif args.model == 'ResNet152':
model = resnet152()
layer = []
for name, param in model.named_parameters():
if 'weight' in name:
n = name.split('.')
pre = '.'.join(n[:len(n)-1])
layer.append(pre)
# dir_prefix = 'ckpt/qat/epoch_'
dir_prefix = 'ckpt/qat/'+ args.model + '/'
quant_type_list = ['INT','POT','FLOAT']
for epoch in [5,10,15,20]:
ws_epoch = wb.create_sheet('epoch_%d'%epoch)
full_state = torch.load(dir_prefix+'%d/'%epoch + 'full.pt')
ws_epoch.cell(row=1,column=2,value='loss')
ws_epoch.cell(row=1,column=3,value='loss_sum')
ws_epoch.cell(row=1,column=4,value='loss_avg')
ws_epoch.cell(row=2,column=1,value='FP32')
ws_epoch.cell(row=2,column=2,value=full_state['loss'].cpu().item())
ws_epoch.cell(row=2,column=3,value=full_state['loss_sum'].cpu().item())
ws_epoch.cell(row=2,column=4,value=full_state['loss_avg'].cpu().item())
# full_grad = full_state['grad_dict']
# full_grad_sum = full_state['grad_dict_sum']
full_grad_avg = full_state['grad_dict_avg']
for name,tmpgrad in full_grad_avg.items():
writer.add_histogram('FULL: '+name,tmpgrad,global_step=epoch)
ws_epoch.cell(row=4,column=1,value='title')
ws_epoch.cell(row=4,column=2,value='loss')
ws_epoch.cell(row=4,column=3,value='loss_sum')
ws_epoch.cell(row=4,column=4,value='loss_avg')
ws_epoch.cell(row=4,column=5,value='js_grad_avg_norm')
# ws_epoch.cell(row=4,column=6,value='conv1.weight')
# ws_epoch.cell(row=4,column=7,value='conv1.bias')
# ws_epoch.cell(row=4,column=8,value='conv2.weight')
# ws_epoch.cell(row=4,column=9,value='conv2.bias')
# ws_epoch.cell(row=4,column=10,value='conv3.weight')
# ws_epoch.cell(row=4,column=11,value='conv3.bias')
# ws_epoch.cell(row=4,column=12,value='conv4.weight')
# ws_epoch.cell(row=4,column=13,value='conv4.bias')
# ws_epoch.cell(row=4,column=14,value='conv5.weight')
# ws_epoch.cell(row=4,column=15,value='conv5.bias')
# ws_epoch.cell(row=4,column=16,value='fc1.weight')
# ws_epoch.cell(row=4,column=17,value='fc1.bias')
# ws_epoch.cell(row=4,column=18,value='fc2.weight')
# ws_epoch.cell(row=4,column=19,value='fc2.bias')
# ws_epoch.cell(row=4,column=20,value='fc3.weight')
# ws_epoch.cell(row=4,column=21,value='fc3.bias')
cnt = 5
for n in layer:
cnt = cnt + 1
ws_epoch.cell(row=4,column=cnt,value=n)
currow=4
for quant_type in quant_type_list:
num_bit_list = numbit_list(quant_type)
for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list:
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
title = '%s_%d' % (quant_type, num_bits)
print('\nAnalyse: '+title)
currow += 1
qat_state=torch.load(dir_prefix+'%d/'%epoch+title+'.pt')
js_grad_avg_norm=0.
grad_avg = qat_state['grad_dict_avg']
for name,tmpgrad in grad_avg.items():
writer.add_histogram(title+': '+name,tmpgrad,global_step=epoch)
colidx=5
for name,_ in full_grad_avg.items():
prefix = name.split('.')[0]
colidx += 1
layer_idx = layer.index(prefix)
js_norm = js_div_norm(full_grad_avg[name],grad_avg[name])
ws_epoch.cell(row=currow,column=colidx,value=js_norm.cpu().item())
js_grad_avg_norm += flop_ratio[layer_idx] * js_norm
ws_epoch.cell(row=currow,column=1,value=title)
ws_epoch.cell(row=currow,column=2,value=qat_state['loss'].cpu().item())
ws_epoch.cell(row=currow,column=3,value=qat_state['loss_sum'].cpu().item())
ws_epoch.cell(row=currow,column=4,value=qat_state['loss_avg'].cpu().item())
ws_epoch.cell(row=currow,column=5,value=js_grad_avg_norm.cpu().item())
wb.save('loss_grad.xlsx')
writer.close()
\ No newline at end of file
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
3937.841102576776 3363.8190205631026 2484.282190565725 1666.2525440043232 1004.491228033656 619.2894405795665 418.1602230520327 312.24372234629664 249.45408998160573 193.16718609999484 138.595522944474 91.58500300086834 56.08617073690407 31.952073152774126 17.65229708562858 3937.814777378682 3068.5487659326095 613.4307511092446 32.126779884075525 0.1766064000339766 0.17675544226821238 0.17669672446190357 2644.0738554237805 1741.4374062424974 1207.0395601708549 1028.5434609756649 668.1556835852132 273.1682252338685 625.3801636402046 430.06294426631223 203.6193231190563 6.546919210669541 419.88446951149035 315.44088340877573 143.55644125961152 2.884048071197845 0.05101264315330961 312.7100494425904 250.59377938308924 93.48144767534725 1.3453977569799638 0.013982248527794 0.10109822916201198
js_param_list:
5523.5133138619085 4892.5106928859905 3857.644315893276 2707.0379699110163 1671.0773218634195 1050.057190366859 723.4234749265271 550.470391030846 445.310410225667 347.5339012995234 250.31388514236687 165.8204601056395 101.56282095527794 57.86236490894071 31.945254415161216 5523.499968355952 4553.489563739368 1040.3277683793247 57.8656798626289 0.006984618427447457 0.006992355997222603 0.007052672537742183 4053.9039848113193 2820.767118173691 1991.9978400778782 1709.3926174557196 1129.0122760496113 485.4325026388394 1059.7935852519893 742.5316942830176 365.9667180995057 11.782087473714286 726.0475165732927 555.5680608520271 259.1510147521308 5.1897931543416576 0.0014493051074338339 551.1504206498132 447.20405741295923 169.2517013866318 2.4155090345624393 0.000397355594218542 0.03827793516377293
ptq_acc_list:
10.0 10.0 10.44 37.82 81.77 88.99 91.31 91.84 91.96 91.9 91.93 91.88 91.87 91.88 91.89 10.0 10.0 14.09 19.53 11.56 11.34 17.05 10.01 10.36 20.89 25.1 62.06 79.8 45.72 81.59 89.76 79.92 71.56 85.92 91.67 90.2 81.46 73.66 87.45 91.65 91.51 90.24 81.11
acc_loss_list:
0.8914812805208898 0.8914812805208898 0.886706456863809 0.5895822029300054 0.11264243081931644 0.034291915355398925 0.009115572436245289 0.003364080303852439 0.002061855670103222 0.0027129679869777536 0.0023874118285404106 0.002930005425936085 0.003038524145415096 0.002930005425936085 0.0028214867064569192 0.8914812805208898 0.8914812805208898 0.8470971242539338 0.7880629408572979 0.8745523602821487 0.8769397721106891 0.8149755832881173 0.8913727618014107 0.8875746066196419 0.773304395008139 0.7276180141074337 0.32653282691264246 0.1340206185567011 0.5038524145415084 0.11459576776994033 0.02593597395550733 0.13271839392295173 0.22344004340748783 0.06760716223548566 0.00520889853499733 0.02116115029842651 0.11600651112316887 0.20065111231687474 0.051003798155181794 0.005425935973955507 0.006945198046663055 0.020727075420510156 0.11980466630493766
## update: <br>2023.4.24<br>
- Problems solved
1. Fixed the issue where the weight-parameter similarity was abnormal and no curve could be fitted.
2. Fixed a few minor bugs.
- Notes on the approach <br>
The implausible weight-parameter similarity already seen in ResNet is even more pronounced in MobileNetV2, to the point where no curve can be fitted for MobileNet. Since the post-PTQ inference accuracies look normal, the problem most likely lies in how the weight-parameter similarity is computed.<br>
I carefully checked whether BN had been folded correctly, whether the corresponding ratios were right, and whether the JS divergence was computed between the two matching weight tensors; all of that was fine. But when inspecting the actual weight values, I found that a few layers had abnormal data, and it is exactly those layers' very large JS values that distort the overall JS computation.<br>
Judging from TensorBoard, the weight distributions of those layers should be very close to the full-precision model, but the computed JS values did not reflect that. Looking at these layers more closely, many of their values are 1 or -1, which suggested that the normalize step applied to the weights of the pre- and post-quantization models distorted the data distribution and caused the problem.<br>
Since the intent of the normalize step was to bring the weights of both models onto the same scale so that JS divergence can be used as a distance, the quantized model's weights can instead be dequantized back to a scale close to the full-precision model, and the JS divergence computed afterwards. I call this procedure fakefreeze. In practice it works well: the recomputed JS divergences now agree with the similarity of the weight distributions observed directly in TensorBoard.
<img src = "fig/defreeze.png" class="h-90 auto">
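A minimal sketch of the fakefreeze idea (the helper and argument names here are hypothetical; the actual logic lives in the models' fakefreeze routines): instead of normalizing both weight tensors, dequantize the quantized weights back to the full-precision scale and only then measure the JS divergence, assuming `js_div` from this repo's utils.py and a `QParam`-style `qparam` that provides `dequantize_tensor`.

```python
from utils import js_div   # JS divergence helper defined in utils.py

def layer_weight_similarity(w_full, w_quant, qparam):
    # bring the quantized weights back to (approximately) the FP32 scale
    w_dq = qparam.dequantize_tensor(w_quant)
    # JS divergence between the flattened weight tensors
    return js_div(w_full.reshape(1, -1), w_dq.reshape(1, -1))
```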
## update: <br>2023.4.23<br>
1. Implemented PTQ quantization for MobileNetV2.
2. Remaining issues:<br>
Although the post-PTQ accuracy is reasonable, the weight-parameter similarity is abnormal and no sensible curve can be fitted.<br>
Analyzing the data and inspecting it in TensorBoard shows that some layers with large FLOPs/parameter counts have distributions that look highly similar in the plots, yet their computed JS divergence is large; since these layers also carry large weighting coefficients, the overall JS distance ends up very large.
3. Next step: determine whether the problem is caused by the code itself or by weighting coefficients that do not suit MobileNetV2.
from model import *
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import os
import os.path as osp
def train(model, device, train_loader, optimizer, epoch):
model.train()
lossLayer = torch.nn.CrossEntropyLoss()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = lossLayer(output, target)
loss.backward()
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
lossLayer = torch.nn.CrossEntropyLoss(reduction='sum')
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += lossLayer(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
test_loss, 100. * correct / len(test_loader.dataset)
))
if __name__ == "__main__":
batch_size = 128
test_batch_size = 128
seed = 1
epochs1 = 15
epochs2 = epochs1+10
epochs3 = epochs2+10
lr1 = 0.01
lr2 = 0.001
lr3 = 0.0001
momentum = 0.5
save_model = True
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../../project/p/data', train=True, download=True,
transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../../project/p/data', train=False, transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
)
# model = AlexNet_BN().to(device)
model = resnet18().to(device)
optimizer1 = optim.SGD(model.parameters(), lr=lr1, momentum=momentum)
optimizer2 = optim.SGD(model.parameters(), lr=lr2, momentum=momentum)
optimizer3 = optim.SGD(model.parameters(), lr=lr3, momentum=momentum)
for epoch in range(1, epochs1 + 1):
train(model, device, train_loader, optimizer1, epoch)
test(model, device, test_loader)
for epoch in range(epochs1 + 1, epochs2 + 1):
train(model, device, train_loader, optimizer2, epoch)
test(model, device, test_loader)
for epoch in range(epochs2 + 1, epochs3 + 1):
train(model, device, train_loader, optimizer3, epoch)
test(model, device, test_loader)
if save_model:
if not osp.exists('ckpt'):
os.makedirs('ckpt')
torch.save(model.state_dict(), 'ckpt/cifar10_ResNet18.pt')
\ No newline at end of file
from model import *
model = MobileNetV2()
model.quantize('INT',8,0)
# for name, module in model.named_modules():
# print(name)
# print('==============================')
# for name, param in model.named_parameters():
# print(name)
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
def js_div(p_output, q_output, get_softmax=True):
"""
Measures the JS divergence between two distributions given as logits: JS(P, Q) = (KL(P||M) + KL(Q||M)) / 2 with M = (P + Q) / 2. Softmax is applied first when get_softmax is True.
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
def ebit_list(quant_type, num_bits):
if quant_type == 'FLOAT':
e_bit_list = list(range(1,num_bits-1))
else:
e_bit_list = [0]
return e_bit_list
def numbit_list(quant_type):
if quant_type == 'INT':
num_bit_list = list(range(2,17))
elif quant_type == 'POT':
num_bit_list = list(range(2,9))
else:
num_bit_list = list(range(2,9))
# num_bit_list = [8]
return num_bit_list
def build_bias_list(quant_type):
if quant_type == 'POT':
return build_pot_list(8) #
else:
return build_float_list(16,7)
def build_list(quant_type, num_bits, e_bits):
if quant_type == 'POT':
return build_pot_list(num_bits)
else:
return build_float_list(num_bits,e_bits)
def build_pot_list(num_bits):
plist = [0.]
for i in range(-2 ** (num_bits-1) + 2, 1):
# i goes up to 0, i.e. the maximum PoT quantized value is 1
plist.append(2. ** i)
plist.append(-2. ** i)
plist = torch.Tensor(list(set(plist)))
# plist = plist.mul(1.0 / torch.max(plist))
return plist
def build_float_list(num_bits,e_bits):
m_bits = num_bits - 1 - e_bits
plist = [0.]
# spacing between adjacent mantissa values
dist_m = 2 ** (-m_bits)
e = -2 ** (e_bits - 1) + 1
for m in range(1, 2 ** m_bits):
frac = m * dist_m # mantissa part
expo = 2 ** e # exponent part
flt = frac * expo
plist.append(flt)
plist.append(-flt)
for e in range(-2 ** (e_bits - 1) + 2, 2 ** (e_bits - 1) + 1):
expo = 2 ** e
for m in range(0, 2 ** m_bits):
frac = 1. + m * dist_m
flt = frac * expo
plist.append(flt)
plist.append(-flt)
plist = torch.Tensor(list(set(plist)))
return plist
def fold_ratio(layer, par_ratio, flop_ratio):
idx = -1
for name in layer:
idx = idx + 1
# `layer` comes from model.named_parameters(), so the downsample entries are guaranteed to be present
if 'bn' in name or 'sample.1' in name:
par_ratio[idx-1] += par_ratio[idx]
flop_ratio[idx-1] += flop_ratio[idx]
return par_ratio,flop_ratio
def fold_model(model):
idx = -1
module_list = []
# print('fold model:')
for name, module in model.named_modules():
# print(name+'-- +')
idx += 1
module_list.append(module)
# the conv inside downsample used to be overlooked here, so some layers were not folded
if 'bn' in name or 'sample.1' in name:
# print(name+'-- *')
module_list[idx-1] = fold_bn(module_list[idx-1],module) # modified here
return model
# def fold_model(model):
# last_conv = None
# last_bn = None
# for name, module in model.named_modules():
# if isinstance(module, nn.Conv2d):
# # if the current module is a conv layer, 'fold' it into the previous BN layer
# if last_bn is not None:
# last_conv = fold_bn(last_conv, last_bn)
# last_bn = None
# last_conv = module
# elif isinstance(module, nn.BatchNorm2d):
# # if the current module is a BN layer, 'fold' it into the previous conv layer
# last_bn = module
# if last_conv is not None:
# last_conv = fold_bn(last_conv, last_bn)
# last_bn = None
# # handle the last BN layer
# if last_bn is not None:
# last_conv = fold_bn(last_conv, last_bn)
# return model
def fold_bn(conv, bn):
# get the BN layer's parameters
gamma = bn.weight.data
beta = bn.bias.data
mean = bn.running_mean
var = bn.running_var
eps = bn.eps
std = torch.sqrt(var + eps)
feat = bn.num_features
# get the conv layer's parameters
weight = conv.weight.data
if conv.bias is not None:
bias = conv.bias.data
if bn.affine:
gamma_ = gamma / std
weight = weight * gamma_.view(feat, 1, 1, 1)
if conv.bias is not None:
bias = gamma_ * bias - gamma_ * mean + beta
else:
bias = beta - gamma_ * mean
else:
gamma_ = 1 / std
weight = weight * gamma_
if conv.bias is not None:
bias = gamma_ * bias - gamma_ * mean
else:
bias = -gamma_ * mean
# set the new weight and bias
conv.weight.data = weight
# handles the case where conv.bias is None
if conv.bias is None:
conv.bias = nn.Parameter(bias)
else:
conv.bias.data = bias
return conv
\ No newline at end of file
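A quick sanity check for fold_bn above (a sketch, assuming this file is the repo's utils.py and that the BN layer is in eval mode so its running statistics are used):
import copy
import torch
import torch.nn as nn
from utils import fold_bn
conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-0.1, 0.1)
bn.running_var.uniform_(0.5, 1.5)
x = torch.randn(1, 3, 8, 8)
y_ref = bn(conv(x))
folded = fold_bn(copy.deepcopy(conv), bn)            # fold_bn mutates conv, hence the deepcopy
print(torch.allclose(y_ref, folded(x), atol=1e-5))   # expected: True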
......@@ -30,9 +30,21 @@ def get_children(model: torch.nn.Module):
# print(flatt_children)
return flatt_children
# Helper that yields all child modules, excluding wrapper containers
def get_all_child_modules(module):
for name, child in module.named_children():
if isinstance(child, nn.Sequential):
yield from get_all_child_modules(child)
elif len(list(child.children())) > 0:
yield from child.children()
else:
yield child
def filter_fn(module, n_inp, outp_shape):
# if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.ReLU,torch.nn.BatchNorm2d,torch.nn.Linear,torch.nn.AdaptiveAvgPool2d)):
if 'conv' in module or 'bn' in module or 'fc' in module or 'avg' in module or 'relu' in module:
return True
return False
if __name__ == "__main__":
......
......@@ -108,16 +108,16 @@ class ResNet(nn.Module):
def fakefreeze(self):
pass
def quantize_inference(self, x):
def quantize_inference(self, x, quant_type):
qx = self.qconvbnrelu1.qi.quantize_tensor(x)
qx = self.qconvbnrelu1.quantize_inference(qx)
qx = self.layer1.quantize_inference(qx)
qx = self.layer2.quantize_inference(qx)
qx = self.layer3.quantize_inference(qx)
qx = self.layer4.quantize_inference(qx)
qx = self.qavgpool1.quantize_inference(qx)
qx = self.qconvbnrelu1.quantize_inference(qx, quant_type)
qx = self.layer1.quantize_inference(qx, quant_type)
qx = self.layer2.quantize_inference(qx, quant_type)
qx = self.layer3.quantize_inference(qx, quant_type)
qx = self.layer4.quantize_inference(qx, quant_type)
qx = self.qavgpool1.quantize_inference(qx, quant_type)
qx = qx.view(qx.size(0), -1)
qx = self.qfc1.quantize_inference(qx)
qx = self.qfc1.quantize_inference(qx, quant_type)
qx = self.qfc1.qo.dequantize_tensor(qx)
......@@ -209,18 +209,18 @@ class BasicBlock(nn.Module):
self.qrelu1.freeze(qi = self.qelementadd.qo)
return self.qrelu1.qi # the qo after relu can reuse the qi tracked by relu
def quantize_inference(self, x):
def quantize_inference(self, x, quant_type):
# There seems to be no need for an initial quantize_tensor/dequantize_tensor here: this is not the first/last layer, and as long as every intermediate layer stays in the quantized domain no such handling is needed.
identity = x
out = self.qconvbnrelu1.quantize_inference(x)
out = self.qconvbn1.quantize_inference(out)
out = self.qconvbnrelu1.quantize_inference(x, quant_type)
out = self.qconvbn1.quantize_inference(out, quant_type)
if self.downsample is not None:
identity = self.qconvbn2.quantize_inference(identity)
identity = self.qconvbn2.quantize_inference(identity, quant_type)
# out = identity + out # an elementwise-add transform may be needed here; to be revised later
out = self.qelementadd.quantize_inference(out,identity)
out = self.qrelu1.quantize_inference(out)
out = self.qelementadd.quantize_inference(out,identity, quant_type)
out = self.qrelu1.quantize_inference(out, quant_type)
return out
......@@ -318,19 +318,19 @@ class Bottleneck(nn.Module):
self.qrelu1.freeze(qi = self.qelementadd.qo) # qi has to be tracked by this layer itself
return self.qrelu1.qi # the qo after relu can reuse the qi tracked by relu
def quantize_inference(self, x):
def quantize_inference(self, x, quant_type):
# There seems to be no need for an initial quantize_tensor/dequantize_tensor here: this is not the first/last layer, and as long as every intermediate layer stays in the quantized domain no such handling is needed.
identity = x
out = self.qconvbnrelu1.quantize_inference(x)
out = self.qconvbnrelu2.quantize_inference(out)
out = self.qconvbn1.quantize_inference(out)
out = self.qconvbnrelu1.quantize_inference(x, quant_type)
out = self.qconvbnrelu2.quantize_inference(out, quant_type)
out = self.qconvbn1.quantize_inference(out, quant_type)
if self.downsample is not None:
identity = self.qconvbn2.quantize_inference(identity)
identity = self.qconvbn2.quantize_inference(identity, quant_type)
# out = identity + out # an elementwise-add transform may be needed here; to be revised later
out = self.qelementadd.quantize_inference(out,identity)
out = self.qrelu1.quantize_inference(out)
out = self.qelementadd.quantize_inference(out,identity, quant_type)
out = self.qrelu1.quantize_inference(out, quant_type)
return out
......@@ -408,10 +408,10 @@ class MakeLayer(nn.Module):
return qo # for use by the subsequent layers
def quantize_inference(self, x):
def quantize_inference(self, x, quant_type):
# There seems to be no need for an initial quantize_tensor/dequantize_tensor here: this is not the first/last layer, and as long as every intermediate layer stays in the quantized domain no such handling is needed.
for _, layer in self.blockdict.items():
x = layer.quantize_inference(x) # each block has its own quantize_inference
x = layer.quantize_inference(x, quant_type) # each block has its own quantize_inference
return x
......
......@@ -105,10 +105,13 @@ class QParam(nn.Module):
self.scale, self.zero_point = calcScaleZeroPoint(self.min, self.max, self.qmax)
def quantize_tensor(self, tensor):
def quantize_tensor(self, tensor,quant_type = None):
if quant_type == None:
return quantize_tensor(self.quant_type, tensor, self.scale, self.zero_point, self.qmax)
else:
return quantize_tensor(quant_type, tensor, self.scale, self.zero_point, self.qmax)
def dequantize_tensor(self, q_x):
def dequantize_tensor(self, q_x,quant_type = None):
return dequantize_tensor(q_x, self.scale, self.zero_point)
# this method makes it possible to restore from the state_dict
......@@ -146,7 +149,7 @@ class QModule(nn.Module):
def freeze(self):
pass # empty statement
def quantize_inference(self, x):
def quantize_inference(self, x, quant_type):
raise NotImplementedError('quantize_inference should be implemented.')
......@@ -219,13 +222,16 @@ class QConv2d(QModule):
return x
# uses the formula q_a = M(\sigma(q_w-Z_w)(q_x-Z_x) + q_b)
def quantize_inference(self, x): # the input here is the already-quantized qx
def quantize_inference(self, x, quant_type): # the input here is the already-quantized qx
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
if self.quant_type is not 'POT':
x = get_nearest_val(self.quant_type,x)
# if self.quant_type is not 'POT':
# x = get_nearest_val(self.quant_type,x)
x = get_nearest_val(quant_type,x)
x = x + self.qo.zero_point
return x
......@@ -279,14 +285,15 @@ class QLinear(QModule):
return x
def quantize_inference(self, x):
def quantize_inference(self, x, quant_type):
x = x - self.qi.zero_point
x = self.fc_module(x)
x = self.M * x
if self.quant_type is not 'POT':
x = get_nearest_val(self.quant_type,x)
# if self.quant_type is not 'POT':
# x = get_nearest_val(self.quant_type,x)
x = get_nearest_val(quant_type,x)
x = x + self.qo.zero_point
......@@ -317,7 +324,7 @@ class QReLU(QModule):
return x
def quantize_inference(self, x):
def quantize_inference(self, x, quant_type):
x = x.clone()
# x[x < self.qi.zero_point] = self.qi.zero_point
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
......@@ -351,7 +358,7 @@ class QMaxPooling2d(QModule):
return x
def quantize_inference(self, x):
def quantize_inference(self, x, quant_type):
return F.max_pool2d(x, self.kernel_size, self.stride, self.padding)
class QConvBNReLU(QModule):
......@@ -457,13 +464,15 @@ class QConvBNReLU(QModule):
return x
def quantize_inference(self, x):
def quantize_inference(self, x, quant_type):
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
if self.quant_type is not 'POT':
x = get_nearest_val(self.quant_type,x)
# if self.quant_type is not 'POT':
# x = get_nearest_val(self.quant_type,x)
x = get_nearest_val(quant_type,x)
x = x + self.qo.zero_point
......@@ -575,14 +584,17 @@ class QConvBN(QModule):
return x
def quantize_inference(self, x):
def quantize_inference(self, x, quant_type):
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
if self.quant_type is not 'POT':
x = get_nearest_val(self.quant_type,x)
# print(self.quant_type)
# if self.quant_type is not 'POT':
# x = get_nearest_val(self.quant_type,x)
x = get_nearest_val(quant_type,x)
x = x + self.qo.zero_point
# x.clamp_(min=0)
......@@ -626,13 +638,14 @@ class QAdaptiveAvgPool2d(QModule):
return x
def quantize_inference(self, x):
def quantize_inference(self, x, quant_type):
x = F.adaptive_avg_pool2d(x,(1, 1)) # quantizing the input and output is treated as quantizing this op
x = self.M * x
if self.quant_type is not 'POT':
x = get_nearest_val(self.quant_type,x)
# if self.quant_type is not 'POT':
# x = get_nearest_val(self.quant_type,x)
x = get_nearest_val(quant_type,x)
return x
......@@ -662,7 +675,7 @@ class QModule_2(nn.Module):
def fakefreeze(self):
pass
def quantize_inference(self, x):
def quantize_inference(self, x, quant_type):
raise NotImplementedError('quantize_inference should be implemented.')
......@@ -718,15 +731,16 @@ class QElementwiseAdd(QModule_2):
return x
def quantize_inference(self, x0, x1): # the inputs here are already-quantized qx
def quantize_inference(self, x0, x1, quant_type): # the inputs here are already-quantized qx
x0 = x0 - self.qi0.zero_point
x1 = x1 - self.qi1.zero_point
x = self.M0 * (x0 + x1*self.M1)
if self.quant_type is not 'POT':
x = get_nearest_val(self.quant_type,x)
# if self.quant_type is not 'POT':
# x = get_nearest_val(self.quant_type,x)
x = get_nearest_val(quant_type,x)
x = x + self.qo.zero_point
......
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
1575.126077030527 980.8324825038856 447.4871705577316 203.8177281153719 94.1658153206219 44.73944284292641 21.730716696253086 10.687903335080755 5.2935009924434775 2.6865031426677675 1.345978185346981 0.6738058971124082 0.34590930672785625 0.16620132379306904 0.09185943251823848 1575.0264663456858 767.5068295225365 59.80415491853343 17.32189175246257 17.160386413787755 17.15972613238827 17.160655554562823 547.0296470821636 228.09197712606053 153.9307141697144 102.8744121697856 63.04910506966272 11.893784458090247 49.68929151890493 30.72369295281706 4.336553462330601 4.810517948583543 25.62475856077897 16.963161148931942 1.7730239215421446 1.2492962287085048 4.844787354857122 14.21240714817728 10.605240065475499 0.7963437572573967 0.32131797583853794 1.3061700599586734 4.844787523330232
js_param_list:
2231.9475377209037 1458.7430817370525 656.866021106162 290.661557510572 132.0211812900384 62.06574209045005 29.96287022906031 14.768791159744465 7.344364349715033 3.757019554618513 1.896182903527843 0.9241808205303167 0.45857306080932436 0.2269121111102425 0.12261352661167306 2231.901673608193 1143.359470049635 82.82637961696304 24.06635574752677 23.843136397545287 23.842358607630732 23.84306741528584 799.9775130544906 323.8336430582792 218.61973701520765 143.18120884416584 88.72081224892759 16.52024912262558 68.08470436272326 43.20128678260041 6.041579655336327 6.686327875421352 34.6238061335222 24.064747116161215 2.491426419987749 1.7403336017725606 6.690842031928857 18.94797143083834 15.257619881935225 1.0957373786589855 0.44768947355373956 1.7705741794826835 6.690842738428997
ptq_acc_list:
10.0 10.0 10.0 78.52 86.7 89.95 90.73 90.96 90.64 87.4 74.21 52.1 40.65 30.51 20.3 10.0 10.0 10.0 39.21 40.15 44.33 34.83 10.0 19.98 10.0 34.59 85.82 80.56 57.06 88.62 90.17 81.06 68.03 89.75 90.85 88.77 10.0 72.61 90.02 91.08 89.55 10.0 10.0
acc_loss_list:
0.8900978129464776 0.8900978129464776 0.8900978129464776 0.1370480272557424 0.04714803824596101 0.011429827453566238 0.0028574568633914815 0.0003297065611605796 0.0038465765468732203 0.03945488515221441 0.18441586987581055 0.4274096054511484 0.5532476096274316 0.6646884272997032 0.7768985602813496 0.8900978129464776 0.8900978129464776 0.8900978129464776 0.5690735245631388 0.5587427189801077 0.5128036047917354 0.6172106824925816 0.8900978129464776 0.7804154302670623 0.8900978129464776 0.6198483349818661 0.05681943070667109 0.11462798109682375 0.3728981206726013 0.026046818331684696 0.00901197933838876 0.10913287174414764 0.2523354214748873 0.013627871194636718 0.0015386306187493194 0.024398285525881955 0.8900978129464776 0.20200021980437408 0.010660512144191657 -0.0009891196834817388 0.015825914935707196 0.8900978129464776 0.8900978129464776
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
1833.4576454973073 814.7863891368864 217.7654229387627 54.07616924802023 13.731802945455469 3.5847427020530582 0.9118541432904458 0.2622900218848318 0.07627003915874074 0.027745791769400664 0.015915006254486226 0.012409352705166696 0.0077479353538904274 0.0062617873011873975 0.005917287498327866 1833.2003417254284 544.2136113656462 35.21026365121499 33.83804856891729 33.83703344984572 33.83750169488491 33.84147193704756 342.096219925368 82.6043808610444 75.92517125989443 27.82235802343243 26.574672151466128 9.6049355988981 14.044291246668882 14.55114135659603 2.4864347446515884 9.426150750133262 10.07193874315086 10.701781541507756 0.6597191298214471 2.4197650833272104 9.550487563237345 8.849135643873504 9.216705123201871 0.1929881628940372 0.6207588325434388 2.5428780026080493 9.550487563237345
js_param_list:
3613.037168160796 1825.7907466758202 512.0932785883192 129.26071365337654 33.314456921282606 8.673843570791789 2.1826018682118424 0.6138186833325912 0.1691841503982388 0.05180905191755439 0.02266508641878177 0.014530378356803484 0.00975786055068809 0.005063431812688739 0.00398069855228542 3612.992302272399 1246.9340617899438 71.14710558688047 67.61964317269017 67.6172664356203 67.61753548832318 67.6175100773394 755.379587970111 181.41267691267066 170.89087380459807 56.989989927129535 59.371069176236894 19.274735775346528 26.031672719261728 32.363778392002544 5.0043194398511135 18.814548222792805 17.309141148134536 23.84953967534161 1.332034978863292 4.83191046013193 18.864051408815957 14.787650268158211 20.519388091926267 0.3942680972083926 1.231435885110694 4.879394902995963 18.864051408815957
ptq_acc_list:
10.0 10.0 31.15 81.89 84.93 85.69 85.78 85.63 82.63 74.8 51.56 29.34 13.78 11.57 10.17 10.0 10.0 44.45 44.64 46.43 44.18 38.58 9.92 38.85 70.91 65.34 82.3 80.82 73.99 84.14 85.05 76.68 75.95 84.95 85.55 81.54 10.0 77.73 85.18 85.98 81.93 10.01 13.34
acc_loss_list:
0.8835991153532767 0.8835991153532767 0.6374112443254569 0.04679315562798273 0.011407286695378766 0.0025608194622279 0.0015132115004073503 0.003259224770108266 0.038179490164125265 0.1293213828425096 0.3998370387614945 0.6584798044465138 0.8395995809568153 0.8653241764637412 0.8816203003142824 0.8835991153532767 0.8835991153532767 0.48259806774531483 0.4803864509370271 0.45955069258526365 0.4857408916307764 0.5509253870329415 0.8845303224304505 0.5477825631474799 0.17460132697008499 0.2394366197183098 0.04202071935746711 0.05924805028518221 0.1387498544988942 0.02060295658246998 0.010010476079618198 0.10743801652892551 0.11593528110813635 0.011174484926085367 0.004190431847282033 0.05086718659061798 0.8835991153532767 0.09521592364101959 0.008497264579210684 -0.0008148061925271492 0.04632755208939576 0.8834827144686299 0.844721219881271
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
1489.6432790793892 858.47390911721 350.38842997977486 146.66108726257698 65.51871772345022 30.802447738403625 15.015633081763848 7.372939759214539 3.602748170145869 1.7596017349024324 0.9023980469489912 0.42604559053986407 0.2086617904343124 0.11696138076612213 0.06650877266410397 1489.4950585919782 648.6726890563766 44.0313761461945 14.813184296979202 14.70886411284525 14.708637223793453 14.708329981851291 442.110054514285 167.03961080744105 110.912352356486 73.14257321252117 44.75826464643717 8.918392355710786 35.41728607793805 22.00069249787684 3.0807357022322006 4.133106679411769 18.786975210869198 12.291142909976228 1.2420341267864268 1.0780820385160967 4.196963771246701 10.816659177228358 7.780715811154463 0.513961854917128 0.2801788104206083 1.150574461896198 4.196963771246701
js_param_list:
2988.8747567488617 1887.9793935628438 794.5371505720092 330.3960680775245 145.92231495119452 67.9559314292448 33.03981244952361 16.124047122726786 8.021401990398326 3.943098007875918 1.9811299823118427 0.9460539051395199 0.44709418282093033 0.22449034273754867 0.12425914862692854 2988.8363531204886 1451.7681143260804 94.67273844326954 30.460878266197444 30.244231409403923 30.2446749589304 30.244134610251493 984.9086948427197 371.60971497639866 248.5749360354289 159.90777702378026 99.54631101875773 19.048673214252524 75.87671359764475 48.95576239520067 6.683113070521427 8.485231215526596 39.31778320380456 27.44412247810391 2.6854627255413566 2.207580403630901 8.479439151405776 21.80574465505866 17.614834435129385 1.148945392883737 0.5553705895013917 2.2254689905601692 8.479439151405776
ptq_acc_list:
10.0 10.0 10.01 72.69 87.21 89.67 90.45 90.33 89.37 79.82 61.97 35.21 22.84 21.47 13.47 10.0 10.0 12.81 17.49 27.49 30.18 34.97 10.0 15.78 21.89 33.3 82.29 82.49 58.04 87.21 88.9 82.42 67.65 88.34 90.33 87.15 10.05 70.35 89.06 90.52 88.78 9.99 10.0
acc_loss_list:
0.8896369054188279 0.8896369054188279 0.8895265423242467 0.19777066548946035 0.037523452157598565 0.010374130890630148 0.0017658095132987153 0.00309016664827283 0.01368502372806528 0.11908177905308472 0.31607990288047677 0.6114115439796932 0.747930691976603 0.7630504359342236 0.8513409115991613 0.8896369054188279 0.8896369054188279 0.8586248758415186 0.8069749475775301 0.6966118529963581 0.6669241805540227 0.6140602582496413 0.8896369054188279 0.8258470367509105 0.7584151859618143 0.6324908950446971 0.09182209469153507 0.08961483279991177 0.3594525990508774 0.037523452157598565 0.018872089173380353 0.09038737446197989 0.253393665158371 0.025052422469926013 0.00309016664827283 0.03818563072508546 0.8890850899459222 0.22359562962145466 0.017106279660081637 0.0009932678512305862 0.020196446308354467 0.8897472685134091 0.8896369054188279
## update: <br>2023.4.24<br>
fakefreeze is now used in ResNet_nobias, so PoT no longer needs to be handled separately.
## update: <br>2023.4.23<br>
PoT quantization is considered separately: first quantize only the weights and look at how the weight-distribution similarity relates to accuracy. Then pick a PoT weight setting with relatively high accuracy, keep those weights fixed, and quantize the activations with different numeric formats to see how activation bit-width relates to accuracy.<br>
Because qb.scale is already fixed after the PTQ freeze (which directly affects the bias quantization), none of the scales can easily be changed afterwards, so INT quantization cannot be used for the activations.<br>
FP quantization is relatively insensitive to the scale, so it can be tried for the activations. The detailed numbers are in POT_ptq_result_ResNet50.xlsx and POT_ptq_result_ResNet152.xlsx.<br>
With FP activations, ResNet50 and ResNet152 show similar patterns, although it is not yet clear why E1 often gives higher accuracy, E2 and E3 lower, and E4 higher again.
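The activation formats swept in this experiment can be enumerated with the helpers from utils.py; a sketch (the PTQ driver script itself is not part of this diff, so only the enumeration is shown):

```python
from utils import numbit_list, ebit_list

# weights stay fixed at a chosen PoT configuration; only the activation format varies
for num_bits in numbit_list('FLOAT'):
    for e_bits in ebit_list('FLOAT', num_bits):
        title = 'FLOAT_%d_E%d' % (num_bits, e_bits)
        print(title)   # FLOAT_3_E1 ... FLOAT_8_E6, matching the title_list entries above
```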
## update: <br>2023.4.17<br>
- The issues from the 4.12 update have been fixed and supplemented:
......
from model import *
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import os
import os.path as osp
def train(model, device, train_loader, optimizer, epoch):
model.train()
lossLayer = torch.nn.CrossEntropyLoss()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = lossLayer(output, target)
loss.backward()
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
lossLayer = torch.nn.CrossEntropyLoss(reduction='sum')
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += lossLayer(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
test_loss, 100. * correct / len(test_loader.dataset)
))
if __name__ == "__main__":
batch_size = 128
test_batch_size = 128
seed = 1
epochs1 = 15
epochs2 = epochs1+10
epochs3 = epochs2+10
lr1 = 0.01
lr2 = 0.001
lr3 = 0.0001
momentum = 0.5
save_model = True
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../../project/p/data', train=True, download=True,
transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../../project/p/data', train=False, transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
)
# model = AlexNet_BN().to(device)
model = resnet18().to(device)
optimizer1 = optim.SGD(model.parameters(), lr=lr1, momentum=momentum)
optimizer2 = optim.SGD(model.parameters(), lr=lr2, momentum=momentum)
optimizer3 = optim.SGD(model.parameters(), lr=lr3, momentum=momentum)
for epoch in range(1, epochs1 + 1):
train(model, device, train_loader, optimizer1, epoch)
test(model, device, test_loader)
for epoch in range(epochs1 + 1, epochs2 + 1):
train(model, device, train_loader, optimizer2, epoch)
test(model, device, test_loader)
for epoch in range(epochs2 + 1, epochs3 + 1):
train(model, device, train_loader, optimizer3, epoch)
test(model, device, test_loader)
if save_model:
if not osp.exists('ckpt'):
os.makedirs('ckpt')
torch.save(model.state_dict(), 'ckpt/cifar10_ResNet18.pt')
\ No newline at end of file
import re
pattern = r"\(\w+(?:\.\w+)*\)"   # non-capturing group so re.findall returns the full "(name.sub)" matches
text = ""  # the input text string (left empty here as a placeholder)
# find the matching submodule names in the text
matches = re.findall(pattern, text)
# extract all submodule paths and store them in a list
submodule_paths = [match.strip("()") for match in matches]
# print all submodule paths
print(submodule_paths)
\ No newline at end of file
import sys
import os
# Extract parameter counts and FLOPs from val.txt, the redirected output of get_param.py
def extract_ratio(md='ResNet18'):
fr = open('param_flops_' + md + '.txt','r')
lines = fr.readlines()
layer = []
par_ratio = []
flop_ratio = []
for line in lines:
# if '(' in line and ')' in line:
if 'Conv' in line or 'BatchNorm2d' in line or 'Linear' in line:
layer.append(line.split(':')[1].split('(')[0])
r1 = line.split('%')[0].split(',')[-1]
r1 = float(r1)
par_ratio.append(r1)
r2 = line.split('%')[-2].split(',')[-1]
r2 = float(r2)
flop_ratio.append(r2)
return layer, par_ratio, flop_ratio
if __name__ == "__main__":
layer, par_ratio, flop_ratio = extract_ratio()
print(len(layer))
print(len(par_ratio))
print(len(flop_ratio))
\ No newline at end of file
from torch.autograd import Function
class FakeQuantize(Function):
@staticmethod
def forward(ctx, x, qparam):
x = qparam.quantize_tensor(x)
x = qparam.dequantize_tensor(x)
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
\ No newline at end of file
from model import *
import torch
from ptflops import get_model_complexity_info
import argparse
def get_children(model: torch.nn.Module):
# get children from model!
# To keep the parameters updatable later, they would need to be held in an nn.ModuleList
# children = nn.ModuleList(model.children())
# print(children)
# makes it convenient to update the contained modules later
# flatt_children = nn.ModuleList()
children = list(model.children())
# flatt_children = nn.ModuleList()
flatt_children = []
if len(children) == 0:
# if model has no children; model is last child! :O
return model
else:
# look for children from children... to the last child!
for child in children:
try:
flatt_children.extend(get_children(child))
except TypeError:
flatt_children.append(get_children(child))
# print(flatt_children)
return flatt_children
# Helper that yields all child modules, excluding wrapper containers
def get_all_child_modules(module):
for name, child in module.named_children():
if isinstance(child, nn.Sequential):
yield from get_all_child_modules(child)
elif len(list(child.children())) > 0:
yield from child.children()
else:
yield child
def filter_fn(module, n_inp, outp_shape):
# if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.ReLU,torch.nn.BatchNorm2d,torch.nn.Linear,torch.nn.AdaptiveAvgPool2d)):
if 'conv' in module or 'bn' in module or 'fc' in module or 'avg' in module or 'relu' in module:
return True
return False
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Model Analysis --- params & flops')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='ResNet18')
args = parser.parse_args()
if args.model == 'ResNet18':
model = resnet18()
elif args.model == 'ResNet50':
model = resnet50()
elif args.model == 'ResNet152':
model = resnet152()
full_file = 'ckpt/cifar10_' + args.model + '.pt'
model.load_state_dict(torch.load(full_file))
# flat = get_children(model)
# print(flat)
# flat = get_children(model)
# new_model = nn.Sequential(*flat)
flops, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict
def get_model_histogram(model):
"""
Description:
- get gradient histograms (and raw gradients) from the model, stored in OrderedDicts
Args:
- model: (torch.nn.Module), torch model
Returns:
- (gradshisto, grads): per-parameter histograms and raw gradient arrays, as OrderedDicts
"""
gradshisto = OrderedDict()
grads = OrderedDict()
for name, params in model.named_parameters():
grad = params.grad
if grad is not None:
tmp = {}
params_np = grad.cpu().numpy()
histogram, bins = np.histogram(params_np.flatten(),bins=20)
tmp['histogram'] = list(histogram)
tmp['bins'] = list(bins)
gradshisto[name] = tmp
grads[name] = params_np
return gradshisto,grads
def get_model_norm_gradient(model):
"""
Description:
- get gradient norms from the model and store them in an OrderedDict
Args:
- model: (torch.nn.Module), torch model
Returns:
- grads in an OrderedDict
"""
grads = OrderedDict()
for name, params in model.named_parameters():
grad = params.grad
if grad is not None:
grads[name] = grad.norm().item()
return grads
def get_grad_histogram(grads_sum):
gradshisto = OrderedDict()
# grads = OrderedDict()
for name, params in grads_sum.items():
grad = params
if grad is not None:
tmp = {}
#params_np = grad.cpu().numpy()
params_np = grad
histogram, bins = np.histogram(params_np.flatten(),bins=20)
tmp['histogram'] = list(histogram)
tmp['bins'] = list(bins)
gradshisto[name] = tmp # one histogram per layer (tmp holds the data describing it)
# grads[name] = params_np
return gradshisto
\ No newline at end of file
class GlobalVariables:
SELF_INPLANES = 0
\ No newline at end of file
# -*- coding: utf-8 -*-
# Used to share global variables across multiple modules
def _init(): # initialization
global _global_dict
_global_dict = {}
def set_value(value,is_bias=False):
# define a global variable
if is_bias:
_global_dict[0] = value
else:
_global_dict[1] = value
def get_value(is_bias=False): # bias gets its own precision, independent of the other variables
if is_bias:
return _global_dict[0]
else:
return _global_dict[1]
from model import *
from extract_ratio import *
from utils import *
import argparse
import openpyxl
import os
import os.path as osp
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
def js_div_norm(a,b):
a_norm = F.normalize(a.data,p=2,dim=-1)
b_norm = F.normalize(b.data,p=2,dim=-1)
return js_div(a_norm,b_norm)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Loss-Grad Analysis')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='ResNet18')
args = parser.parse_args()
wb = openpyxl.Workbook()
ws = wb.active
writer = SummaryWriter(log_dir='log/' + args.model + '/qat_loss_grad')
# layer, par_ratio, flop_ratio = extract_ratio()
layer, par_ratio, flop_ratio = extract_ratio(args.model)
if args.model == 'ResNet18':
model = resnet18()
elif args.model == 'ResNet50':
model = resnet50()
elif args.model == 'ResNet152':
model = resnet152()
layer = []
for name, param in model.named_parameters():
if 'weight' in name:
n = name.split('.')
pre = '.'.join(n[:len(n)-1])
layer.append(pre)
# dir_prefix = 'ckpt/qat/epoch_'
dir_prefix = 'ckpt/qat/'+ args.model + '/'
quant_type_list = ['INT','POT','FLOAT']
for epoch in [5,10,15,20]:
ws_epoch = wb.create_sheet('epoch_%d'%epoch)
full_state = torch.load(dir_prefix+'%d/'%epoch + 'full.pt')
ws_epoch.cell(row=1,column=2,value='loss')
ws_epoch.cell(row=1,column=3,value='loss_sum')
ws_epoch.cell(row=1,column=4,value='loss_avg')
ws_epoch.cell(row=2,column=1,value='FP32')
ws_epoch.cell(row=2,column=2,value=full_state['loss'].cpu().item())
ws_epoch.cell(row=2,column=3,value=full_state['loss_sum'].cpu().item())
ws_epoch.cell(row=2,column=4,value=full_state['loss_avg'].cpu().item())
# full_grad = full_state['grad_dict']
# full_grad_sum = full_state['grad_dict_sum']
full_grad_avg = full_state['grad_dict_avg']
for name,tmpgrad in full_grad_avg.items():
writer.add_histogram('FULL: '+name,tmpgrad,global_step=epoch)
ws_epoch.cell(row=4,column=1,value='title')
ws_epoch.cell(row=4,column=2,value='loss')
ws_epoch.cell(row=4,column=3,value='loss_sum')
ws_epoch.cell(row=4,column=4,value='loss_avg')
ws_epoch.cell(row=4,column=5,value='js_grad_avg_norm')
# ws_epoch.cell(row=4,column=6,value='conv1.weight')
# ws_epoch.cell(row=4,column=7,value='conv1.bias')
# ws_epoch.cell(row=4,column=8,value='conv2.weight')
# ws_epoch.cell(row=4,column=9,value='conv2.bias')
# ws_epoch.cell(row=4,column=10,value='conv3.weight')
# ws_epoch.cell(row=4,column=11,value='conv3.bias')
# ws_epoch.cell(row=4,column=12,value='conv4.weight')
# ws_epoch.cell(row=4,column=13,value='conv4.bias')
# ws_epoch.cell(row=4,column=14,value='conv5.weight')
# ws_epoch.cell(row=4,column=15,value='conv5.bias')
# ws_epoch.cell(row=4,column=16,value='fc1.weight')
# ws_epoch.cell(row=4,column=17,value='fc1.bias')
# ws_epoch.cell(row=4,column=18,value='fc2.weight')
# ws_epoch.cell(row=4,column=19,value='fc2.bias')
# ws_epoch.cell(row=4,column=20,value='fc3.weight')
# ws_epoch.cell(row=4,column=21,value='fc3.bias')
cnt = 5
for n in layer:
cnt = cnt + 1
ws_epoch.cell(row=4,column=cnt,value=n)
currow=4
for quant_type in quant_type_list:
num_bit_list = numbit_list(quant_type)
for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list:
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
title = '%s_%d' % (quant_type, num_bits)
print('\nAnalyse: '+title)
currow += 1
qat_state=torch.load(dir_prefix+'%d/'%epoch+title+'.pt')
js_grad_avg_norm=0.
grad_avg = qat_state['grad_dict_avg']
for name,tmpgrad in grad_avg.items():
writer.add_histogram(title+': '+name,tmpgrad,global_step=epoch)
colidx=5
for name,_ in full_grad_avg.items():
prefix = name.split('.')[0]
colidx += 1
layer_idx = layer.index(prefix)
js_norm = js_div_norm(full_grad_avg[name],grad_avg[name])
ws_epoch.cell(row=currow,column=colidx,value=js_norm.cpu().item())
js_grad_avg_norm += flop_ratio[layer_idx] * js_norm
ws_epoch.cell(row=currow,column=1,value=title)
ws_epoch.cell(row=currow,column=2,value=qat_state['loss'].cpu().item())
ws_epoch.cell(row=currow,column=3,value=qat_state['loss_sum'].cpu().item())
ws_epoch.cell(row=currow,column=4,value=qat_state['loss_avg'].cpu().item())
ws_epoch.cell(row=currow,column=5,value=js_grad_avg_norm.cpu().item())
wb.save('loss_grad.xlsx')
writer.close()
\ No newline at end of file
ResNet(
712.67 k, 101.597% Params, 35.92 MMac, 100.000% MACs,
(conv1): Conv2d(432, 0.062% Params, 442.37 KMac, 1.232% MACs, 3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(32, 0.005% Params, 32.77 KMac, 0.091% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 16.38 KMac, 0.046% MACs, )
(layer1): MakeLayer(
9.34 k, 1.332% Params, 9.63 MMac, 26.822% MACs,
(blockdict): ModuleDict(
9.34 k, 1.332% Params, 9.63 MMac, 26.822% MACs,
(block1): BasicBlock(
4.67 k, 0.666% Params, 4.82 MMac, 13.411% MACs,
(conv1): Conv2d(2.3 k, 0.328% Params, 2.36 MMac, 6.569% MACs, 16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(32, 0.005% Params, 32.77 KMac, 0.091% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(2.3 k, 0.328% Params, 2.36 MMac, 6.569% MACs, 16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(32, 0.005% Params, 32.77 KMac, 0.091% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 32.77 KMac, 0.091% MACs, )
)
(block2): BasicBlock(
4.67 k, 0.666% Params, 4.82 MMac, 13.411% MACs,
(conv1): Conv2d(2.3 k, 0.328% Params, 2.36 MMac, 6.569% MACs, 16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(32, 0.005% Params, 32.77 KMac, 0.091% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(2.3 k, 0.328% Params, 2.36 MMac, 6.569% MACs, 16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(32, 0.005% Params, 32.77 KMac, 0.091% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 32.77 KMac, 0.091% MACs, )
)
)
)
(layer2): MakeLayer(
33.66 k, 4.799% Params, 8.65 MMac, 24.085% MACs,
(downsample): Sequential(
576, 0.082% Params, 147.46 KMac, 0.411% MACs,
(0): Conv2d(512, 0.073% Params, 131.07 KMac, 0.365% MACs, 16, 32, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(64, 0.009% Params, 16.38 KMac, 0.046% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(blockdict): ModuleDict(
33.09 k, 4.717% Params, 8.5 MMac, 23.675% MACs,
(block1): BasicBlock(
14.53 k, 2.071% Params, 3.74 MMac, 10.400% MACs,
(conv1): Conv2d(4.61 k, 0.657% Params, 1.18 MMac, 3.284% MACs, 16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, 0.009% Params, 16.38 KMac, 0.046% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(9.22 k, 1.314% Params, 2.36 MMac, 6.569% MACs, 32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, 0.009% Params, 16.38 KMac, 0.046% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 16.38 KMac, 0.046% MACs, )
(downsample): Sequential(
576, 0.082% Params, 147.46 KMac, 0.411% MACs,
(0): Conv2d(512, 0.073% Params, 131.07 KMac, 0.365% MACs, 16, 32, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(64, 0.009% Params, 16.38 KMac, 0.046% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block2): BasicBlock(
18.56 k, 2.646% Params, 4.77 MMac, 13.274% MACs,
(conv1): Conv2d(9.22 k, 1.314% Params, 2.36 MMac, 6.569% MACs, 32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, 0.009% Params, 16.38 KMac, 0.046% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(9.22 k, 1.314% Params, 2.36 MMac, 6.569% MACs, 32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, 0.009% Params, 16.38 KMac, 0.046% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 16.38 KMac, 0.046% MACs, )
)
)
)
(layer3): MakeLayer(
133.89 k, 19.087% Params, 8.59 MMac, 23.903% MACs,
(downsample): Sequential(
2.18 k, 0.310% Params, 139.26 KMac, 0.388% MACs,
(0): Conv2d(2.05 k, 0.292% Params, 131.07 KMac, 0.365% MACs, 32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, 0.018% Params, 8.19 KMac, 0.023% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(blockdict): ModuleDict(
131.71 k, 18.777% Params, 8.45 MMac, 23.515% MACs,
(block1): BasicBlock(
57.73 k, 8.230% Params, 3.7 MMac, 10.309% MACs,
(conv1): Conv2d(18.43 k, 2.628% Params, 1.18 MMac, 3.284% MACs, 32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, 0.018% Params, 8.19 KMac, 0.023% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(36.86 k, 5.255% Params, 2.36 MMac, 6.569% MACs, 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, 0.018% Params, 8.19 KMac, 0.023% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 8.19 KMac, 0.023% MACs, )
(downsample): Sequential(
2.18 k, 0.310% Params, 139.26 KMac, 0.388% MACs,
(0): Conv2d(2.05 k, 0.292% Params, 131.07 KMac, 0.365% MACs, 32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, 0.018% Params, 8.19 KMac, 0.023% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block2): BasicBlock(
73.98 k, 10.547% Params, 4.74 MMac, 13.206% MACs,
(conv1): Conv2d(36.86 k, 5.255% Params, 2.36 MMac, 6.569% MACs, 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, 0.018% Params, 8.19 KMac, 0.023% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(36.86 k, 5.255% Params, 2.36 MMac, 6.569% MACs, 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, 0.018% Params, 8.19 KMac, 0.023% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 8.19 KMac, 0.023% MACs, )
)
)
)
(layer4): MakeLayer(
534.02 k, 76.129% Params, 8.55 MMac, 23.812% MACs,
(downsample): Sequential(
8.45 k, 1.204% Params, 135.17 KMac, 0.376% MACs,
(0): Conv2d(8.19 k, 1.168% Params, 131.07 KMac, 0.365% MACs, 64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, 0.036% Params, 4.1 KMac, 0.011% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(blockdict): ModuleDict(
525.57 k, 74.924% Params, 8.42 MMac, 23.435% MACs,
(block1): BasicBlock(
230.14 k, 32.809% Params, 3.69 MMac, 10.264% MACs,
(conv1): Conv2d(73.73 k, 10.511% Params, 1.18 MMac, 3.284% MACs, 64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, 0.036% Params, 4.1 KMac, 0.011% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(147.46 k, 21.021% Params, 2.36 MMac, 6.569% MACs, 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, 0.036% Params, 4.1 KMac, 0.011% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 4.1 KMac, 0.011% MACs, )
(downsample): Sequential(
8.45 k, 1.204% Params, 135.17 KMac, 0.376% MACs,
(0): Conv2d(8.19 k, 1.168% Params, 131.07 KMac, 0.365% MACs, 64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, 0.036% Params, 4.1 KMac, 0.011% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block2): BasicBlock(
295.42 k, 42.115% Params, 4.73 MMac, 13.172% MACs,
(conv1): Conv2d(147.46 k, 21.021% Params, 2.36 MMac, 6.569% MACs, 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, 0.036% Params, 4.1 KMac, 0.011% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(147.46 k, 21.021% Params, 2.36 MMac, 6.569% MACs, 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, 0.036% Params, 4.1 KMac, 0.011% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 4.1 KMac, 0.011% MACs, )
)
)
)
(avgpool): AdaptiveAvgPool2d(0, 0.000% Params, 2.05 KMac, 0.006% MACs, output_size=(1, 1))
(fc): Linear(1.29 k, 0.184% Params, 1.29 KMac, 0.004% MACs, in_features=128, out_features=10, bias=True)
)
\ No newline at end of file
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
23.751628696034675 9.970130122904623 1.27369342087768 0.28313082786825217 0.0661330799143482 0.015909788749053112 0.003919561637024714 0.000976824876513856 0.0002478207988959643 6.729750617685151e-05 2.465999254564852e-05 1.4437640662014585e-05 1.753492282816915e-05 1.6118921091019203e-05 1.0942083041102445e-05 19.360958236249306 3.720049772121422 1.203861698176446 1.205551302740704 1.1875784274865653 1.197261990471644 1.187304920096399 3.9054602854075413 1.359654720924604 0.6455689626848736 0.738021827869755 0.21746836151979718 0.24024395039789773 0.5416934888086916 0.1090641843494709 0.062104454894076684 0.24001038871414732 0.4679728318331829 0.07558863061735592 0.015767804836310184 0.062021724817765404 0.24001037755982702 0.43631459469295775 0.06362924454198643 0.003988977252631458 0.015714696763899187 0.06202157326324543 0.24307858155352763
js_param_list:
20.85553941977874 9.981209988545288 1.3539305429512496 0.30139434372520046 0.07038894495212882 0.016983314589243823 0.004172975453783345 0.0010420360112590532 0.00026273648166581423 7.17926315777753e-05 2.4605053810581976e-05 1.317531321555632e-05 1.9623210180810698e-05 1.747944047326041e-05 1.170199482974471e-05 17.390004975267495 3.8124967157643033 1.0178092545455621 1.0125712143846217 0.9928878407387453 1.0125745208871095 0.995760463684502 3.5930026564803716 1.1991610585566574 0.6596486829871696 0.6274086902313186 0.2277369707300968 0.20674715542055963 0.45375203785887974 0.11819179575270293 0.05360085504661807 0.20648789904506443 0.3899066809035698 0.0837931514468961 0.013553482490574238 0.05351237083520507 0.206488208407236 0.3632293653938247 0.07115836515515826 0.003443783278898041 0.01349690316709514 0.0535122559122496 0.20908980476129443
ptq_acc_list:
10.0 10.0 11.86 46.5 75.15 88.16 90.38 90.67 90.81 90.78 90.81 90.82 90.79 90.77 90.77 10.0 10.32 16.21 18.59 17.49 17.34 18.47 13.0 10.63 33.64 21.79 46.89 72.72 36.71 68.56 88.48 76.4 52.7 58.52 90.32 88.99 75.24 51.23 57.34 90.59 90.41 88.78 70.99
acc_loss_list:
0.8898314421064228 0.8898314421064228 0.8693400903382175 0.48771620579486613 0.17208328742976745 0.028753993610223638 0.004296573757849516 0.001101685578935709 -0.0004406742315743776 -0.00011016855789363355 -0.0004406742315743776 -0.0005508427894678545 -0.0002203371157872671 0.0 0.0 0.8898314421064228 0.8863060482538283 0.8214167676545114 0.79519665087584 0.8073151922441336 0.8089677206125372 0.7965186735705629 0.8567808747383496 0.8828908229591275 0.6293929712460063 0.7599427123498953 0.4834196320370166 0.19885424699790677 0.5955712239726781 0.24468436708163485 0.025228599757629085 0.1583122176930703 0.41941169990084826 0.3552935992067863 0.0049575851052110044 0.01961000330505675 0.17109177040872536 0.43560647791120416 0.36829348903822845 0.0019830340420843077 0.003966068084168772 0.021923543020821803 0.21791340751349567
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
17.39068197353568 4.992727157836896 0.7082449333226905 0.15278217922691148 0.035902245013520316 0.008569443717644162 0.0021334381003386507 0.000530096122199641 0.00013202739139047992 4.320540852718224e-05 1.330992682049459e-05 2.0368878546145547e-05 8.560127359019233e-06 1.5414435348480747e-05 1.034017508399343e-05 13.592069723232997 2.603415308985623 1.0163834751275276 1.0004914009898496 1.0267242270943244 1.0377851066533914 1.0093247382340054 2.4882911516734794 0.8719669278759606 0.39824198801738997 0.48712327345618334 0.13437272845045903 0.15283618430614965 0.36366539038391377 0.067001336425203 0.04023162172835145 0.15278158008432788 0.3153947829963702 0.04619076816233325 0.010272439305968415 0.040191841741053155 0.15277940050779587 0.2944708799554367 0.03873975046512612 0.002562897125930476 0.010259629159672047 0.040194168479249566 0.16940471118115202
js_param_list:
13.75589548965823 6.139934297648725 0.9983266223973442 0.21685218331277975 0.05073819961183949 0.012236724314191861 0.0030479968659331298 0.0007642922990751943 0.00018892417520173234 6.914265082028416e-05 2.0654949731546248e-05 1.2782403100869643e-05 9.44194520770214e-06 1.7287066209449194e-05 1.2478257919601581e-05 13.5254986055193 3.463308629604464 1.160512064451701 1.1084529894485604 1.193382464700844 1.2041072939829902 1.2033002903326413 2.4474367498054836 0.7570219992909055 0.5246367584038191 0.3844942455969616 0.18715129322472332 0.13583544042939869 0.2740889994593467 0.10076411331478825 0.03570359182314611 0.13571819932622736 0.2344246202009843 0.07290873555706959 0.008973915201922977 0.03566098435431149 0.13571299092732977 0.21809584818526528 0.0623769948948551 0.00226774688205569 0.00895695770659421 0.03566482115654845 0.18569380791656256
ptq_acc_list:
10.0 9.97 23.88 76.77 87.61 88.95 89.11 89.49 89.46 89.51 89.44 89.49 89.5 89.49 89.48 10.0 10.2 41.62 46.35 41.57 33.14 18.97 11.7 27.9 58.44 63.17 81.54 80.64 71.73 85.63 88.05 80.84 76.92 86.23 88.79 87.8 81.29 79.62 86.98 89.18 89.08 88.03 82.63
acc_loss_list:
0.8882431828341529 0.8885784532856504 0.7331247206079572 0.14204291461779175 0.02089852481001346 0.00592311130978991 0.004135002235136394 -0.00011175681716574548 0.00022351363433180857 -0.00033527045149755406 0.0004470272686634583 -0.00011175681716574548 -0.00022351363433164976 -0.00011175681716574548 0.0 0.8882431828341529 0.8860080464908359 0.5348681269557444 0.4820071524362986 0.5354269110415736 0.6296379079123826 0.787997317836388 0.8692445239159589 0.6881984801072866 0.3468931604827895 0.2940321859633438 0.08873491282968259 0.09879302637460889 0.19836835046937862 0.04302637460885123 0.015981224854716213 0.09655789003129191 0.140366562360304 0.03632096557890031 0.007711220384443425 0.018775145283862392 0.09152883325882875 0.11019222172552524 0.027939204291461777 0.0033527045149753815 0.004470272686633948 0.016204738489047864 0.07655341975860537
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
33.51452234870289 11.491148760988082 1.443232369683574 0.30895904466985846 0.07212736726486271 0.017521603310707724 0.004322850749958015 0.0010654208648101923 0.00027006499577500253 7.161059028231077e-05 2.538806909059987e-05 2.055073568547306e-05 1.620682115453542e-05 1.1986353787221355e-05 9.067183894872267e-06 30.863645293169988 4.784097662511332 1.6646070424398351 1.6585560614298085 1.6573939189141602 1.678128465984322 1.696152544154087 5.78339268163358 2.1114062743040702 0.8425751347317033 1.2052541601543698 0.2767788178149355 0.3748657360498163 0.9081736813090305 0.13299755106486685 0.09626535561783886 0.37471558677670597 0.7918321592327399 0.08953626957132739 0.024398799135561436 0.0962213077547234 0.3747139603715501 0.7410242796157969 0.07442835219519522 0.0061065013202944015 0.024361557223410733 0.09622132719129213 0.3932453051939647
js_param_list:
33.721554105756816 13.269154166033433 1.877153395798168 0.4076554546710941 0.09554313376446293 0.023148392970637945 0.005675039691374743 0.001426682819835179 0.00036345209532231146 9.38985449811094e-05 2.837714142247516e-05 3.191036192167962e-05 2.277355467755515e-05 1.0829856218760036e-05 8.09611746277672e-06 33.383130206500184 5.912517637356709 1.725167775811526 1.6936075456198305 1.7079693950603978 1.745066474153093 1.7771297088256834 6.363561522947833 2.233996327924317 1.0358764351056118 1.245546534684258 0.3518593191648724 0.3809527745802683 0.9287923064091098 0.17743315049465594 0.09900998360709243 0.38070531858108136 0.8063356331707637 0.12348247277252783 0.024949418674407576 0.09893290024757509 0.3807042031089577 0.7535293303127147 0.1040074579869313 0.006337410559870987 0.02488997718210982 0.09893348151517127 0.4220865848581757
ptq_acc_list:
10.0 10.03 14.16 73.84 88.41 89.87 90.3 90.48 90.45 90.54 90.53 90.54 90.58 90.57 90.56 10.0 10.93 22.69 23.64 17.97 24.65 23.01 15.05 15.76 45.01 26.09 75.66 75.82 41.03 84.68 87.95 78.47 49.61 86.32 90.14 88.01 77.2 54.96 87.52 90.17 90.02 88.25 78.85
acc_loss_list:
0.8895759717314488 0.8892446996466431 0.8436395759717314 0.18462897526501765 0.023741166077738577 0.00761925795053001 0.0028710247349823886 0.000883392226148391 0.0012146643109540573 0.00022084805653705853 0.00033127208480566626 0.00022084805653705853 -0.00022084805653705853 -0.0001104240282684508 0.0 0.8895759717314488 0.8793065371024734 0.7494478798586572 0.7389575971731449 0.8015680212014135 0.7278047703180212 0.7459143109540636 0.8338118374558304 0.8259717314487632 0.5029814487632509 0.7119037102473498 0.1645318021201414 0.16276501766784462 0.5469302120141343 0.06492932862190808 0.028820671378091866 0.13350265017667848 0.4521863957597173 0.04681978798586582 0.004637809187279171 0.028158127208480533 0.14752650176678445 0.3931095406360424 0.03356890459363964 0.0043065371024735045 0.005962897526501836 0.02550795053003536 0.12930653710247358
## update: <br>2023.4.24<br>
Added some more data and fitting plots.<br>
We tried fitting the data points of ResNet18, ResNet50, ResNet152 and MobileNetV2 on a single plot, and the result is fairly good. However, since these four models have rather similar structures, it is not yet clear how well the fit would hold once data points from models with markedly different architectures are included.
1. ResNet50:
<img src = "fig/50_flops_f.png" class="h-90 auto">
<img src = "fig/50_params_F.png" class="h-90 auto">
2. ResNet152:
<img src = "fig/152_flops_f.png" class="h-90 auto">
<img src = "fig/152_params_f.png" class="h-90 auto">
3. Fit over the combined data points of ResNet18, ResNet50, ResNet152 and MobileNetV2:
<img src = "fig/total_flops.png" class="h-90 auto">
<img src = "fig/total_params.png" class="h-90 auto">
## update: <br>2023.4.24<br>
- Problems solved:
1. Abnormal PoT fitting <br>
2. Poor curve-fitting quality
- Notes on the approach<br>
While working on the JS-divergence and curve-fitting problems encountered with MobileNetV2, we found that F.normalize has a large impact on the similarity computation: although it helps rescale the quantized weights to the same scale as the full-precision model, it distorts the data distribution too much, which in turn produced the anomalous relationships and fitting results. The detailed reasoning and analysis are in the MobileNetV2 readme. Here we likewise use fakefreeze to dequantize the quantized weights so that they are on the same scale as the full-precision weights, and then compute the distance directly with JS divergence (see the sketch after this section).
<br>
The experiments were rerun on the ResNet family; the data and curves for ResNet50 and ResNet152 will be added once those runs finish. <br>
ResNet18:
<img src = "fig/18_flops_f.png" class="h-90 auto">
<img src = "fig/18_params_f.png" class="h-90 auto">
The curve-fitting quality improves substantially.
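A minimal sketch of the distance computation described above, assuming per-layer tensors `w_full` (full precision) and `w_deq` (quantized then dequantized via fakefreeze); the actual scripts iterate over every layer and weight the per-layer distances by FLOPs/param ratios:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def js_div_flat(p, q):
    # JS divergence between two flattened tensors treated as distributions
    p = F.softmax(p.reshape(1, -1), dim=-1)
    q = F.softmax(q.reshape(1, -1), dim=-1)
    log_mean = ((p + q) / 2).log()
    kl = nn.KLDivLoss(reduction='sum')
    return (kl(log_mean, p) + kl(log_mean, q)) / 2

# w_full: weight tensor of a full-precision layer
# w_deq:  the matching quantized weight after fakefreeze/dequantize, already back
#         on the full-precision scale -- no F.normalize is applied
# dist = js_div_flat(w_full, w_deq)
```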
## update: <br>2023.4.23<br>
1. The Conv layers of the ResNet models in this folder carry no bias, consistent with the classic ResNet. After the fold-BN operation in ConvBN and ConvBNReLU, however, the Conv does acquire a bias, and the code was adjusted accordingly.
2. Without bias, ResNet18 reaches slightly higher accuracy than with bias, and training is noticeably faster.
3. Fixed a series of minor bugs.
4. Fitting plots:
(1) With PoT:
resnet18:
<img src = "fig/18_flops_nobias.png" class="h-90 auto">
<img src = "fig/18_params_nobias.png" class="h-90 auto">
resnet50:
<img src = "fig/50_flops_nobias.png" class="h-90 auto">
<img src = "fig/50_params_nobias.png" class="h-90 auto">
resnet152:
<img src = "fig/152_flops_nobias.png" class="h-90 auto">
<img src = "fig/152_params_nobias.png" class="h-90 auto">
(2) Without PoT:
resnet18:
<img src = "fig/18_flops_nobias_nopot.png" class="h-90 auto">
resnet50:
<img src = "fig/50_flops_nobias_nopot.png" class="h-90 auto">
resnet152:
<img src = "fig/152_flops_nobias_nopot.png" class="h-90 auto">
from model import *
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import os
import os.path as osp
def train(model, device, train_loader, optimizer, epoch):
model.train()
lossLayer = torch.nn.CrossEntropyLoss()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = lossLayer(output, target)
loss.backward()
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
lossLayer = torch.nn.CrossEntropyLoss(reduction='sum')
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += lossLayer(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
test_loss, 100. * correct / len(test_loader.dataset)
))
if __name__ == "__main__":
batch_size = 128
test_batch_size = 128
seed = 1
epochs1 = 15
epochs2 = epochs1+10
epochs3 = epochs2+10
lr1 = 0.01
lr2 = 0.001
lr3 = 0.0001
momentum = 0.5
save_model = True
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../../project/p/data', train=True, download=True,
transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../../project/p/data', train=False, transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
)
# model = AlexNet_BN().to(device)
model = resnet18().to(device)
optimizer1 = optim.SGD(model.parameters(), lr=lr1, momentum=momentum)
optimizer2 = optim.SGD(model.parameters(), lr=lr2, momentum=momentum)
optimizer3 = optim.SGD(model.parameters(), lr=lr3, momentum=momentum)
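    # three-stage schedule: epochs 1-15 at lr1, 16-25 at lr2, 26-35 at lr3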
for epoch in range(1, epochs1 + 1):
train(model, device, train_loader, optimizer1, epoch)
test(model, device, test_loader)
for epoch in range(epochs1 + 1, epochs2 + 1):
train(model, device, train_loader, optimizer2, epoch)
test(model, device, test_loader)
for epoch in range(epochs2 + 1, epochs3 + 1):
train(model, device, train_loader, optimizer3, epoch)
test(model, device, test_loader)
if save_model:
if not osp.exists('ckpt'):
os.makedirs('ckpt')
torch.save(model.state_dict(), 'ckpt/cifar10_ResNet18.pt')
\ No newline at end of file
import re
# Match parenthesized sub-module names such as "(conv1)".
# The group must be non-capturing, otherwise re.findall returns only the last
# captured group instead of the full "(...)" match.
pattern = r"\(\w+(?:\.\w+)*\)"
text = ""  # the input text string (placeholder: paste the model printout to search)
# Find all matching sub-module names in the text
matches = re.findall(pattern, text)
# Extract all sub-module paths into a list
submodule_paths = [match.strip("()") for match in matches]
# Print all sub-module paths
print(submodule_paths)
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
        p_output = F.softmax(p_output, dim=-1)
        q_output = F.softmax(q_output, dim=-1)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
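# Assumed usage: compare two tensors of the same shape, e.g. weights or grads
# flattened to (1, -1); returns a non-negative scalar tensor.
#   w_full = torch.randn(1, 256)
#   w_quant = w_full + 0.01 * torch.randn(1, 256)
#   dist = js_div(w_full, w_quant)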
def ebit_list(quant_type, num_bits):
if quant_type == 'FLOAT':
e_bit_list = list(range(1,num_bits-1))
else:
e_bit_list = [0]
return e_bit_list
def numbit_list(quant_type):
if quant_type == 'INT':
num_bit_list = list(range(2,17))
elif quant_type == 'POT':
num_bit_list = list(range(2,9))
else:
num_bit_list = list(range(2,9))
# num_bit_list = [8]
return num_bit_list
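# Sketch of how the (quant_type, num_bits, e_bits) combinations can be enumerated
# (an assumption about the calling scripts); it reproduces titles such as
# INT_2..INT_16, POT_2..POT_8, FLOAT_8_E1..FLOAT_8_E6:
#   for quant_type in ['INT', 'POT', 'FLOAT']:
#       for num_bits in numbit_list(quant_type):
#           for e_bits in ebit_list(quant_type, num_bits):
#               title = quant_type + '_' + str(num_bits)
#               if quant_type == 'FLOAT':
#                   title += '_E' + str(e_bits)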
def build_bias_list(quant_type):
if quant_type == 'POT':
return build_pot_list(8) #
else:
return build_float_list(16,7)
def build_list(quant_type, num_bits, e_bits):
if quant_type == 'POT':
return build_pot_list(num_bits)
else:
return build_float_list(num_bits,e_bits)
def build_pot_list(num_bits):
plist = [0.]
for i in range(-2 ** (num_bits-1) + 2, 1):
        # i goes up to 0 at most, i.e. the maximum PoT quantization value is 1
plist.append(2. ** i)
plist.append(-2. ** i)
plist = torch.Tensor(list(set(plist)))
# plist = plist.mul(1.0 / torch.max(plist))
return plist
def build_float_list(num_bits,e_bits):
m_bits = num_bits - 1 - e_bits
plist = [0.]
    # difference between adjacent mantissa values
dist_m = 2 ** (-m_bits)
e = -2 ** (e_bits - 1) + 1
for m in range(1, 2 ** m_bits):
        frac = m * dist_m  # mantissa part
        expo = 2 ** e      # exponent part
flt = frac * expo
plist.append(flt)
plist.append(-flt)
for e in range(-2 ** (e_bits - 1) + 2, 2 ** (e_bits - 1) + 1):
expo = 2 ** e
for m in range(0, 2 ** m_bits):
frac = 1. + m * dist_m
flt = frac * expo
plist.append(flt)
plist.append(-flt)
plist = torch.Tensor(list(set(plist)))
return plist
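# Worked examples (small bit widths):
#   build_pot_list(3)      -> {0, +-0.25, +-0.5, +-1.0}, i.e. 2**3 - 1 = 7 levels,
#                             with 1.0 as the largest representable value
#   build_float_list(4, 2) -> {0, +-0.25, +-1.0, +-1.5, +-2.0, +-3.0, +-4.0, +-6.0}
#                             (m_bits = 1, one subnormal step, exponents 0..2)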
def fold_ratio(layer, par_ratio, flop_ratio):
idx = -1
for name in layer:
idx = idx + 1
        # layer comes from `for name, param in model.named_parameters()`, so downsample entries are always present
if 'bn' in name or 'sample.1' in name:
par_ratio[idx-1] += par_ratio[idx]
flop_ratio[idx-1] += flop_ratio[idx]
return par_ratio,flop_ratio
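# Example: layer = ['conv1', 'bn1', 'fc'], par_ratio = [0.6, 0.1, 0.3]
# -> fold_ratio adds each BN share onto the preceding conv entry,
#    giving par_ratio = [0.7, 0.1, 0.3] (the BN slot is kept but no longer used).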
def fold_model(model):
idx = -1
module_list = []
# print('fold model:')
for name, module in model.named_modules():
# print(name+'-- +')
idx += 1
module_list.append(module)
        # the conv inside downsample used to be overlooked here, so some layers were not folded
if 'bn' in name or 'sample.1' in name:
# print(name+'-- *')
            module_list[idx-1] = fold_bn(module_list[idx-1], module)  # modified here
return model
# def fold_model(model):
# last_conv = None
# last_bn = None
# for name, module in model.named_modules():
# if isinstance(module, nn.Conv2d):
# # if the current module is a conv layer, "fold" it into the previous BN layer
# if last_bn is not None:
# last_conv = fold_bn(last_conv, last_bn)
# last_bn = None
# last_conv = module
# elif isinstance(module, nn.BatchNorm2d):
# # if the current module is a BN layer, "fold" it into the previous conv layer
# last_bn = module
# if last_conv is not None:
# last_conv = fold_bn(last_conv, last_bn)
# last_bn = None
# # handle the last BN layer
# if last_bn is not None:
# last_conv = fold_bn(last_conv, last_bn)
# return model
def fold_bn(conv, bn):
    # fetch the BN layer parameters
gamma = bn.weight.data
beta = bn.bias.data
mean = bn.running_mean
var = bn.running_var
eps = bn.eps
std = torch.sqrt(var + eps)
feat = bn.num_features
    # fetch the conv layer parameters
weight = conv.weight.data
if conv.bias is not None:
bias = conv.bias.data
if bn.affine:
gamma_ = gamma / std
weight = weight * gamma_.view(feat, 1, 1, 1)
if conv.bias is not None:
bias = gamma_ * bias - gamma_ * mean + beta
else:
bias = beta - gamma_ * mean
else:
gamma_ = 1 / std
weight = weight * gamma_
if conv.bias is not None:
bias = gamma_ * bias - gamma_ * mean
else:
bias = -gamma_ * mean
    # set the new weight and bias
conv.weight.data = weight
    # handles the case where conv.bias is None
if conv.bias is None:
conv.bias = nn.Parameter(bias)
else:
conv.bias.data = bias
return conv
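# Quick sanity check for fold_bn (a standalone sketch, not part of the pipeline):
# in eval mode the folded conv should reproduce bn(conv(x)) up to float error.
#   conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
#   bn = nn.BatchNorm2d(8)
#   bn.eval()
#   x = torch.randn(1, 3, 8, 8)
#   ref = bn(conv(x))
#   folded = fold_bn(conv, bn)    # mutates conv in place and returns it
#   assert torch.allclose(folded(x), ref, atol=1e-5)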
\ No newline at end of file