Commit 8fae747b by Zhihong Ma

feat: ResNet activation quantization change / ResNet no-bias / MobileNetV2

parent 8352e4cf
import sys
import os
# Extract the per-layer parameter and FLOP ratios from the file produced by
# redirecting get_param.py's output (param_flops_<model>.txt)
def extract_ratio(md='ResNet18'):
    fr = open('param_flops_' + md + '.txt', 'r')
    lines = fr.readlines()
    fr.close()
    layer = []
    par_ratio = []
    flop_ratio = []
    for line in lines:
        # if '(' in line and ')' in line:
        if 'Conv' in line or 'BatchNorm2d' in line or 'Linear' in line:
            layer.append(line.split(':')[1].split('(')[0])
            r1 = line.split('%')[0].split(',')[-1]
            r1 = float(r1)
            par_ratio.append(r1)
            r2 = line.split('%')[-2].split(',')[-1]
            r2 = float(r2)
            flop_ratio.append(r2)
    return layer, par_ratio, flop_ratio


if __name__ == "__main__":
    layer, par_ratio, flop_ratio = extract_ratio()
    print(len(layer))
    print(len(par_ratio))
    print(len(flop_ratio))
\ No newline at end of file
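# Worked example of the parsing above; the sample line is copied from the ptflops
# per-layer dump format that appears later in this commit.
line = "(conv1): Conv2d(432, 0.062% Params, 442.37 KMac, 1.232% MACs, 3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)"
print(line.split(':')[1].split('(')[0])            # ' Conv2d'  -> layer type, leading space included
print(float(line.split('%')[0].split(',')[-1]))    # 0.062     -> parameter ratio, in percent
print(float(line.split('%')[-2].split(',')[-1]))   # 1.232     -> FLOP/MAC ratio, in percent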
from torch.autograd import Function


class FakeQuantize(Function):

    @staticmethod
    def forward(ctx, x, qparam):
        x = qparam.quantize_tensor(x)
        x = qparam.dequantize_tensor(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None
\ No newline at end of file
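# Minimal usage sketch of the straight-through estimator above. DummyQParam is a stand-in
# invented here so the snippet is self-contained; the real QParam class lives in module.py.
import torch

class DummyQParam:
    scale, zero_point = 0.1, 0

    def quantize_tensor(self, x):
        return torch.clamp(torch.round(x / self.scale) + self.zero_point, -128, 127)

    def dequantize_tensor(self, q_x):
        return self.scale * (q_x - self.zero_point)

w = torch.randn(4, requires_grad=True)
w_q = FakeQuantize.apply(w, DummyQParam())   # forward: quantize then dequantize
w_q.sum().backward()
print(w.grad)                                # all ones: the gradient passes straight through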
from model import *
import torch
from ptflops import get_model_complexity_info
import argparse


def get_children(model: torch.nn.Module):
    # get children from the model!
    # To be able to update the parameters later, they need to be held in an nn.ModuleList,
    # children = nn.ModuleList(model.children())
    # print(children)
    # so the modules inside can conveniently be updated afterwards.
    # flatt_children = nn.ModuleList()
    children = list(model.children())
    # flatt_children = nn.ModuleList()
    flatt_children = []
    if len(children) == 0:
        # if model has no children, model is the last child! :O
        return model
    else:
        # look for children of children... down to the last child!
        for child in children:
            try:
                flatt_children.extend(get_children(child))
            except TypeError:
                flatt_children.append(get_children(child))
        # print(flatt_children)
        return flatt_children


# Function that yields all leaf sub-modules, skipping wrapper containers
def get_all_child_modules(module):
    for name, child in module.named_children():
        if isinstance(child, nn.Sequential):
            yield from get_all_child_modules(child)
        elif len(list(child.children())) > 0:
            yield from child.children()
        else:
            yield child


def filter_fn(module, n_inp, outp_shape):
    # if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.ReLU, torch.nn.BatchNorm2d, torch.nn.AdaptiveAvgPool2d)):
    # note: this check treats `module` as a name string rather than an nn.Module instance
    if 'conv' in module or 'bn' in module or 'fc' in module or 'avg' in module or 'relu' in module:
        return True
    return False


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Model Analysis --- params & flops')
    parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='MobileNetV2')
    args = parser.parse_args()
    # if args.model == 'ResNet18':
    #     model = resnet18()
    # elif args.model == 'ResNet50':
    #     model = resnet50()
    # elif args.model == 'ResNet152':
    #     model = resnet152()
    if args.model == 'MobileNetV2':
        model = MobileNetV2()
    full_file = 'ckpt/cifar10_' + args.model + '.pt'
    model.load_state_dict(torch.load(full_file))
    # flat = get_children(model)
    # print(flat)
    # new_model = nn.Sequential(*flat)
    flops, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
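    # Note: extract_ratio() earlier in this commit parses 'param_flops_<model>.txt', which appears to be
    # just this per-layer dump redirected into a file. A hedged sketch of writing it directly (ptflops
    # accepts an 'ost' output-stream argument in the versions I am aware of; the filename is only a convention):
    # with open('param_flops_' + args.model + '.txt', 'w') as f:
    #     flops, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True,
    #                                               print_per_layer_stat=True, ost=f)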
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict


def get_model_histogram(model):
    """
    Description:
        - collect a 20-bin histogram of the gradient of every parameter in the model,
          stored in an OrderedDict (the raw gradient arrays are returned as well)
    Args:
        - model: (torch.nn.Module), torch model
    Returns:
        - (histograms, grads), both OrderedDicts keyed by parameter name
    """
    gradshisto = OrderedDict()
    grads = OrderedDict()
    for name, params in model.named_parameters():
        grad = params.grad
        if grad is not None:
            tmp = {}
            params_np = grad.cpu().numpy()
            histogram, bins = np.histogram(params_np.flatten(), bins=20)
            tmp['histogram'] = list(histogram)
            tmp['bins'] = list(bins)
            gradshisto[name] = tmp
            grads[name] = params_np
    return gradshisto, grads


def get_model_norm_gradient(model):
    """
    Description:
        - get the gradient norm of every parameter in the model, stored in an OrderedDict
    Args:
        - model: (torch.nn.Module), torch model
    Returns:
        - grads in an OrderedDict keyed by parameter name
    """
    grads = OrderedDict()
    for name, params in model.named_parameters():
        grad = params.grad
        if grad is not None:
            grads[name] = grad.norm().item()
    return grads


def get_grad_histogram(grads_sum):
    gradshisto = OrderedDict()
    # grads = OrderedDict()
    for name, params in grads_sum.items():
        grad = params
        if grad is not None:
            tmp = {}
            # params_np = grad.cpu().numpy()
            params_np = grad
            histogram, bins = np.histogram(params_np.flatten(), bins=20)
            tmp['histogram'] = list(histogram)
            tmp['bins'] = list(bins)
            gradshisto[name] = tmp  # one histogram per layer (tmp holds the info describing the histogram)
            # grads[name] = params_np
    return gradshisto
\ No newline at end of file
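# Minimal usage sketch for the helpers above (model, inputs and labels are placeholders
# invented purely for illustration; assumes the functions in the file above are in scope):
model = nn.Linear(8, 2)
loss = F.cross_entropy(model(torch.randn(4, 8)), torch.tensor([0, 1, 0, 1]))
loss.backward()
print(get_model_norm_gradient(model))   # {'weight': ..., 'bias': ...} gradient norms
histos, raw = get_model_histogram(model)
print(histos['weight']['histogram'])    # 20 bin counts for the weight gradient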
class GlobalVariables:
    SELF_INPLANES = 0
\ No newline at end of file
# -*- coding: utf-8 -*-
# Used to share global variables across multiple modules
def _init():  # initialization
    global _global_dict
    _global_dict = {}


def set_value(value, is_bias=False):
    # define a global variable; biases get their own slot
    if is_bias:
        _global_dict[0] = value
    else:
        _global_dict[1] = value


def get_value(is_bias=False):  # give the bias a precision independent of the other variables
    if is_bias:
        return _global_dict[0]
    else:
        return _global_dict[1]
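

# Minimal usage sketch for this value store; the bit-width numbers below are made up.
if __name__ == "__main__":
    _init()
    set_value(8)                   # shared precision for most tensors
    set_value(16, is_bias=True)    # biases keep their own, independent precision
    print(get_value(), get_value(is_bias=True))   # -> 8 16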
from model import *
from extract_ratio import *
from utils import *
import argparse
import openpyxl
import os
import os.path as osp
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
def js_div_norm(a,b):
a_norm = F.normalize(a.data,p=2,dim=-1)
b_norm = F.normalize(b.data,p=2,dim=-1)
return js_div(a_norm,b_norm)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Loss-Grad Analysis')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='ResNet18')
args = parser.parse_args()
wb = openpyxl.Workbook()
ws = wb.active
writer = SummaryWriter(log_dir='log/' + args.model + '/qat_loss_grad')
# layer, par_ratio, flop_ratio = extract_ratio()
layer, par_ratio, flop_ratio = extract_ratio(args.model)
if args.model == 'ResNet18':
model = resnet18()
elif args.model == 'ResNet50':
model = resnet50()
elif args.model == 'ResNet152':
model = resnet152()
layer = []
for name, param in model.named_parameters():
if 'weight' in name:
n = name.split('.')
pre = '.'.join(n[:len(n)-1])
layer.append(pre)
# dir_prefix = 'ckpt/qat/epoch_'
dir_prefix = 'ckpt/qat/'+ args.model + '/'
quant_type_list = ['INT','POT','FLOAT']
for epoch in [5,10,15,20]:
ws_epoch = wb.create_sheet('epoch_%d'%epoch)
full_state = torch.load(dir_prefix+'%d/'%epoch + 'full.pt')
ws_epoch.cell(row=1,column=2,value='loss')
ws_epoch.cell(row=1,column=3,value='loss_sum')
ws_epoch.cell(row=1,column=4,value='loss_avg')
ws_epoch.cell(row=2,column=1,value='FP32')
ws_epoch.cell(row=2,column=2,value=full_state['loss'].cpu().item())
ws_epoch.cell(row=2,column=3,value=full_state['loss_sum'].cpu().item())
ws_epoch.cell(row=2,column=4,value=full_state['loss_avg'].cpu().item())
# full_grad = full_state['grad_dict']
# full_grad_sum = full_state['grad_dict_sum']
full_grad_avg = full_state['grad_dict_avg']
for name,tmpgrad in full_grad_avg.items():
writer.add_histogram('FULL: '+name,tmpgrad,global_step=epoch)
ws_epoch.cell(row=4,column=1,value='title')
ws_epoch.cell(row=4,column=2,value='loss')
ws_epoch.cell(row=4,column=3,value='loss_sum')
ws_epoch.cell(row=4,column=4,value='loss_avg')
ws_epoch.cell(row=4,column=5,value='js_grad_avg_norm')
# ws_epoch.cell(row=4,column=6,value='conv1.weight')
# ws_epoch.cell(row=4,column=7,value='conv1.bias')
# ws_epoch.cell(row=4,column=8,value='conv2.weight')
# ws_epoch.cell(row=4,column=9,value='conv2.bias')
# ws_epoch.cell(row=4,column=10,value='conv3.weight')
# ws_epoch.cell(row=4,column=11,value='conv3.bias')
# ws_epoch.cell(row=4,column=12,value='conv4.weight')
# ws_epoch.cell(row=4,column=13,value='conv4.bias')
# ws_epoch.cell(row=4,column=14,value='conv5.weight')
# ws_epoch.cell(row=4,column=15,value='conv5.bias')
# ws_epoch.cell(row=4,column=16,value='fc1.weight')
# ws_epoch.cell(row=4,column=17,value='fc1.bias')
# ws_epoch.cell(row=4,column=18,value='fc2.weight')
# ws_epoch.cell(row=4,column=19,value='fc2.bias')
# ws_epoch.cell(row=4,column=20,value='fc3.weight')
# ws_epoch.cell(row=4,column=21,value='fc3.bias')
cnt = 5
for n in layer:
cnt = cnt + 1
ws_epoch.cell(row=4,column=cnt,value=n)
currow=4
for quant_type in quant_type_list:
num_bit_list = numbit_list(quant_type)
for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list:
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
title = '%s_%d' % (quant_type, num_bits)
print('\nAnalyse: '+title)
currow += 1
qat_state=torch.load(dir_prefix+'%d/'%epoch+title+'.pt')
js_grad_avg_norm=0.
grad_avg = qat_state['grad_dict_avg']
for name,tmpgrad in grad_avg.items():
writer.add_histogram(title+': '+name,tmpgrad,global_step=epoch)
colidx=5
for name,_ in full_grad_avg.items():
prefix = name.split('.')[0]
colidx += 1
layer_idx = layer.index(prefix)
js_norm = js_div_norm(full_grad_avg[name],grad_avg[name])
ws_epoch.cell(row=currow,column=colidx,value=js_norm.cpu().item())
js_grad_avg_norm += flop_ratio[layer_idx] * js_norm
ws_epoch.cell(row=currow,column=1,value=title)
ws_epoch.cell(row=currow,column=2,value=qat_state['loss'].cpu().item())
ws_epoch.cell(row=currow,column=3,value=qat_state['loss_sum'].cpu().item())
ws_epoch.cell(row=currow,column=4,value=qat_state['loss_avg'].cpu().item())
ws_epoch.cell(row=currow,column=5,value=js_grad_avg_norm.cpu().item())
wb.save('loss_grad.xlsx')
writer.close()
\ No newline at end of file
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
3937.841102576776 3363.8190205631026 2484.282190565725 1666.2525440043232 1004.491228033656 619.2894405795665 418.1602230520327 312.24372234629664 249.45408998160573 193.16718609999484 138.595522944474 91.58500300086834 56.08617073690407 31.952073152774126 17.65229708562858 3937.814777378682 3068.5487659326095 613.4307511092446 32.126779884075525 0.1766064000339766 0.17675544226821238 0.17669672446190357 2644.0738554237805 1741.4374062424974 1207.0395601708549 1028.5434609756649 668.1556835852132 273.1682252338685 625.3801636402046 430.06294426631223 203.6193231190563 6.546919210669541 419.88446951149035 315.44088340877573 143.55644125961152 2.884048071197845 0.05101264315330961 312.7100494425904 250.59377938308924 93.48144767534725 1.3453977569799638 0.013982248527794 0.10109822916201198
js_param_list:
5523.5133138619085 4892.5106928859905 3857.644315893276 2707.0379699110163 1671.0773218634195 1050.057190366859 723.4234749265271 550.470391030846 445.310410225667 347.5339012995234 250.31388514236687 165.8204601056395 101.56282095527794 57.86236490894071 31.945254415161216 5523.499968355952 4553.489563739368 1040.3277683793247 57.8656798626289 0.006984618427447457 0.006992355997222603 0.007052672537742183 4053.9039848113193 2820.767118173691 1991.9978400778782 1709.3926174557196 1129.0122760496113 485.4325026388394 1059.7935852519893 742.5316942830176 365.9667180995057 11.782087473714286 726.0475165732927 555.5680608520271 259.1510147521308 5.1897931543416576 0.0014493051074338339 551.1504206498132 447.20405741295923 169.2517013866318 2.4155090345624393 0.000397355594218542 0.03827793516377293
ptq_acc_list:
10.0 10.0 10.44 37.82 81.77 88.99 91.31 91.84 91.96 91.9 91.93 91.88 91.87 91.88 91.89 10.0 10.0 14.09 19.53 11.56 11.34 17.05 10.01 10.36 20.89 25.1 62.06 79.8 45.72 81.59 89.76 79.92 71.56 85.92 91.67 90.2 81.46 73.66 87.45 91.65 91.51 90.24 81.11
acc_loss_list:
0.8914812805208898 0.8914812805208898 0.886706456863809 0.5895822029300054 0.11264243081931644 0.034291915355398925 0.009115572436245289 0.003364080303852439 0.002061855670103222 0.0027129679869777536 0.0023874118285404106 0.002930005425936085 0.003038524145415096 0.002930005425936085 0.0028214867064569192 0.8914812805208898 0.8914812805208898 0.8470971242539338 0.7880629408572979 0.8745523602821487 0.8769397721106891 0.8149755832881173 0.8913727618014107 0.8875746066196419 0.773304395008139 0.7276180141074337 0.32653282691264246 0.1340206185567011 0.5038524145415084 0.11459576776994033 0.02593597395550733 0.13271839392295173 0.22344004340748783 0.06760716223548566 0.00520889853499733 0.02116115029842651 0.11600651112316887 0.20065111231687474 0.051003798155181794 0.005425935973955507 0.006945198046663055 0.020727075420510156 0.11980466630493766
from model import *
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import os
import os.path as osp
def train(model, device, train_loader, optimizer, epoch):
model.train()
lossLayer = torch.nn.CrossEntropyLoss()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = lossLayer(output, target)
loss.backward()
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
lossLayer = torch.nn.CrossEntropyLoss(reduction='sum')
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += lossLayer(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
test_loss, 100. * correct / len(test_loader.dataset)
))
if __name__ == "__main__":
batch_size = 128
test_batch_size = 128
seed = 1
epochs1 = 15
epochs2 = epochs1+10
epochs3 = epochs2+10
lr1 = 0.01
lr2 = 0.001
lr3 = 0.0001
momentum = 0.5
save_model = True
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../../project/p/data', train=True, download=True,
transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../../project/p/data', train=False, transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
)
# model = AlexNet_BN().to(device)
model = resnet18().to(device)
optimizer1 = optim.SGD(model.parameters(), lr=lr1, momentum=momentum)
optimizer2 = optim.SGD(model.parameters(), lr=lr2, momentum=momentum)
optimizer3 = optim.SGD(model.parameters(), lr=lr3, momentum=momentum)
for epoch in range(1, epochs1 + 1):
train(model, device, train_loader, optimizer1, epoch)
test(model, device, test_loader)
for epoch in range(epochs1 + 1, epochs2 + 1):
train(model, device, train_loader, optimizer2, epoch)
test(model, device, test_loader)
for epoch in range(epochs2 + 1, epochs3 + 1):
train(model, device, train_loader, optimizer3, epoch)
test(model, device, test_loader)
if save_model:
if not osp.exists('ckpt'):
os.makedirs('ckpt')
torch.save(model.state_dict(), 'ckpt/cifar10_ResNet18.pt')
\ No newline at end of file
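# The three optimizers above implement a manual step decay (lr 0.01 -> 0.001 -> 0.0001 at epochs 15 and 25).
# A more compact alternative sketch with a single optimizer and a scheduler; not what the script does,
# just an equivalent formulation for reference, reusing train/test/model/loaders from the script above:
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15, 25], gamma=0.1)
for epoch in range(1, 36):
    train(model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()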
from model import *
model = MobileNetV2()
model.quantize('INT',8,0)
# for name, module in model.named_modules():
# print(name)
# print('==============================')
# for name, param in model.named_parameters():
# print(name)
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F


def js_div(p_output, q_output, get_softmax=True):
    """
    Function that measures JS divergence between target and output logits
    """
    KLDivLoss = nn.KLDivLoss(reduction='sum')
    if get_softmax:
        p_output = F.softmax(p_output, dim=-1)
        q_output = F.softmax(q_output, dim=-1)
    log_mean_output = ((p_output + q_output) / 2).log()
    return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output)) / 2
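

# A quick, self-contained sanity check of js_div (shapes invented for illustration):
# identical inputs give ~0 and the divergence is symmetric in its arguments.
if __name__ == "__main__":
    p, q = torch.randn(4, 10), torch.randn(4, 10)
    print(js_div(p, p).item())                        # ~0.0
    print(js_div(p, q).item(), js_div(q, p).item())   # equal, by symmetry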
def ebit_list(quant_type, num_bits):
    if quant_type == 'FLOAT':
        e_bit_list = list(range(1, num_bits - 1))
    else:
        e_bit_list = [0]
    return e_bit_list


def numbit_list(quant_type):
    if quant_type == 'INT':
        num_bit_list = list(range(2, 17))
    elif quant_type == 'POT':
        num_bit_list = list(range(2, 9))
    else:
        num_bit_list = list(range(2, 9))
        # num_bit_list = [8]
    return num_bit_list


def build_bias_list(quant_type):
    if quant_type == 'POT':
        return build_pot_list(8)
    else:
        return build_float_list(16, 7)


def build_list(quant_type, num_bits, e_bits):
    if quant_type == 'POT':
        return build_pot_list(num_bits)
    else:
        return build_float_list(num_bits, e_bits)


def build_pot_list(num_bits):
    plist = [0.]
    for i in range(-2 ** (num_bits - 1) + 2, 1):
        # i goes up to 0, so the largest power-of-two value is 1
        plist.append(2. ** i)
        plist.append(-2. ** i)
    plist = torch.Tensor(list(set(plist)))
    # plist = plist.mul(1.0 / torch.max(plist))
    return plist
def build_float_list(num_bits, e_bits):
    m_bits = num_bits - 1 - e_bits
    plist = [0.]
    # spacing between adjacent mantissa values
    dist_m = 2 ** (-m_bits)
    e = -2 ** (e_bits - 1) + 1
    # subnormal values (exponent fixed at the minimum, no implicit leading 1)
    for m in range(1, 2 ** m_bits):
        frac = m * dist_m   # mantissa part
        expo = 2 ** e       # exponent part
        flt = frac * expo
        plist.append(flt)
        plist.append(-flt)
    # normal values (implicit leading 1)
    for e in range(-2 ** (e_bits - 1) + 2, 2 ** (e_bits - 1) + 1):
        expo = 2 ** e
        for m in range(0, 2 ** m_bits):
            frac = 1. + m * dist_m
            flt = frac * expo
            plist.append(flt)
            plist.append(-flt)
    plist = torch.Tensor(list(set(plist)))
    return plist
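

# Worked example of the grids the two builders above produce, following directly from the loops:
if __name__ == "__main__":
    # POT_3: 0 plus +/-2**i for i in {-2, -1, 0}
    print(sorted(build_pot_list(3).tolist()))      # [-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0]
    # FLOAT_4_E1 (m_bits = 2, dist_m = 0.25):
    #   subnormals (e = 0): +/-0.25, +/-0.5, +/-0.75
    #   normals    (e = 1): +/-2.0, +/-2.5, +/-3.0, +/-3.5
    print(sorted(build_float_list(4, 1).tolist()))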
def fold_ratio(layer, par_ratio, flop_ratio):
    idx = -1
    for name in layer:
        idx = idx + 1
        # layer comes from model.named_parameters(), so the downsample entries are always present
        if 'bn' in name or 'sample.1' in name:
            par_ratio[idx - 1] += par_ratio[idx]
            flop_ratio[idx - 1] += flop_ratio[idx]
    return par_ratio, flop_ratio


def fold_model(model):
    idx = -1
    module_list = []
    # print('fold model:')
    for name, module in model.named_modules():
        # print(name + '-- +')
        idx += 1
        module_list.append(module)
        # the conv inside downsample used to be overlooked here, which left some layers unfused
        if 'bn' in name or 'sample.1' in name:
            # print(name + '-- *')
            module_list[idx - 1] = fold_bn(module_list[idx - 1], module)  # fuse the BN into the preceding conv
    return model


# def fold_model(model):
#     last_conv = None
#     last_bn = None
#     for name, module in model.named_modules():
#         if isinstance(module, nn.Conv2d):
#             # current module is a conv layer: fold any pending BN into the previous conv
#             if last_bn is not None:
#                 last_conv = fold_bn(last_conv, last_bn)
#                 last_bn = None
#             last_conv = module
#         elif isinstance(module, nn.BatchNorm2d):
#             # current module is a BN layer: fold it into the previous conv layer
#             last_bn = module
#             if last_conv is not None:
#                 last_conv = fold_bn(last_conv, last_bn)
#                 last_bn = None
#     # handle the last BN layer
#     if last_bn is not None:
#         last_conv = fold_bn(last_conv, last_bn)
#     return model


def fold_bn(conv, bn):
    # BatchNorm parameters
    gamma = bn.weight.data
    beta = bn.bias.data
    mean = bn.running_mean
    var = bn.running_var
    eps = bn.eps
    std = torch.sqrt(var + eps)
    feat = bn.num_features
    # convolution parameters
    weight = conv.weight.data
    if conv.bias is not None:
        bias = conv.bias.data

    if bn.affine:
        gamma_ = gamma / std
        weight = weight * gamma_.view(feat, 1, 1, 1)
        if conv.bias is not None:
            bias = gamma_ * bias - gamma_ * mean + beta
        else:
            bias = beta - gamma_ * mean
    else:
        gamma_ = 1 / std
        weight = weight * gamma_.view(feat, 1, 1, 1)  # broadcast over output channels, as in the affine branch
        if conv.bias is not None:
            bias = gamma_ * bias - gamma_ * mean
        else:
            bias = -gamma_ * mean

    # write back the fused weight and bias
    conv.weight.data = weight
    # handles conv layers created with bias=None
    if conv.bias is None:
        conv.bias = nn.Parameter(bias)
    else:
        conv.bias.data = bias
    return conv
\ No newline at end of file
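# A minimal sanity check for fold_bn above (shapes and statistics invented): the fused conv
# should reproduce conv -> BN exactly in eval mode, which is the property the folding relies on.
import copy
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
# give BN non-trivial statistics so the check is meaningful
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 2.0)
bn.bias.data.uniform_(-1, 1)
bn.eval()   # use running statistics, as at inference time

x = torch.randn(1, 3, 32, 32)
ref = bn(conv(x))
fused = fold_bn(copy.deepcopy(conv), bn)
print(torch.allclose(fused(x), ref, atol=1e-5))   # expected: True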
@@ -30,9 +30,21 @@ def get_children(model: torch.nn.Module):
         # print(flatt_children)
         return flatt_children
 
+# Function that yields all leaf sub-modules, skipping wrapper containers
+def get_all_child_modules(module):
+    for name, child in module.named_children():
+        if isinstance(child, nn.Sequential):
+            yield from get_all_child_modules(child)
+        elif len(list(child.children())) > 0:
+            yield from child.children()
+        else:
+            yield child
+
+def filter_fn(module, n_inp, outp_shape):
+    # if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.ReLU, torch.nn.BatchNorm2d, torch.nn.AdaptiveAvgPool2d)):
+    if 'conv' in module or 'bn' in module or 'fc' in module or 'avg' in module or 'relu' in module:
+        return True
+    return False
+
 if __name__ == "__main__":
...
@@ -108,16 +108,16 @@ class ResNet(nn.Module):
     def fakefreeze(self):
         pass
 
-    def quantize_inference(self, x):
+    def quantize_inference(self, x, quant_type):
         qx = self.qconvbnrelu1.qi.quantize_tensor(x)
-        qx = self.qconvbnrelu1.quantize_inference(qx)
-        qx = self.layer1.quantize_inference(qx)
-        qx = self.layer2.quantize_inference(qx)
-        qx = self.layer3.quantize_inference(qx)
-        qx = self.layer4.quantize_inference(qx)
-        qx = self.qavgpool1.quantize_inference(qx)
+        qx = self.qconvbnrelu1.quantize_inference(qx, quant_type)
+        qx = self.layer1.quantize_inference(qx, quant_type)
+        qx = self.layer2.quantize_inference(qx, quant_type)
+        qx = self.layer3.quantize_inference(qx, quant_type)
+        qx = self.layer4.quantize_inference(qx, quant_type)
+        qx = self.qavgpool1.quantize_inference(qx, quant_type)
         qx = qx.view(qx.size(0), -1)
-        qx = self.qfc1.quantize_inference(qx)
+        qx = self.qfc1.quantize_inference(qx, quant_type)
         qx = self.qfc1.qo.dequantize_tensor(qx)
@@ -209,18 +209,18 @@ class BasicBlock(nn.Module):
         self.qrelu1.freeze(qi=self.qelementadd.qo)
         return self.qrelu1.qi  # relu's qo can reuse the qi statistics gathered by relu
 
-    def quantize_inference(self, x):
+    def quantize_inference(self, x, quant_type):
         # The initial quantize_tensor/dequantize_tensor seem unnecessary here: this is not the first/last layer,
         # so as long as every intermediate layer stays in the quantized domain no such handling is needed.
         identity = x
-        out = self.qconvbnrelu1.quantize_inference(x)
-        out = self.qconvbn1.quantize_inference(out)
+        out = self.qconvbnrelu1.quantize_inference(x, quant_type)
+        out = self.qconvbn1.quantize_inference(out, quant_type)
 
         if self.downsample is not None:
-            identity = self.qconvbn2.quantize_inference(identity)
+            identity = self.qconvbn2.quantize_inference(identity, quant_type)
 
         # out = identity + out  # an elementwise-add transform may be needed here, to be revised later
-        out = self.qelementadd.quantize_inference(out, identity)
-        out = self.qrelu1.quantize_inference(out)
+        out = self.qelementadd.quantize_inference(out, identity, quant_type)
+        out = self.qrelu1.quantize_inference(out, quant_type)
 
         return out
@@ -318,19 +318,19 @@ class Bottleneck(nn.Module):
         self.qrelu1.freeze(qi=self.qelementadd.qo)  # needs to gather its own qi statistics
         return self.qrelu1.qi  # relu's qo can reuse the qi statistics gathered by relu
 
-    def quantize_inference(self, x):
+    def quantize_inference(self, x, quant_type):
         # The initial quantize_tensor/dequantize_tensor seem unnecessary here: this is not the first/last layer,
         # so as long as every intermediate layer stays in the quantized domain no such handling is needed.
         identity = x
-        out = self.qconvbnrelu1.quantize_inference(x)
-        out = self.qconvbnrelu2.quantize_inference(out)
-        out = self.qconvbn1.quantize_inference(out)
+        out = self.qconvbnrelu1.quantize_inference(x, quant_type)
+        out = self.qconvbnrelu2.quantize_inference(out, quant_type)
+        out = self.qconvbn1.quantize_inference(out, quant_type)
 
         if self.downsample is not None:
-            identity = self.qconvbn2.quantize_inference(identity)
+            identity = self.qconvbn2.quantize_inference(identity, quant_type)
 
         # out = identity + out  # an elementwise-add transform may be needed here, to be revised later
-        out = self.qelementadd.quantize_inference(out, identity)
-        out = self.qrelu1.quantize_inference(out)
+        out = self.qelementadd.quantize_inference(out, identity, quant_type)
+        out = self.qrelu1.quantize_inference(out, quant_type)
 
         return out
@@ -408,10 +408,10 @@ class MakeLayer(nn.Module):
         return qo  # for use by the subsequent layers
 
-    def quantize_inference(self, x):
+    def quantize_inference(self, x, quant_type):
         # no initial quantize/dequantize needed here either; intermediate layers stay in the quantized domain
         for _, layer in self.blockdict.items():
-            x = layer.quantize_inference(x)  # each block has its own concrete quantize_inference
+            x = layer.quantize_inference(x, quant_type)  # each block has its own concrete quantize_inference
         return x
...
@@ -105,10 +105,13 @@ class QParam(nn.Module):
         self.scale, self.zero_point = calcScaleZeroPoint(self.min, self.max, self.qmax)
 
-    def quantize_tensor(self, tensor):
-        return quantize_tensor(self.quant_type, tensor, self.scale, self.zero_point, self.qmax)
+    def quantize_tensor(self, tensor, quant_type=None):
+        if quant_type == None:
+            return quantize_tensor(self.quant_type, tensor, self.scale, self.zero_point, self.qmax)
+        else:
+            return quantize_tensor(quant_type, tensor, self.scale, self.zero_point, self.qmax)
 
-    def dequantize_tensor(self, q_x):
+    def dequantize_tensor(self, q_x, quant_type=None):
         return dequantize_tensor(q_x, self.scale, self.zero_point)
 
     # this method guarantees the parameters can be restored from a state_dict
@@ -146,7 +149,7 @@ class QModule(nn.Module):
     def freeze(self):
         pass  # no-op
 
-    def quantize_inference(self, x):
+    def quantize_inference(self, x, quant_type):
         raise NotImplementedError('quantize_inference should be implemented.')
 
@@ -219,13 +222,16 @@ class QConv2d(QModule):
         return x
 
     # uses the formula q_a = M(Σ(q_w - Z_w)(q_x - Z_x) + q_b)
-    def quantize_inference(self, x):  # the input here is the already-quantized qx
+    def quantize_inference(self, x, quant_type):  # the input here is the already-quantized qx
         x = x - self.qi.zero_point
         x = self.conv_module(x)
         x = self.M * x
-        if self.quant_type is not 'POT':
-            x = get_nearest_val(self.quant_type, x)
+        # if self.quant_type is not 'POT':
+        #     x = get_nearest_val(self.quant_type, x)
+        x = get_nearest_val(quant_type, x)
         x = x + self.qo.zero_point
         return x
 
@@ -279,14 +285,15 @@ class QLinear(QModule):
         return x
 
-    def quantize_inference(self, x):
+    def quantize_inference(self, x, quant_type):
         x = x - self.qi.zero_point
         x = self.fc_module(x)
         x = self.M * x
-        if self.quant_type is not 'POT':
-            x = get_nearest_val(self.quant_type, x)
+        # if self.quant_type is not 'POT':
+        #     x = get_nearest_val(self.quant_type, x)
+        x = get_nearest_val(quant_type, x)
         x = x + self.qo.zero_point
 
@@ -317,7 +324,7 @@ class QReLU(QModule):
         return x
 
-    def quantize_inference(self, x):
+    def quantize_inference(self, x, quant_type):
         x = x.clone()
         # x[x < self.qi.zero_point] = self.qi.zero_point
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
@@ -351,7 +358,7 @@ class QMaxPooling2d(QModule):
         return x
 
-    def quantize_inference(self, x):
+    def quantize_inference(self, x, quant_type):
         return F.max_pool2d(x, self.kernel_size, self.stride, self.padding)
 
 class QConvBNReLU(QModule):
@@ -457,13 +464,15 @@ class QConvBNReLU(QModule):
         return x
 
-    def quantize_inference(self, x):
+    def quantize_inference(self, x, quant_type):
         x = x - self.qi.zero_point
         x = self.conv_module(x)
         x = self.M * x
-        if self.quant_type is not 'POT':
-            x = get_nearest_val(self.quant_type, x)
+        # if self.quant_type is not 'POT':
+        #     x = get_nearest_val(self.quant_type, x)
+        x = get_nearest_val(quant_type, x)
         x = x + self.qo.zero_point
 
@@ -575,14 +584,17 @@ class QConvBN(QModule):
         return x
 
-    def quantize_inference(self, x):
+    def quantize_inference(self, x, quant_type):
         x = x - self.qi.zero_point
         x = self.conv_module(x)
         x = self.M * x
-        if self.quant_type is not 'POT':
-            x = get_nearest_val(self.quant_type, x)
+        # print(self.quant_type)
+        # if self.quant_type is not 'POT':
+        #     x = get_nearest_val(self.quant_type, x)
+        x = get_nearest_val(quant_type, x)
         x = x + self.qo.zero_point
 
         # x.clamp_(min=0)
@@ -626,13 +638,14 @@ class QAdaptiveAvgPool2d(QModule):
         return x
 
-    def quantize_inference(self, x):
+    def quantize_inference(self, x, quant_type):
         x = F.adaptive_avg_pool2d(x, (1, 1))  # quantizing both input and output counts as quantizing this op
         x = self.M * x
-        if self.quant_type is not 'POT':
-            x = get_nearest_val(self.quant_type, x)
+        # if self.quant_type is not 'POT':
+        #     x = get_nearest_val(self.quant_type, x)
+        x = get_nearest_val(quant_type, x)
         return x
 
@@ -662,7 +675,7 @@ class QModule_2(nn.Module):
     def fakefreeze(self):
         pass
 
-    def quantize_inference(self, x):
+    def quantize_inference(self, x, quant_type):
         raise NotImplementedError('quantize_inference should be implemented.')
 
@@ -718,15 +731,16 @@ class QElementwiseAdd(QModule_2):
         return x
 
-    def quantize_inference(self, x0, x1):  # the inputs here are already-quantized qx
+    def quantize_inference(self, x0, x1, quant_type):  # the inputs here are already-quantized qx
         x0 = x0 - self.qi0.zero_point
         x1 = x1 - self.qi1.zero_point
         x = self.M0 * (x0 + x1 * self.M1)
-        if self.quant_type is not 'POT':
-            x = get_nearest_val(self.quant_type, x)
+        # if self.quant_type is not 'POT':
+        #     x = get_nearest_val(self.quant_type, x)
+        x = get_nearest_val(quant_type, x)
         x = x + self.qo.zero_point
...
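# For readers unfamiliar with the formula referenced in the QConv2d hunk above,
#   q_a = M * (sum((q_w - Z_w) * (q_x - Z_x)) + q_b)   with   M = S_w * S_x / S_a,
# here is a hypothetical scalar walk-through of the requantization step. The numbers are invented,
# the bias is assumed to be quantized at scale S_w * S_x (the usual convention), and plain rounding
# stands in for get_nearest_val, whose definition is not part of this commit.
S_x, Z_x = 0.05, 10     # input scale / zero point
S_w, Z_w = 0.02, 0      # weight scale / zero point
S_a, Z_a = 0.10, 5      # output scale / zero point
M = S_w * S_x / S_a     # the multiplier stored as self.M above

q_x, q_w, q_b = 30, 12, 7                # already-quantized input, weight, bias
acc = (q_w - Z_w) * (q_x - Z_x) + q_b    # integer accumulator
q_a = round(M * acc) + Z_a               # requantize; get_nearest_val would snap M*acc onto the INT/POT/FLOAT grid here
print(q_a)                               # 7

# cross-check against the real-valued computation
real = (S_w * (q_w - Z_w)) * (S_x * (q_x - Z_x)) + S_w * S_x * q_b
print(round(real / S_a) + Z_a)           # 7, matches q_a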
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
1575.126077030527 980.8324825038856 447.4871705577316 203.8177281153719 94.1658153206219 44.73944284292641 21.730716696253086 10.687903335080755 5.2935009924434775 2.6865031426677675 1.345978185346981 0.6738058971124082 0.34590930672785625 0.16620132379306904 0.09185943251823848 1575.0264663456858 767.5068295225365 59.80415491853343 17.32189175246257 17.160386413787755 17.15972613238827 17.160655554562823 547.0296470821636 228.09197712606053 153.9307141697144 102.8744121697856 63.04910506966272 11.893784458090247 49.68929151890493 30.72369295281706 4.336553462330601 4.810517948583543 25.62475856077897 16.963161148931942 1.7730239215421446 1.2492962287085048 4.844787354857122 14.21240714817728 10.605240065475499 0.7963437572573967 0.32131797583853794 1.3061700599586734 4.844787523330232
js_param_list:
2231.9475377209037 1458.7430817370525 656.866021106162 290.661557510572 132.0211812900384 62.06574209045005 29.96287022906031 14.768791159744465 7.344364349715033 3.757019554618513 1.896182903527843 0.9241808205303167 0.45857306080932436 0.2269121111102425 0.12261352661167306 2231.901673608193 1143.359470049635 82.82637961696304 24.06635574752677 23.843136397545287 23.842358607630732 23.84306741528584 799.9775130544906 323.8336430582792 218.61973701520765 143.18120884416584 88.72081224892759 16.52024912262558 68.08470436272326 43.20128678260041 6.041579655336327 6.686327875421352 34.6238061335222 24.064747116161215 2.491426419987749 1.7403336017725606 6.690842031928857 18.94797143083834 15.257619881935225 1.0957373786589855 0.44768947355373956 1.7705741794826835 6.690842738428997
ptq_acc_list:
10.0 10.0 10.0 78.52 86.7 89.95 90.73 90.96 90.64 87.4 74.21 52.1 40.65 30.51 20.3 10.0 10.0 10.0 39.21 40.15 44.33 34.83 10.0 19.98 10.0 34.59 85.82 80.56 57.06 88.62 90.17 81.06 68.03 89.75 90.85 88.77 10.0 72.61 90.02 91.08 89.55 10.0 10.0
acc_loss_list:
0.8900978129464776 0.8900978129464776 0.8900978129464776 0.1370480272557424 0.04714803824596101 0.011429827453566238 0.0028574568633914815 0.0003297065611605796 0.0038465765468732203 0.03945488515221441 0.18441586987581055 0.4274096054511484 0.5532476096274316 0.6646884272997032 0.7768985602813496 0.8900978129464776 0.8900978129464776 0.8900978129464776 0.5690735245631388 0.5587427189801077 0.5128036047917354 0.6172106824925816 0.8900978129464776 0.7804154302670623 0.8900978129464776 0.6198483349818661 0.05681943070667109 0.11462798109682375 0.3728981206726013 0.026046818331684696 0.00901197933838876 0.10913287174414764 0.2523354214748873 0.013627871194636718 0.0015386306187493194 0.024398285525881955 0.8900978129464776 0.20200021980437408 0.010660512144191657 -0.0009891196834817388 0.015825914935707196 0.8900978129464776 0.8900978129464776
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
1833.4576454973073 814.7863891368864 217.7654229387627 54.07616924802023 13.731802945455469 3.5847427020530582 0.9118541432904458 0.2622900218848318 0.07627003915874074 0.027745791769400664 0.015915006254486226 0.012409352705166696 0.0077479353538904274 0.0062617873011873975 0.005917287498327866 1833.2003417254284 544.2136113656462 35.21026365121499 33.83804856891729 33.83703344984572 33.83750169488491 33.84147193704756 342.096219925368 82.6043808610444 75.92517125989443 27.82235802343243 26.574672151466128 9.6049355988981 14.044291246668882 14.55114135659603 2.4864347446515884 9.426150750133262 10.07193874315086 10.701781541507756 0.6597191298214471 2.4197650833272104 9.550487563237345 8.849135643873504 9.216705123201871 0.1929881628940372 0.6207588325434388 2.5428780026080493 9.550487563237345
js_param_list:
3613.037168160796 1825.7907466758202 512.0932785883192 129.26071365337654 33.314456921282606 8.673843570791789 2.1826018682118424 0.6138186833325912 0.1691841503982388 0.05180905191755439 0.02266508641878177 0.014530378356803484 0.00975786055068809 0.005063431812688739 0.00398069855228542 3612.992302272399 1246.9340617899438 71.14710558688047 67.61964317269017 67.6172664356203 67.61753548832318 67.6175100773394 755.379587970111 181.41267691267066 170.89087380459807 56.989989927129535 59.371069176236894 19.274735775346528 26.031672719261728 32.363778392002544 5.0043194398511135 18.814548222792805 17.309141148134536 23.84953967534161 1.332034978863292 4.83191046013193 18.864051408815957 14.787650268158211 20.519388091926267 0.3942680972083926 1.231435885110694 4.879394902995963 18.864051408815957
ptq_acc_list:
10.0 10.0 31.15 81.89 84.93 85.69 85.78 85.63 82.63 74.8 51.56 29.34 13.78 11.57 10.17 10.0 10.0 44.45 44.64 46.43 44.18 38.58 9.92 38.85 70.91 65.34 82.3 80.82 73.99 84.14 85.05 76.68 75.95 84.95 85.55 81.54 10.0 77.73 85.18 85.98 81.93 10.01 13.34
acc_loss_list:
0.8835991153532767 0.8835991153532767 0.6374112443254569 0.04679315562798273 0.011407286695378766 0.0025608194622279 0.0015132115004073503 0.003259224770108266 0.038179490164125265 0.1293213828425096 0.3998370387614945 0.6584798044465138 0.8395995809568153 0.8653241764637412 0.8816203003142824 0.8835991153532767 0.8835991153532767 0.48259806774531483 0.4803864509370271 0.45955069258526365 0.4857408916307764 0.5509253870329415 0.8845303224304505 0.5477825631474799 0.17460132697008499 0.2394366197183098 0.04202071935746711 0.05924805028518221 0.1387498544988942 0.02060295658246998 0.010010476079618198 0.10743801652892551 0.11593528110813635 0.011174484926085367 0.004190431847282033 0.05086718659061798 0.8835991153532767 0.09521592364101959 0.008497264579210684 -0.0008148061925271492 0.04632755208939576 0.8834827144686299 0.844721219881271
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
1489.6432790793892 858.47390911721 350.38842997977486 146.66108726257698 65.51871772345022 30.802447738403625 15.015633081763848 7.372939759214539 3.602748170145869 1.7596017349024324 0.9023980469489912 0.42604559053986407 0.2086617904343124 0.11696138076612213 0.06650877266410397 1489.4950585919782 648.6726890563766 44.0313761461945 14.813184296979202 14.70886411284525 14.708637223793453 14.708329981851291 442.110054514285 167.03961080744105 110.912352356486 73.14257321252117 44.75826464643717 8.918392355710786 35.41728607793805 22.00069249787684 3.0807357022322006 4.133106679411769 18.786975210869198 12.291142909976228 1.2420341267864268 1.0780820385160967 4.196963771246701 10.816659177228358 7.780715811154463 0.513961854917128 0.2801788104206083 1.150574461896198 4.196963771246701
js_param_list:
2988.8747567488617 1887.9793935628438 794.5371505720092 330.3960680775245 145.92231495119452 67.9559314292448 33.03981244952361 16.124047122726786 8.021401990398326 3.943098007875918 1.9811299823118427 0.9460539051395199 0.44709418282093033 0.22449034273754867 0.12425914862692854 2988.8363531204886 1451.7681143260804 94.67273844326954 30.460878266197444 30.244231409403923 30.2446749589304 30.244134610251493 984.9086948427197 371.60971497639866 248.5749360354289 159.90777702378026 99.54631101875773 19.048673214252524 75.87671359764475 48.95576239520067 6.683113070521427 8.485231215526596 39.31778320380456 27.44412247810391 2.6854627255413566 2.207580403630901 8.479439151405776 21.80574465505866 17.614834435129385 1.148945392883737 0.5553705895013917 2.2254689905601692 8.479439151405776
ptq_acc_list:
10.0 10.0 10.01 72.69 87.21 89.67 90.45 90.33 89.37 79.82 61.97 35.21 22.84 21.47 13.47 10.0 10.0 12.81 17.49 27.49 30.18 34.97 10.0 15.78 21.89 33.3 82.29 82.49 58.04 87.21 88.9 82.42 67.65 88.34 90.33 87.15 10.05 70.35 89.06 90.52 88.78 9.99 10.0
acc_loss_list:
0.8896369054188279 0.8896369054188279 0.8895265423242467 0.19777066548946035 0.037523452157598565 0.010374130890630148 0.0017658095132987153 0.00309016664827283 0.01368502372806528 0.11908177905308472 0.31607990288047677 0.6114115439796932 0.747930691976603 0.7630504359342236 0.8513409115991613 0.8896369054188279 0.8896369054188279 0.8586248758415186 0.8069749475775301 0.6966118529963581 0.6669241805540227 0.6140602582496413 0.8896369054188279 0.8258470367509105 0.7584151859618143 0.6324908950446971 0.09182209469153507 0.08961483279991177 0.3594525990508774 0.037523452157598565 0.018872089173380353 0.09038737446197989 0.253393665158371 0.025052422469926013 0.00309016664827283 0.03818563072508546 0.8890850899459222 0.22359562962145466 0.017106279660081637 0.0009932678512305862 0.020196446308354467 0.8897472685134091 0.8896369054188279
from model import *
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import os
import os.path as osp
def train(model, device, train_loader, optimizer, epoch):
model.train()
lossLayer = torch.nn.CrossEntropyLoss()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = lossLayer(output, target)
loss.backward()
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
lossLayer = torch.nn.CrossEntropyLoss(reduction='sum')
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += lossLayer(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
test_loss, 100. * correct / len(test_loader.dataset)
))
if __name__ == "__main__":
batch_size = 128
test_batch_size = 128
seed = 1
epochs1 = 15
epochs2 = epochs1+10
epochs3 = epochs2+10
lr1 = 0.01
lr2 = 0.001
lr3 = 0.0001
momentum = 0.5
save_model = True
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../../project/p/data', train=True, download=True,
transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../../project/p/data', train=False, transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
)
# model = AlexNet_BN().to(device)
model = resnet18().to(device)
optimizer1 = optim.SGD(model.parameters(), lr=lr1, momentum=momentum)
optimizer2 = optim.SGD(model.parameters(), lr=lr2, momentum=momentum)
optimizer3 = optim.SGD(model.parameters(), lr=lr3, momentum=momentum)
for epoch in range(1, epochs1 + 1):
train(model, device, train_loader, optimizer1, epoch)
test(model, device, test_loader)
for epoch in range(epochs1 + 1, epochs2 + 1):
train(model, device, train_loader, optimizer2, epoch)
test(model, device, test_loader)
for epoch in range(epochs2 + 1, epochs3 + 1):
train(model, device, train_loader, optimizer3, epoch)
test(model, device, test_loader)
if save_model:
if not osp.exists('ckpt'):
os.makedirs('ckpt')
torch.save(model.state_dict(), 'ckpt/cifar10_ResNet18.pt')
\ No newline at end of file
import re

# Match parenthesized module names such as "(conv1)" or "(layer1.block1.conv2)".
# A non-capturing group is used so that re.findall returns the full match instead of the last group.
pattern = r"\(\w+(?:\.\w+)*\)"

text = ""  # the input text string (e.g. a redirected model printout)

# find the matching sub-module names in the text
matches = re.findall(pattern, text)

# extract all sub-module paths into a list
submodule_paths = [match.strip("()") for match in matches]

# print all sub-module paths
print(submodule_paths)
\ No newline at end of file
import sys
import os
# Extract the per-layer parameter and FLOP ratios from the file produced by redirecting get_param.py's output
def extract_ratio(md='ResNet18'):
fr = open('param_flops_' + md + '.txt','r')
lines = fr.readlines()
layer = []
par_ratio = []
flop_ratio = []
for line in lines:
# if '(' in line and ')' in line:
if 'Conv' in line or 'BatchNorm2d' in line or 'Linear' in line:
layer.append(line.split(':')[1].split('(')[0])
r1 = line.split('%')[0].split(',')[-1]
r1 = float(r1)
par_ratio.append(r1)
r2 = line.split('%')[-2].split(',')[-1]
r2 = float(r2)
flop_ratio.append(r2)
return layer, par_ratio, flop_ratio
if __name__ == "__main__":
layer, par_ratio, flop_ratio = extract_ratio()
print(len(layer))
print(len(par_ratio))
print(len(flop_ratio))
\ No newline at end of file
from torch.autograd import Function
class FakeQuantize(Function):
@staticmethod
def forward(ctx, x, qparam):
x = qparam.quantize_tensor(x)
x = qparam.dequantize_tensor(x)
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
\ No newline at end of file
from model import *
import torch
from ptflops import get_model_complexity_info
import argparse
def get_children(model: torch.nn.Module):
# get children from the model!
# To be able to update the parameters later, they need to be held in an nn.ModuleList,
# children = nn.ModuleList(model.children())
# print(children)
# so the modules inside can conveniently be updated afterwards.
# flatt_children = nn.ModuleList()
children = list(model.children())
# flatt_children = nn.ModuleList()
flatt_children = []
if len(children) == 0:
# if model has no children; model is last child! :O
return model
else:
# look for children from children... to the last child!
for child in children:
try:
flatt_children.extend(get_children(child))
except TypeError:
flatt_children.append(get_children(child))
# print(flatt_children)
return flatt_children
# Function that yields all leaf sub-modules, skipping wrapper containers
def get_all_child_modules(module):
for name, child in module.named_children():
if isinstance(child, nn.Sequential):
yield from get_all_child_modules(child)
elif len(list(child.children())) > 0:
yield from child.children()
else:
yield child
def filter_fn(module, n_inp, outp_shape):
# if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.ReLU,torch.nn.BatchNorm2d,torch.nn.Linear,torch.nn.AdaptiveAvgPool2d)):
if 'conv' in module or 'bn' in module or 'fc' in module or 'avg' in module or 'relu' in module:
return True
return False
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Model Analysis --- params & flops')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='ResNet18')
args = parser.parse_args()
if args.model == 'ResNet18':
model = resnet18()
elif args.model == 'ResNet50':
model = resnet50()
elif args.model == 'ResNet152':
model = resnet152()
full_file = 'ckpt/cifar10_' + args.model + '.pt'
model.load_state_dict(torch.load(full_file))
# flat = get_children(model)
# print(flat)
# flat = get_children(model)
# new_model = nn.Sequential(*flat)
flops, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict
def get_model_histogram(model):
"""
Description:
- get norm gradients from model, and store in a OrderDict
Args:
- model: (torch.nn.Module), torch model
Returns:
- grads in OrderDict
"""
gradshisto = OrderedDict()
grads = OrderedDict()
for name, params in model.named_parameters():
grad = params.grad
if grad is not None:
tmp = {}
params_np = grad.cpu().numpy()
histogram, bins = np.histogram(params_np.flatten(),bins=20)
tmp['histogram'] = list(histogram)
tmp['bins'] = list(bins)
gradshisto[name] = tmp
grads[name] = params_np
return gradshisto,grads
def get_model_norm_gradient(model):
"""
Description:
- get norm gradients from model, and store in a OrderDict
Args:
- model: (torch.nn.Module), torch model
Returns:
- grads in OrderDict
"""
grads = OrderedDict()
for name, params in model.named_parameters():
grad = params.grad
if grad is not None:
grads[name] = grad.norm().item()
return grads
def get_grad_histogram(grads_sum):
gradshisto = OrderedDict()
# grads = OrderedDict()
for name, params in grads_sum.items():
grad = params
if grad is not None:
tmp = {}
#params_np = grad.cpu().numpy()
params_np = grad
histogram, bins = np.histogram(params_np.flatten(),bins=20)
tmp['histogram'] = list(histogram)
tmp['bins'] = list(bins)
gradshisto[name] = tmp # one histogram per layer (tmp holds the info describing the histogram)
# grads[name] = params_np
return gradshisto
\ No newline at end of file
class GlobalVariables:
SELF_INPLANES = 0
\ No newline at end of file
# -*- coding: utf-8 -*-
# Used to share global variables across multiple modules
def _init(): # initialization
global _global_dict
_global_dict = {}
def set_value(value,is_bias=False):
# define a global variable; biases get their own slot
if is_bias:
_global_dict[0] = value
else:
_global_dict[1] = value
def get_value(is_bias=False): # give the bias a precision independent of the other variables
if is_bias:
return _global_dict[0]
else:
return _global_dict[1]
from model import *
from extract_ratio import *
from utils import *
import argparse
import openpyxl
import os
import os.path as osp
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
def js_div_norm(a,b):
a_norm = F.normalize(a.data,p=2,dim=-1)
b_norm = F.normalize(b.data,p=2,dim=-1)
return js_div(a_norm,b_norm)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Loss-Grad Analysis')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='ResNet18')
args = parser.parse_args()
wb = openpyxl.Workbook()
ws = wb.active
writer = SummaryWriter(log_dir='log/' + args.model + '/qat_loss_grad')
# layer, par_ratio, flop_ratio = extract_ratio()
layer, par_ratio, flop_ratio = extract_ratio(args.model)
if args.model == 'ResNet18':
model = resnet18()
elif args.model == 'ResNet50':
model = resnet50()
elif args.model == 'ResNet152':
model = resnet152()
layer = []
for name, param in model.named_parameters():
if 'weight' in name:
n = name.split('.')
pre = '.'.join(n[:len(n)-1])
layer.append(pre)
# dir_prefix = 'ckpt/qat/epoch_'
dir_prefix = 'ckpt/qat/'+ args.model + '/'
quant_type_list = ['INT','POT','FLOAT']
for epoch in [5,10,15,20]:
ws_epoch = wb.create_sheet('epoch_%d'%epoch)
full_state = torch.load(dir_prefix+'%d/'%epoch + 'full.pt')
ws_epoch.cell(row=1,column=2,value='loss')
ws_epoch.cell(row=1,column=3,value='loss_sum')
ws_epoch.cell(row=1,column=4,value='loss_avg')
ws_epoch.cell(row=2,column=1,value='FP32')
ws_epoch.cell(row=2,column=2,value=full_state['loss'].cpu().item())
ws_epoch.cell(row=2,column=3,value=full_state['loss_sum'].cpu().item())
ws_epoch.cell(row=2,column=4,value=full_state['loss_avg'].cpu().item())
# full_grad = full_state['grad_dict']
# full_grad_sum = full_state['grad_dict_sum']
full_grad_avg = full_state['grad_dict_avg']
for name,tmpgrad in full_grad_avg.items():
writer.add_histogram('FULL: '+name,tmpgrad,global_step=epoch)
ws_epoch.cell(row=4,column=1,value='title')
ws_epoch.cell(row=4,column=2,value='loss')
ws_epoch.cell(row=4,column=3,value='loss_sum')
ws_epoch.cell(row=4,column=4,value='loss_avg')
ws_epoch.cell(row=4,column=5,value='js_grad_avg_norm')
# ws_epoch.cell(row=4,column=6,value='conv1.weight')
# ws_epoch.cell(row=4,column=7,value='conv1.bias')
# ws_epoch.cell(row=4,column=8,value='conv2.weight')
# ws_epoch.cell(row=4,column=9,value='conv2.bias')
# ws_epoch.cell(row=4,column=10,value='conv3.weight')
# ws_epoch.cell(row=4,column=11,value='conv3.bias')
# ws_epoch.cell(row=4,column=12,value='conv4.weight')
# ws_epoch.cell(row=4,column=13,value='conv4.bias')
# ws_epoch.cell(row=4,column=14,value='conv5.weight')
# ws_epoch.cell(row=4,column=15,value='conv5.bias')
# ws_epoch.cell(row=4,column=16,value='fc1.weight')
# ws_epoch.cell(row=4,column=17,value='fc1.bias')
# ws_epoch.cell(row=4,column=18,value='fc2.weight')
# ws_epoch.cell(row=4,column=19,value='fc2.bias')
# ws_epoch.cell(row=4,column=20,value='fc3.weight')
# ws_epoch.cell(row=4,column=21,value='fc3.bias')
cnt = 5
for n in layer:
cnt = cnt + 1
ws_epoch.cell(row=4,column=cnt,value=n)
currow=4
for quant_type in quant_type_list:
num_bit_list = numbit_list(quant_type)
for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list:
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
title = '%s_%d' % (quant_type, num_bits)
print('\nAnalyse: '+title)
currow += 1
qat_state=torch.load(dir_prefix+'%d/'%epoch+title+'.pt')
js_grad_avg_norm=0.
grad_avg = qat_state['grad_dict_avg']
for name,tmpgrad in grad_avg.items():
writer.add_histogram(title+': '+name,tmpgrad,global_step=epoch)
colidx=5
for name,_ in full_grad_avg.items():
prefix = name.split('.')[0]
colidx += 1
layer_idx = layer.index(prefix)
js_norm = js_div_norm(full_grad_avg[name],grad_avg[name])
ws_epoch.cell(row=currow,column=colidx,value=js_norm.cpu().item())
js_grad_avg_norm += flop_ratio[layer_idx] * js_norm
ws_epoch.cell(row=currow,column=1,value=title)
ws_epoch.cell(row=currow,column=2,value=qat_state['loss'].cpu().item())
ws_epoch.cell(row=currow,column=3,value=qat_state['loss_sum'].cpu().item())
ws_epoch.cell(row=currow,column=4,value=qat_state['loss_avg'].cpu().item())
ws_epoch.cell(row=currow,column=5,value=js_grad_avg_norm.cpu().item())
wb.save('loss_grad.xlsx')
writer.close()
\ No newline at end of file
ResNet(
712.67 k, 101.597% Params, 35.92 MMac, 100.000% MACs,
(conv1): Conv2d(432, 0.062% Params, 442.37 KMac, 1.232% MACs, 3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(32, 0.005% Params, 32.77 KMac, 0.091% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 16.38 KMac, 0.046% MACs, )
(layer1): MakeLayer(
9.34 k, 1.332% Params, 9.63 MMac, 26.822% MACs,
(blockdict): ModuleDict(
9.34 k, 1.332% Params, 9.63 MMac, 26.822% MACs,
(block1): BasicBlock(
4.67 k, 0.666% Params, 4.82 MMac, 13.411% MACs,
(conv1): Conv2d(2.3 k, 0.328% Params, 2.36 MMac, 6.569% MACs, 16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(32, 0.005% Params, 32.77 KMac, 0.091% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(2.3 k, 0.328% Params, 2.36 MMac, 6.569% MACs, 16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(32, 0.005% Params, 32.77 KMac, 0.091% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 32.77 KMac, 0.091% MACs, )
)
(block2): BasicBlock(
4.67 k, 0.666% Params, 4.82 MMac, 13.411% MACs,
(conv1): Conv2d(2.3 k, 0.328% Params, 2.36 MMac, 6.569% MACs, 16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(32, 0.005% Params, 32.77 KMac, 0.091% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(2.3 k, 0.328% Params, 2.36 MMac, 6.569% MACs, 16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(32, 0.005% Params, 32.77 KMac, 0.091% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 32.77 KMac, 0.091% MACs, )
)
)
)
(layer2): MakeLayer(
33.66 k, 4.799% Params, 8.65 MMac, 24.085% MACs,
(downsample): Sequential(
576, 0.082% Params, 147.46 KMac, 0.411% MACs,
(0): Conv2d(512, 0.073% Params, 131.07 KMac, 0.365% MACs, 16, 32, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(64, 0.009% Params, 16.38 KMac, 0.046% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(blockdict): ModuleDict(
33.09 k, 4.717% Params, 8.5 MMac, 23.675% MACs,
(block1): BasicBlock(
14.53 k, 2.071% Params, 3.74 MMac, 10.400% MACs,
(conv1): Conv2d(4.61 k, 0.657% Params, 1.18 MMac, 3.284% MACs, 16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, 0.009% Params, 16.38 KMac, 0.046% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(9.22 k, 1.314% Params, 2.36 MMac, 6.569% MACs, 32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, 0.009% Params, 16.38 KMac, 0.046% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 16.38 KMac, 0.046% MACs, )
(downsample): Sequential(
576, 0.082% Params, 147.46 KMac, 0.411% MACs,
(0): Conv2d(512, 0.073% Params, 131.07 KMac, 0.365% MACs, 16, 32, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(64, 0.009% Params, 16.38 KMac, 0.046% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block2): BasicBlock(
18.56 k, 2.646% Params, 4.77 MMac, 13.274% MACs,
(conv1): Conv2d(9.22 k, 1.314% Params, 2.36 MMac, 6.569% MACs, 32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, 0.009% Params, 16.38 KMac, 0.046% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(9.22 k, 1.314% Params, 2.36 MMac, 6.569% MACs, 32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, 0.009% Params, 16.38 KMac, 0.046% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 16.38 KMac, 0.046% MACs, )
)
)
)
(layer3): MakeLayer(
133.89 k, 19.087% Params, 8.59 MMac, 23.903% MACs,
(downsample): Sequential(
2.18 k, 0.310% Params, 139.26 KMac, 0.388% MACs,
(0): Conv2d(2.05 k, 0.292% Params, 131.07 KMac, 0.365% MACs, 32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, 0.018% Params, 8.19 KMac, 0.023% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(blockdict): ModuleDict(
131.71 k, 18.777% Params, 8.45 MMac, 23.515% MACs,
(block1): BasicBlock(
57.73 k, 8.230% Params, 3.7 MMac, 10.309% MACs,
(conv1): Conv2d(18.43 k, 2.628% Params, 1.18 MMac, 3.284% MACs, 32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, 0.018% Params, 8.19 KMac, 0.023% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(36.86 k, 5.255% Params, 2.36 MMac, 6.569% MACs, 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, 0.018% Params, 8.19 KMac, 0.023% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 8.19 KMac, 0.023% MACs, )
(downsample): Sequential(
2.18 k, 0.310% Params, 139.26 KMac, 0.388% MACs,
(0): Conv2d(2.05 k, 0.292% Params, 131.07 KMac, 0.365% MACs, 32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, 0.018% Params, 8.19 KMac, 0.023% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block2): BasicBlock(
73.98 k, 10.547% Params, 4.74 MMac, 13.206% MACs,
(conv1): Conv2d(36.86 k, 5.255% Params, 2.36 MMac, 6.569% MACs, 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, 0.018% Params, 8.19 KMac, 0.023% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(36.86 k, 5.255% Params, 2.36 MMac, 6.569% MACs, 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, 0.018% Params, 8.19 KMac, 0.023% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 8.19 KMac, 0.023% MACs, )
)
)
)
(layer4): MakeLayer(
534.02 k, 76.129% Params, 8.55 MMac, 23.812% MACs,
(downsample): Sequential(
8.45 k, 1.204% Params, 135.17 KMac, 0.376% MACs,
(0): Conv2d(8.19 k, 1.168% Params, 131.07 KMac, 0.365% MACs, 64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, 0.036% Params, 4.1 KMac, 0.011% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(blockdict): ModuleDict(
525.57 k, 74.924% Params, 8.42 MMac, 23.435% MACs,
(block1): BasicBlock(
230.14 k, 32.809% Params, 3.69 MMac, 10.264% MACs,
(conv1): Conv2d(73.73 k, 10.511% Params, 1.18 MMac, 3.284% MACs, 64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, 0.036% Params, 4.1 KMac, 0.011% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(147.46 k, 21.021% Params, 2.36 MMac, 6.569% MACs, 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, 0.036% Params, 4.1 KMac, 0.011% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 4.1 KMac, 0.011% MACs, )
(downsample): Sequential(
8.45 k, 1.204% Params, 135.17 KMac, 0.376% MACs,
(0): Conv2d(8.19 k, 1.168% Params, 131.07 KMac, 0.365% MACs, 64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, 0.036% Params, 4.1 KMac, 0.011% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(block2): BasicBlock(
295.42 k, 42.115% Params, 4.73 MMac, 13.172% MACs,
(conv1): Conv2d(147.46 k, 21.021% Params, 2.36 MMac, 6.569% MACs, 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, 0.036% Params, 4.1 KMac, 0.011% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(147.46 k, 21.021% Params, 2.36 MMac, 6.569% MACs, 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, 0.036% Params, 4.1 KMac, 0.011% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 4.1 KMac, 0.011% MACs, )
)
)
)
(avgpool): AdaptiveAvgPool2d(0, 0.000% Params, 2.05 KMac, 0.006% MACs, output_size=(1, 1))
(fc): Linear(1.29 k, 0.184% Params, 1.29 KMac, 0.004% MACs, in_features=128, out_features=10, bias=True)
)
\ No newline at end of file
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
1550.9715484177518 1010.6618450832785 490.36711983121603 228.63618240796055 105.95458399477697 49.891832454127176 23.971386463834847 11.766940263306672 5.802198156728322 2.945552854494226 1.4089209459809153 0.7028945356689421 0.3627421448027583 0.19639652432616655 0.10290621589207728 1550.8075202402079 809.3304044813559 63.26114709313436 15.830878394687405 15.634705354751606 15.634431154765087 15.634698562637233 587.3992465237176 253.07460850671436 167.7732979941726 114.12219518775181 68.27832659271789 12.144600108734192 54.196499986788375 32.71026185994328 4.564269861924828 4.394313417840971 27.2226004480338 17.773922603276603 1.835626977923447 1.140768842343387 4.352845163504906 14.670441182682843 10.848507506546051 0.8384693109379681 0.2879456941504031 1.125951118032696 4.513798580930398
js_param_list:
2002.106323493944 1372.02362009663 671.0149474230941 309.663259250007 143.27027648429524 67.49540905020389 32.55983555593742 15.869736824352183 7.926162889715859 4.02332834190554 1.9402921921025507 0.9559948430950564 0.508804418130463 0.2717493824489641 0.13919396885736918 2001.9790999589547 1107.2067857978273 83.57876316913644 19.292681748877833 19.02125598032171 19.020799891258196 19.02136321962437 798.8315967007426 340.9924481049778 224.3565008774726 153.1878319793109 91.23587962872053 15.849439499156954 72.38482837963767 43.77844645529062 6.041587890012816 5.361963458618261 36.153465149572746 23.556317981886238 2.4844498630535665 1.3908570555914732 5.3079356070464785 19.032545512994524 14.363684371124728 1.1298448661196772 0.35511280677379703 1.3705161714374172 5.430401902437382
ptq_acc_list:
10.0 10.01 11.68 35.86 71.22 88.15 90.01 90.61 90.77 90.77 90.79 90.77 90.77 90.77 90.77 10.0 8.88 18.15 23.5 19.47 15.55 15.61 12.42 18.8 37.09 19.94 44.44 74.39 40.94 49.59 88.61 75.76 46.23 56.11 90.24 88.38 78.55 56.06 53.84 90.71 90.27 88.32 71.82
acc_loss_list:
0.8898314421064228 0.8897212735485291 0.8713231243803019 0.6049355513936322 0.21537953068194335 0.028864162168117113 0.008372810399911765 0.0017626969262971974 0.0 0.0 -0.0002203371157872671 0.0 0.0 0.0 0.0 0.8898314421064228 0.9021703205905035 0.8000440674231575 0.7411038889500936 0.7855018177812052 0.8286878924754875 0.8280268811281261 0.8631706510961771 0.7928831111600749 0.5913848187727222 0.7803238955602071 0.5104109287209431 0.18045609782967936 0.548969923983695 0.45367412140575075 0.023796408505012634 0.16536300539825924 0.49069075685799274 0.38184422165913845 0.005838933568359603 0.02633028533656495 0.1346259777459513 0.3823950644486063 0.4068524843009804 0.0006610113473614881 0.0055084278946788585 0.02699129668392644 0.20876941720832878
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
1354.2945784950603 722.8096011458398 255.38235121030203 77.72514867219756 24.195293094663096 7.6188281643621805 2.711759981703649 1.1327320388891737 0.4918583902845934 0.2514171979338759 0.0851226729374511 0.031643228209198994 0.014960137087802734 0.014775519758112417 0.014717476906251569 1354.017620655594 530.956126028799 29.270594072095776 23.95970497578537 23.946279778497708 23.945813186814195 23.946195581641064 349.9397526929324 98.90322611449565 77.39375850717117 32.951875627562636 26.29951378038865 7.452277149981623 13.696730185700897 13.489537436477727 2.070642776277882 6.647399487862168 7.9287061447075295 9.3465570209075 0.5745562357250971 1.7289203756753833 6.647398112960826 6.036095650090301 7.68762550328311 0.18322236331308744 0.43723723366736766 1.728919472161257 6.887208440038224
js_param_list:
1905.034397666333 1304.431928190669 532.9886077118246 162.16424238527497 45.90216325868581 12.55083851407541 3.737229506223789 1.1881391381377582 0.4333202294403889 0.17966875415589165 0.0686934709653976 0.02711635083384763 0.012422401703485457 0.011924650749714109 0.011811756854569257 1904.9690246678717 1024.4893074529332 44.47069839549869 37.182205895517406 37.17084546399217 37.170360480922994 37.170596963709116 676.9975549510849 192.141761781115 146.37833455189465 54.830161088403294 46.75208855879235 11.459473919635343 17.814252125337802 23.67092841602859 3.110218252992642 10.411529270878503 8.13295996799088 16.519087866858648 0.8965427653389706 2.6972225487456454 10.411531041285842 5.310143987405656 13.920591998362562 0.3160476684964184 0.6824627078929684 2.697221260834613 10.486319878844261
ptq_acc_list:
10.0 10.0 18.77 75.9 86.88 88.89 89.36 89.46 89.51 89.48 89.47 89.5 89.5 89.49 89.48 10.0 11.15 45.68 42.59 32.54 49.07 45.67 13.83 30.06 63.49 63.24 81.67 81.33 74.09 84.67 88.05 80.51 78.4 86.44 89.1 87.93 79.02 78.99 86.51 89.29 89.18 88.01 82.21
acc_loss_list:
0.8882431828341529 0.8882431828341529 0.790232454179705 0.15176575771122036 0.029056772463120346 0.006593652212785018 0.0013410818059902162 0.00022351363433180857 -0.00033527045149755406 0.0 0.00011175681716590428 -0.00022351363433164976 -0.00022351363433164976 -0.00011175681716574548 0.0 0.8882431828341529 0.8753911488600804 0.4894948591864104 0.5240277156906571 0.6363433169423335 0.4516092981671882 0.48960661600357624 0.8454403218596335 0.6640590075994636 0.29045596781403665 0.2932498882431828 0.08728207420652662 0.09108180599016547 0.17199374161823872 0.053755029056772485 0.015981224854716213 0.10024586499776485 0.12382655341975858 0.03397407241841759 0.004246759052302298 0.01732230666070627 0.11689763075547617 0.11723290120697372 0.03319177469825658 0.0021233795261510697 0.0033527045149753815 0.016428252123379512 0.08124720607957096
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
1469.414342814702 866.0874342352733 381.5565602277652 169.18221229085108 78.82267155260986 37.754526594520954 18.515153791688547 9.354666264651968 4.563694245906456 2.205479237878888 1.1091092104104436 0.5377804901213837 0.24827767034479217 0.14283815911342562 0.07463956191264126 1469.2994550287851 670.7959107285827 47.486064140133465 11.113135358953256 10.970783263210283 10.970809026674786 10.971045295924528 468.1899266921865 189.15608989973475 119.8974617444954 85.97368505940334 49.70188531983513 9.180499482602498 41.72516188841945 23.929720892316485 3.3157563198123947 3.094939418253566 21.65610446037262 12.921220588085308 1.37834375168831 0.8028676010370622 3.0495072368117375 12.211155381042943 7.566622549160813 0.6068437164655487 0.206629018492998 0.7881132018682301 3.209749482698125
js_param_list:
2428.388695366032 1576.3835862311166 735.7144816775965 330.7879670826985 153.8651132826839 73.29165336554594 35.886440151847374 18.054417795113007 8.71550268107359 4.208844604961609 2.089532093278776 0.9633847051060299 0.3967191667447389 0.20409221634042343 0.09841711129863948 2428.32641725148 1252.5298934046723 88.6695229301097 18.02081222203601 17.817638699939312 17.81805197224828 17.817443494485715 883.6762566395342 364.86096182524466 232.53743290664556 164.79476403956693 96.11400180562215 16.67957044460877 78.7486628351845 46.083739584063444 6.189988926373023 4.988900305675509 39.98449203615441 24.734066963325354 2.585781231886643 1.3054649551746371 4.9472935954404 21.66781080710387 14.319081489805278 1.0990111398076245 0.33141428747561985 1.2786737197199143 5.037454816228963
ptq_acc_list:
10.0 10.0 20.08 75.03 88.19 89.81 90.23 90.5 90.48 90.54 90.53 90.57 90.59 90.57 90.56 10.0 12.24 16.16 13.0 21.38 19.87 21.15 12.25 12.99 37.39 24.11 76.45 76.57 38.86 85.05 88.16 78.1 51.66 86.67 90.09 87.89 76.06 54.48 87.17 90.32 90.06 87.96 79.95
acc_loss_list:
0.8895759717314488 0.8895759717314488 0.7782685512367491 0.17148851590106007 0.026170494699646694 0.008281802120141343 0.003643992932862172 0.0006625441696113325 0.000883392226148391 0.00022084805653705853 0.00033127208480566626 -0.0001104240282684508 -0.00033127208480566626 -0.0001104240282684508 0.0 0.8895759717314488 0.8648409893992933 0.8215547703180213 0.8564487632508834 0.7639134275618376 0.7805874558303887 0.766453180212014 0.8647305653710248 0.856559187279152 0.5871245583038869 0.733767667844523 0.1558083038869258 0.15448321554770328 0.57089222614841 0.06084363957597179 0.02650176678445236 0.13758833922261493 0.42954946996466437 0.04295494699646644 0.0051899293286218956 0.0294832155477032 0.1601148409893993 0.3984098939929329 0.03743374558303887 0.00265017667844533 0.005521201413427562 0.028710247349823415 0.11715989399293285
from model import *
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import os
import os.path as osp
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    lossLayer = torch.nn.CrossEntropyLoss()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = lossLayer(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
            ))
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    lossLayer = torch.nn.CrossEntropyLoss(reduction='sum')
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        test_loss += lossLayer(output, target).item()
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
        test_loss, 100. * correct / len(test_loader.dataset)
    ))
if __name__ == "__main__":
    batch_size = 128
    test_batch_size = 128
    seed = 1
    epochs1 = 15
    epochs2 = epochs1 + 10
    epochs3 = epochs2 + 10
    lr1 = 0.01
    lr2 = 0.001
    lr3 = 0.0001
    momentum = 0.5
    save_model = True
    torch.manual_seed(seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('../../project/p/data', train=True, download=True,
                         transform=transforms.Compose([
                             transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                         ])),
        batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('../../project/p/data', train=False, transform=transforms.Compose([
            transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])),
        batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
    )
    # model = AlexNet_BN().to(device)
    model = resnet18().to(device)
    optimizer1 = optim.SGD(model.parameters(), lr=lr1, momentum=momentum)
    optimizer2 = optim.SGD(model.parameters(), lr=lr2, momentum=momentum)
    optimizer3 = optim.SGD(model.parameters(), lr=lr3, momentum=momentum)
    for epoch in range(1, epochs1 + 1):
        train(model, device, train_loader, optimizer1, epoch)
        test(model, device, test_loader)
    for epoch in range(epochs1 + 1, epochs2 + 1):
        train(model, device, train_loader, optimizer2, epoch)
        test(model, device, test_loader)
    for epoch in range(epochs2 + 1, epochs3 + 1):
        train(model, device, train_loader, optimizer3, epoch)
        test(model, device, test_loader)
    if save_model:
        if not osp.exists('ckpt'):
            os.makedirs('ckpt')
        torch.save(model.state_dict(), 'ckpt/cifar10_ResNet18.pt')
\ No newline at end of file
import re
# Match parenthesized submodule names such as "(conv1)" or "(layer1.blockdict.block1.conv1)".
# The inner group is non-capturing so that re.findall returns the full "(...)" match
# rather than only the last captured group.
pattern = r"\(\w+(?:\.\w+)*\)"
text = ""  # the input text string to search (e.g. a per-layer model printout)
# find the matching submodule names in the text
matches = re.findall(pattern, text)
# strip the surrounding parentheses and collect all submodule paths in a list
submodule_paths = [match.strip("()") for match in matches]
# print all submodule paths
print(submodule_paths)
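# A minimal demonstration (the sample string is copied from the ptflops output above,
# purely for illustration): only the parenthesized module name is matched, not the
# numeric tuples such as kernel_size=(3, 3).
sample = "(conv1): Conv2d(432, 0.062% Params, 442.37 KMac, 1.232% MACs, 3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)"
print([m.strip("()") for m in re.findall(pattern, sample)])  # -> ['conv1']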
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
def js_div(p_output, q_output, get_softmax=True):
    """
    Measure the JS divergence between two sets of logits (or probabilities).
    """
    KLDivLoss = nn.KLDivLoss(reduction='sum')
    if get_softmax:
        # convert logits to probabilities; an explicit dim avoids the implicit-dim deprecation warning
        p_output = F.softmax(p_output, dim=-1)
        q_output = F.softmax(q_output, dim=-1)
    log_mean_output = ((p_output + q_output) / 2).log()
    # JS(P, Q) = (KL(P || M) + KL(Q || M)) / 2, with M = (P + Q) / 2
    return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output)) / 2
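# A minimal usage sketch (the names below are illustrative, not from this repo):
#   p_logits = full_precision_model(images)     # e.g. shape [batch, num_classes]
#   q_logits = quantized_model(images)          # same shape
#   gap = js_div(p_logits, q_logits)            # non-negative scalar tensor; 0 iff the two distributions match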
def ebit_list(quant_type, num_bits):
    if quant_type == 'FLOAT':
        e_bit_list = list(range(1, num_bits - 1))
    else:
        e_bit_list = [0]
    return e_bit_list
def numbit_list(quant_type):
    if quant_type == 'INT':
        num_bit_list = list(range(2, 17))
    elif quant_type == 'POT':
        num_bit_list = list(range(2, 9))
    else:
        num_bit_list = list(range(2, 9))
        # num_bit_list = [8]
    return num_bit_list
def build_bias_list(quant_type):
    if quant_type == 'POT':
        return build_pot_list(8)
    else:
        return build_float_list(16, 7)
def build_list(quant_type, num_bits, e_bits):
    if quant_type == 'POT':
        return build_pot_list(num_bits)
    else:
        return build_float_list(num_bits, e_bits)
def build_pot_list(num_bits):
    plist = [0.]
    for i in range(-2 ** (num_bits-1) + 2, 1):
        # i goes up to 0 at most, so the largest POT quantization value is 1
        plist.append(2. ** i)
        plist.append(-2. ** i)
    plist = torch.Tensor(list(set(plist)))
    # plist = plist.mul(1.0 / torch.max(plist))
    return plist
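# For reference (derived directly from the loop above): build_pot_list(4) yields the grid
# {0, ±2**-6, ±2**-5, ±2**-4, ±2**-3, ±2**-2, ±2**-1, ±1}.
# Every non-zero level is a signed power of two, so multiplying by a POT-quantized value
# can in principle be implemented as a bit shift.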
def build_float_list(num_bits, e_bits):
    m_bits = num_bits - 1 - e_bits
    plist = [0.]
    # spacing between adjacent mantissa values
    dist_m = 2 ** (-m_bits)
    e = -2 ** (e_bits - 1) + 1
    for m in range(1, 2 ** m_bits):
        frac = m * dist_m   # mantissa part
        expo = 2 ** e       # exponent part
        flt = frac * expo
        plist.append(flt)
        plist.append(-flt)
    for e in range(-2 ** (e_bits - 1) + 2, 2 ** (e_bits - 1) + 1):
        expo = 2 ** e
        for m in range(0, 2 ** m_bits):
            frac = 1. + m * dist_m
            flt = frac * expo
            plist.append(flt)
            plist.append(-flt)
    plist = torch.Tensor(list(set(plist)))
    return plist
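# For reference (derived directly from the loops above): build_float_list(4, 2), i.e. 4 bits
# with 2 exponent bits and 1 mantissa bit, yields the grid
# {0, ±0.25, ±1.0, ±1.5, ±2.0, ±3.0, ±4.0, ±6.0}.
# The first loop contributes the subnormal-style values (frac < 1) and the second loop the
# normal values (frac in [1, 2)).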
def fold_ratio(layer, par_ratio, flop_ratio):
    idx = -1
    for name in layer:
        idx = idx + 1
        # layer is extracted from `for name, param in model.named_parameters()`, so the downsample entries are always present
        if 'bn' in name or 'sample.1' in name:
            par_ratio[idx-1] += par_ratio[idx]
            flop_ratio[idx-1] += flop_ratio[idx]
    return par_ratio, flop_ratio
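# A small illustration with made-up names and numbers: with
#   layer      = ['conv1', 'bn1', 'fc']
#   par_ratio  = [1.0, 0.1, 2.0]
#   flop_ratio = [3.0, 0.2, 1.0]
# fold_ratio adds each BN entry onto the conv entry right before it, giving
# par_ratio = [1.1, 0.1, 2.0] and flop_ratio = [3.2, 0.2, 1.0]; the BN entries themselves stay in place.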
def fold_model(model):
    idx = -1
    module_list = []
    # print('fold model:')
    for name, module in model.named_modules():
        # print(name+'-- +')
        idx += 1
        module_list.append(module)
        # the conv inside downsample used to be missed here, which left some conv/BN pairs unfolded;
        # 'sample.1' matches the BN that follows the 1x1 conv in a downsample branch
        if 'bn' in name or 'sample.1' in name:
            # print(name+'-- *')
            module_list[idx-1] = fold_bn(module_list[idx-1], module)  # fold this BN into the preceding conv
    return model
# Alternative implementation, kept commented out for reference:
# def fold_model(model):
#     last_conv = None
#     last_bn = None
#     for name, module in model.named_modules():
#         if isinstance(module, nn.Conv2d):
#             # current module is a conv layer: first fold any pending BN into the previous conv
#             if last_bn is not None:
#                 last_conv = fold_bn(last_conv, last_bn)
#                 last_bn = None
#             last_conv = module
#         elif isinstance(module, nn.BatchNorm2d):
#             # current module is a BN layer: fold it into the previous conv layer
#             last_bn = module
#             if last_conv is not None:
#                 last_conv = fold_bn(last_conv, last_bn)
#                 last_bn = None
#     # handle the final BN layer
#     if last_bn is not None:
#         last_conv = fold_bn(last_conv, last_bn)
#     return model
def fold_bn(conv, bn):
    # fetch the BN layer's parameters
    gamma = bn.weight.data
    beta = bn.bias.data
    mean = bn.running_mean
    var = bn.running_var
    eps = bn.eps
    std = torch.sqrt(var + eps)
    feat = bn.num_features
    # fetch the conv layer's parameters
    weight = conv.weight.data
    if conv.bias is not None:
        bias = conv.bias.data
    if bn.affine:
        gamma_ = gamma / std
        weight = weight * gamma_.view(feat, 1, 1, 1)
        if conv.bias is not None:
            bias = gamma_ * bias - gamma_ * mean + beta
        else:
            bias = beta - gamma_ * mean
    else:
        gamma_ = 1 / std
        # reshape so the scale is applied per output channel
        weight = weight * gamma_.view(feat, 1, 1, 1)
        if conv.bias is not None:
            bias = gamma_ * bias - gamma_ * mean
        else:
            bias = -gamma_ * mean
    # write back the new weight and bias
    conv.weight.data = weight
    # also handles convs created with bias=False / bias=None
    if conv.bias is None:
        conv.bias = nn.Parameter(bias)
    else:
        conv.bias.data = bias
    return conv
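# A minimal sanity check (a sketch, not part of the training pipeline): in eval mode,
# conv -> bn and the folded conv should produce numerically matching outputs.
if __name__ == "__main__":
    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8)
    # give the BN non-trivial running statistics and affine parameters
    bn.running_mean.uniform_(-0.5, 0.5)
    bn.running_var.uniform_(0.5, 1.5)
    bn.weight.data.uniform_(0.5, 1.5)
    bn.bias.data.uniform_(-0.5, 0.5)
    bn.eval()
    x = torch.randn(2, 3, 32, 32)
    with torch.no_grad():
        ref = bn(conv(x))           # reference output before folding
        folded = fold_bn(conv, bn)  # folds the BN parameters into the conv in place
        out = folded(x)
    print(torch.allclose(ref, out, atol=1e-4))  # expected: True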
\ No newline at end of file