Commit 217de88a by Zhihong Ma

fix: LSTM-OCR fp32 & naive fake quantization

parent 70477dab
......@@ -15,8 +15,8 @@ class seq_mnist(Dataset):
self.labels = []
self.input_lengths = np.ones(1, dtype=np.int32) * (28 * self.trainer_params.word_size)
self.label_lengths = np.ones(1, dtype=np.int32) * (self.trainer_params.word_size)
self.build_dataset()
# self.load_dataset()
# self.build_dataset()
self.load_dataset()
def build_dataset(self):
imgs = []
......@@ -57,12 +57,13 @@ class seq_mnist(Dataset):
def load_dataset(self):
self.images = np.load('data/images{}.npy'.format(self.suffix))
self.labels = np.load('data/labels{}.npy'.format(self.suffix))
print("Successfully load dataset!")
if self.trainer_params.quantize_input:
self.images = self.quantize_tensor_image(self.images)
self.images = np.asarray(self.images)
# 这里无需单独做
# 这里无需单独做, ptq时修改
def quantize_tensor_image(self, tensor_image):
frac_bits = self.trainer_params.recurrent_activation_bit_width-1
prescale = 2**frac_bits
......
import sys
import os
# 从get_param.py输出重定向文件val.txt中提取参数量和计算量
def extract_ratio(md='ResNet18'):
fr = open('param_flops_' + md + '.txt','r')
lines = fr.readlines()
layer = []
par_ratio = []
flop_ratio = []
for line in lines:
# if '(' in line and ')' in line:
if 'Conv' in line or 'BatchNorm2d' in line or 'Linear' in line:
layer.append(line.split(':')[1].split('(')[0])
r1 = line.split('%')[0].split(',')[-1]
r1 = float(r1)
par_ratio.append(r1)
r2 = line.split('%')[-2].split(',')[-1]
r2 = float(r2)
flop_ratio.append(r2)
return layer, par_ratio, flop_ratio
if __name__ == "__main__":
layer, par_ratio, flop_ratio = extract_ratio()
print(len(layer))
print(len(par_ratio))
print(len(flop_ratio))
\ No newline at end of file
from torch.autograd import Function
class FakeQuantize(Function):
@staticmethod
def forward(ctx, x, qparam):
x = qparam.quantize_tensor(x)
x = qparam.dequantize_tensor(x)
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
\ No newline at end of file
from model import *
import torch
from ptflops import get_model_complexity_info
import argparse
def get_children(model: torch.nn.Module):
# get children form model!
# 为了后续也能够更新参数,需要用nn.ModuleList来承载
# children = nn.ModuleList(model.children())
# print(children)
# 方便对其中的module进行后续的更新
# flatt_children = nn.ModuleList()
children = list(model.children())
# flatt_children = nn.ModuleList()
flatt_children = []
if len(children) == 0:
# if model has no children; model is last child! :O
return model
else:
# look for children from children... to the last child!
for child in children:
try:
flatt_children.extend(get_children(child))
except TypeError:
flatt_children.append(get_children(child))
# print(flatt_children)
return flatt_children
# 定义获取不包含wrapper的所有子模块的函数
def get_all_child_modules(module):
for name, child in module.named_children():
if isinstance(child, nn.Sequential):
yield from get_all_child_modules(child)
elif len(list(child.children())) > 0:
yield from child.children()
else:
yield child
def filter_fn(module, n_inp, outp_shape):
# if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d, torch.nn.ReLU,torch.nn.BatchNorm2d,torch.nn.Linear,torch.nn.AdaptiveAvgPool2d)):
if 'conv' in module or 'bn' in module or 'fc' in module or 'avg' in module or 'relu' in module:
return True
return False
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Model Analysis --- params & flops')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='resnet18')
args = parser.parse_args()
if args.model == 'ResNet18':
model = resnet18()
elif args.model == 'ResNet50':
model = resnet50()
elif args.model == 'ResNet152':
model = resnet152()
full_file = 'ckpt/cifar10_' + args.model + '.pt'
model.load_state_dict(torch.load(full_file))
# flat = get_children(model)
# print(flat)
# flat = get_children(model)
# new_model = nn.Sequential(*flat)
flops, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
class GlobalVariables:
SELF_INPLANES = 0
\ No newline at end of file
# -*- coding: utf-8 -*-
# 用于多个module之间共享全局变量
def _init(): # 初始化
global _global_dict
_global_dict = {}
def set_value(value,is_bias=False):
# 定义一个全局变量
if is_bias:
_global_dict[0] = value
else:
_global_dict[1] = value
def get_value(is_bias=False): # 给bias独立于各变量外的精度
if is_bias:
return _global_dict[0]
else:
return _global_dict[1]
......@@ -36,7 +36,7 @@ def non_or_str(value):
if __name__ == '__main__':
# Training settings
parser = argparse.ArgumentParser(description='PyTorch Quantized BiLSTM Sequential MNIST Example')
parser = argparse.ArgumentParser(description='PyTorch BiLSTM Sequential MNIST Example')
parser.add_argument('--params', '-p', type=str, default="default_trainer_params.json", help='Path to params JSON file. Default ignored when resuming.')
# 这里是可以改的,原版本应该是支持多机训练
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
......@@ -69,18 +69,7 @@ if __name__ == '__main__':
parser.add_argument('--reduce_bidirectional', type=str)
parser.add_argument('--recurrent_bias_enabled', type=bool)
parser.add_argument('--checkpoint_interval', type=int)
parser.add_argument('--recurrent_weight_bit_width', type=int)
parser.add_argument('--recurrent_weight_quantization', type=str)
parser.add_argument('--recurrent_bias_bit_width', type=int)
parser.add_argument('--recurrent_bias_quantization', type=str)
parser.add_argument('--recurrent_activation_bit_width', type=int)
parser.add_argument('--recurrent_activation_quantization', type=str)
parser.add_argument('--internal_activation_bit_width', type=int)
parser.add_argument('--fc_weight_bit_width', type=int)
parser.add_argument('--fc_weight_quantization', type=str)
parser.add_argument('--fc_bias_bit_width', type=int)
parser.add_argument('--fc_bias_quantization', type=str)
parser.add_argument('--quantize_input', type=bool)
args = parser.parse_args()
if args.export:
......@@ -92,18 +81,23 @@ if __name__ == '__main__':
# 直接恢复
if (args.resume or args.eval or args.export) and args.params == "default_trainer_params.json":
package = torch.load(args.resume, map_location=lambda storage, loc: storage)
trainer_params = package['trainer_params']
# if (args.resume or args.eval) and args.params == "default_trainer_params.json":
# package = torch.load(args.resume, map_location=lambda storage, loc: storage)
# trainer_params = package['trainer_params']
# 重新训练
else:
# else:
with open(args.params) as d:
trainer_params = json.load(d, object_hook=ascii_encode_dict)
# trainer_params = json.load(d, object_hook=ascii_encode_dict)
trainer_params = json.load(d)
trainer_params = objdict(trainer_params)
for k in trainer_params.keys():
print(k, trainer_params[k])
if trainer_params[k] == 'LSTM':
print("LSTM YES")
elif trainer_params[k] == 'CONCAT':
print("CONCAT YES")
# args还是有用的,trainer_params中的default的和args中关注的参数往往是互补的
trainer = Seq_MNIST_Trainer(trainer_params, args)
......
......@@ -32,7 +32,7 @@ import math
import numpy
import torch
import torch.nn as nn
from module import *
from functools import partial
# from quantization.modules.rnn import QuantizedLSTM
......@@ -90,13 +90,13 @@ class BiLSTM(nn.Module):
def __init__(self, trainer_params):
super(BiLSTM, self).__init__()
self.trainer_params = trainer_params
# print(self.trainer_params.reduce_bidirectional)
self.trainer_params.reduce_bidirectional = 'CONCAT'
print(f"self.trainer_params.reduce_bidirectional:{self.trainer_params.reduce_bidirectional}")
# self.trainer_params.reduce_bidirectional = 'CONCAT'
if self.trainer_params.bidirectional and self.trainer_params.reduce_bidirectional == 'CONCAT':
self.reduce_factor = 2
else:
self.reduce_factor = 1
# if self.trainer_params.bidirectional and self.trainer_params.reduce_bidirectional == 'CONCAT':
# self.reduce_factor = 2
# else:
# self.reduce_factor = 1
# 若是 LSTM ,则括号中的是对类的输入设置
# self.recurrent_layer = self.recurrent_layer_type(input_size=self.trainer_params.input_size,
# hidden_size=self.trainer_params.num_units,
......@@ -105,39 +105,59 @@ class BiLSTM(nn.Module):
# bidirectional=self.trainer_params.bidirectional,
# bias=self.trainer_params.recurrent_bias_enabled)
self.recurrent_layer = nn.LSTM(input_size=self.trainer_params.input_size,
# self.recurrent_layer = nn.LSTM(input_size=self.trainer_params.input_size,
# hidden_size=self.trainer_params.num_units,
# num_layers=self.trainer_params.num_layers,
# batch_first=False,
# bidirectional=self.trainer_params.bidirectional,
# bias=self.trainer_params.recurrent_bias_enabled)
self.lstm_layers = nn.ModuleList()
# 创建第1层LSTM模型,并添加到ModuleList中
lstm = nn.LSTM( input_size=self.trainer_params.input_size,
hidden_size=self.trainer_params.num_units,
num_layers=1,
batch_first=False,
bidirectional=self.trainer_params.bidirectional,
bias=self.trainer_params.recurrent_bias_enabled)
self.lstm_layers.append(lstm)
# 创建第2至num_layers层LSTM模型,并添加到ModuleList中
for i in range(1, self.trainer_params.num_layers):
lstm = nn.LSTM(input_size=self.trainer_params.num_units * 2 if self.trainer_params.bidirectional else self.trainer_params.num_units,
hidden_size=self.trainer_params.num_units,
num_layers=self.trainer_params.num_layers,
num_layers=1,
batch_first=False,
bidirectional=self.trainer_params.bidirectional,
bias=self.trainer_params.recurrent_bias_enabled)
self.lstm_layers.append(lstm)
self.batch_norm_fc = FusedBatchNorm1dLinear(
trainer_params,
nn.BatchNorm1d(self.reduce_factor * self.trainer_params.num_units),
# QuantizedLinear(
# bias=True,
# self.batch_norm_fc = FusedBatchNorm1dLinear(
# trainer_params,
# nn.BatchNorm1d(self.reduce_factor * self.trainer_params.num_units),
# nn.Linear(
# in_features=self.reduce_factor * self.trainer_params.num_units,
# out_features=trainer_params.num_classes,
# bias_bit_width=self.trainer_params.fc_bias_bit_width,
# bias_q_type=self.trainer_params.fc_bias_quantization,
# weight_bit_width=self.trainer_params.fc_weight_bit_width,
# weight_q_type=self.trainer_params.fc_weight_quantization)
nn.Linear(
# bias=True )
# )
self.fc1 = nn.Linear(
in_features=self.reduce_factor * self.trainer_params.num_units,
out_features=trainer_params.num_classes,
bias=True )
)
# self.output_layer = nn.Sequential(SequenceWise(self.batch_norm_fc), nn.LogSoftmax(dim=2))
self.output_layer = nn.Sequential(SequenceWise(self.fc1), nn.LogSoftmax(dim=2))
self.output_layer = nn.Sequential(SequenceWise(self.batch_norm_fc), nn.LogSoftmax(dim=2))
# @property
# def reduce_factor(self):
# if self.trainer_params.bidirectional and self.trainer_params.reduce_bidirectional == 'CONCAT':
# return 2
# else:
# return 1
@property
def reduce_factor(self):
if self.trainer_params.bidirectional and self.trainer_params.reduce_bidirectional == 'CONCAT':
return 2
else:
return 1
# @property
# def recurrent_layer_type(self):
......@@ -163,7 +183,24 @@ class BiLSTM(nn.Module):
def forward(self, x):
# 似乎是因为现在只有一个lstm cell (num_layers = 1),所以h没用上
x, h = self.recurrent_layer(x)
# x, h = self.recurrent_layer(x)
h_n = []
c_n = []
# 遍历ModuleList中的每个LSTM模型,依次进行前向计算
for i, lstm in enumerate(self.lstm_layers):
# 如果不是第1层LSTM,则将输入的隐藏状态和细胞状态作为该层LSTM的初始状态
if i > 0:
x, (h, c) = lstm(x, (h_n[-1], c_n[-1]))
else:
x, (h, c) = lstm(x)
# 将该层LSTM的隐藏状态和细胞状态添加到列表中,用于下一层LSTM的输入
h_n.append(h)
c_n.append(c)
if self.trainer_params.bidirectional:
if self.trainer_params.reduce_bidirectional == 'SUM':
x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1)
......@@ -175,6 +212,78 @@ class BiLSTM(nn.Module):
x = self.output_layer(x)
return x
def quantize(self, quant_type, num_bits=8, e_bits=3):
self.qlstm_layers = nn.ModuleDict()
for i, lstm in enumerate(self.lstm_layers):
# 如果不是第1层LSTM,则将输入的隐藏状态和细胞状态作为该层LSTM的初始状态
if i > 0:
self.qlstm_layers[str(i)] = QLSTM(quant_type=quant_type,lstm_module=lstm,qix=False,qih=False,qic=False,qox=True,qoh=True,qoc=True,num_bits=num_bits,e_bits=e_bits)
# 第一层lstm layer没有输入的h和c,因此qih,qic为False,有x,qix置为True
else:
self.qlstm_layers[str(i)] = QLSTM(quant_type=quant_type,lstm_module=lstm,qix=True,qih=False,qic=False,qox=True,qoh=True,qoc=True,num_bits=num_bits,e_bits=e_bits)
self.qfc1 = QLinear(quant_type, self.fc1,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
# for name,layer in self.qlstm_layers.items():
# print(f"name:{name}")
def quantize_forward(self, x):
for name, layer in self.qlstm_layers.items():
if '0' in name:
x,(h,c) = layer(x)
else:
x,(h,c) = layer(x,h,c)
t, n = x.size(0), x.size(1)
x = x.view(t * n, -1)
x = self.qfc1(x)
x = x.view(t, n, -1)
x = F.log_softmax(x,dim=2)
# out = F.softmax(x, dim=1)
# return out
return x
def freeze(self):
for name, layer in self.qlstm_layers.items():
if '0' in name:
layer.freeze(flag=0)
else:
layer.freeze(qix = self.qlstm_layers[str(int(name)-1)].qox, qih=self.qlstm_layers[str(int(name)-1)].qoh, qic=self.qlstm_layers[str(int(name)-1)].qoc,flag=1)
self.qfc1.freeze(qi=self.qlstm_layers[name].qox)
def quantize_inference(self, x):
# 首先对x进行一个伪量化 (适配于lstm的伪量化)
x = FakeQuantize.apply(x,self.qlstm_layers['0'].qix)
for name, layer in self.qlstm_layers.items():
if '0' in name:
x,(h,c) = layer.quantize_inference(x)
else:
x,(h,c) = layer.quantize_inference(x,h,c)
t, n = x.size(0), x.size(1)
x = x.view(t * n, -1)
# 经过修改后的QLinear的quantize_inference中对输入的x进行过quantize,因此在这里需要dequantize一下.
x = self.qfc1.quantize_inference(x)
x = self.qfc1.qo.dequantize_tensor(x)
x = x.view(t, n, -1)
x = F.log_softmax(x,dim=2)
return x
def export(self, output_path, simd_factor, pe):
if self.trainer_params.neuron_type == 'QLSTM':
assert(self.trainer_params.input_size % simd_factor == 0)
......
This source diff could not be displayed because it is too large. You can view the blob instead.
## update 2023.5.4
### 对fp32的模型进行了改进,并进行了初步的PTQ实验(只做到了对部分参数伪量化 naive fakequantization)
1. 对fp32的改进
- 支持多层LSTM,并用nn.ModuleList组织
2. 对PTQ的尝试
- 对LSTM的结构,数据流向,输入输出有了更细致的理解。PTQ遇到的主要问题有:
- BiLSTM涉及到了双向的output,需要用SUM或者CONCAT处理,PTQ时可能需要引入更多module处理。
- LSTM内部进行的运算较为复杂,如下图所示:<img src = "fig/math.png" class="h-90 auto">
首先涉及到了多个门i,f,g,他们内部有Wx+b+Wh+b的结构,他们的scale是否整体考虑(即,是共用一个scale,还是多个scale并在+的时候做rescale)是一个问题。<br>他们外部的sigmoid或tanh也是一个问题(会导致不方便在各个层间通过scale变换保证PTQ结果的正确性)。<br>对c',h'的更新涉及到的*,+也是一个问题。
- LSTM内部各个门的权值被组织在`weight_ih_l[k]`、`weight_hh_l[k]`、`bias_ih_l[k]`和`bias_hh_l[k]`中,分属在不同行中,是否应该把这个权值矩阵拆开,分别量化后再合并。
- 暂时还没想好怎么处理上述问题,于是进行了简单的尝试性实验:
- 先只处理单向的LSTM,不考虑双向所需的SUM和CONCAT
-`weight_ih_l[k]``weight_hh_l[k]``bias_ih_l[k]``bias_hh_l[k]`各自整体量化,没有再对每个张量切成四块再分别对`weight_ir_l[k]`,`weight_hr_l[k]`,`bias_r_l[k]`,`weight_if_l[k]`,`weight_hf_l[k]`,`bias_f_l[k]`,`weight_ii_l[k]`,`weight_hi_l[k]`进行量化
- 对weight和bias和每层的输出tensor采取伪量化的方法,避免层间的scale变换。<br><br>后果:<br> (1) freeze后的权值不再是量化后的值,而是其量化值经过scale变换后的值
<br>(2) 没有考虑tanh,sigmoid的量化,他们不再是矩阵乘的形式,还没想好怎么消掉scale。只模拟了将权值和输出张量量化产生的rounding,溢出等误差。
<br>(3) 与实际可直接部署到硬件的lstm量化可能有差异。(不过目前我不太确定LSTM实际的量化应该怎么做,在网上没有找到很多lstm量化相关的资料)
<br>(4) 与之前的其他网络的PTQ量化略有差异,与scale相关的的计算顺序有些不同,但运算原理是类似的。
- 有待改进:
- 对LSTM的PTQ量化的更真实模拟,考虑sigmoid,tanh,各种乘加组合
- 补充对BiLSTM的PTQ量化 (如果还按现在的简化版处理方式,BiLSTM很容易实现,因为不需要考虑各种scale方面的问题,直接SUM或者CONCAT即可)
- 可以考虑把`weight_ih_l[k]``weight_hh_l[k]``bias_ih_l[k]``bias_hh_l[k]`拆开,按各自门的权值分别伪量化后再组合
- 使用更复杂的数据集
- 度量相似度
- 找到比较好的指标来度量精度
## update 2023.5.2
basic version: FP32版本,只有单个lstm cell,训练数据集采用序列化的MNIST,仅作记录方便后续修改。
\ No newline at end of file
import os
import math
import numpy
import torch
import torch.nn as nn
from model import *
import argparse
import json
# input_size = 32
# num_units = 128
# num_layers = 1
# bidirectional = True
# recurrent_bias_enabled = True
# lstm1 = nn.LSTM(input_size=input_size,
# hidden_size=num_units,
# num_layers=num_layers,
# batch_first=False,
# bidirectional= bidirectional,
# bias= recurrent_bias_enabled)
# lstm2 = nn.LSTM(input_size=input_size,
# hidden_size=num_units,
# num_layers=num_layers + 1,
# batch_first=False,
# bidirectional= bidirectional,
# bias= recurrent_bias_enabled)
# print("LSTM1:")
# for name,params in lstm1.named_parameters():
# print(f"name:{name},params:{params.shape}")
# print("=============================================")
# print("LSTM2:")
# for name,params in lstm2.named_parameters():
# print(f"name:{name},params:{params.shape}")
class objdict(dict):
def __getattr__(self, name):
if name in self:
return self[name]
else:
raise AttributeError("No such attribute: " + name)
def __setattr__(self, name, value):
self[name] = value
def __delattr__(self, name):
if name in self:
del self[name]
else:
raise AttributeError("No such attribute: " + name)
parser = argparse.ArgumentParser(description='PyTorch BiLSTM Sequential MNIST Example')
parser.add_argument('--params', '-p', type=str, default="default_trainer_params.json", help='Path to params JSON file. Default ignored when resuming.')
args = parser.parse_args()
with open(args.params) as d:
trainer_params = json.load(d)
# trainer_params = json.load(d, object_hook=ascii_encode_dict)
trainer_params = objdict(trainer_params)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 加入设备选择
model = BiLSTM(trainer_params).to(device)
model.quantize('INT',8,0)
......@@ -38,12 +38,29 @@ class Seq_MNIST_Trainer():
self.prev_loss = 10000
self.model = BiLSTM(trainer_params)
# self.criterion = wp.CTCLoss(size_average=False)
self.criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)
self.labels = [i for i in range(trainer_params.num_classes-1)]
self.decoder = seq_mnist_decoder(labels=self.labels)
self.optimizer = optim.Adam(self.model.parameters(), lr=trainer_params.lr)
# self.criterion = wp.CTCLoss(size_average=False)
# 默认为false
# if args.init_bn_fc_fusion:
# # 默认是false的,应该是用于记录当前是否fuse了吧
# if not trainer_params.prefused_bn_fc:
# self.model.batch_norm_fc.init_fusion() # fuse bn-fc
# self.trainer_params.prefused_bn_fc = True # 已fuse
# else:
# raise Exception("BN and FC are already fused.")
# 先fuse了再load fuse后的
if args.eval or args.resume :
save_dir = 'ckpt'
full_file = save_dir + '/mnist_' + self.trainer_params.reduce_bidirectional +'_' + str(self.trainer_params.bidirectional) + '.pt'
self.model.load_state_dict(torch.load(full_file))
print("load Model from existing file finished!")
if args.cuda:
# torch.cuda.set_device(args.gpus)
......@@ -53,29 +70,6 @@ class Seq_MNIST_Trainer():
self.model = self.model.to(device)
self.criterion = self.criterion.to(device)
if args.resume or args.eval or args.export:
print("Loading model from {}".format(args.resume))
package = torch.load(args.resume, map_location=lambda storage, loc: storage)
self.model.load_state_dict(package['state_dict'])
self.optimizer.load_state_dict(package['optim_dict'])
self.starting_epoch = package['starting_epoch']
self.prev_loss = package['prev_loss']
if args.cuda:
for state in self.optimizer.state.values():
for k, v in state.items():
if torch.is_tensor(v):
# state[k] = v.cuda()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state[k] = v.to(device)
# 默认为false
if args.init_bn_fc_fusion:
# 默认是false的,应该是用于记录当前是否fuse了吧
if not trainer_params.prefused_bn_fc:
self.model.batch_norm_fc.init_fusion()
self.trainer_params.prefused_bn_fc = True
else:
raise Exception("BN and FC are already fused.")
def serialize(self, model, trainer_params, optimizer, starting_epoch, prev_loss):
package = {'state_dict': model.state_dict(),
......@@ -86,11 +80,16 @@ class Seq_MNIST_Trainer():
}
return package
# 存储
def save_model(self, epoch, name):
path = self.args.experiments + '/' + name
print("Model saved at: {}\n".format(path))
torch.save(self.serialize(model=self.model, trainer_params=self.trainer_params,
optimizer=self.optimizer, starting_epoch=epoch + 1, prev_loss=self.prev_loss), path)
def save_model(self):
save_dir = 'ckpt'
if not os.path.isdir(save_dir):
os.makedirs(save_dir, mode=0o777)
os.chmod(save_dir, mode=0o777)
# path = self.args.experiments + '/' + name
torch.save(self.model.state_dict(), save_dir + '/mnist_' + self.trainer_params.reduce_bidirectional +'_' + str(self.trainer_params.bidirectional) + '.pt')
# print("Model saved at: {}\n".format(path))
# torch.save(self.serialize(model=self.model, trainer_params=self.trainer_params,
# optimizer=self.optimizer, starting_epoch=epoch + 1, prev_loss=self.prev_loss), path)
def train(self, epoch):
......@@ -163,9 +162,9 @@ class Seq_MNIST_Trainer():
print("Average Loss Value for Val Data is = {:.4f}\n".format(float(loss_value)))
if loss_value < self.prev_loss and save_model_flag:
self.prev_loss = loss_value
self.save_model(epoch, "best.tar")
elif save_model_flag:
self.save_model(epoch, "checkpoint.tar")
self.save_model()
# elif save_model_flag:
# self.save_model(epoch, "checkpoint.tar")
def eval_model(self):
self.test()
......@@ -173,7 +172,7 @@ class Seq_MNIST_Trainer():
def train_model(self):
for epoch in range(self.starting_epoch, self.args.epochs + 1):
self.train(epoch)
self.test(epoch=epoch, save_model_flag=False)
acc = self.test(epoch=epoch, save_model_flag=True) # 默认不save model (等做ptq的实验时在处理)
if epoch%20==0:
self.optimizer.param_groups[0]['lr'] = self.optimizer.param_groups[0]['lr']*0.98
......
import torch
import torch.nn as nn
import torch.nn.functional as F
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
def ebit_list(quant_type, num_bits):
if quant_type == 'FLOAT':
e_bit_list = list(range(1,num_bits-1))
else:
e_bit_list = [0]
return e_bit_list
def numbit_list(quant_type):
if quant_type == 'INT':
num_bit_list = list(range(2,17))
elif quant_type == 'POT':
num_bit_list = list(range(2,9))
else:
num_bit_list = list(range(2,9))
# num_bit_list = [8]
return num_bit_list
def build_bias_list(quant_type):
if quant_type == 'POT':
return build_pot_list(8) #
else:
return build_float_list(16,7)
def build_list(quant_type, num_bits, e_bits):
if quant_type == 'POT':
return build_pot_list(num_bits)
else:
return build_float_list(num_bits,e_bits)
def build_pot_list(num_bits):
plist = [0.]
for i in range(-2 ** (num_bits-1) + 2, 1):
# i最高到0,即pot量化最大值为1
plist.append(2. ** i)
plist.append(-2. ** i)
plist = torch.Tensor(list(set(plist)))
# plist = plist.mul(1.0 / torch.max(plist))
return plist
def build_float_list(num_bits,e_bits):
m_bits = num_bits - 1 - e_bits
plist = [0.]
# 相邻尾数的差值
dist_m = 2 ** (-m_bits)
e = -2 ** (e_bits - 1) + 1
for m in range(1, 2 ** m_bits):
frac = m * dist_m # 尾数部分
expo = 2 ** e # 指数部分
flt = frac * expo
plist.append(flt)
plist.append(-flt)
for e in range(-2 ** (e_bits - 1) + 2, 2 ** (e_bits - 1) + 1):
expo = 2 ** e
for m in range(0, 2 ** m_bits):
frac = 1. + m * dist_m
flt = frac * expo
plist.append(flt)
plist.append(-flt)
plist = torch.Tensor(list(set(plist)))
return plist
def fold_ratio(layer, par_ratio, flop_ratio):
idx = -1
for name in layer:
idx = idx + 1
# layer是for name, param in model.named_parameters()中提取出来的,一定是有downsample的
if 'bn' in name or 'sample.1' in name:
par_ratio[idx-1] += par_ratio[idx]
flop_ratio[idx-1] += flop_ratio[idx]
return par_ratio,flop_ratio
def fold_model(model):
idx = -1
module_list = []
# print('fold model:')
for name, module in model.named_modules():
# print(name+'-- +')
idx += 1
module_list.append(module)
# 这里之前忘记考虑downsampl里的conv了,导致少融合了一些
if 'bn' in name or 'sample.1' in name:
# print(name+'-- *')
module_list[idx-1] = fold_bn(module_list[idx-1],module) # 在这里修改了
return model
# def fold_model(model):
# last_conv = None
# last_bn = None
# for name, module in model.named_modules():
# if isinstance(module, nn.Conv2d):
# # 如果当前模块是卷积层,则将其 "fold" 到上一个 BN 层中
# if last_bn is not None:
# last_conv = fold_bn(last_conv, last_bn)
# last_bn = None
# last_conv = module
# elif isinstance(module, nn.BatchNorm2d):
# # 如果当前模块是 BN 层,则将其 "fold" 到上一个卷积层中
# last_bn = module
# if last_conv is not None:
# last_conv = fold_bn(last_conv, last_bn)
# last_bn = None
# # 处理最后一个 BN 层
# if last_bn is not None:
# last_conv = fold_bn(last_conv, last_bn)
# return model
def fold_bn(conv, bn):
# 获取 BN 层的参数
gamma = bn.weight.data
beta = bn.bias.data
mean = bn.running_mean
var = bn.running_var
eps = bn.eps
std = torch.sqrt(var + eps)
feat = bn.num_features
# 获取卷积层的参数
weight = conv.weight.data
if conv.bias is not None:
bias = conv.bias.data
if bn.affine:
gamma_ = gamma / std
weight = weight * gamma_.view(feat, 1, 1, 1)
if conv.bias is not None:
bias = gamma_ * bias - gamma_ * mean + beta
else:
bias = beta - gamma_ * mean
else:
gamma_ = 1 / std
weight = weight * gamma_
if conv.bias is not None:
bias = gamma_ * bias - gamma_ * mean
else:
bias = -gamma_ * mean
# 设置新的 weight 和 bias
conv.weight.data = weight
# 适用于bias=none的
if conv.bias is None:
conv.bias = nn.Parameter(bias)
else:
conv.bias.data = bias
return conv
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment