Commit 15882bd1 by Klin

feat: VGG_16: ptq and fit

parent 4d897197
# AlexNet_BN 量化说明
+ 结构和AlexNet基本一致,在训练中的每个conv后加了bn层,量化时将conv、bn、relu合并为QConvBNReLu层。
+ 合并之后如果直接和原全精度模型的conv层进行相似度比对,不符合眼里,且拟合效果较差(可见之前commit的README)
+ 措施:将全精度模型的BN层参数fold至Conv层中,参数量和计算量也相应加至Conv层
+ 该方法不会降低全精度模型推理精度(可为全精度模型加入排除BN层的inference方法,使用fold后模型推理验证)
+ 另外测试了分别单独量化Conv和BN层的方案,精度下降较为明显。同时当前大部分量化策略都采用了Conv+BN的方案,融合能够减少运算量,使得模型更为高效。
## ptq部分
+ 量化结果:
FP32-acc:87.09
![image-20230410030841210](image/image-20230410030841210.png)
+ 数据拟合:
matlab导入数据,选择列向量
+ 加入FP3-FP7前:
+ js_flops - acc_loss
Rational: Numerator degree 2 / Denominator degree 2
- [ ] center and scale
![image-20230410030613387](image/image-20230410030613387.png)
- [x] center and scale
![image-20230410030625395](image/image-20230410030625395.png)
+ js_param - acc_loss
Rational: Numerator degree 2 / Denominator degree 2
- [ ] center and scale
![image-20230410030654186](image/image-20230410030654186.png)
- [x] center and scale
![image-20230410030707277](image/image-20230410030707277.png)
+ 加入FP3-FP7后
+ js_flops - acc_loss
Rational: Numerator degree 2 / Denominator degree 2
- [ ] center and scale
![image-20230410030018190](image/image-20230410030018190.png)
- [x] center and scale
![image-20230410030035550](image/image-20230410030035550.png)
+ js_param - acc_loss
Rational: Numerator degree 2 / Denominator degree 2
+ [ ] center and scale
![image-20230410030148120](image/image-20230410030148120.png)
+ [x] center and scale
![image-20230410030206554](image/image-20230410030206554.png)
\ No newline at end of file
import sys
import os
# 从get_param.py输出重定向文件val.txt中提取参数量和计算量
def extract_ratio():
fr = open('param_flops.txt','r')
lines = fr.readlines()
layer = []
par_ratio = []
flop_ratio = []
for line in lines:
if '(' in line and ')' in line:
layer.append(line.split(')')[0].split('(')[1])
r1 = line.split('%')[0].split(',')[-1]
r1 = float(r1)
par_ratio.append(r1)
r2 = line.split('%')[-2].split(',')[-1]
r2 = float(r2)
flop_ratio.append(r2)
return layer, par_ratio, flop_ratio
if __name__ == "__main__":
layer, par_ratio, flop_ratio = extract_ratio()
print(layer)
print(par_ratio)
print(flop_ratio)
\ No newline at end of file
from torch.autograd import Function
class FakeQuantize(Function):
@staticmethod
def forward(ctx, x, qparam):
x = qparam.quantize_tensor(x)
x = qparam.dequantize_tensor(x)
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
\ No newline at end of file
from model import *
import torch
from ptflops import get_model_complexity_info
if __name__ == "__main__":
model = VGG_16()
full_file = 'ckpt/cifar10_VGG_16.pt'
model.load_state_dict(torch.load(full_file))
flops, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
# -*- coding: utf-8 -*-
# 用于多个module之间共享全局变量
def _init(): # 初始化
global _global_dict
_global_dict = {}
def set_value(value,is_bias=False):
# 定义一个全局变量
if is_bias:
_global_dict[0] = value
else:
_global_dict[1] = value
def get_value(is_bias=False): # 给bias独立于各变量外的精度
if is_bias:
return _global_dict[0]
else:
return _global_dict[1]
import torch
import torch.nn as nn
import torch.nn.functional as F
from module import *
import module
# cfg = {
# 'A' : [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
# 'B' : [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
# 'D' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
# 'E' : [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M']
# }
feature_cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
classifier_cfg = [4096, 4096, 'LF']
def make_feature_layers(cfg, batch_norm=False):
layers = []
names = []
input_channel = 3
idx = 0
for l in cfg:
if l == 'M':
names.append('pool%d'%idx)
layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
continue
idx += 1
names.append('conv%d'%idx)
layers.append(nn.Conv2d(input_channel, l, kernel_size=3, padding=1))
if batch_norm:
names.append('bn%d'%idx)
layers.append(nn.BatchNorm2d(l))
names.append('relu%d'%idx)
layers.append(nn.ReLU(inplace=True))
input_channel = l
return names,layers
def make_classifier_layers(cfg, in_features, num_classes):
layers=[]
names=[]
idx = 0
for l in cfg:
idx += 1
if l=='LF': #last fc
names.append('fc%d'%idx)
layers.append(nn.Linear(in_features,num_classes))
continue
names.append('fc%d'%idx)
layers.append(nn.Linear(in_features,l))
in_features=l
names.append('crelu%d'%idx) # classifier relu
layers.append(nn.ReLU(inplace=True))
names.append('drop%d'%idx)
layers.append(nn.Dropout())
return names,layers
def quantize_feature_layers(model,name_list,quant_type,num_bits,e_bits):
layers=[]
names=[]
last_conv = None
last_bn = None
idx = 0
for name in name_list:
if 'pool' in name:
names.append('qpool%d'%idx)
layers.append(QMaxPooling2d(quant_type, kernel_size=2, stride=2, padding=0, num_bits=num_bits, e_bits=e_bits))
elif 'conv' in name:
last_conv = getattr(model,name)
elif 'bn' in name:
last_bn = getattr(model,name)
elif 'relu' in name:
idx += 1
names.append('qconv%d'%idx)
if idx == 1:
layers.append(QConvBNReLU(quant_type, last_conv, last_bn, qi=True, qo=True, num_bits=num_bits, e_bits=e_bits))
else:
layers.append(QConvBNReLU(quant_type, last_conv, last_bn, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits))
return names,layers
def quantize_classifier_layers(model,name_list,quant_type,num_bits,e_bits):
layers=[]
names=[]
idx=0
for name in name_list:
layer = getattr(model,name)
if 'fc' in name:
idx+=1
names.append('qfc%d'%idx)
layers.append(QLinear(quant_type, layer, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits))
elif 'crelu' in name:
names.append('qcrelu%d'%idx)
layers.append(QReLU(quant_type, num_bits=num_bits, e_bits=e_bits))
elif 'drop' in name:
names.append(name)
layers.append(layer)
return names, layers
def quantize_utils(model,qfeature_name,qclassifier_name,func, x=None):
if func == 'inference':
layer=getattr(model,qfeature_name[0])
x = layer.qi.quantize_tensor(x)
last_qo = None
for name in qfeature_name:
layer = getattr(model,name)
if func == 'forward':
x = layer(x)
elif func == 'inference':
x = layer.quantize_inference(x)
else: #freeze
layer.freeze(last_qo)
if 'conv' in name:
last_qo = layer.qo
if func != 'freeze':
x = torch.flatten(x, start_dim=1)
for name in qclassifier_name:
layer = getattr(model,name)
if func == 'forward':
x = layer(x)
elif 'drop' not in name:
if func == 'inference':
x = layer.quantize_inference(x)
else: # freeze
layer.freeze(last_qo)
if 'fc' in name:
last_qo = layer.qo
if func == 'inference':
x = last_qo.dequantize_tensor(x)
return x
class VGG_16(nn.Module):
def __init__(self, num_class=10):
super().__init__()
feature_name,feature_layer = make_feature_layers(feature_cfg,batch_norm=True)
self.feature_name = feature_name
for name,layer in zip(feature_name,feature_layer):
self.add_module(name,layer)
classifier_name,classifier_layer = make_classifier_layers(classifier_cfg,512,num_class)
self.classifier_name = classifier_name
for name,layer in zip(classifier_name,classifier_layer):
self.add_module(name,layer)
# self.fc1 = nn.Linear(512, 4096)
# self.crelu1 = nn.ReLU(inplace=True)
# self.drop1 = nn.Dropout()
# self.fc2 = nn.Linear(4096, 4096)
# self.crelu2 = nn.ReLU(inplace=True)
# self.drop2 = nn.Dropout()
# self.fc3 = nn.Linear(4096, num_class)
def forward(self, x):
#feature
for name in self.feature_name:
layer = getattr(self,name)
x = layer(x)
x = torch.flatten(x, start_dim=1)
#classifier
for name in self.classifier_name:
layer = getattr(self,name)
x = layer(x)
# x = self.fc1(x)
# x = self.crelu1(x)
# x = self.drop1(x)
# x = self.fc2(x)
# x = self.crelu2(x)
# x = self.drop2(x)
# x = self.fc3(x)
return x
def quantize(self, quant_type, num_bits=8, e_bits=3):
# feature
qfeature_name,qfeature_layer = quantize_feature_layers(self,self.feature_name,quant_type,num_bits,e_bits)
self.qfeature_name = qfeature_name
for name,layer in zip(qfeature_name,qfeature_layer):
self.add_module(name,layer)
# classifier
qclassifier_name,qclassifier_layer = quantize_classifier_layers(self,self.classifier_name,quant_type,num_bits,e_bits)
self.qclassifier_name = qclassifier_name
for name,layer in zip(qclassifier_name,qclassifier_layer):
if 'drop' not in name:
self.add_module(name,layer)
# self.qfc1 = QLinear(quant_type, self.fc1, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
# self.qcrelu1 = QReLU(quant_type, num_bits=num_bits, e_bits=e_bits)
# self.qfc2 = QLinear(quant_type, self.fc2, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
# self.qcrelu2 = QReLU(quant_type, num_bits=num_bits, e_bits=e_bits)
# self.qfc3 = QLinear(quant_type, self.fc3, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
def quantize_forward(self,x):
x = quantize_utils(self, self.qfeature_name, self.qclassifier_name,
func='forward', x=x)
return x
def freeze(self):
quantize_utils(self, self.qfeature_name, self.qclassifier_name,
func='freeze', x=None)
def quantize_inference(self,x):
x = quantize_utils(self, self.qfeature_name, self.qclassifier_name,
func='inference', x=x)
return x
\ No newline at end of file
import math
import numpy as np
import gol
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from function import FakeQuantize
# 获取最近的量化值
def get_nearest_val(quant_type, x, is_bias=False, block_size=1000000):
if quant_type == 'INT':
return x.round_()
plist = gol.get_value(is_bias)
shape = x.shape
xhard = x.view(-1)
xout = torch.zeros_like(xhard)
plist = plist.type_as(x)
n_blocks = (x.numel() + block_size - 1) // block_size
for i in range(n_blocks):
start_idx = i * block_size
end_idx = min(start_idx + block_size, xhard.numel())
block_size_i = end_idx - start_idx
# print(x.numel())
# print(block_size_i)
# print(start_idx)
# print(end_idx)
xblock = xhard[start_idx:end_idx]
# xblock = xblock.view(shape[start_idx:end_idx])
plist_block = plist.unsqueeze(1) #.expand(-1, block_size_i)
idx = (xblock.unsqueeze(0) - plist_block).abs().min(dim=0)[1]
# print(xblock.shape)
xhard_block = plist[idx].view(xblock.shape)
xout[start_idx:end_idx] = (xhard_block - xblock).detach() + xblock
xout = xout.view(shape)
return xout
# 采用对称有符号量化时,获取量化范围最大值
def get_qmax(quant_type,num_bits=None, e_bits=None):
if quant_type == 'INT':
qmax = 2. ** (num_bits - 1) - 1
elif quant_type == 'POT':
qmax = 1
else: #FLOAT
m_bits = num_bits - 1 - e_bits
dist_m = 2 ** (-m_bits)
e = 2 ** (e_bits - 1)
expo = 2 ** e
m = 2 ** m_bits -1
frac = 1. + m * dist_m
qmax = frac * expo
return qmax
# 都采用有符号量化,zeropoint都置为0
def calcScaleZeroPoint(min_val, max_val, qmax):
scale = torch.max(max_val.abs(),min_val.abs()) / qmax
zero_point = torch.tensor(0.)
return scale, zero_point
# 将输入进行量化,输入输出都为tensor
def quantize_tensor(quant_type, x, scale, zero_point, qmax, is_bias=False):
# 量化后范围,直接根据位宽确定
qmin = -qmax
q_x = zero_point + x / scale
q_x.clamp_(qmin, qmax)
q_x = get_nearest_val(quant_type, q_x, is_bias)
return q_x
# bias使用不同精度,需要根据量化类型指定num_bits/e_bits
def bias_qmax(quant_type):
if quant_type == 'INT':
return get_qmax(quant_type, 64)
elif quant_type == 'POT':
return get_qmax(quant_type)
else:
return get_qmax(quant_type, 16, 5)
# 转化为FP32,不需再做限制
def dequantize_tensor(q_x, scale, zero_point):
return scale * (q_x - zero_point)
class QParam(nn.Module):
def __init__(self,quant_type, num_bits=8, e_bits=3):
super(QParam, self).__init__()
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.qmax = get_qmax(quant_type, num_bits, e_bits)
scale = torch.tensor([], requires_grad=False)
zero_point = torch.tensor([], requires_grad=False)
min = torch.tensor([], requires_grad=False)
max = torch.tensor([], requires_grad=False)
# 通过注册为register,使得buffer可以被记录到state_dict
self.register_buffer('scale', scale)
self.register_buffer('zero_point', zero_point)
self.register_buffer('min', min)
self.register_buffer('max', max)
# 更新统计范围及量化参数
def update(self, tensor):
if self.max.nelement() == 0 or self.max.data < tensor.max().data:
self.max.data = tensor.max().data
self.max.clamp_(min=0)
if self.min.nelement() == 0 or self.min.data > tensor.min().data:
self.min.data = tensor.min().data
self.min.clamp_(max=0)
self.scale, self.zero_point = calcScaleZeroPoint(self.min, self.max, self.qmax)
def quantize_tensor(self, tensor):
return quantize_tensor(self.quant_type, tensor, self.scale, self.zero_point, self.qmax)
def dequantize_tensor(self, q_x):
return dequantize_tensor(q_x, self.scale, self.zero_point)
# 该方法保证了可以从state_dict里恢复
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
key_names = ['scale', 'zero_point', 'min', 'max']
for key in key_names:
value = getattr(self, key)
value.data = state_dict[prefix + key].data
state_dict.pop(prefix + key)
# 该方法返回值将是打印该对象的结果
def __str__(self):
info = 'scale: %.10f ' % self.scale
info += 'zp: %.6f ' % self.zero_point
info += 'min: %.6f ' % self.min
info += 'max: %.6f' % self.max
return info
# 作为具体量化层的父类,qi和qo分别为量化输入/输出
class QModule(nn.Module):
def __init__(self,quant_type, qi=True, qo=True, num_bits=8, e_bits=3):
super(QModule, self).__init__()
if qi:
self.qi = QParam(quant_type,num_bits, e_bits)
if qo:
self.qo = QParam(quant_type,num_bits, e_bits)
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.bias_qmax = bias_qmax(quant_type)
def freeze(self):
pass # 空语句
def quantize_inference(self, x):
raise NotImplementedError('quantize_inference should be implemented.')
"""
QModule 量化卷积
:quant_type: 量化类型
:conv_module: 卷积模块
:qi: 是否量化输入特征图
:qo: 是否量化输出特征图
:num_bits: 8位bit数
"""
class QConv2d(QModule):
def __init__(self, quant_type, conv_module, qi=True, qo=True, num_bits=8, e_bits=3):
super(QConv2d, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.qw = QParam(quant_type, num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # 将M注册为buffer
# freeze方法可以固定真量化的权重参数,并将该值更新到原全精度层上,便于散度计算
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
# 这里因为在池化或者激活的输入,不需要对最大值和最小是进行额外的统计,会共享相同的输出
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
# 根据https://zhuanlan.zhihu.com/p/156835141, 这是式3 的系数
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
self.conv_module.weight.data = self.qw.quantize_tensor(self.conv_module.weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
self.conv_module.bias.data = quantize_tensor(self.quant_type,
self.conv_module.bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0.,qmax=self.bias_qmax, is_bias=True)
def forward(self, x): # 前向传播,输入张量,x为浮点型数据
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi) # 对输入张量X完成量化
# foward前更新qw,保证量化weight时候scale正确
self.qw.update(self.conv_module.weight.data)
# 注意:此处主要为了统计各层x和weight范围,未对bias进行量化操作
tmp_wgt = FakeQuantize.apply(self.conv_module.weight, self.qw)
x = F.conv2d(x, tmp_wgt, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
# 利用公式 q_a = M(\sigma(q_w-Z_w)(q_x-Z_x) + q_b)
def quantize_inference(self, x): # 此处input为已经量化的qx
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
class QLinear(QModule):
def __init__(self, quant_type, fc_module, qi=True, qo=True, num_bits=8, e_bits=3):
super(QLinear, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.fc_module = fc_module
self.qw = QParam(quant_type, num_bits, e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # 将M注册为buffer
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
self.fc_module.weight.data = self.qw.quantize_tensor(self.fc_module.weight.data)
self.fc_module.weight.data = self.fc_module.weight.data - self.qw.zero_point
self.fc_module.bias.data = quantize_tensor(self.quant_type,
self.fc_module.bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax, is_bias=True)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
self.qw.update(self.fc_module.weight.data)
tmp_wgt = FakeQuantize.apply(self.fc_module.weight, self.qw)
x = F.linear(x, tmp_wgt, self.fc_module.bias)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = self.fc_module(x)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
class QReLU(QModule):
def __init__(self,quant_type, qi=False, qo=True, num_bits=8, e_bits=3):
super(QReLU, self).__init__(quant_type, qi, qo, num_bits, e_bits)
def freeze(self, qi=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if qi is not None:
self.qi = qi
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
x = F.relu(x)
return x
def quantize_inference(self, x):
x = x.clone()
x[x < self.qi.zero_point] = self.qi.zero_point
return x
class QMaxPooling2d(QModule):
def __init__(self, quant_type, kernel_size=3, stride=1, padding=0, qi=False, qo=True, num_bits=8,e_bits=3):
super(QMaxPooling2d, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
def freeze(self, qi=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if qi is not None:
self.qi = qi
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding)
return x
def quantize_inference(self, x):
return F.max_pool2d(x, self.kernel_size, self.stride, self.padding)
class QConvBNReLU(QModule):
def __init__(self, quant_type, conv_module, bn_module, qi=True, qo=True, num_bits=8, e_bits=3):
super(QConvBNReLU, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.bn_module = bn_module
self.qw = QParam(quant_type, num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # 将M注册为buffer
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = self.qw.quantize_tensor(weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
self.conv_module.bias.data = quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
self.qw.update(weight.data)
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
x = F.relu(x)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
x.clamp_(min=0)
return x
\ No newline at end of file
VGG_16(
33.65 M, 100.000% Params, 333.36 MMac, 100.000% MACs,
(conv1): Conv2d(1.79 k, 0.005% Params, 1.84 MMac, 0.550% MACs, 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn1): BatchNorm2d(128, 0.000% Params, 131.07 KMac, 0.039% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu1): ReLU(0, 0.000% Params, 65.54 KMac, 0.020% MACs, inplace=True)
(conv2): Conv2d(36.93 k, 0.110% Params, 37.81 MMac, 11.343% MACs, 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(128, 0.000% Params, 131.07 KMac, 0.039% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu2): ReLU(0, 0.000% Params, 65.54 KMac, 0.020% MACs, inplace=True)
(pool2): MaxPool2d(0, 0.000% Params, 65.54 KMac, 0.020% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv3): Conv2d(73.86 k, 0.220% Params, 18.91 MMac, 5.672% MACs, 64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn3): BatchNorm2d(256, 0.001% Params, 65.54 KMac, 0.020% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu3): ReLU(0, 0.000% Params, 32.77 KMac, 0.010% MACs, inplace=True)
(conv4): Conv2d(147.58 k, 0.439% Params, 37.78 MMac, 11.334% MACs, 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn4): BatchNorm2d(256, 0.001% Params, 65.54 KMac, 0.020% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu4): ReLU(0, 0.000% Params, 32.77 KMac, 0.010% MACs, inplace=True)
(pool4): MaxPool2d(0, 0.000% Params, 32.77 KMac, 0.010% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv5): Conv2d(295.17 k, 0.877% Params, 18.89 MMac, 5.667% MACs, 128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn5): BatchNorm2d(512, 0.002% Params, 32.77 KMac, 0.010% MACs, 256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu5): ReLU(0, 0.000% Params, 16.38 KMac, 0.005% MACs, inplace=True)
(conv6): Conv2d(590.08 k, 1.754% Params, 37.77 MMac, 11.329% MACs, 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn6): BatchNorm2d(512, 0.002% Params, 32.77 KMac, 0.010% MACs, 256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu6): ReLU(0, 0.000% Params, 16.38 KMac, 0.005% MACs, inplace=True)
(conv7): Conv2d(590.08 k, 1.754% Params, 37.77 MMac, 11.329% MACs, 256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn7): BatchNorm2d(512, 0.002% Params, 32.77 KMac, 0.010% MACs, 256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu7): ReLU(0, 0.000% Params, 16.38 KMac, 0.005% MACs, inplace=True)
(pool7): MaxPool2d(0, 0.000% Params, 16.38 KMac, 0.005% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv8): Conv2d(1.18 M, 3.508% Params, 18.88 MMac, 5.664% MACs, 256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn8): BatchNorm2d(1.02 k, 0.003% Params, 16.38 KMac, 0.005% MACs, 512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu8): ReLU(0, 0.000% Params, 8.19 KMac, 0.002% MACs, inplace=True)
(conv9): Conv2d(2.36 M, 7.013% Params, 37.76 MMac, 11.326% MACs, 512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn9): BatchNorm2d(1.02 k, 0.003% Params, 16.38 KMac, 0.005% MACs, 512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu9): ReLU(0, 0.000% Params, 8.19 KMac, 0.002% MACs, inplace=True)
(conv10): Conv2d(2.36 M, 7.013% Params, 37.76 MMac, 11.326% MACs, 512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn10): BatchNorm2d(1.02 k, 0.003% Params, 16.38 KMac, 0.005% MACs, 512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu10): ReLU(0, 0.000% Params, 8.19 KMac, 0.002% MACs, inplace=True)
(pool10): MaxPool2d(0, 0.000% Params, 8.19 KMac, 0.002% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv11): Conv2d(2.36 M, 7.013% Params, 9.44 MMac, 2.832% MACs, 512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn11): BatchNorm2d(1.02 k, 0.003% Params, 4.1 KMac, 0.001% MACs, 512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu11): ReLU(0, 0.000% Params, 2.05 KMac, 0.001% MACs, inplace=True)
(conv12): Conv2d(2.36 M, 7.013% Params, 9.44 MMac, 2.832% MACs, 512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn12): BatchNorm2d(1.02 k, 0.003% Params, 4.1 KMac, 0.001% MACs, 512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu12): ReLU(0, 0.000% Params, 2.05 KMac, 0.001% MACs, inplace=True)
(conv13): Conv2d(2.36 M, 7.013% Params, 9.44 MMac, 2.832% MACs, 512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn13): BatchNorm2d(1.02 k, 0.003% Params, 4.1 KMac, 0.001% MACs, 512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu13): ReLU(0, 0.000% Params, 2.05 KMac, 0.001% MACs, inplace=True)
(pool13): MaxPool2d(0, 0.000% Params, 2.05 KMac, 0.001% MACs, kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(fc1): Linear(2.1 M, 6.245% Params, 2.1 MMac, 0.630% MACs, in_features=512, out_features=4096, bias=True)
(crelu1): ReLU(0, 0.000% Params, 4.1 KMac, 0.001% MACs, inplace=True)
(drop1): Dropout(0, 0.000% Params, 0.0 Mac, 0.000% MACs, p=0.5, inplace=False)
(fc2): Linear(16.78 M, 49.875% Params, 16.78 MMac, 5.034% MACs, in_features=4096, out_features=4096, bias=True)
(crelu2): ReLU(0, 0.000% Params, 4.1 KMac, 0.001% MACs, inplace=True)
(drop2): Dropout(0, 0.000% Params, 0.0 Mac, 0.000% MACs, p=0.5, inplace=False)
(fc3): Linear(40.97 k, 0.122% Params, 40.97 KMac, 0.012% MACs, in_features=4096, out_features=10, bias=True)
)
from torch.serialization import load
from model import *
from extract_ratio import *
from utils import *
import gol
import openpyxl
import sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import torch.utils.bottleneck as bn
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
def direct_quantize(model, test_loader,device):
for i, (data, target) in enumerate(test_loader, 1):
data = data.to(device)
output = model.quantize_forward(data).cpu()
if i % 500 == 0:
break
print('direct quantization finish')
def full_inference(model, test_loader, device):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data = data.to(device)
output = model(data).cpu()
pred = output.argmax(dim=1, keepdim=True)
# print(pred)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Full Model Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))
return 100. * correct / len(test_loader.dataset)
def quantize_inference(model, test_loader, device):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data = data.to(device)
output = model.quantize_inference(data).cpu()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('Test set: Quant Model Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))
return 100. * correct / len(test_loader.dataset)
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
if __name__ == "__main__":
batch_size = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=True, download=True,
transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
model = VGG_16()
writer = SummaryWriter(log_dir='./log')
full_file = 'ckpt/cifar10_VGG_16.pt'
model.load_state_dict(torch.load(full_file))
model.to(device)
load_ptq = True
store_ptq = False
ptq_file_prefix = 'ckpt/cifar10_VGG_16_ptq_'
model.eval()
full_acc = full_inference(model, test_loader, device)
model_fold = fold_model(model)
full_params = []
layer, par_ratio, flop_ratio = extract_ratio()
par_ratio, flop_ratio = fold_ratio(layer, par_ratio, flop_ratio)
for name, param in model_fold.named_parameters():
if 'bn' in name:
continue
param_norm = F.normalize(param.data.cpu(),p=2,dim=-1)
full_params.append(param_norm)
writer.add_histogram(tag='Full_' + name + '_data', values=param.data)
gol._init()
quant_type_list = ['INT','POT','FLOAT']
title_list = []
js_flops_list = []
js_param_list = []
ptq_acc_list = []
acc_loss_list = []
for quant_type in quant_type_list:
num_bit_list = numbit_list(quant_type)
# 对一个量化类别,只需设置一次bias量化表
# int由于位宽大,使用量化表开销过大,直接_round即可
if quant_type != 'INT':
bias_list = build_bias_list(quant_type)
gol.set_value(bias_list, is_bias=True)
for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list:
model_ptq = VGG_16()
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
title = '%s_%d' % (quant_type, num_bits)
print('\nPTQ: '+title)
title_list.append(title)
# 设置量化表
if quant_type != 'INT':
plist = build_list(quant_type, num_bits, e_bits)
gol.set_value(plist)
# 判断是否需要载入
if load_ptq is True and osp.exists(ptq_file_prefix + title + '.pt'):
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.load_state_dict(torch.load(ptq_file_prefix + title + '.pt'))
model_ptq.to(device)
print('Successfully load ptq model: ' + title)
else:
model_ptq.load_state_dict(torch.load(full_file))
model_ptq.to(device)
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.eval()
direct_quantize(model_ptq, train_loader, device)
if store_ptq:
torch.save(model_ptq.state_dict(), ptq_file_prefix + title + '.pt')
model_ptq.freeze()
ptq_acc = quantize_inference(model_ptq, test_loader, device)
ptq_acc_list.append(ptq_acc)
acc_loss = (full_acc - ptq_acc) / full_acc
acc_loss_list.append(acc_loss)
idx = -1
# 获取计算量/参数量下的js-div
js_flops = 0.
js_param = 0.
for name, param in model_ptq.named_parameters():
if '.' not in name or 'bn' in name:
continue
idx = idx + 1
prefix = name.split('.')[0]
if prefix in layer:
layer_idx = layer.index(prefix)
ptq_param = param.data.cpu()
# 取L2范数
ptq_norm = F.normalize(ptq_param,p=2,dim=-1)
writer.add_histogram(tag=title +':'+ name + '_data', values=ptq_param)
js = js_div(ptq_norm,full_params[idx])
js = js.item()
if js < 0.:
js = 0.
js_flops = js_flops + js * flop_ratio[layer_idx]
js_param = js_param + js * flop_ratio[layer_idx]
js_flops_list.append(js_flops)
js_param_list.append(js_param)
print(title + ': js_flops: %f js_param: %f acc_loss: %f' % (js_flops, js_param, acc_loss))
# 写入xlsx
workbook = openpyxl.Workbook()
worksheet = workbook.active
worksheet.cell(row=1,column=1,value='FP32-acc')
worksheet.cell(row=1,column=2,value=full_acc)
worksheet.cell(row=3,column=1,value='title')
worksheet.cell(row=3,column=2,value='js_flops')
worksheet.cell(row=3,column=3,value='js_param')
worksheet.cell(row=3,column=4,value='ptq_acc')
worksheet.cell(row=3,column=5,value='acc_loss')
for i in range(len(title_list)):
worksheet.cell(row=i+4, column=1, value=title_list[i])
worksheet.cell(row=i+4, column=2, value=js_flops_list[i])
worksheet.cell(row=i+4, column=3, value=js_param_list[i])
worksheet.cell(row=i+4, column=4, value=ptq_acc_list[i])
worksheet.cell(row=i+4, column=5, value=acc_loss_list[i])
workbook.save('ptq_result.xlsx')
writer.close()
ft = open('ptq_result.txt','w')
print('title_list:',file=ft)
print(" ".join(title_list),file=ft)
print('js_flops_list:',file=ft)
print(" ".join(str(i) for i in js_flops_list), file=ft)
print('js_param_list:',file=ft)
print(" ".join(str(i) for i in js_param_list), file=ft)
print('ptq_acc_list:',file=ft)
print(" ".join(str(i) for i in ptq_acc_list), file=ft)
print('acc_loss_list:',file=ft)
print(" ".join(str(i) for i in acc_loss_list), file=ft)
ft.close()
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
9536.47106469081 2226.0626158889118 479.08937782929337 110.29737767616203 26.512546399733633 6.543175037098545 1.6082547275660537 0.4010994730450761 0.09957296685951723 0.025111061601978204 0.006197339062824894 0.0015193397317902462 0.0003899318133136797 7.940267844549078e-05 5.244585611005781e-05 9536.459291644247 1346.1955440489223 186.04289013449343 184.66787552474796 184.66837148575442 184.66792146213083 184.66817620985174 1162.9049833811032 334.8873847416959 213.5935420790586 162.9083436672534 74.46976163903415 51.09336851340675 114.00986836720416 39.88037938309794 13.180682125861468 50.9202909583944 97.0445174571127 28.848990052727817 3.3335250261221736 13.120693891064953 50.939926759367474 90.27653845386365 24.693794509374676 0.8539888374026919 3.300766001642734 13.138977654051457 50.93992491304259
js_param_list:
9536.47106469081 2226.0626158889118 479.08937782929337 110.29737767616203 26.512546399733633 6.543175037098545 1.6082547275660537 0.4010994730450761 0.09957296685951723 0.025111061601978204 0.006197339062824894 0.0015193397317902462 0.0003899318133136797 7.940267844549078e-05 5.244585611005781e-05 9536.459291644247 1346.1955440489223 186.04289013449343 184.66787552474796 184.66837148575442 184.66792146213083 184.66817620985174 1162.9049833811032 334.8873847416959 213.5935420790586 162.9083436672534 74.46976163903415 51.09336851340675 114.00986836720416 39.88037938309794 13.180682125861468 50.9202909583944 97.0445174571127 28.848990052727817 3.3335250261221736 13.120693891064953 50.939926759367474 90.27653845386365 24.693794509374676 0.8539888374026919 3.300766001642734 13.138977654051457 50.93992491304259
ptq_acc_list:
10.0 12.67 51.93 86.38 88.89 89.24 89.57 89.51 89.54 89.46 89.43 89.42 89.43 89.44 89.44 10.0 19.8 68.72 63.57 64.2 68.3 64.97 14.62 64.76 80.17 78.59 87.49 86.76 82.86 88.38 88.85 75.46 84.16 88.8 89.27 69.36 10.25 84.8 88.95 89.4 65.43 10.38 10.32
acc_loss_list:
0.8881932021466905 0.8583407871198568 0.41938729874776387 0.034212880143112724 0.00614937388193199 0.0022361359570662216 -0.0014534883720929725 -0.000782647584973249 -0.0011180679785331902 -0.00022361359570657448 0.0001118067978532078 0.00022361359570657448 0.0001118067978532078 0.0 0.0 0.8881932021466905 0.7786225402504473 0.23166368515205724 0.2892441860465116 0.28220035778175306 0.23635957066189625 0.2735912343470483 0.8365384615384615 0.27593917710196775 0.10364490161001785 0.12131037567084073 0.02180232558139538 0.02996422182468686 0.07356887298747762 0.01185152057245083 0.006596601073345297 0.1563059033989267 0.05903398926654742 0.007155635062611813 0.0019007155635062804 0.22450805008944544 0.8853980322003577 0.051878354203935606 0.005478533094812108 0.00044722719141314897 0.26844812164579596 0.8839445438282648 0.8846153846153847
from model import *
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import os
import os.path as osp
def train(model, device, train_loader, optimizer, epoch):
model.train()
lossLayer = torch.nn.CrossEntropyLoss()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = lossLayer(output, target)
loss.backward()
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
lossLayer = torch.nn.CrossEntropyLoss(reduction='sum')
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += lossLayer(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
test_loss, 100. * correct / len(test_loader.dataset)
))
if __name__ == "__main__":
batch_size = 32
test_batch_size = 32
seed = 1
epochs1 = 15
epochs2 = epochs1+10
epochs3 = epochs2+10
lr1 = 0.01
lr2 = 0.001
lr3 = 0.0001
momentum = 0.5
save_model = True
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=True, download=True,
transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
)
model = VGG_16().to(device)
optimizer1 = optim.SGD(model.parameters(), lr=lr1, momentum=momentum)
optimizer2 = optim.SGD(model.parameters(), lr=lr2, momentum=momentum)
optimizer3 = optim.SGD(model.parameters(), lr=lr3, momentum=momentum)
for epoch in range(1, epochs1 + 1):
train(model, device, train_loader, optimizer1, epoch)
test(model, device, test_loader)
for epoch in range(epochs1 + 1, epochs2 + 1):
train(model, device, train_loader, optimizer2, epoch)
test(model, device, test_loader)
for epoch in range(epochs2 + 1, epochs3 + 1):
train(model, device, train_loader, optimizer3, epoch)
test(model, device, test_loader)
if save_model:
if not osp.exists('ckpt'):
os.makedirs('ckpt')
torch.save(model.state_dict(), 'ckpt/cifar10_VGG_16.pt')
\ No newline at end of file
import torch
import torch.nn as nn
def ebit_list(quant_type, num_bits):
if quant_type == 'FLOAT':
e_bit_list = list(range(1,num_bits-1))
else:
e_bit_list = [0]
return e_bit_list
def numbit_list(quant_type):
if quant_type == 'INT':
num_bit_list = list(range(2,17))
elif quant_type == 'POT':
num_bit_list = list(range(2,9))
# num_bit_list = [8]
else:
num_bit_list = list(range(2,9))
# num_bit_list = [8]
return num_bit_list
def build_bias_list(quant_type):
if quant_type == 'POT':
return build_pot_list(8)
else:
return build_float_list(16,7)
def build_list(quant_type, num_bits, e_bits):
if quant_type == 'POT':
return build_pot_list(num_bits)
else:
return build_float_list(num_bits,e_bits)
def build_pot_list(num_bits):
plist = [0.]
for i in range(-2 ** (num_bits-1) + 2, 1):
# i最高到0,即pot量化最大值为1
plist.append(2. ** i)
plist.append(-2. ** i)
plist = torch.Tensor(list(set(plist)))
# plist = plist.mul(1.0 / torch.max(plist))
return plist
def build_float_list(num_bits,e_bits):
m_bits = num_bits - 1 - e_bits
plist = [0.]
# 相邻尾数的差值
dist_m = 2 ** (-m_bits)
e = -2 ** (e_bits - 1) + 1
for m in range(1, 2 ** m_bits):
frac = m * dist_m # 尾数部分
expo = 2 ** e # 指数部分
flt = frac * expo
plist.append(flt)
plist.append(-flt)
for e in range(-2 ** (e_bits - 1) + 2, 2 ** (e_bits - 1) + 1):
expo = 2 ** e
for m in range(0, 2 ** m_bits):
frac = 1. + m * dist_m
flt = frac * expo
plist.append(flt)
plist.append(-flt)
plist = torch.Tensor(list(set(plist)))
return plist
def fold_ratio(layer, par_ratio, flop_ratio):
idx = -1
for name in layer:
idx = idx + 1
if 'bn' in name:
par_ratio[idx-1] += par_ratio[idx]
flop_ratio[idx-1] += flop_ratio[idx]
return par_ratio,flop_ratio
def fold_model(model):
idx = -1
module_list = []
for name, module in model.named_modules():
idx += 1
module_list.append(module)
if 'bn' in name:
module_list[idx-1] = fold_bn(module_list[idx-1],module)
return model
# def fold_model(model):
# last_conv = None
# last_bn = None
# for name, module in model.named_modules():
# if isinstance(module, nn.Conv2d):
# # 如果当前模块是卷积层,则将其 "fold" 到上一个 BN 层中
# if last_bn is not None:
# last_conv = fold_bn(last_conv, last_bn)
# last_bn = None
# last_conv = module
# elif isinstance(module, nn.BatchNorm2d):
# # 如果当前模块是 BN 层,则将其 "fold" 到上一个卷积层中
# last_bn = module
# if last_conv is not None:
# last_conv = fold_bn(last_conv, last_bn)
# last_bn = None
# # 处理最后一个 BN 层
# if last_bn is not None:
# last_conv = fold_bn(last_conv, last_bn)
# return model
def fold_bn(conv, bn):
# 获取 BN 层的参数
gamma = bn.weight.data
beta = bn.bias.data
mean = bn.running_mean
var = bn.running_var
eps = bn.eps
std = torch.sqrt(var + eps)
feat = bn.num_features
# 获取卷积层的参数
weight = conv.weight.data
bias = conv.bias.data
if bn.affine:
gamma_ = gamma / std
weight = weight * gamma_.view(feat, 1, 1, 1)
if bias is not None:
bias = gamma_ * bias - gamma_ * mean + beta
else:
bias = beta - gamma_ * mean
else:
gamma_ = 1 / std
weight = weight * gamma_
if bias is not None:
bias = gamma_ * bias - gamma_ * mean
else:
bias = -gamma_ * mean
# 设置新的 weight 和 bias
conv.weight.data = weight
conv.bias.data = bias
return conv
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment