Commit 061582e9 by Klin

feat: PTB_LSTM: inference with fake_quantize

parent f4b96743
Warning! No positional inputs found for a module, assuming batch size is 1.
Model(
14.86 M, 100.000% Params, 5.2 GMac, 100.000% MACs,
(embed1): Embedding(0, 0.000% Params, 0.0 Mac, 0.000% MACs, 10000, 700)
(drop2): Dropout(0, 0.000% Params, 0.0 Mac, 0.000% MACs, p=0.65, inplace=False)
(lstm3): LSTM(3.93 M, 26.415% Params, 1.38 GMac, 26.455% MACs, 700, 700, dropout=0.65)
(lstm4): LSTM(3.93 M, 26.415% Params, 1.38 GMac, 26.455% MACs, 700, 700)
(drop5): Dropout(0, 0.000% Params, 0.0 Mac, 0.000% MACs, p=0.65, inplace=False)
(fc6): Linear(7.01 M, 47.170% Params, 2.45 GMac, 47.090% MACs, in_features=700, out_features=10000, bias=True)
)
# PTB_LSTM Quantization Notes
## Full-precision model
1. Dataset and preprocessing: the model uses the PTB dataset, a large collection of English sentences. Unlike CIFAR-10, whose samples are images, this dataset slices sentences into samples and builds a vocabulary that maps each word to an integer index so the network can process it.
2. Objective and metric: predict the word that follows a given word/phrase. Because the word following the same word/phrase varies across samples, plain accuracy is not a suitable metric. Language models are evaluated with perplexity (ppl): the smaller the value, the more concentrated the distribution over the next word, and the better the model. Mathematically, ppl is the exponential of the loss (a short sketch follows this list).
3. Extra requirements on the model: the embedding layer of PTB_LSTM holds a sparse weight matrix with many zero entries, which breaks Adam's learning-rate adaptation. According to the PyTorch documentation for the embedding class, the only CUDA-capable optimizers currently supported for it are SGD and SparseAdam (an Adam variant designed for sparse matrices). In addition, at inference time after training one usually calls flatten_parameters() on the LSTM layers to flatten their weights into a single contiguous block for efficient parallel computation.
4. Obtaining parameter and FLOP counts: the inputs of PTB_LSTM are not FloatTensors but an int64 LongTensor plus a hidden state, so the ptflops package used previously has to be driven through its custom-input API: an extra input constructor returning a dict is passed to get_model_complexity_info so the model receives the specific inputs it expects.
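For reference, a minimal sketch (my own illustration, not part of the repo) of the relationship ppl = exp(mean cross-entropy); `full_inference` in ptq.py below does the same thing over real batches:

```python
import math
import torch
import torch.nn as nn

# fake logits/targets only to show the relationship between loss and perplexity
logits = torch.randn(35 * 10, 10000)           # (seq_len * batch, vocab) model outputs
targets = torch.randint(0, 10000, (35 * 10,))  # next-word indices
loss = nn.CrossEntropyLoss()(logits, targets)  # average negative log-likelihood
ppl = math.exp(loss.item())                    # smaller ppl = more concentrated prediction
```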
# PTQ part
### Notes on the quantized layers
#### embedding
+ Every element of the input is an int index; the corresponding row of the embedding matrix (i.e. the weight parameter) is gathered into the output. The input therefore needs no quantization; only the weight is quantized.
+ Note that every output element comes from the weight, so quantizing the weight is sufficient. Since the output may cover only part of the embedding matrix, a qo is added to track output statistics, and rescaling is done via `self.M.data = (self.qw.scale / self.qo.scale).data` (a short sketch follows).
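A minimal sketch of the idea, assuming a plain symmetric INT8-style scheme (illustrative only; the actual implementation is `QEmbedding` in module.py below): only the weight is quantized, the int indices pass through untouched, and `M` rescales from the weight scale to the output scale.

```python
import torch
import torch.nn.functional as F

def quantized_embedding_inference(indices, weight, w_scale, o_scale):
    # symmetric fake quantization of the embedding matrix only
    q_w = torch.clamp(torch.round(weight / w_scale), -127, 127)
    q_rows = F.embedding(indices, q_w)  # gather quantized rows; the int indices need no quantization
    M = w_scale / o_scale               # as in self.M.data = (self.qw.scale / self.qo.scale).data
    return torch.round(M * q_rows)      # output rescaled onto the output scale
```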
#### LSTM
+ The full-precision layer may stack several layers; for ease of quantization we split a multi-layer LSTM into single layers.
+ The layer takes an input x and an optional hidden state hidden, which unpacks into the previous hidden state h and cell state c. If hidden is unspecified or None, nn.LSTM automatically initializes zero hidden states. h and c have shape (nlayer, batch_sz, hidden_size). The first layer may never receive a hidden input, or may receive None for the first batch and non-zero values afterwards, so quantization takes a has_hidden flag indicating whether statistics on the hidden input should be tracked.
When has_hidden is true, statistics are tracked in qih and qic. Since the first batch may pass hidden=None, qih and qic are updated and hidden is fake-quantized only when hidden is not None, which prevents a zero scale from producing NaN quantized values.
When has_hidden is false, the layer never receives a non-zero hidden state, no statistics are needed, and the corresponding methods check for this.
+ For a single-layer lstm_module the parameters are weight_ih_l0, weight_hh_l0, bias_ih_l0 and bias_hh_l0. The _ih suffix denotes the mapping from the input x to the output hidden state, and _hh the mapping from the input hidden state to the output hidden state. For simplicity we currently do not split these matrices further.
Also, since this layer's computation is not plain addition, multiplication or convolution, it cannot be rescaled conveniently, so inference currently still uses fake quantization. A later rescale could be approximated with a nearby linear function.
+ In quantize_forward, unlike the other layers that call torch.nn.functional directly, we still go through nn.LSTM. Calling the functional interface directly would also involve flatten_weight and similar handling, so to keep the call logic simple we create a temporary nn.LSTM layer, overwrite its parameters with the quantized values, and run that instead.
\ No newline at end of file
import sys
import os
# Extract the parameter count and FLOPs from val.txt, the file produced by redirecting the output of get_param.py
def extract_ratio(model_name):
fr = open('param_flops/'+model_name+'.txt','r')
lines = fr.readlines()
Mac = lines[1].split('Mac,')[0].split(',')[-1]
if 'M' in Mac:
Mac = Mac.split('M')[0]
Mac = float(Mac)
elif 'G' in Mac:
Mac = Mac.split('G')[0]
Mac = float(Mac)
Mac *= 1024
Param = lines[1].split('M,')[0]
Param = float(Param)
layer = []
par_ratio = []
flop_ratio = []
weight_ratio = []
for line in lines:
if '(' in line and ')' in line:
layer.append(line.split(')')[0].split('(')[1])
r1 = line.split('%')[0].split(',')[-1]
r1 = float(r1)
par_ratio.append(r1)
r2 = line.split('%')[-2].split(',')[-1]
r2 = float(r2)
flop_ratio.append(r2)
if 'conv' in line:
#computed regardless of bias=False; after folding, the conv approximation is used directly
inch = line.split(',')[4]
# outch = line.split(',')[5]
klsz = line.split(',')[6].split('(')[-1]
inch = float(inch)
# outch = float(outch)
klsz = float(klsz)
wr = inch * klsz * klsz
wr = wr / (1+wr)
weight_ratio.append(wr)
elif 'fc' in line:
inch = line.split(',')[4].split('=')[-1]
inch = float(inch)
wr = inch / (1+inch)
weight_ratio.append(wr)
else:
weight_ratio.append(0)
return Mac, Param, layer, par_ratio, flop_ratio, weight_ratio
if __name__ == "__main__":
Mac, Param, layer, par_ratio, flop_ratio, weight_ratio = extract_ratio('Inception_BN')
print(Mac)
print(Param)
print(layer)
print(par_ratio)
print(flop_ratio)
print(weight_ratio)
\ No newline at end of file
from torch.autograd import Function
class FakeQuantize(Function):
@staticmethod
def forward(ctx, x, qparam):
x = qparam.quantize_tensor(x)
x = qparam.dequantize_tensor(x)
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
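# Usage sketch (illustrative, not exercised in this file): qparam is any object exposing
# quantize_tensor/dequantize_tensor, e.g. QParam from module.py in this commit:
#   qp = QParam('INT', num_bits=8)
#   qp.update(x)                       # track min/max and derive scale/zero_point
#   x_fq = FakeQuantize.apply(x, qp)   # forward: quantize + dequantize; backward: straight-through gradient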
\ No newline at end of file
from model import *
from functools import partial
from lstm_utils import *
import sys
import torch
from ptflops import get_model_complexity_info
data_path = '../data/ptb'
embed_size = 700
hidden_size = 700
eval_batch_size = 10
dropout = 0.65
tied = True
def lstm_constructor(shape,hidden):
return {"x": torch.zeros(shape,dtype=torch.int64),
"hidden": hidden}
if __name__ == "__main__":
corpus = Corpus(data_path)
ntokens = len(corpus.dictionary)
model = Model(ntokens, embed_size, hidden_size, dropout, tied)
full_file = 'ckpt/ptb_PTB_LSTM.pt'
model.load_state_dict(torch.load(full_file))
hidden = model.init_hidden(eval_batch_size)
flops, params = get_model_complexity_info(model, (35,10), as_strings=True,
input_constructor = partial(lstm_constructor,hidden=hidden),
print_per_layer_stat=True)
#!/bin/bash
#- Job parameters
# (TODO)
# Please modify job name
#SBATCH -J PTB_LSTM # The job name
#SBATCH -o ret/ret-%j.out # Write the standard output to file named 'ret-<job_number>.out'
#SBATCH -e ret/ret-%j.err # Write the standard error to file named 'ret-<job_number>.err'
#- Resources
# (TODO)
# Please modify your requirements
#SBATCH -p nv-gpu # Submit to 'nv-gpu' Partition
#SBATCH -t 0-01:30:00 # Run for a maximum time of 0 days, 01 hours, 30 mins, 00 secs
#SBATCH --nodes=1 # Request N nodes
#SBATCH --gres=gpu:1 # Request M GPU per node
#SBATCH --gres-flags=enforce-binding # CPU-GPU Affinity
#SBATCH --qos=gpu-debug # Request QOS Type
###
### The system will alloc 8 or 16 cores per gpu by default.
### If you need more or less, use following:
### #SBATCH --cpus-per-task=K # Request K cores
###
###
### Without specifying the constraint, any available nodes that meet the requirement will be allocated
### You can specify the characteristics of the compute nodes, and even the names of the compute nodes
###
### #SBATCH --nodelist=gpu-v00 # Request a specific list of hosts
### #SBATCH --constraint="Volta|RTX8000" # Request GPU Type: Volta(V100 or V100S) or RTX8000
###
# set constraint for RTX8000 to meet my cuda
#SBATCH --constraint="Ampere|RTX8000|T4"
#- Log information
echo "Job start at $(date "+%Y-%m-%d %H:%M:%S")"
echo "Job run at:"
echo "$(hostnamectl)"
#- Load environments
source /tools/module_env.sh
module list # list modules loaded
##- Tools
module load cluster-tools/v1.0
module load slurm-tools/v1.0
module load cmake/3.15.7
module load git/2.17.1
module load vim/8.1.2424
##- language
module load python3/3.6.8
##- CUDA
# module load cuda-cudnn/10.2-7.6.5
# module load cuda-cudnn/11.2-8.2.1
module load cuda-cudnn/11.1-8.2.1
##- virtualenv
# source xxxxx/activate
echo $(module list) # list modules loaded
echo $(which gcc)
echo $(which python)
echo $(which python3)
cluster-quota # nas quota
nvidia-smi --format=csv --query-gpu=name,driver_version,power.limit # gpu info
#- Warning! Please do not change your CUDA_VISIBLE_DEVICES
#- in `.bashrc`, `env.sh`, or your job script
echo "Use GPU ${CUDA_VISIBLE_DEVICES}" # which gpus
#- The CUDA_VISIBLE_DEVICES variable is assigned and specified by SLURM
#- Job step
# [EDIT HERE(TODO)]
python get_param_flops.py > PTB_LSTM.txt
#- End
echo "Job end at $(date "+%Y-%m-%d %H:%M:%S")"
# -*- coding: utf-8 -*-
# Shares global variables across multiple modules
def _init(): # initialization
global _global_dict
_global_dict = {}
def set_value(value,is_bias=False):
# define a global variable
if is_bias:
_global_dict[0] = value
else:
_global_dict[1] = value
def get_value(is_bias=False): # bias keeps its own precision, separate from the other values
if is_bias:
return _global_dict[0]
else:
return _global_dict[1]
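# Usage sketch (illustrative): ptq.py below calls gol._init() once at start-up, publishes the
# value grids with gol.set_value(plist) and gol.set_value(bias_list, is_bias=True), and
# module.get_nearest_val() later reads them back through gol.get_value(is_bias).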
import os
from io import open
import torch
def batchify(data, bsz, device):
nbatch = data.size(0) // bsz
data = data.narrow(0, 0, nbatch * bsz)
data = data.view(bsz, -1).t().contiguous()
return data.to(device)
def get_batch(source, i, bptt):
seq_len = min(bptt, len(source) - 1 - i)
data = source[i:i+seq_len]
target = source[i+1:i+1+seq_len].view(-1)
return data, target
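# Shape example (illustrative): with eval_batch_size=10 the token stream becomes a
# (nbatch, 10) LongTensor after batchify; get_batch(source, i, bptt=35) then returns
# data of shape (35, 10) and targets of shape (350,), where each target is the token
# that follows the corresponding position in data.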
def repackage_hidden(h):
"""Wraps hidden states in new Tensors, to detach them from their history."""
if isinstance(h, torch.Tensor):
return h.detach()
else:
return tuple(repackage_hidden(v) for v in h)
# flatten_parameters() compacts the LSTM weights into one contiguous memory block so the
# backend (cuDNN) can run its fused kernel without extra copies; call it after loading a
# checkpoint and before inference to avoid the non-contiguous-weights warning and the overhead.
def lstm_flatten(model):
for name,layer in model.named_modules():
if 'lstm' in name:
layer.flatten_parameters()
class Dictionary(object):
def __init__(self):
self.word2idx = {}
self.idx2word = []
def add_word(self, word):
if word not in self.word2idx:
self.idx2word.append(word)
self.word2idx[word] = len(self.idx2word) - 1
return self.word2idx[word]
def __len__(self):
return len(self.idx2word)
class Corpus(object):
def __init__(self, path):
self.dictionary = Dictionary()
self.train = self.tokenize(os.path.join(path, 'train.txt'))
self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
self.test = self.tokenize(os.path.join(path, 'test.txt'))
def tokenize(self, path):
"""Tokenizes a text file."""
assert os.path.exists(path)
# Add words to the dictionary
with open(path, 'r', encoding="utf8") as f:
tokens = 0
for line in f:
words = line.split() + ['<eos>']
tokens += len(words)
for word in words:
self.dictionary.add_word(word)
with open(path, 'r', encoding="utf8") as f:
ids = torch.LongTensor(tokens)
token = 0
for line in f:
words = line.split() + ['<eos>']
for word in words:
ids[token] = self.dictionary.word2idx[word]
token += 1
return ids
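# Usage sketch (illustrative): corpus = Corpus('../data/ptb') builds the vocabulary from the
# train/valid/test files and returns each split as a 1-D LongTensor of word indices;
# len(corpus.dictionary) gives the vocabulary size (10000 here, matching the
# Embedding(10000, 700) layer in the model summary at the top of this commit).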
import torch.nn as nn
from module import *
class Model(nn.Module):
def __init__(self,ntoken, ninp, nhid, dropout=0.5, tie_weights=False):
super(Model, self).__init__()
self.embed1 = nn.Embedding(ntoken, ninp)
self.drop2 = nn.Dropout(dropout)
self.lstm3 = nn.LSTM(ninp, nhid, 1, dropout=dropout)
# self.drop4 = nn.Dropout(dropout)
self.lstm4 = nn.LSTM(nhid, nhid, 1)
self.drop5 = nn.Dropout(dropout)
self.fc6 = nn.Linear(nhid, ntoken)
# "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
if tie_weights:
if nhid != ninp:
raise ValueError('the number of hidden unit per layer must be equal to the embedding size')
self.fc6.weight = self.embed1.weight
self.init_weights()
self.nhid = nhid
def init_weights(self):
initrange = 0.1
self.embed1.weight.data.uniform_(-initrange, initrange)
self.fc6.bias.data.zero_()
self.fc6.weight.data.uniform_(-initrange, initrange)
def forward(self, x, hidden=None):
x = self.embed1(x)
x = self.drop2(x)
x, hidden = self.lstm3(x, hidden)
x, hidden = self.lstm4(x, hidden)
x = self.drop5(x)
t, n = x.size(0), x.size(1)
x = x.view(t*n,-1)
x = self.fc6(x)
x = x.view(t,n,-1)
return x, hidden
def init_hidden(self,bsz):
# use the model's first parameter to get the right device and dtype
weight = next(self.parameters())
return (weight.new_zeros(1, bsz, self.nhid),
weight.new_zeros(1, bsz, self.nhid))
def quantize(self, quant_type, num_bits=8, e_bits=3):
# the embedding input is an index, so it needs no quantization
self.qembed1 = QEmbedding(quant_type, self.embed1, num_bits=num_bits, e_bits=e_bits)
#qix is inherited from the previous layer and needs no re-quantization; qih and qic see hidden for the first time and must be quantized
# self.qlstm3 = QLSTM(quant_type, self.lstm3, has_hidden=True, qix=False, qih=True, qic=True, num_bits=num_bits, e_bits=e_bits)
self.qlstm3 = QLSTM(quant_type, self.lstm3, has_hidden=True, qix=False, qih=True, qic=True, num_bits=num_bits, e_bits=e_bits)
self.qlstm4 = QLSTM(quant_type, self.lstm4, has_hidden=True, qix=False, qih=False, qic=False, num_bits=num_bits, e_bits=e_bits)
self.qfc6 = QLinear(quant_type, self.fc6, num_bits=num_bits, e_bits=e_bits)
def quantize_forward(self, x, hidden=None):
x = self.qembed1(x)
x = self.drop2(x)
x,hidden = self.qlstm3(x,hidden)
x,hidden = self.qlstm4(x,hidden)
x = self.drop5(x)
t,n = x.size(0), x.size(1)
x = x.view(t*n, -1)
x = self.qfc6(x)
x = x.view(t,n,-1)
return x,hidden
def freeze(self):
self.qembed1.freeze()
self.qlstm3.freeze(qix=self.qembed1.qo)
self.qlstm4.freeze(qix=self.qlstm3.qox, qih=self.qlstm3.qoh, qic=self.qlstm3.qoc)
self.qfc6.freeze(qi=self.qlstm4.qox)
def quantize_inference(self, x, hidden=None):
x = self.qembed1.quantize_inference(x)
x,hidden = self.qlstm3.quantize_inference(x,hidden)
x,hidden = self.qlstm4.quantize_inference(x,hidden)
t,n = x.size(0), x.size(1)
x = x.view(t*n, -1)
x = self.qfc6.quantize_inference(x)
x = x.view(t,n,-1)
return x,hidden
\ No newline at end of file
import math
import numpy as np
import gol
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from function import FakeQuantize
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
# snap each value to the nearest point of the quantization grid
def get_nearest_val(quant_type, x, is_bias=False, block_size=1000000):
if quant_type == 'INT':
return x.round_()
plist = gol.get_value(is_bias)
shape = x.shape
xhard = x.view(-1)
xout = torch.zeros_like(xhard)
plist = plist.type_as(x)
n_blocks = (x.numel() + block_size - 1) // block_size
for i in range(n_blocks):
start_idx = i * block_size
end_idx = min(start_idx + block_size, xhard.numel())
block_size_i = end_idx - start_idx
# print(x.numel())
# print(block_size_i)
# print(start_idx)
# print(end_idx)
xblock = xhard[start_idx:end_idx]
# xblock = xblock.view(shape[start_idx:end_idx])
plist_block = plist.unsqueeze(1) #.expand(-1, block_size_i)
idx = (xblock.unsqueeze(0) - plist_block).abs().min(dim=0)[1]
# print(xblock.shape)
xhard_block = plist[idx].view(xblock.shape)
xout[start_idx:end_idx] = (xhard_block - xblock).detach() + xblock
xout = xout.view(shape)
return xout
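# Example (illustrative): if the published grid were plist = [0, ±0.25, ±0.5, ±1] (a POT-style
# grid), a value of 0.3 would be snapped to 0.25, the closest grid point; the
# (xhard_block - xblock).detach() + xblock trick keeps a straight-through gradient on the
# original values. The actual grids come from build_list()/build_bias_list() in utils.py.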
# maximum of the quantization range under symmetric signed quantization
def get_qmax(quant_type,num_bits=None, e_bits=None):
if quant_type == 'INT':
qmax = 2. ** (num_bits - 1) - 1
elif quant_type == 'POT':
qmax = 1
else: #FLOAT
m_bits = num_bits - 1 - e_bits
dist_m = 2 ** (-m_bits)
e = 2 ** (e_bits - 1)
expo = 2 ** e
m = 2 ** m_bits -1
frac = 1. + m * dist_m
qmax = frac * expo
return qmax
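# Worked example (illustrative): get_qmax('INT', 8) = 2**7 - 1 = 127; get_qmax('POT') = 1;
# get_qmax('FLOAT', 8, 3) has m_bits = 4, frac = 1 + 15/16 = 1.9375, expo = 2 ** (2 ** 2) = 16,
# so qmax = 1.9375 * 16 = 31.0.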
# signed symmetric quantization throughout, so zero_point is always 0
def calcScaleZeroPoint(min_val, max_val, qmax):
scale = torch.max(max_val.abs(),min_val.abs()) / qmax
zero_point = torch.tensor(0.)
return scale, zero_point
# quantize the input; both input and output are tensors
def quantize_tensor(quant_type, x, scale, zero_point, qmax, is_bias=False):
# quantized range, determined directly by the bit width
qmin = -qmax
q_x = zero_point + x / scale
q_x.clamp_(qmin, qmax)
q_x = get_nearest_val(quant_type, q_x, is_bias)
return q_x
# bias uses a different precision; num_bits/e_bits depend on the quantization type
def bias_qmax(quant_type):
if quant_type == 'INT':
return get_qmax(quant_type, 64)
elif quant_type == 'POT':
return get_qmax(quant_type)
else:
return get_qmax(quant_type, 16, 7)
# convert back to FP32; no further clamping needed
def dequantize_tensor(q_x, scale, zero_point):
return scale * (q_x - zero_point)
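# Worked example (illustrative): with INT8 symmetric quantization and max(|x|) = 2.0,
# calcScaleZeroPoint gives scale = 2.0 / 127 and zero_point = 0; quantize_tensor maps
# x = 1.0 to round(1.0 / scale) = 64, and dequantize_tensor maps it back to 64 * scale ≈ 1.008.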
class QParam(nn.Module):
def __init__(self,quant_type, num_bits=8, e_bits=3):
super(QParam, self).__init__()
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.qmax = get_qmax(quant_type, num_bits, e_bits)
scale = torch.tensor([], requires_grad=False)
zero_point = torch.tensor([], requires_grad=False)
min = torch.tensor([], requires_grad=False)
max = torch.tensor([], requires_grad=False)
# register these as buffers so they are saved into the state_dict
self.register_buffer('scale', scale)
self.register_buffer('zero_point', zero_point)
self.register_buffer('min', min)
self.register_buffer('max', max)
# update the observed range and the quantization parameters
def update(self, tensor):
if self.max.nelement() == 0 or self.max.data < tensor.max().data:
self.max.data = tensor.max().data
self.max.clamp_(min=0)
if self.min.nelement() == 0 or self.min.data > tensor.min().data:
self.min.data = tensor.min().data
self.min.clamp_(max=0)
self.scale, self.zero_point = calcScaleZeroPoint(self.min, self.max, self.qmax)
def quantize_tensor(self, tensor):
return quantize_tensor(self.quant_type, tensor, self.scale, self.zero_point, self.qmax)
def dequantize_tensor(self, q_x):
return dequantize_tensor(q_x, self.scale, self.zero_point)
# this method makes the object restorable from a state_dict
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
key_names = ['scale', 'zero_point', 'min', 'max']
for key in key_names:
value = getattr(self, key)
value.data = state_dict[prefix + key].data
state_dict.pop(prefix + key)
# the returned string is what gets printed for this object
def __str__(self):
info = 'scale: %.10f ' % self.scale
info += 'zp: %.6f ' % self.zero_point
info += 'min: %.6f ' % self.min
info += 'max: %.6f ' % self.max
info += 'qmax: %6f ' % self.qmax
return info
# parent class of the concrete quantized layers; qi and qo are the quantized input/output observers
class QModule(nn.Module):
def __init__(self,quant_type, qi=False, qo=True, num_bits=8, e_bits=3):
super(QModule, self).__init__()
if qi:
self.qi = QParam(quant_type,num_bits, e_bits)
if qo:
self.qo = QParam(quant_type,num_bits, e_bits)
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.bias_qmax = bias_qmax(quant_type)
def freeze(self):
pass # no-op
def quantize_inference(self, x):
raise NotImplementedError('quantize_inference should be implemented.')
def fakefreeze(self):
pass
"""
QModule 量化卷积
:quant_type: 量化类型
:conv_module: 卷积模块
:qi: 是否量化输入特征图
:qo: 是否量化输出特征图
:num_bits: 8位bit数
"""
class QConv2d(QModule):
def __init__(self, quant_type, conv_module, qi=False, qo=True, num_bits=8, e_bits=3):
super(QConv2d, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.qw = QParam(quant_type, num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
# freeze fixes the truly-quantized weight values and writes them back onto the original full-precision layer, which makes divergence computation convenient
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
# Pooling or activation layers need no extra min/max statistics for their inputs; they share the same output qparams, which is why qi/qo can be passed in here
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
# per https://zhuanlan.zhihu.com/p/156835141, this is the coefficient of Eq. (3)
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
self.conv_module.weight.data = self.qw.quantize_tensor(self.conv_module.weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
#the conv layer may have no bias; forward and inference can then simply pass None
if self.conv_module.bias is not None:
self.conv_module.bias.data = quantize_tensor(self.quant_type,
self.conv_module.bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0.,qmax=self.bias_qmax, is_bias=True)
def fakefreeze(self):
self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
if self.conv_module.bias is not None:
self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data,scale=self.qi.scale*self.qw.scale,zero_point=0.)
def forward(self, x): # forward pass; x is a floating-point tensor
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi) # 对输入张量X完成量化
# update qw before the forward so the weight is quantized with the correct scale
self.qw.update(self.conv_module.weight.data)
# note: this mainly collects the x and weight ranges per layer; the bias is not quantized here
tmp_wgt = FakeQuantize.apply(self.conv_module.weight, self.qw)
x = F.conv2d(x, tmp_wgt, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
# uses the formula q_a = M * (Σ (q_w - Z_w)(q_x - Z_x) + q_b)
def quantize_inference(self, x): # the input here is the already-quantized qx
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
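# Derivation note (illustrative): with symmetric weights (Z_w = 0) the float output is
# approximately s_w * s_x * (conv(q_w, q_x - Z_x) + q_b), since the bias was quantized with
# scale s_w * s_x in freeze(); dividing by s_o and adding Z_o gives q_o = M * acc + Z_o,
# where M = s_w * s_x / s_o folds the three scales into a single multiplier.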
class QLinear(QModule):
def __init__(self, quant_type, fc_module, qi=False, qo=True, num_bits=8, e_bits=3):
super(QLinear, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.fc_module = fc_module
self.qw = QParam(quant_type, num_bits, e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
self.fc_module.weight.data = self.qw.quantize_tensor(self.fc_module.weight.data)
self.fc_module.weight.data = self.fc_module.weight.data - self.qw.zero_point
if self.fc_module.bias is not None:
self.fc_module.bias.data = quantize_tensor(self.quant_type,
self.fc_module.bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax, is_bias=True)
def fakefreeze(self):
self.fc_module.weight.data = self.qw.dequantize_tensor(self.fc_module.weight.data)
if self.fc_module.bias is not None:
self.fc_module.bias.data = dequantize_tensor(self.fc_module.bias.data,scale=self.qi.scale*self.qw.scale,zero_point=0.)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
self.qw.update(self.fc_module.weight.data)
tmp_wgt = FakeQuantize.apply(self.fc_module.weight, self.qw)
x = F.linear(x, tmp_wgt, self.fc_module.bias)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
# x = x - self.qi.zero_point
# x = self.fc_module(x)
# x = self.M * x
# x = get_nearest_val(self.quant_type,x)
# x = x + self.qo.zero_point
# return x
if hasattr(self, 'qi'):
x = FakeQuantize.apply(x, self.qi)
tmp_wgt = FakeQuantize.apply(self.fc_module.weight, self.qw)
x = F.linear(x, tmp_wgt, self.fc_module.bias)
if hasattr(self, 'qo'):
x = FakeQuantize.apply(x, self.qo)
return x
class QReLU(QModule):
def __init__(self,quant_type, qi=False, qo=False, num_bits=8, e_bits=3):
super(QReLU, self).__init__(quant_type, qi, qo, num_bits, e_bits)
def freeze(self, qi=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if qi is not None:
self.qi = qi
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
x = F.relu(x)
return x
def quantize_inference(self, x):
x = x.clone()
x[x < self.qi.zero_point] = self.qi.zero_point
return x
class QMaxPooling2d(QModule):
def __init__(self, quant_type, kernel_size=3, stride=1, padding=0, qi=False, qo=False, num_bits=8,e_bits=3):
super(QMaxPooling2d, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
def freeze(self, qi=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if qi is not None:
self.qi = qi
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding)
return x
def quantize_inference(self, x):
return F.max_pool2d(x, self.kernel_size, self.stride, self.padding)
class QAdaptiveAvgPool2d(QModule):
def __init__(self, quant_type, output_size, qi=False, qo=True, num_bits=8,e_bits=3):
super(QAdaptiveAvgPool2d, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.output_size = output_size
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qi.scale / self.qo.scale).data
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
x = F.adaptive_avg_pool2d(x, self.output_size)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = F.adaptive_avg_pool2d(x, self.output_size)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x+self.qo.zero_point
return x
class QConvBNReLU(QModule):
def __init__(self, quant_type, conv_module, bn_module, qi=False, qo=True, num_bits=8, e_bits=3):
super(QConvBNReLU, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.bn_module = bn_module
self.qw = QParam(quant_type, num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
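# Folding reference (illustrative): BN computes y = gamma * (conv(x) - mean) / std + beta, which
# is equivalent to a single convolution with weight' = weight * gamma / std and
# bias' = (bias - mean) * gamma / std + beta -- exactly what fold_bn returns
# (with gamma = 1, beta = 0 when the BN layer is not affine).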
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = self.qw.quantize_tensor(weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
if self.conv_module.bias is not None:
self.conv_module.bias.data = quantize_tensor(self.quant_type,
bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
else:
bias = quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
self.conv_module.bias = torch.nn.Parameter(bias)
def fakefreeze(self):
self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data,scale=self.qi.scale*self.qw.scale,zero_point=0.)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
self.qw.update(weight.data)
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
x = F.relu(x)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
x.clamp_(min=0)
return x
class QConvBN(QModule):
def __init__(self, quant_type, conv_module, bn_module, qi=False, qo=True, num_bits=8, e_bits=3):
super(QConvBN, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.bn_module = bn_module
self.qw = QParam(quant_type, num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = self.qw.quantize_tensor(weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
if self.conv_module.bias is not None:
self.conv_module.bias.data = quantize_tensor(self.quant_type,
bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
else:
bias = quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
self.conv_module.bias = torch.nn.Parameter(bias)
def fakefreeze(self):
self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data,scale=self.qi.scale*self.qw.scale,zero_point=0.)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
self.qw.update(weight.data)
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
# Parent class of quantized layers; qi and qo are the quantized input/output observers
# Used to handle the results of multiple layers, with qi passed in as an array
class QModule_array(nn.Module):
def __init__(self,quant_type,len,qi_array=False, qo=True, num_bits=8, e_bits=3):
super(QModule_array, self).__init__()
if qi_array:
for i in range(len):
self.add_module('qi%d'%i,QParam(quant_type,num_bits, e_bits))
if qo:
self.qo = QParam(quant_type,num_bits, e_bits)
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.bias_qmax = bias_qmax(quant_type)
self.len = len
def freeze(self):
pass # no-op
def quantize_inference(self, x):
raise NotImplementedError('quantize_inference should be implemented.')
class QConcat(QModule_array):
def __init__(self, quant_type, len, qi_array=False, qo=True, num_bits=8, e_bits=3):
super(QConcat,self).__init__(quant_type, len, qi_array, qo, num_bits, e_bits)
for i in range(len):
self.register_buffer('M%d'%i,torch.tensor([], requires_grad=False))
def freeze(self, qi_array=None, qo=None):
if qi_array is None:
raise ValueError('qi_array should be provided')
elif len(qi_array) != self.len:
raise ValueError('qi_array len no match')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
for i in range(self.len):
self.add_module('qi%d'%i,qi_array[i])
if qo is not None:
self.qo = qo
for i in range(self.len):
getattr(self,'M%d'%i).data = (getattr(self,'qi%d'%i).scale / self.qo.scale).data
def forward(self,x_array):
outs=[]
for i in range(self.len):
x = x_array[i]
if hasattr(self,'qi%d'%i):
qi = getattr(self,'qi%d'%i)
qi.update(x)
x = FakeQuantize.apply(x,qi)
outs.append(x)
out = torch.cat(outs,1)
if hasattr(self,'qo'):
self.qo.update(out)
out = FakeQuantize.apply(out,self.qo)
return out
def quantize_inference(self, x_array):
outs=[]
for i in range(self.len):
qi = getattr(self,'qi%d'%i)
x = x_array[i] - qi.zero_point
x = getattr(self,'M%d'%i) * x
outs.append(x)
out = torch.concat(outs,1)
out = get_nearest_val(self.quant_type,out)
out = out + self.qo.zero_point
return out
class QModule_rnnbase(nn.Module):
def __init__(self,quant_type, qix=False, qih=False, qic=False, qox=True, qoh=True, qoc=True, num_bits=8, e_bits=3):
super(QModule_rnnbase,self).__init__()
if qix:
self.qix = QParam(quant_type, num_bits, e_bits)
if qih:
self.qih = QParam(quant_type, num_bits, e_bits)
if qic:
self.qic = QParam(quant_type, num_bits, e_bits)
if qox:
self.qox = QParam(quant_type, num_bits, e_bits)
if qoh:
self.qoh = QParam(quant_type, num_bits, e_bits)
if qoc:
self.qoc = QParam(quant_type, num_bits, e_bits)
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.bias_qmax = bias_qmax(quant_type)
def freeze(self):
pass # no-op
def quantize_inference(self, x):
raise NotImplementedError('quantize_inference should be implemented.')
# has_hidden indicates whether this layer can receive a non-zero hidden input
# For the first layer, the first batch may pass a zero/None hidden while later batches pass non-zero values; has_hidden must still be set to true in that case
# An extra check that hidden is not None is needed to avoid the scale-equals-0 case
class QLSTM(QModule_rnnbase):
def __init__(self, quant_type, lstm_module, has_hidden, qix=False, qih=False, qic=False, qox=True, qoh=True, qoc=True, num_bits=8, e_bits=3):
super(QLSTM,self).__init__(quant_type,qix,qih,qic,qox,qoh,qoc,num_bits,e_bits)
self.lstm_module = lstm_module
self.qwih = QParam(quant_type, num_bits, e_bits)
self.qwhh = QParam(quant_type, num_bits, e_bits)
# whether a hidden state is fed into this layer
self.has_hidden = has_hidden
def freeze(self, qix=None, qih=None, qic=None, qox=None, qoh=None, qoc=None):
if hasattr(self, 'qix') and qix is not None:
raise ValueError('qix has been provided in init function.')
if not hasattr(self, 'qix') and qix is None:
raise ValueError('qix is not existed, should be provided.')
if self.has_hidden:
if hasattr(self, 'qih') and qih is not None:
raise ValueError('qih has been provided in init function.')
if not hasattr(self, 'qih') and qih is None:
raise ValueError('qih is not existed, should be provided.')
if hasattr(self, 'qic') and qic is not None:
raise ValueError('qic has been provided in init function.')
if not hasattr(self, 'qic') and qic is None:
raise ValueError('qic is not existed, should be provided.')
if hasattr(self, 'qox') and qox is not None:
raise ValueError('qox has been provided in init function.')
if not hasattr(self, 'qox') and qox is None:
raise ValueError('qox is not existed, should be provided.')
if hasattr(self, 'qoh') and qoh is not None:
raise ValueError('qoh has been provided in init function.')
if not hasattr(self, 'qoh') and qoh is None:
raise ValueError('qoh is not existed, should be provided.')
if hasattr(self, 'qoc') and qoc is not None:
raise ValueError('qoc has been provided in init function.')
if not hasattr(self, 'qoc') and qoc is None:
raise ValueError('qoc is not existed, should be provided.')
if qix is not None:
self.qix = qix
if self.has_hidden:
if qih is not None:
self.qih = qih
if qic is not None:
self.qic = qic
if qox is not None:
self.qox = qox
if qoh is not None:
self.qoh = qoh
if qoc is not None:
self.qoc = qoc
self.lstm_module.weight_ih_l0.data = FakeQuantize.apply(self.lstm_module.weight_ih_l0.data, self.qwih)
self.lstm_module.weight_hh_l0.data = FakeQuantize.apply(self.lstm_module.weight_hh_l0.data, self.qwhh)
self.lstm_module.bias_ih_l0.data = quantize_tensor(self.quant_type,
self.lstm_module.bias_ih_l0.data,
scale=self.qix.scale * self.qwih.scale,
zero_point=0., qmax=self.bias_qmax, is_bias=True)
self.lstm_module.bias_ih_l0.data = dequantize_tensor(self.lstm_module.bias_ih_l0.data,
scale=self.qix.scale * self.qwih.scale, zero_point=0.)
if self.has_hidden:
self.lstm_module.bias_hh_l0.data = quantize_tensor(self.quant_type,
self.lstm_module.bias_hh_l0.data,
scale=self.qih.scale * self.qwhh.scale,
zero_point=0., qmax=self.bias_qmax, is_bias=True)
self.lstm_module.bias_hh_l0.data = dequantize_tensor(self.lstm_module.bias_hh_l0.data,
scale=self.qih.scale * self.qwhh.scale, zero_point=0.)
def forward(self, x, hidden=None):
# if self.has_hidden and hidden is None:
# raise ValueError("LSTM layer with has_hidden must accept hidden")
if not self.has_hidden and hidden is not None:
raise ValueError("LSTM layer without has_hidden cannot accept hidden")
if hasattr(self, 'qix'):
self.qix.update(x)
x = FakeQuantize.apply(x,self.qix)
if self.has_hidden and hidden is not None:
h,c = hidden
if hasattr(self, 'qih'):
self.qih.update(h)
h = FakeQuantize.apply(h, self.qih)
if hasattr(self, 'qic'):
self.qic.update(c)
c = FakeQuantize.apply(c, self.qic)
self.qwih.update(self.lstm_module.weight_ih_l0.data)
self.qwhh.update(self.lstm_module.weight_hh_l0.data)
# Unlike Conv etc., calling _VF.lstm directly would require flattening the weights
# To simplify the call logic we build a fresh nn.LSTM and run that instead
tmplayer = nn.LSTM(input_size = self.lstm_module.input_size,
hidden_size = self.lstm_module.hidden_size,
num_layers = 1,
batch_first = False,
bidirectional = self.lstm_module.bidirectional,
bias = self.lstm_module.bias,
dropout = self.lstm_module.dropout)
tmplayer.weight_ih_l0.data = FakeQuantize.apply(self.lstm_module.weight_ih_l0.data, self.qwih)
tmplayer.weight_hh_l0.data = FakeQuantize.apply(self.lstm_module.weight_hh_l0.data, self.qwhh)
tmplayer.bias_ih_l0.data = self.lstm_module.bias_ih_l0.data
tmplayer.bias_hh_l0.data = self.lstm_module.bias_hh_l0.data
if self.has_hidden and hidden is not None:
x, (h,c) = tmplayer(x,(h,c))
else:
x, (h,c) = tmplayer(x)
if hasattr(self,'qox'):
self.qox.update(x)
x = FakeQuantize.apply(x,self.qox)
if hasattr(self,'qoh'):
self.qoh.update(h)
h = FakeQuantize.apply(h,self.qoh)
if hasattr(self,'qoc'):
self.qoc.update(c)
c = FakeQuantize.apply(c,self.qoc)
hidden = (h,c)
return x, hidden
def quantize_inference(self, x, hidden=None):
# if self.has_hidden and hidden is None:
# raise ValueError("LSTM layer with has_hidden must accept hidden")
# if not self.has_hidden and hidden is not None:
# raise ValueError("LSTM layer without has_hidden cannot accept hidden")
# if self.has_hidden:
# x, hidden = self.lstm_module(x,hidden)
# else:
# x, hidden = self.lstm_module(x)
# h,c = hidden
# if hasattr(self,'qox'):
# x = FakeQuantize.apply(x, self.qox)
# if hasattr(self,'qoh'):
# h = FakeQuantize.apply(h, self.qoh)
# if hasattr(self,'qoc'):
# c = FakeQuantize.apply(c, self.qoc)
# hidden = (h,c)
# return x,hidden
# if self.has_hidden and hidden is None:
# raise ValueError("LSTM layer with has_hidden must accept hidden")
if not self.has_hidden and hidden is not None:
raise ValueError("LSTM layer without has_hidden cannot accept hidden")
if hasattr(self, 'qix'):
x = FakeQuantize.apply(x,self.qix)
if self.has_hidden and hidden is not None:
h,c = hidden
if hasattr(self, 'qih'):
h = FakeQuantize.apply(h, self.qih)
if hasattr(self, 'qic'):
c = FakeQuantize.apply(c, self.qic)
# Unlike Conv etc., calling _VF.lstm directly would require flattening the weights
# To simplify the call logic we build a fresh nn.LSTM and run that instead
tmplayer = nn.LSTM(input_size = self.lstm_module.input_size,
hidden_size = self.lstm_module.hidden_size,
num_layers = 1,
batch_first = False,
bidirectional = self.lstm_module.bidirectional,
bias = self.lstm_module.bias,
dropout = self.lstm_module.dropout)
tmplayer.weight_ih_l0.data = FakeQuantize.apply(self.lstm_module.weight_ih_l0.data, self.qwih)
tmplayer.weight_hh_l0.data = FakeQuantize.apply(self.lstm_module.weight_hh_l0.data, self.qwhh)
tmplayer.bias_ih_l0.data = self.lstm_module.bias_ih_l0.data
tmplayer.bias_hh_l0.data = self.lstm_module.bias_hh_l0.data
if self.has_hidden and hidden is not None:
x, (h,c) = tmplayer(x,(h,c))
else:
x, (h,c) = tmplayer(x)
if hasattr(self,'qox'):
x = FakeQuantize.apply(x,self.qox)
if hasattr(self,'qoh'):
h = FakeQuantize.apply(h,self.qoh)
if hasattr(self,'qoc'):
c = FakeQuantize.apply(c,self.qoc)
hidden = (h,c)
return x, hidden
class QEmbedding(QModule):
def __init__(self,quant_type,embedding_module,qi=False, qo=True, num_bits=8, e_bits=3):
super(QEmbedding, self).__init__(quant_type,qi,qo, num_bits, e_bits)
self.embedding_module = embedding_module
self.qw = QParam(quant_type, num_bits, e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False))
def freeze(self, qo=None):
#the input is an index; it is not quantized
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale / self.qo.scale).data
self.embedding_module.weight.data = self.qw.quantize_tensor(self.embedding_module.weight.data)
def forward(self, x):
self.qw.update(self.embedding_module.weight.data)
tmp_wgt = FakeQuantize.apply(self.embedding_module.weight, self.qw)
# tmp_wgt = self.embedding_module.weight
x = F.embedding(x, tmp_wgt)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
# x = self.embedding_module(x)
# x = self.M * x
# x = get_nearest_val(self.quant_type, x)
# return x
tmp_wgt = FakeQuantize.apply(self.embedding_module.weight, self.qw)
# tmp_wgt = self.embedding_module.weight
x = F.embedding(x, tmp_wgt)
x = FakeQuantize.apply(x, self.qo)
return x
\ No newline at end of file
from torch.serialization import load
from model import *
from extract_ratio import *
from utils import *
from lstm_utils import *
import gol
import openpyxl
import sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
import time
import math
data_path = '../data/ptb'
embed_size = 700
hidden_size = 700
lr = 22
clip = 0.25
eval_batch_size = 10
bptt = 35 # length of each sequence slice
dropout = 0.65
tied = True
seed = 1111
seed_gpu = 1111
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def direct_quantize(model, val_data, eval_batch_size, bptt):
# hidden = model.init_hidden(eval_batch_size)
hidden=None
with torch.no_grad():
for i in range(0, val_data.size(0) - 1, bptt):
data, targets = get_batch(val_data, i, bptt)
output, hidden = model.quantize_forward(data,hidden)
hidden = repackage_hidden(hidden)
print('direct quantization finish')
def full_inference(model, test_data, ntokens, eval_batch_size, bptt):
total_loss = 0.
lossLayer = nn.CrossEntropyLoss()
# hidden = model.init_hidden(eval_batch_size)
hidden=None
with torch.no_grad():
for i in range(0, test_data.size(0) - 1, bptt):
data, targets = get_batch(test_data, i, bptt)
output, hidden = model(data,hidden)
output_flat = output.view(-1, ntokens)
total_loss += len(data) * lossLayer(output_flat, targets).item()
hidden = repackage_hidden(hidden)
test_loss = total_loss / (len(test_data) - 1)
ppl = math.exp(test_loss)
print('\nTest set: Full Model Perplexity: {:.4f} Loss {:f}'.format(ppl,test_loss))
return ppl
def quantize_inference(model, test_data, ntokens, eval_batch_size, bptt):
total_loss = 0.
lossLayer = nn.CrossEntropyLoss()
# hidden = model.init_hidden(eval_batch_size)
hidden=None
# print(model.qembed1.qw)
# print(model.qembed1.qo)
# print(model.qlstm3.qix)
# print(model.qlstm3.qih)
# print(model.qlstm3.qic)
# print(model.qlstm3.qox)
# print(model.qlstm3.qoh)
# print(model.qlstm3.qoc)
with torch.no_grad():
for i in range(0, test_data.size(0) - 1, bptt):
data, targets = get_batch(test_data, i, bptt)
output, hidden = model.quantize_inference(data,hidden)
output_flat = output.view(-1, ntokens)
total_loss += len(data) * lossLayer(output_flat, targets).item()
hidden = repackage_hidden(hidden)
test_loss = total_loss / (len(test_data) - 1)
ppl = math.exp(test_loss)
print('Test set: Quant Model Perplexity: {:.4f} Loss {:f}\n'.format(ppl,test_loss))
return ppl
if __name__ == "__main__":
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed_gpu)
corpus = Corpus(data_path)
ntokens = len(corpus.dictionary)
val_data = batchify(corpus.valid, eval_batch_size, device)
test_data = batchify(corpus.test, eval_batch_size, device)
load_ptq = False
store_ptq = False
gol._init()
excel_path = 'ptq_result.xlsx'
workbook = openpyxl.Workbook()
if 'Sheet' in workbook.sheetnames:
workbook.remove(workbook['Sheet'])
txt_path = 'ptq_result.txt'
ft = open(txt_path,'w')
model = Model(ntokens, embed_size, hidden_size, dropout, tied).to(device)
full_file = 'ckpt/ptb_PTB_LSTM.pt'
model.load_state_dict(torch.load(full_file))
ptq_file_prefix = 'ckpt/ptb_PTB_LSTM_ptq_'
model.eval()
lstm_flatten(model)
full_ppl = full_inference(model,test_data,ntokens,eval_batch_size,bptt)
quant_type_list = ['INT','POT','FLOAT']
# quant_type_list = ['FLOAT']
title_list = []
js_flops_list = []
js_param_list = []
ptq_ppl_list = []
ppl_ratio_list = []
for quant_type in quant_type_list:
num_bit_list = numbit_list(quant_type)
if quant_type != 'INT':
bias_list = build_bias_list(quant_type)
gol.set_value(bias_list, is_bias=True)
for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list:
model_ptq = Model(ntokens, embed_size, hidden_size, dropout, tied).to(device)
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
title = '%s_%d' % (quant_type, num_bits)
print('\nPTB_LSTM: PTQ: '+title)
title_list.append(title)
if quant_type != 'INT':
plist = build_list(quant_type, num_bits, e_bits)
gol.set_value(plist)
model_ptq.load_state_dict(torch.load(full_file))
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.eval()
direct_quantize(model_ptq, val_data, eval_batch_size, bptt)
#quantize_inference here uses fake quantization throughout; compared with forward it only drops the update calls, so freeze is not needed
ptq_ppl = quantize_inference(model_ptq, test_data, ntokens, eval_batch_size, bptt)
ppl_ratio = ptq_ppl / full_ppl
print(title+': ppl_ratio: %f'%ppl_ratio)
ptq_ppl_list.append(ptq_ppl)
ppl_ratio_list.append(ppl_ratio)
worksheet = workbook.create_sheet('PTB_LSTM')
worksheet.cell(row=1,column=1,value='FP32-ppl')
worksheet.cell(row=1,column=2,value=full_ppl)
worksheet.cell(row=3,column=1,value='title')
# worksheet.cell(row=3,column=2,value='js_flops')
# worksheet.cell(row=3,column=3,value='js_param')
worksheet.cell(row=3,column=2,value='ptq_ppl')
worksheet.cell(row=3,column=3,value='ppl_ratio')
for i in range(len(title_list)):
worksheet.cell(row=i+4,column=1,value=title_list[i])
worksheet.cell(row=i+4,column=2,value=ptq_ppl_list[i])
worksheet.cell(row=i+4,column=3,value=ppl_ratio_list[i])
if 'Sheet' in workbook.sheetnames:
workbook.remove(workbook['Sheet'])
workbook.save(excel_path)
print('PTB_LSTM',file=ft)
print('Full_ppl: %f'%full_ppl,file=ft)
print('title_list:',file=ft)
print(title_list,file=ft)
# print('js_flops_list:',file=ft)
# print(js_flops_list, file=ft)
# print('js_param_list:',file=ft)
# print(js_param_list, file=ft)
print('ptq_ppl_list:',file=ft)
print(ptq_ppl_list, file=ft)
print('ppl_ratio_list:',file=ft)
print(ppl_ratio_list, file=ft)
print("\n",file=ft)
ft.close()
#!/bin/bash
#- Job parameters
# (TODO)
# Please modify job name
#SBATCH -J PTB_LSTM # The job name
#SBATCH -o ret/ret-%j.out # Write the standard output to file named 'ret-<job_number>.out'
#SBATCH -e ret/ret-%j.err # Write the standard error to file named 'ret-<job_number>.err'
#- Resources
# (TODO)
# Please modify your requirements
#SBATCH -p nv-gpu # Submit to 'nv-gpu' Partition
#SBATCH -t 3-00:00:00 # Run for a maximum time of 3 days, 00 hours, 00 mins, 00 secs
#SBATCH --nodes=1 # Request N nodes
#SBATCH --gres=gpu:1 # Request M GPU per node
#SBATCH --gres-flags=enforce-binding # CPU-GPU Affinity
#SBATCH --qos=gpu-long # Request QOS Type
###
### The system will alloc 8 or 16 cores per gpu by default.
### If you need more or less, use following:
### #SBATCH --cpus-per-task=K # Request K cores
###
###
### Without specifying the constraint, any available nodes that meet the requirement will be allocated
### You can specify the characteristics of the compute nodes, and even the names of the compute nodes
###
### #SBATCH --nodelist=gpu-v00 # Request a specific list of hosts
### #SBATCH --constraint="Volta|RTX8000" # Request GPU Type: Volta(V100 or V100S) or RTX8000
###
# set constraint for RTX8000 to meet my cuda
#SBATCH --constraint="Ampere|RTX8000|T4"
#- Log information
echo "Job start at $(date "+%Y-%m-%d %H:%M:%S")"
echo "Job run at:"
echo "$(hostnamectl)"
#- Load environments
source /tools/module_env.sh
module list # list modules loaded
##- Tools
module load cluster-tools/v1.0
module load slurm-tools/v1.0
module load cmake/3.15.7
module load git/2.17.1
module load vim/8.1.2424
##- language
module load python3/3.6.8
##- CUDA
# module load cuda-cudnn/10.2-7.6.5
# module load cuda-cudnn/11.2-8.2.1
module load cuda-cudnn/11.1-8.2.1
##- virtualenv
# source xxxxx/activate
echo $(module list) # list modules loaded
echo $(which gcc)
echo $(which python)
echo $(which python3)
cluster-quota # nas quota
nvidia-smi --format=csv --query-gpu=name,driver_version,power.limit # gpu info
#- Warning! Please do not change your CUDA_VISIBLE_DEVICES
#- in `.bashrc`, `env.sh`, or your job script
echo "Use GPU ${CUDA_VISIBLE_DEVICES}" # which gpus
#- The CUDA_VISIBLE_DEVICES variable is assigned and specified by SLURM
#- Job step
# [EDIT HERE(TODO)]
python ptq.py
#- End
echo "Job end at $(date "+%Y-%m-%d %H:%M:%S")"
PTB_LSTM
Full_ppl: 80.347527
title_list:
['INT_2', 'INT_3', 'INT_4', 'INT_5', 'INT_6', 'INT_7', 'INT_8', 'INT_9', 'INT_10', 'INT_11', 'INT_12', 'INT_13', 'INT_14', 'INT_15', 'INT_16', 'POT_2', 'POT_3', 'POT_4', 'POT_5', 'POT_6', 'POT_7', 'POT_8', 'FLOAT_3_E1', 'FLOAT_4_E1', 'FLOAT_4_E2', 'FLOAT_5_E1', 'FLOAT_5_E2', 'FLOAT_5_E3', 'FLOAT_6_E1', 'FLOAT_6_E2', 'FLOAT_6_E3', 'FLOAT_6_E4', 'FLOAT_7_E1', 'FLOAT_7_E2', 'FLOAT_7_E3', 'FLOAT_7_E4', 'FLOAT_7_E5', 'FLOAT_8_E1', 'FLOAT_8_E2', 'FLOAT_8_E3', 'FLOAT_8_E4', 'FLOAT_8_E5', 'FLOAT_8_E6']
ptq_ppl_list:
[3680.054005643725, 1113.1595756444422, 237.175147011631, 118.14850804671768, 85.68659660098655, 82.04448564575729, 80.74354328756583, 80.4501107581829, 80.37609633787858, 80.35521256832457, 80.34711404715937, 80.34729820019561, 80.34810939943782, 80.34762953587571, 80.34751136939029, 3680.054005643725, 1436.4914411630305, 168.02534152231982, 167.8943965460058, 166.00384274940092, 166.00408006630718, 166.00413183236293, 829.8560620309429, 407.67140698415596, 142.51986167702356, 165.40337795565802, 96.49152981578862, 101.44738407817896, 131.7285436670383, 87.54655615370719, 86.09250067555212, 101.35647647974928, 119.5743194917011, 84.69916168997997, 81.94981964738962, 86.0939502797634, 101.40409626471063, 117.6051487311836, 83.61499193954714, 80.77412538207982, 81.96937299585758, 85.96081475056866, 101.39158416596052]
ppl_ratio_list:
[45.80170816062014, 13.854310274163526, 2.951866154605171, 1.470468497709415, 1.0664497001338629, 1.0211202286630146, 1.0049287863275775, 1.0012767445219362, 1.0003555659541061, 1.000095647942865, 0.9999948542853881, 0.9999971462418676, 1.0000072423738524, 1.0000012700237741, 0.9999997993315312, 45.80170816062014, 17.87847723497442, 2.091232260230823, 2.0896025277374854, 2.0660728205317613, 2.0660757741622424, 2.06607641843914, 10.32833352721793, 5.073851302041298, 1.7737927491314487, 2.058599475520639, 1.2009271825404801, 1.2626074161089171, 1.6394847206612713, 1.0895986334086523, 1.0715015553337994, 1.2614759861661844, 1.4882140524200391, 1.0541601507000304, 1.0199420219238415, 1.071519597011702, 1.2620686588525436, 1.4637058836938923, 1.0406666458683778, 1.0053094090482557, 1.0201853816024549, 1.0698626010424483, 1.2619129341010724]
# coding: utf-8
import time
import math
import os
import torch
import torch.nn as nn
from lstm_utils import *
from model import *
import sys
data_path = '../data/ptb'
embed_size = 700
hidden_size = 700
lr = 22
clip = 0.25
epochs = 10
train_batch_size = 20
eval_batch_size = 10
bptt = 35 # length of each sequence slice
dropout = 0.65
tied = True
seed = 1111
seed_gpu = 1111
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
log_interval = 200
save_path = 'ckpt/ptb_PTB_LSTM.pt'
def evaluate(model, eval_data, ntokens, eval_batch_size, bptt):
model.eval()
total_loss = 0.
lossLayer = nn.CrossEntropyLoss()
hidden = model.init_hidden(eval_batch_size)
with torch.no_grad():
for i in range(0, eval_data.size(0) - 1, bptt):
data, targets = get_batch(eval_data, i, bptt)
output, hidden = model(data,hidden)
output_flat = output.view(-1, ntokens)
total_loss += len(data) * lossLayer(output_flat, targets).item()
hidden = repackage_hidden(hidden)
return total_loss / (len(eval_data) - 1)
def train(model, train_data, ntokens, train_batch_size, bptt, lr, clip):
model.train()
total_loss = 0.
lossLayer = nn.CrossEntropyLoss()
start_time = time.time()
hidden = model.init_hidden(train_batch_size)
for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
data, targets = get_batch(train_data, i, bptt)
# hidden = repackage_hidden(hidden)
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
optimizer.zero_grad()
output, hidden = model(data,hidden)
hidden = repackage_hidden(hidden)
loss = lossLayer(output.view(-1, ntokens), targets)
loss.backward()
#gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
for p in model.parameters():
p.data.add_(p.grad.data, alpha=-lr)
total_loss += loss.item()
if batch % 200 == 0 and batch > 0:
cur_loss = total_loss / 200
elapsed = time.time() - start_time
print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
'loss {:5.2f} | ppl {:8.2f}'.format(
epoch, batch, len(train_data) // bptt, lr,
elapsed * 1000 / 200, cur_loss, math.exp(cur_loss)))
total_loss = 0
start_time = time.time()
if __name__ == "__main__":
sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed_gpu)
corpus = Corpus(data_path)
ntokens = len(corpus.dictionary)
train_data = batchify(corpus.train, train_batch_size, device)
val_data = batchify(corpus.valid, eval_batch_size, device)
test_data = batchify(corpus.test, eval_batch_size, device)
model = Model(ntokens, embed_size, hidden_size, dropout, tied).to(device)
best_val_loss = None
for epoch in range(1, epochs+1):
epoch_start_time = time.time()
train(model, train_data, ntokens, train_batch_size, bptt, lr, clip)
val_loss = evaluate(model, val_data, ntokens, eval_batch_size, bptt)
print('-' * 89)
print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
val_loss, math.exp(val_loss)))
print('-' * 89)
if not best_val_loss or val_loss < best_val_loss:
with open(save_path, 'wb') as f:
torch.save(model.state_dict(), f)
best_val_loss = val_loss
else:
lr /= 2.5
if lr < 0.0001:
break
with open(save_path, 'rb') as f:
# model = torch.load(f)
model = Model(ntokens, embed_size, hidden_size, dropout, tied).to(device)
model.load_state_dict(torch.load(f))
        # Flatten the LSTM layers' parameters into one contiguous chunk to make parallel computation easier
lstm_flatten(model)
# model.lstm3.flatten_parameters()
# model.lstm4.flatten_parameters()
test_loss = evaluate(model, test_data, ntokens, eval_batch_size, bptt)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
test_loss, math.exp(test_loss)))
print('=' * 89)
#!/bin/bash
#- Job parameters
# (TODO)
# Please modify job name
#SBATCH -J PTB_LSTM # The job name
#SBATCH -o ret/ret-%j.out # Write the standard output to file named 'ret-<job_number>.out'
#SBATCH -e ret/ret-%j.err # Write the standard error to file named 'ret-<job_number>.err'
#- Resources
# (TODO)
# Please modify your requirements
#SBATCH -p nv-gpu # Submit to 'nv-gpu' Partitiion
#SBATCH -t 3-00:00:00 # Run for a maximum time of 3 days, 00 hours, 00 mins, 00 secs
#SBATCH --nodes=1 # Request N nodes
#SBATCH --gres=gpu:1 # Request M GPU per node
#SBATCH --gres-flags=enforce-binding # CPU-GPU Affinity
#SBATCH --qos=gpu-long # Request QOS Type
###
### The system will alloc 8 or 16 cores per gpu by default.
### If you need more or less, use following:
### #SBATCH --cpus-per-task=K # Request K cores
###
###
### Without specifying the constraint, any available nodes that meet the requirement will be allocated
### You can specify the characteristics of the compute nodes, and even the names of the compute nodes
###
### #SBATCH --nodelist=gpu-v00 # Request a specific list of hosts
### #SBATCH --constraint="Volta|RTX8000" # Request GPU Type: Volta(V100 or V100S) or RTX8000
###
# set constraint for RTX8000 to meet my cuda
#SBATCH --constraint="Ampere|RTX8000|T4"
#- Log information
echo "Job start at $(date "+%Y-%m-%d %H:%M:%S")"
echo "Job run at:"
echo "$(hostnamectl)"
#- Load environments
source /tools/module_env.sh
module list # list modules loaded
##- Tools
module load cluster-tools/v1.0
module load slurm-tools/v1.0
module load cmake/3.15.7
module load git/2.17.1
module load vim/8.1.2424
##- language
module load python3/3.6.8
##- CUDA
# module load cuda-cudnn/10.2-7.6.5
# module load cuda-cudnn/11.2-8.2.1
module load cuda-cudnn/11.1-8.2.1
##- virtualenv
# source xxxxx/activate
echo $(module list) # list modules loaded
echo $(which gcc)
echo $(which python)
echo $(which python3)
cluster-quota # nas quota
nvidia-smi --format=csv --query-gpu=name,driver_version,power.limit # gpu info
#- Warning! Please do not change your CUDA_VISIBLE_DEVICES
#- in `.bashrc`, `env.sh`, or your job script
echo "Use GPU ${CUDA_VISIBLE_DEVICES}" # which gpus
#- The CUDA_VISIBLE_DEVICES variable is assigned and specified by SLURM
#- Job step
# [EDIT HERE(TODO)]
python train.py
#- End
echo "Job end at $(date "+%Y-%m-%d %H:%M:%S")"
# import torch
# from transformers import BertForSequenceClassification, BertTokenizer
# from ptflops import get_model_complexity_info
# def bert_input_constructor(input_shape, tokenizer):
# inp_seq = ""
# for _ in range(input_shape[1] - 2): # there are two special tokens [CLS] and [SEP]
# inp_seq += tokenizer.pad_token # let's use pad token to form a fake
# # sequence for subsequent flops calculation
# inputs = tokenizer([inp_seq] * input_shape[0], padding=True, truncation=True,
# return_tensors="pt")
# labels = torch.tensor([1] * input_shape[0])
# # Batch size input_shape[0], sequence length input_shape[1]
# inputs = dict(inputs)
# inputs.update({"labels": labels})
# return inputs
# if __name__ == '__main__':
# shape = (35,10)
# bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# tmp = bert_input_constructor(shape,bert_tokenizer)
# print(tmp)
# print(**tmp)
from functools import partial
import torch
from transformers import BertForSequenceClassification, BertTokenizer
from ptflops import get_model_complexity_info
def bert_input_constructor(input_shape, tokenizer):
inp_seq = ""
for _ in range(input_shape[1] - 2): # there are two special tokens [CLS] and [SEP]
inp_seq += tokenizer.pad_token # let's use pad token to form a fake
# sequence for subsequent flops calculation
inputs = tokenizer([inp_seq] * input_shape[0], padding=True, truncation=True,
return_tensors="pt")
labels = torch.tensor([1] * input_shape[0])
    # Batch size input_shape[0], sequence length input_shape[1]
inputs = dict(inputs)
inputs.update({"labels": labels})
return inputs
if __name__ == '__main__':
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
flops_count, params_count = get_model_complexity_info(
model, (2, 128), as_strings=True,
input_constructor=partial(bert_input_constructor, tokenizer=bert_tokenizer),
print_per_layer_stat=True)
print('{:<30} {:<8}'.format('Computational complexity: ', flops_count))
print('{:<30} {:<8}'.format('Number of parameters: ', params_count))
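# A hedged sketch (an assumption, not code from this repo): an analogous dict-returning
# input constructor for the PTB LSTM model, mirroring how train.py calls
# model(data, hidden) and model.init_hidden(batch_size). The keyword names 'data' and
# 'hidden' are hypothetical and must match Model.forward's actual argument names.
#
# def lstm_input_constructor(input_shape, model, device):
#     bptt, batch_size = input_shape
#     data = torch.zeros((bptt, batch_size), dtype=torch.long, device=device)  # fake token indices
#     hidden = model.init_hidden(batch_size)
#     return {'data': data, 'hidden': hidden}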
\ No newline at end of file
import torch
import torch.nn as nn
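# Candidate bit-width configurations for the three quantization types (INT / POT / FLOAT).
# For FLOAT, exponent widths range over 1..num_bits-2, keeping at least one mantissa bit;
# INT and POT have no exponent field.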
def ebit_list(quant_type, num_bits):
if quant_type == 'FLOAT':
e_bit_list = list(range(1,num_bits-1))
else:
e_bit_list = [0]
return e_bit_list
def numbit_list(quant_type):
if quant_type == 'INT':
num_bit_list = list(range(2,17))
# num_bit_list = [4,5]
elif quant_type == 'POT':
num_bit_list = list(range(2,9))
# num_bit_list = [5]
else:
num_bit_list = list(range(2,9))
# num_bit_list = [8]
return num_bit_list
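# Quantization grid used for biases: an 8-bit POT list for POT quantization, otherwise
# a 16-bit float list with 7 exponent bits.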
def build_bias_list(quant_type):
if quant_type == 'POT':
return build_pot_list(8)
else:
return build_float_list(16,7)
def build_list(quant_type, num_bits, e_bits):
if quant_type == 'POT':
return build_pot_list(num_bits)
else:
return build_float_list(num_bits,e_bits)
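# Power-of-two (POT) quantization grid: 0 and ±2^i for i from -(2^(num_bits-1))+2 up to 0.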
def build_pot_list(num_bits):
plist = [0.]
for i in range(-2 ** (num_bits-1) + 2, 1):
        # i goes up to 0, so the largest POT quantization value is 2**0 = 1
plist.append(2. ** i)
plist.append(-2. ** i)
plist = torch.Tensor(list(set(plist)))
# plist = plist.mul(1.0 / torch.max(plist))
return plist
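# Float quantization grid with e_bits exponent bits and m_bits = num_bits-1-e_bits
# mantissa bits: the first loop enumerates subnormal values m*2^(-m_bits) * 2^e_min
# (no implicit leading 1), the second loop enumerates normal values (1 + m*2^(-m_bits)) * 2^e.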
def build_float_list(num_bits,e_bits):
m_bits = num_bits - 1 - e_bits
plist = [0.]
    # spacing between adjacent mantissa values
dist_m = 2 ** (-m_bits)
e = -2 ** (e_bits - 1) + 1
for m in range(1, 2 ** m_bits):
        frac = m * dist_m  # mantissa part
        expo = 2 ** e  # exponent part
flt = frac * expo
plist.append(flt)
plist.append(-flt)
for e in range(-2 ** (e_bits - 1) + 2, 2 ** (e_bits - 1) + 1):
expo = 2 ** e
for m in range(0, 2 ** m_bits):
frac = 1. + m * dist_m
flt = frac * expo
plist.append(flt)
plist.append(-flt)
plist = torch.Tensor(list(set(plist)))
return plist
# No cfg is needed here: matching layers are found by sharing the same name prefix and suffix. ReLU layers are merged in as well.
def fold_ratio(layer, par_ratio, flop_ratio):
idx = -1
for name in layer:
if 'conv' in name:
conv_idx = layer.index(name)
[prefix,suffix] = name.split('conv')
bn_name = prefix+'bn'+suffix
relu_name = prefix+'relu'+suffix
if bn_name in layer:
bn_idx = layer.index(bn_name)
par_ratio[conv_idx]+=par_ratio[bn_idx]
flop_ratio[conv_idx]+=flop_ratio[bn_idx]
if relu_name in layer:
relu_idx = layer.index(relu_name)
par_ratio[conv_idx]+=par_ratio[relu_idx]
                flop_ratio[conv_idx] += flop_ratio[relu_idx]
return par_ratio,flop_ratio
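# Walk the model and fold every BN layer into the conv layer whose name shares its
# prefix/suffix (e.g. conv1 <-> bn1).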
def fold_model(model):
for name, module in model.named_modules():
if 'conv' in name:
[prefix,suffix] = name.split('conv')
bn_name = prefix+'bn'+suffix
if hasattr(model,bn_name):
bn_layer = getattr(model,bn_name)
fold_bn(module,bn_layer)
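# Fold BN statistics into the preceding conv:
#   w' = w * gamma / sqrt(var + eps)
#   b' = (b - mean) * gamma / sqrt(var + eps) + beta
# When the BN layer has no affine parameters, gamma = 1 and beta = 0.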
def fold_bn(conv, bn):
    # Fetch the BN layer's statistics and parameters
mean = bn.running_mean
var = bn.running_var
eps = bn.eps
std = torch.sqrt(var + eps)
if bn.affine:
gamma_ = bn.weight / std
weight = conv.weight * gamma_.view(conv.out_channels, 1, 1, 1)
if conv.bias is not None:
bias = gamma_ * conv.bias - gamma_ * mean + bn.bias
else:
bias = bn.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = conv.weight * gamma_
if conv.bias is not None:
bias = gamma_ * conv.bias - gamma_ * mean
else:
bias = -gamma_ * mean
    # Write the folded weight and bias back into the conv layer
conv.weight.data = weight.data
if conv.bias is not None:
conv.bias.data = bias.data
else:
conv.bias = torch.nn.Parameter(bias)