Commit a2ea6085 by Klin

feat: add Gen of ALL-cifar in GDFQ

parent 394c19ce
# GDFQ Notes
+ The approach follows the paper **Generative Low-bitwidth Data Free Quantization**; the reference code is at https://github.com/xushoukai/GDFQ
+ Core idea of the paper:
    + Train a generator from information captured from the pretrained (full-precision) model:
        + Classification information matching: given a random label y and Gaussian noise, the generator produces pseudo data x; feeding x through the full-precision model gives an output z, and z together with y is used to compute loss_one_hot.
        + Data distribution matching: the BN layers of the pretrained model store the distribution of the training data; comparing the statistics of the generated data against these stored statistics gives BNS_loss.
    + The generator is updated by back-propagating loss_one_hot and BNS_loss (see the sketch at the end of this section). The resulting generator produces data that closely follows the full-precision model's decision boundary, and whose distribution matches the real (training-set) distribution, as the figure from the paper shows:
![paper_img](image/paper_img.png)
+ Pseudo-data-driven low-bitwidth quantization:
  The trained generator supplies the inputs and the full-precision model supplies the labels for training the quantized network, so that given the same input the quantized model's output moves closer to the full-precision model's output.
+ What we keep from the paper's code:
    + Our quantization is meant for direct deployment without fine-tuning, i.e. no further training or adjustment, so the paper's pseudo-data-driven fine-tuning stage can be dropped.
    + The paper's generator models the full-precision model's decision boundary well, which matches our original definition of transfer security: tolerance to perturbations of samples near the boundary. The generator therefore fits naturally into our framework.
+ Experimental changes:
    + The original paper takes officially pretrained full-precision models from torchcv; we use our own pretrained full-precision models, because security is not about the objective decision boundary of the training set but about how quantization changes the boundary of our full-precision model.
+ Generator quality: we trained a generator for each of the 9 models from the earlier ALL-cifar10 experiments. Random labels y plus noise are passed through the generator to obtain pseudo data, the pseudo data are fed into the full-precision model, and agreement with y is reported as accuracy. The generators for the 3 ResNet models reach 60-70% accuracy, while all other models exceed 99%. This shows the generators fit the decision boundaries well; ResNet, with its many residual connections, likely has a more complex boundary and is slightly harder to fit.
+ How to run:
```shell
python main.py --conf_path=./cifar100_resnet20.hocon --id=01 --model_name=ResNet_18
```
+ Planned security evaluation:
  Generate random labels, pass them through the generator to obtain pseudo data, and feed the pseudo data into both the full-precision model and the quantized model. With the full-precision model's outputs as the reference, and since the generator fits the full-precision decision boundary well, the fraction of samples on which the quantized model disagrees with the full-precision model measures how much quantization changes the decision boundary (see the second sketch below).
+ Further improvements:
  Following **Qimera: Data-free Quantization with Synthetic Boundary Supporting Samples [NeurIPS 2021]**, increase the share of boundary-supporting samples.
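+ Sketch of the generator losses, for reference only: the actual training code lives in trainer.py; the names `generator`, `teacher`, `beta` and the hook bookkeeping below are illustrative placeholders, not the repository's exact implementation.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def generator_step(generator, teacher, optimizer_G, n_classes=100, latent_dim=100,
                   batch_size=200, beta=0.1, device='cuda'):
    """One hypothetical generator update combining loss_one_hot and BNS_loss."""
    z = torch.randn(batch_size, latent_dim, device=device)
    y = torch.randint(0, n_classes, (batch_size,), device=device)
    fake = generator(z, y)                                 # pseudo data x

    # Collect BN-input statistics of the fake batch via forward hooks on the teacher.
    bns_terms, hooks = [], []
    def make_hook(bn):
        def hook(module, inp, out):
            x = inp[0]
            mu = x.mean(dim=(0, 2, 3))
            var = x.var(dim=(0, 2, 3))
            # BNS_loss: match the running statistics stored in the teacher's BN layers.
            bns_terms.append(F.mse_loss(mu, bn.running_mean) +
                             F.mse_loss(var, bn.running_var))
        return hook
    for m in teacher.modules():
        if isinstance(m, nn.BatchNorm2d):
            hooks.append(m.register_forward_hook(make_hook(m)))

    logits = teacher(fake)                                 # output z of the full-precision model
    for h in hooks:
        h.remove()

    loss_one_hot = F.cross_entropy(logits, y)              # classification information matching
    loss_bns = torch.stack(bns_terms).sum()                # data distribution matching
    loss = loss_one_hot + beta * loss_bns

    optimizer_G.zero_grad()
    loss.backward()
    optimizer_G.step()
    return loss_one_hot.item(), loss_bns.item()
```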
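+ Sketch of the planned consistency-based evaluation (again illustrative; `full_model`, `quant_model` and `generator` stand for the loaded checkpoints):
```python
import torch

@torch.no_grad()
def boundary_change_rate(generator, full_model, quant_model, n_batches=50,
                         batch_size=200, n_classes=100, latent_dim=100, device='cuda'):
    """Fraction of generated samples on which the quantized model's prediction
    differs from the full-precision model's prediction (the security metric)."""
    disagree, total = 0, 0
    for _ in range(n_batches):
        z = torch.randn(batch_size, latent_dim, device=device)
        y = torch.randint(0, n_classes, (batch_size,), device=device)
        fake = generator(z, y)
        ref = full_model(fake).argmax(dim=1)     # full-precision predictions as the reference
        pred = quant_model(fake).argmax(dim=1)   # quantized-model predictions
        disagree += (ref != pred).sum().item()
        total += batch_size
    return disagree / total
```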
\ No newline at end of file
# conv: 'C',''/'B'/'BRL'/'BRS',qi,in_ch,out_ch,kernel_size,stride,padding,bias
# relu: 'RL'
# relu6: 'RS'
# inception: 'Inc'
# maxpool: 'MP',kernel_size,stride,padding
# adaptiveavgpool: 'AAP',output_size
# view: 'VW'
# default: x = x.view(x.size(0),-1)
# dropout: 'D'
# fc: 'FC',in_features,out_features,bias
# MakeLayer: 'ML','BBLK'/'BTNK'/'IRES', ml_idx, blocks
# softmax: 'SM'
# 100 classes (CIFAR-100)
ResNet_18_cfg_table = [
['C','BRL',True,3,16,3,1,1,True],
['ML','BBLK',0,2],
['ML','BBLK',1,2],
['ML','BBLK',2,2],
['ML','BBLK',3,2],
['AAP',1],
['VW'],
['FC',128,100,True],
['SM']
]
ResNet_50_cfg_table = [
['C','BRL',True,3,16,3,1,1,True],
['ML','BTNK',0,3],
['ML','BTNK',1,4],
['ML','BTNK',2,6],
['ML','BTNK',3,3],
['AAP',1],
['VW'],
['FC',512,100,True],
['SM']
]
ResNet_152_cfg_table = [
['C','BRL',True,3,16,3,1,1,True],
['ML','BTNK',0,3],
['ML','BTNK',1,8],
['ML','BTNK',2,36],
['ML','BTNK',3,3],
['AAP',1],
['VW'],
['FC',512,100,True],
['SM']
]
MobileNetV2_cfg_table = [
['C','BRS',True,3,32,3,1,1,True],
['ML','IRES',0,1],
['ML','IRES',1,2],
['ML','IRES',2,3],
['ML','IRES',3,3],
['ML','IRES',4,3],
['ML','IRES',5,1],
['C','',False,320,1280,1,1,0,True],
['AAP',1],
['VW'],
['FC',1280,100,True]
]
AlexNet_cfg_table = [
['C','',True,3,32,3,1,1,True],
['RL'],
['MP',2,2,0],
['C','',False,32,64,3,1,1,True],
['RL'],
['MP',2,2,0],
['C','',False,64,128,3,1,1,True],
['RL'],
['C','',False,128,256,3,1,1,True],
['RL'],
['C','',False,256,256,3,1,1,True],
['RL'],
['MP',3,2,0],
['VW'],
['D',0.5],
['FC',2304,1024,True],
['RL'],
['D',0.5],
['FC',1024,512,True],
['RL'],
['FC',512,100,True]
]
AlexNet_BN_cfg_table = [
['C','BRL',True,3,32,3,1,1,True],
['MP',2,2,0],
['C','BRL',False,32,64,3,1,1,True],
['MP',2,2,0],
['C','BRL',False,64,128,3,1,1,True],
['C','BRL',False,128,256,3,1,1,True],
['C','BRL',False,256,256,3,1,1,True],
['MP',3,2,0],
['VW'],
['D',0.5],
['FC',2304,1024,True],
['RL'],
['D',0.5],
['FC',1024,512,True],
['RL'],
['FC',512,100,True]
]
VGG_16_cfg_table = [
['C','BRL',True,3,64,3,1,1,True],
['C','BRL',False,64,64,3,1,1,True],
['MP',2,2,0],
['C','BRL',False,64,128,3,1,1,True],
['C','BRL',False,128,128,3,1,1,True],
['MP',2,2,0],
['C','BRL',False,128,256,3,1,1,True],
['C','BRL',False,256,256,3,1,1,True],
['C','BRL',False,256,256,3,1,1,True],
['MP',2,2,0],
['C','BRL',False,256,512,3,1,1,True],
['C','BRL',False,512,512,3,1,1,True],
['C','BRL',False,512,512,3,1,1,True],
['MP',2,2,0],
['C','BRL',False,512,512,3,1,1,True],
['C','BRL',False,512,512,3,1,1,True],
['C','BRL',False,512,512,3,1,1,True],
['MP',2,2,0],
['VW'],
['FC',512,4096,True],
['RL'],
['D',0.5],
['FC',4096,4096,True],
['RL'],
['D',0.5],
['FC',4096,100,True]
]
VGG_19_cfg_table = [
['C','BRL',True,3,64,3,1,1,True],
['C','BRL',False,64,64,3,1,1,True],
['MP',2,2,0],
['C','BRL',False,64,128,3,1,1,True],
['C','BRL',False,128,128,3,1,1,True],
['MP',2,2,0],
['C','BRL',False,128,256,3,1,1,True],
['C','BRL',False,256,256,3,1,1,True],
['C','BRL',False,256,256,3,1,1,True],
['C','BRL',False,256,256,3,1,1,True],
['MP',2,2,0],
['C','BRL',False,256,512,3,1,1,True],
['C','BRL',False,512,512,3,1,1,True],
['C','BRL',False,512,512,3,1,1,True],
['C','BRL',False,512,512,3,1,1,True],
['MP',2,2,0],
['C','BRL',False,512,512,3,1,1,True],
['C','BRL',False,512,512,3,1,1,True],
['C','BRL',False,512,512,3,1,1,True],
['C','BRL',False,512,512,3,1,1,True],
['MP',2,2,0],
['VW'],
['FC',512,4096,True],
['RL'],
['D',0.5],
['FC',4096,4096,True],
['RL'],
['D',0.5],
['FC',4096,100,True]
]
Inception_BN_cfg_table = [
['C','',True,3,64,3,1,1,True],
['RL'],
['C','',False,64,64,3,1,1,True],
['RL'],
['Inc',0],
['Inc',1],
['MP',3,2,1],
['Inc',2],
['Inc',3],
['Inc',4],
['Inc',5],
['Inc',6],
['MP',3,2,1],
['Inc',7],
['Inc',8],
['AAP',1],
['C','',False,1024,100,1,1,0,True],
['VW']
]
model_cfg_table = {
'AlexNet' : AlexNet_cfg_table,
'AlexNet_BN' : AlexNet_BN_cfg_table,
'VGG_16' : VGG_16_cfg_table,
'VGG_19' : VGG_19_cfg_table,
'Inception_BN' : Inception_BN_cfg_table,
'ResNet_18' : ResNet_18_cfg_table,
'ResNet_50' : ResNet_50_cfg_table,
'ResNet_152' : ResNet_152_cfg_table,
'MobileNetV2' : MobileNetV2_cfg_table
}
# Each row is the (channel) parameter table of one Inc block
inc_ch_table=[
[ 64, 64, 96,128, 16, 32, 32],#3a
[256,128,128,192, 32, 96, 64],#3b
[480,192, 96,208, 16, 48, 64],#4a
[512,160,112,224, 24, 64, 64],#4b
[512,128,128,256, 24, 64, 64],#4c
[512,112,144,288, 32, 64, 64],#4d
[528,256,160,320, 32,128,128],#4e
[832,256,160,320, 32,128,128],#5a
[832,384,192,384, 48,128,128] #5b
]
# br0,br1,br2,br3 <- br1x1,br3x3,br5x5,brM
# Each sub-list describes one branch of an Inc block; every conv implicitly carries 'BRL' and uses bias=False
# The 2nd and 3rd conv parameters are indices into that Inc block's channel row (a row of inc_ch_table)
# All Inc blocks share the same structure and differ only in weights, so indices instead of concrete values keep the table reusable
# Each branch is followed by a Concat; since it has only one possible form it is not listed explicitly
# conv: 'C', ('BRL' default), in_ch_idx, out_ch_idx, kernel_size, stride, padding (bias: False default)
# maxpool: 'MP', kernel_size, stride, padding
# relu: 'RL'
inc_cfg_table = [
[
['C',0,1,1,1,0]
],
[
['C',0,2,1,1,0],
['C',2,3,3,1,1]
],
[
['C',0,4,1,1,0],
['C',4,5,5,1,2]
],
[
['MP',3,1,1],
['RL'],
['C',0,6,1,1,0]
]
]
# ml_cfg_table = []
#BasicBlock
#value: downsample,inplanes,planes,planes*expansion,stride,1(default stride and groups)
bblk_ch_table = [
[False, 16, 16, 16,1,1], #layer1,first
[False, 16, 16, 16,1,1], # other
[True, 16, 32, 32,2,1], #layer2
[False, 32, 32, 32,1,1],
[True, 32, 64, 64,2,1], #layer3
[False, 64, 64, 64,1,1],
[True, 64,128,128,2,1], #layer4
[False,128,128,128,1,1]
]
#conv: 'C','B'/'BRL'/'BRS', in_ch_idx, out_ch_idx, kernel_sz, stride_idx, padding, groups_idx (bias: True default)
#add: 'AD', unconditional. When unconditional is True, or the block's flag is True, the two elements of outs are added
bblk_cfg_table = [
[
['C','BRL',1,2,3,4,1,5],
['C','B' ,2,2,3,5,1,5],
],
# downsample branch, used only when downsample is passed as True
[
['C','B' ,1,3,1,4,0,5]
],
# operations after the branches merge
[
['AD',True],
['RL']
]
]
#BottleNeck
#value: downsample,inplanes,planes,planes*expansion,stride,1(default stride and groups)
btnk_ch_table = [
[True, 16, 16, 64,1,1], #layer1,first
[False, 64, 16, 64,1,1], # other
[True, 64, 32,128,2,1], #layer2
[False,128, 32,128,1,1],
[True, 128, 64,256,2,1], #layer3
[False,256, 64,256,1,1],
[True, 256,128,512,2,1], #layer4
[False,512,128,512,1,1]
]
#conv: 'C','B'/'BRL'/'BRS', in_ch_idx, out_ch_idx, kernel_sz, stride_idx, padding, groups_idx (bias: True default)
#add: 'AD', unconditional. When unconditional is True, or the block's flag is True, the two elements of outs are added
btnk_cfg_table = [
[
['C','BRL',1,2,1,5,0,5],
['C','BRL',2,2,3,4,1,5],
['C','B' ,2,3,1,5,0,5]
],
# downsample branch, used only when downsample is passed as True
[
['C','B' ,1,3,1,4,0,5]
],
# operations after the branches merge
[
['AD',True],
['RL']
]
]
#InvertedResidual
#value: identity_flag, in_ch, out_ch, in_ch*expand_ratio, stride, 1(default stride and group)
ires_ch_table = [
[False, 32, 16, 32,1,1], #layer1,first
[ True, 16, 16, 16,1,1], # other
[False, 16, 24, 96,2,1], #layer2
[ True, 24, 24, 144,1,1],
[False, 24, 32, 144,2,1], #layer3
[ True, 32, 32, 192,1,1],
[False, 32, 96, 192,1,1], #layer4
[ True, 96, 96, 576,1,1],
[False, 96,160, 576,2,1], #layer5
[ True,160,160, 960,1,1],
[False,160,320, 960,1,1], #layer6
[ True,320,320,1920,1,1]
]
#conv: 'C','B'/'BRL'/'BRS', in_ch_idx, out_ch_idx, kernel_sz, stride_idx, padding, groups_idx (bias: True default)
#add: 'AD', unconditional. When unconditional is True, or the block's flag is True, the two elements of outs are added
ires_cfg_table = [
[
['C','BRS',1,3,1,5,0,5],
['C','BRS',3,3,3,4,1,3],
['C','B' ,3,2,1,5,0,5]
],
# identity_br empty
[
],
# operations after the branches merge
[
['AD',False] # conditional add (only when identity_flag is True)
]
]
\ No newline at end of file
# ------------ General options ----------------------------------------
save_path = "./log_cifar100_ResNet_epoch1600/"
dataPath = "/lustre/datasets/CIFAR100"
dataset = "cifar100" # options: imagenet | cifar100
nGPU = 1 # number of GPUs to use by default
GPU = 0 # default gpu to use, options: range(nGPU)
visible_devices = "0"
# ------------- Data options -------------------------------------------
nThreads = 8 # number of data loader threads
# ---------- Optimization options for S --------------------------------------
# nEpochs = 400 # number of total epochs to train 400
nEpochs = 1600
batchSize = 200 # batchsize
momentum = 0.9 # momentum 0.9
weightDecay = 1e-4 # weight decay 1e-4
opt_type = "SGD"
warmup_epochs = 4 # number of epochs for warmup
lr_S = 0.0001 # initial learning rate = 0.00001
lrPolicy_S = "multi_step" # options: multi_step | linear | exp | const | step
step_S = [100,200,300] # step for linear or exp learning rate policy default [100, 200, 300]
decayRate_S = 0.1 # lr decay rate
# ---------- Model options ---------------------------------------------
experimentID = "_cifar100_4bit_"
nClasses = 100 # number of classes in the dataset
# ---------- Quantization options ---------------------------------------------
qw = 4
qa = 4
# ----------KD options ---------------------------------------------
temperature = 20
alpha = 1
# ----------Generator options ---------------------------------------------
latent_dim = 100
img_size = 32
channels = 3
lr_G = 0.001 # default 0.001
lrPolicy_G = "multi_step" # options: multi_step | linear | exp | const | step
#step_G = [100,200,300] # step for linear or exp learning rate policy
step_G = [1000,1200,1400]
decayRate_G = 0.1 # lr decay rate
b1 = 0.5
b2 = 0.999
\ No newline at end of file
# -*- coding: utf-8 -*-
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
# Modified from the standard BatchNorm2d layer
class ConditionalBatchNorm2d(nn.BatchNorm2d):
"""Conditional Batch Normalization"""
def __init__(self, num_features, eps=1e-05, momentum=0.1,
affine=False, track_running_stats=True):
super(ConditionalBatchNorm2d, self).__init__(
num_features, eps, momentum, affine, track_running_stats
)
def forward(self, input, weight, bias, **kwargs):
self._check_input_dim(input)
exponential_average_factor = 0.0
if self.training and self.track_running_stats:
self.num_batches_tracked += 1
# pick the running-average factor
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / self.num_batches_tracked.item()
else: # use exponential moving average
exponential_average_factor = self.momentum
output = F.batch_norm(input, self.running_mean, self.running_var,
self.weight, self.bias,
self.training or not self.track_running_stats,
exponential_average_factor, self.eps)
if weight.dim() == 1:
weight = weight.unsqueeze(0)
if bias.dim() == 1:
bias = bias.unsqueeze(0)
size = output.size()
weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
return weight * output + bias
class CategoricalConditionalBatchNorm2d(ConditionalBatchNorm2d):
def __init__(self, num_classes, num_features, eps=1e-5, momentum=0.1,
affine=False, track_running_stats=True):
super(CategoricalConditionalBatchNorm2d, self).__init__(
num_features, eps, momentum, affine, track_running_stats
)
self.weights = nn.Embedding(num_classes, num_features)
self.biases = nn.Embedding(num_classes, num_features)
self._initialize()
def _initialize(self):
init.ones_(self.weights.weight.data)
init.zeros_(self.biases.weight.data)
def forward(self, input, c, **kwargs):
weight = self.weights(c)
bias = self.biases(c)
return super(CategoricalConditionalBatchNorm2d, self).forward(input, weight, bias)
if __name__ == '__main__':
"""Forward computation check."""
import torch
size = (3, 3, 12, 12)
# first two dimensions
batch_size, num_features = size[:2]
print('# Affirm embedding output')
naive_bn = nn.BatchNorm2d(3)
idx_input = torch.tensor([1, 2, 0], dtype=torch.long)
embedding = nn.Embedding(3, 3)
weights = embedding(idx_input)
print('# weights size', weights.size())
empty = torch.tensor((), dtype=torch.float)
running_mean = empty.new_zeros((3,))
running_var = empty.new_ones((3,))
naive_bn_W = naive_bn.weight
# print('# weights from embedding | type {}\n'.format(type(weights)), weights)
# print('# naive_bn_W | type {}\n'.format(type(naive_bn_W)), naive_bn_W)
input = torch.rand(*size, dtype=torch.float32)
print('input size', input.size())
print('input ndim ', input.dim())
_ = naive_bn(input)
print('# batch_norm with given weights')
try:
with torch.no_grad():
output = F.batch_norm(input, running_mean, running_var,
weights, naive_bn.bias, False, 0.0, 1e-05)
except Exception as e:
print("\tFailed to use given weights")
print('# Error msg:', e)
print()
else:
print("Succeeded to use given weights")
print('\n# Batch norm before use given weights')
with torch.no_grad():
tmp_out = F.batch_norm(input, running_mean, running_var,
naive_bn_W, naive_bn.bias, False, .0, 1e-05)
weights_cast = weights.unsqueeze(-1).unsqueeze(-1)
weights_cast = weights_cast.expand(tmp_out.size())
try:
out = weights_cast * tmp_out
except Exception:
print("Failed")
else:
print("Succeeded!")
print('\t {}'.format(out.size()))
print(type(tuple(out.size())))
print('--- condBN and catCondBN ---')
catCondBN = CategoricalConditionalBatchNorm2d(3, 3)
output = catCondBN(input, idx_input)
assert tuple(output.size()) == size
condBN = ConditionalBatchNorm2d(3)
idx = torch.tensor([1], dtype=torch.long)
out = catCondBN(input, idx)
print('cat cond BN weights\n', catCondBN.weights.weight.data)
print('cat cond BN biases\n', catCondBN.biases.weight.data)
"""
data loader for loading data
"""
import os
import math
import torch
import torch.utils.data as data
import numpy as np
from PIL import Image
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import struct
__all__ = ["DataLoader", "PartDataLoader"]
class ImageLoader(data.Dataset):
def __init__(self, dataset_dir, transform=None, target_transform=None):
class_list = os.listdir(dataset_dir)
datasets = []
for cla in class_list:
cla_path = os.path.join(dataset_dir, cla)
files = os.listdir(cla_path)
for file_name in files:
file_path = os.path.join(cla_path, file_name)
if os.path.isfile(file_path):
# datasets.append((file_path, tuple([float(v) for v in int(cla)])))
datasets.append((file_path, [float(cla)]))
# print(datasets)
# assert False
self.dataset_dir = dataset_dir
self.datasets = datasets
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
frames = []
file_path, label = self.datasets[index]
noise = torch.load(file_path, map_location=torch.device('cpu'))
return noise, torch.Tensor(label)
def __len__(self):
return len(self.datasets)
class DataLoader(object):
"""
data loader for CV data sets
"""
def __init__(self, dataset, batch_size, n_threads=4,
ten_crop=False, data_path='/home/dataset/', logger=None):
"""
create data loader for specific data set
:params n_treads: number of threads to load data, default: 4
:params ten_crop: use ten crop for testing, default: False
:params data_path: path to data set, default: /home/dataset/
"""
self.dataset = dataset
self.batch_size = batch_size
self.n_threads = n_threads
self.ten_crop = ten_crop
self.data_path = data_path
self.logger = logger
self.dataset_root = data_path
self.logger.info("|===>Creating data loader for " + self.dataset)
if self.dataset in ["cifar100"]:
self.train_loader, self.test_loader = self.cifar(
dataset=self.dataset)
elif self.dataset in ["imagenet"]:
self.train_loader, self.test_loader = self.imagenet(
dataset=self.dataset)
else:
assert False, "invalid data set"
def getloader(self):
"""
get train_loader and test_loader
"""
return self.train_loader, self.test_loader
def imagenet(self, dataset="imagenet"):
traindir = os.path.join(self.data_path, "train")
testdir = os.path.join(self.data_path, "val")
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
train_loader = torch.utils.data.DataLoader(
dsets.ImageFolder(traindir, transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
])),
batch_size=self.batch_size,
shuffle=True,
num_workers=self.n_threads,
pin_memory=True)
test_transform = transforms.Compose([
transforms.Resize(256),
# transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
normalize
])
test_loader = torch.utils.data.DataLoader(
dsets.ImageFolder(testdir, test_transform),
batch_size=self.batch_size,
shuffle=False,
num_workers=self.n_threads,
pin_memory=False)
return train_loader, test_loader
def cifar(self, dataset="cifar100"):
"""
dataset: cifar
"""
if dataset == "cifar10":
norm_mean = [0.49139968, 0.48215827, 0.44653124]
norm_std = [0.24703233, 0.24348505, 0.26158768]
elif dataset == "cifar100":
norm_mean = [0.50705882, 0.48666667, 0.44078431]
norm_std = [0.26745098, 0.25568627, 0.27607843]
# norm_mean = [0.4914, 0.4822, 0.4465]
# norm_std = [0.2023, 0.1994, 0.2010]
else:
assert False, "Invalid cifar dataset"
test_data_root = self.dataset_root
test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(norm_mean, norm_std)])
if self.dataset == "cifar10":
test_dataset = dsets.CIFAR10(root=test_data_root,
train=False,
transform=test_transform)
elif self.dataset == "cifar100":
test_dataset = dsets.CIFAR100(root=test_data_root,
train=False,
transform=test_transform,
download=True)
else:
assert False, "invalid data set"
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=200,
# batch_size=128,
shuffle=False,
pin_memory=True,
num_workers=self.n_threads)
return None, test_loader
from torch.autograd import Function
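# Fake quantization with a straight-through estimator: the forward pass quantizes and
# immediately dequantizes the tensor, while the backward pass lets the gradient through unchanged.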
class FakeQuantize(Function):
@staticmethod
def forward(ctx, x, qparam):
x = qparam.quantize_tensor(x)
x = qparam.dequantize_tensor(x)
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
\ No newline at end of file
# -*- coding: utf-8 -*-
# Shares global values between multiple modules
def _init(): # initialization
global _global_dict
_global_dict = {}
def set_value(value,is_bias=False):
# store a global value
if is_bias:
_global_dict[0] = value
else:
_global_dict[1] = value
def get_value(is_bias=False): # bias gets its own precision, independent of the other values
if is_bias:
return _global_dict[0]
else:
return _global_dict[1]
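# Usage sketch (illustrative; the concrete value list below is a placeholder, the real lists
# are built by the quantization utilities before get_nearest_val in module.py reads them back):
#
#   import torch, gol
#   gol._init()                                   # create the shared dict once at startup
#   plist = torch.tensor([0., 0.125, 0.25, 0.5, 1.])
#   gol.set_value(plist)                          # value list for weights/activations
#   gol.set_value(plist, is_bias=True)            # a separate list may be stored for biases
#   plist = gol.get_value(is_bias=False)          # later, e.g. inside get_nearest_val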
# ------------ General options ----------------------------------------
save_path = "./save_ImageNet/"
dataPath = "/home/datasets/Datasets/imagenet"
dataset = "imagenet" # options: imagenet | cifar100
nGPU = 1 # number of GPUs to use by default
GPU = 0 # default gpu to use, options: range(nGPU)
visible_devices = "2"
# ------------- Data options -------------------------------------------
nThreads = 8 # number of data loader threads
# ---------- Optimization options --------------------------------------
nEpochs = 400 # number of total epochs to train 400
batchSize = 16 # batchsize
momentum = 0.9 # momentum 0.9
weightDecay = 1e-4 # weight decay 1e-4
opt_type = "SGD"
warmup_epochs = 50 # number of epochs for warmup
lr_S = 0.000001 # initial learning rate = 0.000001
lrPolicy_S = "multi_step" # options: multi_step | linear | exp | const | step
step_S = [100,200,300] # step for linear or exp learning rate policy default [200, 300, 400]
decayRate_S = 0.1 # lr decay rate
# ---------- Model options ---------------------------------------------
experimentID = "imganet_4bit_"
nClasses = 1000 # number of classes in the dataset
# ---------- Quantization options ---------------------------------------------
qw = 4
qa = 4
# ----------KD options ---------------------------------------------
temperature = 20
alpha = 1
# ----------Generator options ---------------------------------------------
latent_dim = 100
img_size = 224
channels = 3
lr_G = 0.001 # default 0.001
lrPolicy_G = "multi_step" # options: multi_step | linear | exp | const | step
step_G = [100,200,300] # step for linear or exp learning rate policy
decayRate_G = 0.1 # lr decay rate
b1 = 0.5
b2 = 0.999
\ No newline at end of file
from model import *
import argparse
import datetime
import logging
import os
import time
import traceback
import sys
import copy
import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torch.nn as nn
# option file should be modified according to your experiment
from options import Option
from dataloader import DataLoader
from trainer import Trainer
import utils as utils
from quantization_utils.quant_modules import *
# from pytorchcv.model_provider import get_model as ptcv_get_model
from conditional_batchnorm import CategoricalConditionalBatchNorm2d
# The generator, itself a network
class Generator(nn.Module):
def __init__(self, options=None, conf_path=None):
super(Generator, self).__init__()
# note how the settings are resolved here
self.settings = options or Option(conf_path)
# note the embedding layer: its two arguments are the dictionary size (number of classes) and the vector length (latent_dim)
# it maps labels to vectors
self.label_emb = nn.Embedding(self.settings.nClasses, self.settings.latent_dim)
self.init_size = self.settings.img_size // 4
self.l1 = nn.Sequential(nn.Linear(self.settings.latent_dim, 128 * self.init_size ** 2))
self.conv_blocks0 = nn.Sequential(
nn.BatchNorm2d(128),
)
self.conv_blocks1 = nn.Sequential(
nn.Conv2d(128, 128, 3, stride=1, padding=1),
nn.BatchNorm2d(128, 0.8),
nn.LeakyReLU(0.2, inplace=True),
)
self.conv_blocks2 = nn.Sequential(
nn.Conv2d(128, 64, 3, stride=1, padding=1),
nn.BatchNorm2d(64, 0.8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, self.settings.channels, 3, stride=1, padding=1),
nn.Tanh(),
nn.BatchNorm2d(self.settings.channels, affine=False)
)
def forward(self, z, labels):
# multiply the label embedding with the noise vector element-wise to get the generator input
gen_input = torch.mul(self.label_emb(labels), z)
out = self.l1(gen_input)
out = out.view(out.shape[0], 128, self.init_size, self.init_size)
img = self.conv_blocks0(out)
img = nn.functional.interpolate(img, scale_factor=2)
img = self.conv_blocks1(img)
img = nn.functional.interpolate(img, scale_factor=2)
img = self.conv_blocks2(img)
return img
class Generator_imagenet(nn.Module):
def __init__(self, options=None, conf_path=None):
self.settings = options or Option(conf_path)
super(Generator_imagenet, self).__init__()
self.init_size = self.settings.img_size // 4
self.l1 = nn.Sequential(nn.Linear(self.settings.latent_dim, 128 * self.init_size ** 2))
self.conv_blocks0_0 = CategoricalConditionalBatchNorm2d(1000, 128)
self.conv_blocks1_0 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
self.conv_blocks1_1 = CategoricalConditionalBatchNorm2d(1000, 128, 0.8)
self.conv_blocks1_2 = nn.LeakyReLU(0.2, inplace=True)
self.conv_blocks2_0 = nn.Conv2d(128, 64, 3, stride=1, padding=1)
self.conv_blocks2_1 = CategoricalConditionalBatchNorm2d(1000, 64, 0.8)
self.conv_blocks2_2 = nn.LeakyReLU(0.2, inplace=True)
self.conv_blocks2_3 = nn.Conv2d(64, self.settings.channels, 3, stride=1, padding=1)
self.conv_blocks2_4 = nn.Tanh()
self.conv_blocks2_5 = nn.BatchNorm2d(self.settings.channels, affine=False)
def forward(self, z, labels):
out = self.l1(z)
out = out.view(out.shape[0], 128, self.init_size, self.init_size)
img = self.conv_blocks0_0(out, labels)
img = nn.functional.interpolate(img, scale_factor=2)
img = self.conv_blocks1_0(img)
img = self.conv_blocks1_1(img, labels)
img = self.conv_blocks1_2(img)
img = nn.functional.interpolate(img, scale_factor=2)
img = self.conv_blocks2_0(img)
img = self.conv_blocks2_1(img, labels)
img = self.conv_blocks2_2(img)
img = self.conv_blocks2_3(img)
img = self.conv_blocks2_4(img)
img = self.conv_blocks2_5(img)
return img
class ExperimentDesign:
def __init__(self, model_name, generator=None, options=None, conf_path=None):
self.settings = options or Option(conf_path)
self.generator = generator
self.train_loader = None
self.test_loader = None
self.model = None
self.model_teacher = None
self.optimizer_state = None
self.trainer = None
self.start_epoch = 0
self.test_input = None
self.unfreeze_Flag = True
os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = self.settings.visible_devices
self.settings.set_save_path()
self.logger = self.set_logger()
self.settings.paramscheck(self.logger)
self.prepare(model_name)
def set_logger(self):
logger = logging.getLogger('baseline')
file_formatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')
console_formatter = logging.Formatter('%(message)s')
# file log
file_handler = logging.FileHandler(os.path.join(self.settings.save_path, "train_test.log"))
file_handler.setFormatter(file_formatter)
# console log
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(console_formatter)
logger.addHandler(file_handler)
logger.addHandler(console_handler)
logger.setLevel(logging.INFO)
return logger
def prepare(self,model_name):
self._set_gpu()
self._set_dataloader()
self._set_model(model_name)
self._replace()
self.logger.info(self.model)
self._set_trainer()
def _set_gpu(self):
torch.manual_seed(self.settings.manualSeed)
torch.cuda.manual_seed(self.settings.manualSeed)
assert self.settings.GPU <= torch.cuda.device_count() - 1, "Invalid GPU ID"
cudnn.benchmark = True
def _set_dataloader(self):
# create data loader
data_loader = DataLoader(dataset=self.settings.dataset,
batch_size=self.settings.batchSize,
data_path=self.settings.dataPath,
n_threads=self.settings.nThreads,
ten_crop=self.settings.tenCrop,
logger=self.logger)
self.train_loader, self.test_loader = data_loader.getloader()
def _set_model(self,model_name):
# for CIFAR, ResNet replaces its first conv layer (kernel 7 -> 3)
# resnet18/50/152 are not provided as CIFAR pretrained models
if self.settings.dataset in ["cifar100"]:
self.test_input = torch.randn(1, 3, 32, 32).cuda()
# student and teacher are identical at this point
# build the model via the deployment code and load the pretrained full-precision weights
ckpt_path = 'ckpt_full/cifar100_'+model_name+'.pt'
self.model = Model(model_name).cuda()
self.model.load_state_dict(torch.load(ckpt_path))
self.model_teacher = Model(model_name).cuda()
self.model_teacher.load_state_dict(torch.load(ckpt_path))
self.model_teacher.eval()
# To use torchcv's pretrained full-precision models instead:
# 1. redirect the model cache path so it is not stored under /home
# 2. the server needs internet access
# if self.settings.dataset in ["cifar100"]:
# self.test_input = Variable(torch.randn(1, 3, 32, 32).cuda())
# # 这里student和teacher暂时相同
# self.model = ptcv_get_model('resnet20_cifar100', pretrained=True)
# self.model_teacher = ptcv_get_model('resnet20_cifar100', pretrained=True)
# self.model_teacher.eval()
# elif self.settings.dataset in ["imagenet"]:
# self.test_input = Variable(torch.randn(1, 3, 224, 224).cuda())
# self.model = ptcv_get_model('resnet18', pretrained=True)
# self.model_teacher = ptcv_get_model('resnet18', pretrained=True)
# self.model_teacher.eval()
else:
assert False, "unsupport data set: " + self.settings.dataset
def _set_trainer(self):
# set lr master
lr_master_S = utils.LRPolicy(self.settings.lr_S,
self.settings.nEpochs,
self.settings.lrPolicy_S)
lr_master_G = utils.LRPolicy(self.settings.lr_G,
self.settings.nEpochs,
self.settings.lrPolicy_G)
params_dict_S = {
'step': self.settings.step_S,
'decay_rate': self.settings.decayRate_S
}
params_dict_G = {
'step': self.settings.step_G,
'decay_rate': self.settings.decayRate_G
}
lr_master_S.set_params(params_dict=params_dict_S)
lr_master_G.set_params(params_dict=params_dict_G)
# set trainer
# the trainer's train method trains the generator
self.trainer = Trainer(
model=self.model,
model_teacher=self.model_teacher,
generator = self.generator,
train_loader=self.train_loader,
test_loader=self.test_loader,
lr_master_S=lr_master_S,
lr_master_G=lr_master_G,
settings=self.settings,
logger=self.logger,
opt_type=self.settings.opt_type,
optimizer_state=self.optimizer_state,
run_count=self.start_epoch)
def quantize_model(self,model):
"""
Recursively quantize a pretrained full-precision model into a low-bitwidth quantized model
model: pretrained full-precision model
"""
# quantize the pretrained full-precision model into the student
# the hocon config sets w4a4
weight_bit = self.settings.qw
act_bit = self.settings.qa
# quantize convolutional and linear layers
# conv and fc
if type(model) == nn.Conv2d:
quant_mod = Quant_Conv2d(weight_bit=weight_bit)
quant_mod.set_param(model)
return quant_mod
elif type(model) == nn.Linear:
quant_mod = Quant_Linear(weight_bit=weight_bit)
quant_mod.set_param(model)
return quant_mod
# quantize all the activation
# relu and relu6
elif type(model) == nn.ReLU or type(model) == nn.ReLU6:
# append a QuantAct layer
return nn.Sequential(*[model, QuantAct(activation_bit=act_bit)])
# recursively use the quantized module to replace the single-precision module
# the replacement is done directly on the existing layers
elif type(model) == nn.Sequential:
# recurse
mods = []
for n, m in model.named_children():
mods.append(self.quantize_model(m))
return nn.Sequential(*mods)
else:
q_model = copy.deepcopy(model)
for attr in dir(model): # iterate over all attribute names
mod = getattr(model, attr)
# BN layers are not replaced
if isinstance(mod, nn.Module) and 'norm' not in attr:
setattr(q_model, attr, self.quantize_model(mod))
return q_model
def _replace(self):
# quantize the student model
self.model = self.quantize_model(self.model)
def freeze_model(self,model):
"""
freeze the activation range
"""
if type(model) == QuantAct: # corresponds to relu and relu6
model.fix()
# recurse
elif type(model) == nn.Sequential:
for n, m in model.named_children():
self.freeze_model(m)
else:
for attr in dir(model):
mod = getattr(model, attr)
if isinstance(mod, nn.Module) and 'norm' not in attr:
self.freeze_model(mod)
return model
def unfreeze_model(self,model):
"""
unfreeze the activation range
"""
if type(model) == QuantAct:
model.unfix()
elif type(model) == nn.Sequential:
for n, m in model.named_children():
self.unfreeze_model(m)
else:
for attr in dir(model):
mod = getattr(model, attr)
if isinstance(mod, nn.Module) and 'norm' not in attr:
self.unfreeze_model(mod)
return model
def run(self,gen_path):
best_top1 = 100
best_top5 = 100
start_time = time.time()
# the teacher is fixed, so test it only once
test_error, test_loss, test5_error = self.trainer.test_teacher(0)
best_gen_acc = None
try:
for epoch in range(self.start_epoch, self.settings.nEpochs):
self.epoch = epoch
self.start_epoch = 0
# if epoch < 4:
# print ("\n self.unfreeze_model(self.model)\n")
# self.unfreeze_model(self.model)
# gen_acc, train_error, train_loss, train5_error = self.trainer.train(epoch=epoch)
gen_acc = self.trainer.train(epoch=epoch)
if not best_gen_acc or gen_acc > best_gen_acc:
best_gen_acc = gen_acc
torch.save(self.generator, gen_path)
# self.freeze_model(self.model)
# if self.settings.dataset in ["cifar100"]:
# test_error, test_loss, test5_error = self.trainer.test(epoch=epoch)
# elif self.settings.dataset in ["imagenet"]:
# if epoch > self.settings.warmup_epochs - 2:
# test_error, test_loss, test5_error = self.trainer.test(epoch=epoch)
# else:
# test_error = 100
# test5_error = 100
# else:
# assert False, "invalid data set"
# if best_top1 >= test_error:
# best_top1 = test_error
# best_top5 = test5_error
# corresponds to lines 3 and 4 of each output group, reporting the quantized network's performance
self.logger.info(">>> Cur Gen acc: {:f}, Best Gen acc: {:f}".format(gen_acc,best_gen_acc))
# self.logger.info("#==>Best Result is: Top1 Error: {:f}, Top5 Error: {:f}".format(best_top1, best_top5))
# self.logger.info("#==>Best Result is: Top1 Accuracy: {:f}, Top5 Accuracy: {:f}\n".format(100 - best_top1,
# 100 - best_top5))
except BaseException as e:
self.logger.error("Training is terminating due to exception: {}".format(str(e)))
traceback.print_exc()
end_time = time.time()
time_interval = end_time - start_time
t_string = "Running Time is: " + str(datetime.timedelta(seconds=time_interval)) + "\n"
self.logger.info(t_string)
return best_top1, best_top5
def main():
parser = argparse.ArgumentParser(description='Baseline')
parser.add_argument('--conf_path', type=str, metavar='conf_path',
help='input the path of config file')
parser.add_argument('--id', type=int, metavar='experiment_id',
help='Experiment ID')
parser.add_argument('--model_name',metavar='model_name',type=str,help='Model Name')
args = parser.parse_args()
option = Option(args.conf_path)
option.manualSeed = 1
option.experimentID = args.model_name + option.experimentID
if option.dataset in ["cifar100"]:
generator = Generator(option)
elif option.dataset in ["imagenet"]:
generator = Generator_imagenet(option)
else:
assert False, "invalid data set"
experiment = ExperimentDesign(args.model_name,generator, option)
print('>>> Gen: '+args.model_name)
gen_path = 'ckpt_gen_rn_1600/cifar100_'+args.model_name+'.pt'
experiment.run(gen_path)
if __name__ == '__main__':
main()
import torch.nn as nn
from cfg import *
from module import *
from model_deployment import *
class Model(nn.Module):
def __init__(self,model_name):
super(Model, self).__init__()
self.cfg_table = model_cfg_table[model_name]
make_layers(self,self.cfg_table)
# # parameter initialization
# for m in self.modules():
# if isinstance(m, nn.Conv2d):
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# elif isinstance(m, nn.BatchNorm2d):
# nn.init.constant_(m.weight, 1)
# nn.init.constant_(m.bias, 0)
# elif isinstance(m, nn.Linear):
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
def forward(self,x):
x = model_forward(self,self.cfg_table,x)
return x
def quantize(self, quant_type, num_bits=8, e_bits=3):
model_quantize(self,self.cfg_table,quant_type,num_bits,e_bits)
def quantize_forward(self,x):
return model_utils(self,self.cfg_table,func='forward',x=x)
def freeze(self):
model_utils(self,self.cfg_table,func='freeze')
def quantize_inference(self,x):
return model_utils(self,self.cfg_table,func='inference',x=x)
def fakefreeze(self):
model_utils(self,self.cfg_table,func='fakefreeze')
# if __name__ == "__main__":
# model = Inception_BN()
# model.quantize('INT',8,3)
# print(model.named_modules)
# print('-------')
# print(model.named_parameters)
# print(len(model.conv0.named_parameters()))
\ No newline at end of file
import torch.nn as nn
import torch.nn.functional as F
from cfg import *
from module import *
def make_layers(model,cfg_table):
for i in range(len(cfg_table)):
cfg = cfg_table[i]
if cfg[0] == 'Inc':
make_inc_layers(model,cfg[1])
elif cfg[0] == 'ML':
make_ml_layers(model,cfg[1],cfg[2],cfg[3])
elif cfg[0] == 'C':
name = 'conv%d'%i
layer = nn.Conv2d(cfg[3],cfg[4],kernel_size=cfg[5],stride=cfg[6],padding=cfg[7],bias=cfg[8])
model.add_module(name,layer)
if 'B' in cfg[1]:
name = 'bn%d'%i
layer = nn.BatchNorm2d(cfg[4])
model.add_module(name,layer)
if 'RL' in cfg[1]:
name = 'relu%d'%i
layer = nn.ReLU(True)
model.add_module(name,layer)
elif 'RS' in cfg[1]:
name = 'relus%d'%i
layer = nn.ReLU6(True)
model.add_module(name,layer)
elif cfg[0] == 'RL':
name = 'relu%d'%i
layer = nn.ReLU(True)
model.add_module(name,layer)
elif cfg[0] == 'RS':
name = 'relus%d'%i
layer = nn.ReLU6(True)
model.add_module(name,layer)
elif cfg[0] == 'MP':
name = 'pool%d'%i
layer = nn.MaxPool2d(kernel_size=cfg[1],stride=cfg[2],padding=cfg[3])
model.add_module(name,layer)
elif cfg[0] == 'AAP':
name = 'aap%d'%i
layer = nn.AdaptiveAvgPool2d(cfg[1])
model.add_module(name,layer)
elif cfg[0] == 'FC':
name = 'fc%d'%i
layer = nn.Linear(cfg[1],cfg[2],bias=cfg[3])
model.add_module(name,layer)
elif cfg[0] == 'D':
name = 'drop%d'%i
layer = nn.Dropout(cfg[1])
model.add_module(name,layer)
def model_forward(model,cfg_table,x):
for i in range(len(cfg_table)):
cfg = cfg_table[i]
if cfg[0] == 'Inc':
x = inc_forward(model,cfg[1],x)
elif cfg[0] == 'ML':
x = ml_forward(model,cfg[1],cfg[2],cfg[3],x)
elif cfg[0] == 'C':
name = 'conv%d'%i
layer = getattr(model,name)
x = layer(x)
if 'B' in cfg[1]:
name = 'bn%d'%i
layer = getattr(model,name)
x = layer(x)
if 'RL' in cfg[1]:
name = 'relu%d'%i
layer = getattr(model,name)
x = layer(x)
elif 'RS' in cfg[1]:
name = 'relus%d'%i
layer = getattr(model,name)
x = layer(x)
elif cfg[0] == 'RL':
name = 'relu%d'%i
layer = getattr(model,name)
x = layer(x)
elif cfg[0] == 'RS':
name = 'relus%d'%i
layer = getattr(model,name)
x = layer(x)
elif cfg[0] == 'MP':
name = 'pool%d'%i
layer = getattr(model,name)
x = layer(x)
elif cfg[0] == 'AAP':
name = 'aap%d'%i
layer = getattr(model,name)
x = layer(x)
elif cfg[0] == 'FC':
name = 'fc%d'%i
layer = getattr(model,name)
x = layer(x)
elif cfg[0] == 'D':
name = 'drop%d'%i
layer = getattr(model,name)
x = layer(x)
elif cfg[0] == 'VW':
if len(cfg) == 1: #default
x = x.view(x.size(0),-1)
elif cfg[0] == 'SM':
x = F.softmax(x,dim=1)
return x
def model_quantize(model,cfg_table,quant_type,num_bits,e_bits):
for i in range(len(cfg_table)):
cfg = cfg_table[i]
if cfg[0] == 'Inc':
inc_quantize(model,cfg[1],quant_type,num_bits,e_bits)
elif cfg[0] == 'ML':
ml_quantize(model,cfg[1],cfg[2],cfg[3],quant_type,num_bits,e_bits)
elif cfg[0] == 'C':
conv_name = 'conv%d'%i
conv_layer = getattr(model,conv_name)
qname = 'q_'+conv_name
if 'B' in cfg[1]:
bn_name = 'bn%d'%i
bn_layer = getattr(model,bn_name)
if 'RL' in cfg[1]:
qlayer = QConvBNReLU(quant_type,conv_layer,bn_layer,qi=cfg[2],num_bits=num_bits,e_bits=e_bits)
elif 'RS' in cfg[1]:
qlayer = QConvBNReLU6(quant_type,conv_layer,bn_layer,qi=cfg[2],num_bits=num_bits,e_bits=e_bits)
else:
qlayer = QConvBN(quant_type,conv_layer,bn_layer,qi=cfg[2],num_bits=num_bits,e_bits=e_bits)
else:
qlayer = QConv2d(quant_type,conv_layer,qi=cfg[2],num_bits=num_bits,e_bits=e_bits)
model.add_module(qname,qlayer)
elif cfg[0] == 'RL':
name = 'relu%d'%i
qname = 'q_'+name
qlayer = QReLU(quant_type,num_bits=num_bits,e_bits=e_bits)
model.add_module(qname,qlayer)
elif cfg[0] == 'RS':
name = 'relus%d'%i
qname = 'q_'+name
qlayer = QReLU6(quant_type,num_bits=num_bits,e_bits=e_bits)
model.add_module(qname,qlayer)
elif cfg[0] == 'MP':
name = 'pool%d'%i
qname = 'q_'+name
qlayer = QMaxPooling2d(quant_type,kernel_size=cfg[1],stride=cfg[2],padding=cfg[3],num_bits=num_bits,e_bits=e_bits)
model.add_module(qname,qlayer)
elif cfg[0] == 'AAP':
name = 'aap%d'%i
qname = 'q_'+name
qlayer = QAdaptiveAvgPool2d(quant_type,output_size=cfg[1],num_bits=num_bits,e_bits=e_bits)
model.add_module(qname,qlayer)
elif cfg[0] == 'FC':
name = 'fc%d'%i
layer = getattr(model,name)
qname = 'q_'+name
qlayer = QLinear(quant_type,layer,num_bits=num_bits,e_bits=e_bits)
model.add_module(qname,qlayer)
# adds func='fakefreeze'
def model_utils(model,cfg_table,func,x=None):
last_qo = None
# marks that dequantization has already happened, for the case where it occurs before softmax rather than at the very end
done_flag = False
for i in range(len(cfg_table)):
cfg = cfg_table[i]
if cfg[0] == 'Inc':
x,last_qo = inc_utils(model,cfg[1],func,x,last_qo)
elif cfg[0] == 'ML':
x,last_qo = ml_utils(model,cfg[1],cfg[2],cfg[3],func,x,last_qo)
elif cfg[0] == 'C':
qname = 'q_conv%d'%i
qlayer = getattr(model,qname)
if func == 'forward':
x = qlayer(x)
elif func == 'inference':
# cfg[2] == True marks the first layer, whose input needs to be quantized
if cfg[2]:
x = qlayer.qi.quantize_tensor(x)
x = qlayer.quantize_inference(x)
elif func == 'freeze':
qlayer.freeze(last_qo)
elif func == 'fakefreeze':
qlayer.fakefreeze()
last_qo = qlayer.qo
elif cfg[0] == 'RL':
qname = 'q_relu%d'%i
qlayer = getattr(model,qname)
if func == 'forward':
x = qlayer(x)
elif func == 'inference':
x = qlayer.quantize_inference(x)
elif func == 'freeze':
qlayer.freeze(last_qo)
elif cfg[0] == 'RS':
qname = 'q_relus%d'%i
qlayer = getattr(model,qname)
if func == 'forward':
x = qlayer(x)
elif func == 'inference':
x = qlayer.quantize_inference(x)
elif func == 'freeze':
qlayer.freeze(last_qo)
elif cfg[0] == 'MP':
qname = 'q_pool%d'%i
qlayer = getattr(model,qname)
if func == 'forward':
x = qlayer(x)
elif func == 'inference':
x = qlayer.quantize_inference(x)
elif func == 'freeze':
qlayer.freeze(last_qo)
elif cfg[0] == 'AAP':
qname = 'q_aap%d'%i
qlayer = getattr(model,qname)
if func == 'forward':
x = qlayer(x)
elif func == 'inference':
x = qlayer.quantize_inference(x)
elif func == 'freeze':
qlayer.freeze(last_qo)
last_qo = qlayer.qo
elif cfg[0] == 'FC':
qname = 'q_fc%d'%i
qlayer = getattr(model,qname)
if func == 'forward':
x = qlayer(x)
elif func == 'inference':
x = qlayer.quantize_inference(x)
elif func == 'freeze':
qlayer.freeze(last_qo)
elif func == 'fakefreeze':
qlayer.fakefreeze()
last_qo = qlayer.qo
elif cfg[0] == 'D':
if func == 'forward':
name = 'drop%d'%i
layer = getattr(model,name)
x = layer(x)
elif cfg[0] == 'VW':
if func == 'inference' or func == 'forward':
if len(cfg) == 1: #default
x = x.view(x.size(0),-1)
elif cfg[0] == 'SM':
if func == 'inference':
done_flag = True
x = last_qo.dequantize_tensor(x)
x = F.softmax(x,dim=1)
elif func == 'forward':
x = F.softmax(x,dim=1)
if func == 'inference' and not done_flag:
x = last_qo.dequantize_tensor(x)
return x
def make_inc_layers(model,inc_idx):
inc_name = 'inc%d'%inc_idx
ch = inc_ch_table[inc_idx]
for i in range(4): # branch
prefix = inc_name+'_br%d_'%i
for j in range(len(inc_cfg_table[i])):
cfg = inc_cfg_table[i][j]
if cfg[0] == 'MP':
name = prefix+'pool%d'%j
layer =nn.MaxPool2d(kernel_size=cfg[1],stride=cfg[2],padding=cfg[3])
model.add_module(name,layer)
elif cfg[0] == 'RL':
name=prefix+'relu%d'%j
layer=nn.ReLU(True)
model.add_module(name,layer)
elif cfg[0] == 'C': # 'BRL' default
name=prefix+'conv%d'%j
layer=nn.Conv2d(ch[cfg[1]],ch[cfg[2]],kernel_size=cfg[3],stride=cfg[4],padding=cfg[5],bias=False)
model.add_module(name,layer)
name=prefix+'bn%d'%j
layer=nn.BatchNorm2d(ch[cfg[2]])
model.add_module(name,layer)
name=prefix+'relu%d'%j
layer=nn.ReLU(True)
model.add_module(name,layer)
def inc_forward(model,inc_idx,x):
inc_name = 'inc%d'%inc_idx
outs = []
for i in range(4):
prefix = inc_name+'_br%d_'%i
tmp = x
for j in range(len(inc_cfg_table[i])):
cfg = inc_cfg_table[i][j]
if cfg[0] == 'MP':
name = prefix+'pool%d'%j
layer = getattr(model,name)
tmp = layer(tmp)
elif cfg[0] == 'RL':
name=prefix+'relu%d'%j
layer = getattr(model,name)
tmp = layer(tmp)
elif cfg[0] == 'C': # 'BRL' default
name=prefix+'conv%d'%j
layer = getattr(model,name)
tmp = layer(tmp)
name=prefix+'bn%d'%j
layer = getattr(model,name)
tmp = layer(tmp)
name=prefix+'relu%d'%j
layer = getattr(model,name)
tmp = layer(tmp)
outs.append(tmp)
out = torch.cat(outs,1)
return out
def inc_quantize(model,inc_idx,quant_type,num_bits,e_bits):
inc_name = 'inc%d'%inc_idx
for i in range(4):
prefix = inc_name+'_br%d_'%i
for j in range(len(inc_cfg_table[i])):
cfg = inc_cfg_table[i][j]
if cfg[0] == 'MP':
name = prefix+'pool%d'%j
qname = 'q_'+name
qlayer = QMaxPooling2d(quant_type,kernel_size=cfg[1],stride=cfg[2],padding=cfg[3],num_bits=num_bits,e_bits=e_bits)
model.add_module(qname,qlayer)
elif cfg[0] == 'RL':
name = prefix+'relu%d'%j
qname = 'q_'+name
qlayer = QReLU(quant_type, num_bits=num_bits, e_bits=e_bits)
model.add_module(qname,qlayer)
elif cfg[0] == 'C': # 'BRL' default
conv_name=prefix+'conv%d'%j
conv_layer=getattr(model,conv_name)
bn_name=prefix+'bn%d'%j
bn_layer=getattr(model,bn_name)
qname='q_'+conv_name
qlayer=QConvBNReLU(quant_type, conv_layer, bn_layer, num_bits=num_bits, e_bits=e_bits)
model.add_module(qname,qlayer)
qname = 'q_'+inc_name+'_concat'
qlayer = QConcat(quant_type,4,qi_array=False,qo=True,num_bits=num_bits,e_bits=e_bits)
model.add_module(qname,qlayer)
def inc_utils(model,inc_idx,func,x=None,qo=None):
inc_name = 'inc%d'%inc_idx
outs=[]
qos=[]
for i in range(4):
qprefix = 'q_'+inc_name+'_br%d_'%i
tmp = x
last_qo = qo
for j in range(len(inc_cfg_table[i])):
cfg = inc_cfg_table[i][j]
if cfg[0] == 'MP':
qname = qprefix+'pool%d'%j
qlayer = getattr(model,qname)
if func == 'forward':
tmp = qlayer(tmp)
elif func == 'inference':
tmp = qlayer.quantize_inference(tmp)
elif func == 'freeze':
qlayer.freeze(last_qo)
elif cfg[0] == 'RL':
qname = qprefix+'relu%d'%j
qlayer = getattr(model,qname)
if func == 'forward':
tmp = qlayer(tmp)
elif func == 'inference':
tmp = qlayer.quantize_inference(tmp)
elif func == 'freeze':
qlayer.freeze(last_qo)
elif cfg[0] == 'C': # 'BRL' default
qname = qprefix+'conv%d'%j
qlayer = getattr(model,qname)
if func == 'forward':
tmp = qlayer(tmp)
elif func == 'inference':
tmp = qlayer.quantize_inference(tmp)
elif func == 'freeze':
qlayer.freeze(last_qo)
elif func == 'fakefreeze':
qlayer.fakefreeze()
last_qo = qlayer.qo
outs.append(tmp)
qos.append(last_qo)
qname = 'q_'+inc_name+'_concat'
qlayer = getattr(model,qname)
out = None
if func == 'forward':
out = qlayer(outs)
elif func == 'inference':
out = qlayer.quantize_inference(outs)
elif func == 'freeze':
qlayer.freeze(qos)
last_qo = qlayer.qo
return out,last_qo
def make_ml_layers(model,blk_type,ml_idx,blocks):
ml_name = 'ml%d'%ml_idx
if blk_type == 'BBLK':
blk_ch_table = bblk_ch_table
blk_cfg_table = bblk_cfg_table
elif blk_type == 'BTNK':
blk_ch_table = btnk_ch_table
blk_cfg_table = btnk_cfg_table
elif blk_type == 'IRES':
blk_ch_table = ires_ch_table
blk_cfg_table = ires_cfg_table
else:
raise ValueError("Make_ml_layers: Illegal blk_type")
# each MakeLayer corresponds to two channel-table rows: one for the first block and one for the remaining blocks
make_blk_layers(model,blk_ch_table,blk_cfg_table,ml_name,2*ml_idx,0)
for i in range(1,blocks):
make_blk_layers(model,blk_ch_table,blk_cfg_table,ml_name,2*ml_idx+1,i)
# ma is the main branch, ds is the downsample branch
# when the cfg table has three entries, the 3rd lists the layers applied after the branches merge.
# BasicBlock and BottleNeck apply a ReLU after the merge; InvertedResidual does not
def make_blk_layers(model,blk_ch_table,blk_cfg_table,ml_name,ch_idx,blk_idx):
blk_name = ml_name+'_blk%d'%blk_idx
ch = blk_ch_table[ch_idx]
for i in range(2):
if i == 0:
prefix = blk_name+'_ma_'
elif i == 1:
if ch[0]: #downsample/identity_flag
prefix = blk_name+'_ds_'
else:
continue
for j in range(len(blk_cfg_table[i])):
cfg = blk_cfg_table[i][j]
if cfg[0] == 'C':
name = prefix+'conv%d'%j
layer = nn.Conv2d(ch[cfg[2]],ch[cfg[3]],kernel_size=cfg[4],stride=ch[cfg[5]],padding=cfg[6],groups=ch[cfg[7]])
model.add_module(name,layer)
if 'B' in cfg[1]:
name = prefix+'bn%d'%j
layer=nn.BatchNorm2d(ch[cfg[3]])
model.add_module(name,layer)
if 'RL' in cfg[1]:
name = prefix+'relu%d'%j
layer = nn.ReLU(True)
model.add_module(name,layer)
elif 'RS' in cfg[1]:
name = prefix+'relus%d'%j
layer = nn.ReLU6(True)
model.add_module(name,layer)
# merge the branches
prefix = blk_name+'_'
for j in range(len(blk_cfg_table[-1])):
cfg = blk_cfg_table[-1][j]
if cfg[0] == 'RL': # no block currently uses ReLU6 at the merge point
name = prefix+'relu%d'%j
layer = nn.ReLU(True)
model.add_module(name,layer)
def ml_forward(model,blk_type,ml_idx,blocks,x):
ml_name = 'ml%d'%ml_idx
if blk_type == 'BBLK':
blk_ch_table = bblk_ch_table
blk_cfg_table = bblk_cfg_table
elif blk_type == 'BTNK':
blk_ch_table = btnk_ch_table
blk_cfg_table = btnk_cfg_table
elif blk_type == 'IRES':
blk_ch_table = ires_ch_table
blk_cfg_table = ires_cfg_table
else:
raise ValueError("ml_forward: Illegal blk_type")
x = blk_forward(model,blk_ch_table,blk_cfg_table,ml_name,2*ml_idx,0,x)
for i in range(1,blocks):
x = blk_forward(model,blk_ch_table,blk_cfg_table,ml_name,2*ml_idx+1,i,x)
return x
def blk_forward(model,blk_ch_table,blk_cfg_table,ml_name,ch_idx,blk_idx,x):
blk_name = ml_name+'_blk%d'%blk_idx
ch = blk_ch_table[ch_idx]
outs = []
for i in range(2):
tmp=x
if i == 0:
prefix = blk_name+'_ma_'
elif i == 1:
if ch[0]: #downsample/identity_flag
prefix = blk_name+'_ds_'
else:
outs.append(tmp)
continue
for j in range(len(blk_cfg_table[i])):
cfg = blk_cfg_table[i][j]
if cfg[0] == 'C':
name = prefix+'conv%d'%j
layer = getattr(model,name)
tmp = layer(tmp)
if 'B' in cfg[1]:
name = prefix+'bn%d'%j
layer = getattr(model,name)
tmp = layer(tmp)
if 'RL' in cfg[1]:
name = prefix+'relu%d'%j
layer = getattr(model,name)
tmp = layer(tmp)
elif 'RS' in cfg[1]:
name = prefix+'relus%d'%j
layer = getattr(model,name)
tmp = layer(tmp)
outs.append(tmp)
# merge the branches
prefix = blk_name+'_'
for j in range(len(blk_cfg_table[-1])):
cfg = blk_cfg_table[-1][j]
if cfg[0] == 'AD':
if cfg[1] or ch[0]: # unconditional add, or the flag is True
out = outs[0] + outs[1]
else:
out = outs[0]
elif cfg[0] == 'RL':
name = prefix+'relu%d'%j
layer = getattr(model,name)
out = layer(out)
return out
def ml_quantize(model,blk_type,ml_idx,blocks,quant_type,num_bits,e_bits):
ml_name = 'ml%d'%ml_idx
if blk_type == 'BBLK':
blk_ch_table = bblk_ch_table
blk_cfg_table = bblk_cfg_table
elif blk_type == 'BTNK':
blk_ch_table = btnk_ch_table
blk_cfg_table = btnk_cfg_table
elif blk_type == 'IRES':
blk_ch_table = ires_ch_table
blk_cfg_table = ires_cfg_table
else:
raise ValueError("ml_quantize: Illegal blk_type")
blk_quantize(model,blk_ch_table,blk_cfg_table,ml_name,2*ml_idx,0,quant_type,num_bits,e_bits)
for i in range(1,blocks):
blk_quantize(model,blk_ch_table,blk_cfg_table,ml_name,2*ml_idx+1,i,quant_type,num_bits,e_bits)
def blk_quantize(model,blk_ch_table,blk_cfg_table,ml_name,ch_idx,blk_idx,quant_type,num_bits,e_bits):
blk_name = ml_name+'_blk%d'%blk_idx
ch = blk_ch_table[ch_idx]
for i in range(2):
if i == 0:
prefix = blk_name+'_ma_'
elif i == 1:
if ch[0]: #downsample/identity_flag
prefix = blk_name+'_ds_'
else:
continue
for j in range(len(blk_cfg_table[i])):
cfg = blk_cfg_table[i][j]
if cfg[0] == 'C':
conv_name = prefix+'conv%d'%j
conv_layer = getattr(model,conv_name)
qname = 'q_'+conv_name
if 'B' in cfg[1]:
bn_name = prefix+'bn%d'%j
bn_layer = getattr(model,bn_name)
if 'RL' in cfg[1]:
qlayer = QConvBNReLU(quant_type,conv_layer,bn_layer,num_bits=num_bits,e_bits=e_bits)
elif 'RS' in cfg[1]:
qlayer = QConvBNReLU6(quant_type,conv_layer,bn_layer,num_bits=num_bits,e_bits=e_bits)
else:
qlayer = QConvBN(quant_type,conv_layer,bn_layer,num_bits=num_bits,e_bits=e_bits)
else:
qlayer = QConv2d(quant_type,conv_layer,num_bits=num_bits,e_bits=e_bits)
model.add_module(qname,qlayer)
# merge the branches
prefix = blk_name+'_'
for j in range(len(blk_cfg_table[-1])):
cfg = blk_cfg_table[-1][j]
if cfg[0] == 'AD':
if cfg[1] or ch[0]: # unconditional add, or the flag is True
qname = 'q_'+prefix+'add%d'%j
qlayer = QElementwiseAdd(quant_type,2,qi_array=False,qo=True,num_bits=num_bits,e_bits=e_bits)
model.add_module(qname,qlayer)
elif cfg[0] == 'RL':
qname = 'q_'+prefix+'relu%d'%j
qlayer = QReLU(quant_type,num_bits=num_bits,e_bits=e_bits)
model.add_module(qname,qlayer)
def ml_utils(model,blk_type,ml_idx,blocks,func,x=None,qo=None):
ml_name = 'ml%d'%ml_idx
if blk_type == 'BBLK':
blk_ch_table = bblk_ch_table
blk_cfg_table = bblk_cfg_table
elif blk_type == 'BTNK':
blk_ch_table = btnk_ch_table
blk_cfg_table = btnk_cfg_table
elif blk_type == 'IRES':
blk_ch_table = ires_ch_table
blk_cfg_table = ires_cfg_table
else:
raise ValueError("ml_quantize: Illegal blk_type")
last_qo = qo
x,last_qo = blk_utils(model,blk_ch_table,blk_cfg_table,ml_name,2*ml_idx,0,func,x,last_qo)
for i in range(1,blocks):
x,last_qo = blk_utils(model,blk_ch_table,blk_cfg_table,ml_name,2*ml_idx+1,i,func,x,last_qo)
return x, last_qo
def blk_utils(model,blk_ch_table,blk_cfg_table,ml_name,ch_idx,blk_idx,func,x=None,qo=None):
blk_name = ml_name+'_blk%d'%blk_idx
ch = blk_ch_table[ch_idx]
outs = []
qos = []
for i in range(2):
tmp=x
last_qo = qo
if i == 0:
qprefix = 'q_'+blk_name+'_ma_'
elif i == 1:
if ch[0]: #downsample/identity_flag
qprefix = 'q_'+blk_name+'_ds_'
else:
outs.append(tmp)
qos.append(last_qo)
continue
for j in range(len(blk_cfg_table[i])):
cfg = blk_cfg_table[i][j]
if cfg[0] == 'C':
qname = qprefix+'conv%d'%j
qlayer = getattr(model,qname)
if func == 'forward':
tmp = qlayer(tmp)
elif func == 'inference':
tmp = qlayer.quantize_inference(tmp)
elif func == 'freeze':
qlayer.freeze(last_qo)
elif func == 'fakefreeze':
qlayer.fakefreeze()
last_qo = qlayer.qo
outs.append(tmp)
qos.append(last_qo)
# merge the branches
qprefix = 'q_'+blk_name+'_'
for j in range(len(blk_cfg_table[-1])):
cfg = blk_cfg_table[-1][j]
if cfg[0] == 'AD':
if cfg[1] or ch[0]: # unconditional add, or the flag is True
qname = qprefix+'add%d'%j
qlayer = getattr(model,qname)
out = None
if func == 'forward':
out = qlayer(outs)
elif func == 'inference':
out = qlayer.quantize_inference(outs)
elif func == 'freeze':
qlayer.freeze(qos)
last_qo = qlayer.qo
else:
out = outs[0]
last_qo = qos[0]
elif cfg[0] == 'RL':
qname = qprefix+'relu%d'%j
qlayer = getattr(model,qname)
if func == 'forward':
out = qlayer(out)
elif func == 'inference':
out = qlayer.quantize_inference(out)
elif func == 'freeze':
qlayer.freeze(last_qo)
return out,last_qo
\ No newline at end of file
import math
import numpy as np
import gol
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from function import FakeQuantize
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output, dim=1)
q_output = F.softmax(q_output, dim=1)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
# snap each value to the nearest quantization level
def get_nearest_val(quant_type, x, is_bias=False, block_size=1000000):
if quant_type == 'INT':
return x.round_()
plist = gol.get_value(is_bias)
shape = x.shape
xhard = x.view(-1)
xout = torch.zeros_like(xhard)
plist = plist.type_as(x)
n_blocks = (x.numel() + block_size - 1) // block_size
for i in range(n_blocks):
start_idx = i * block_size
end_idx = min(start_idx + block_size, xhard.numel())
block_size_i = end_idx - start_idx
xblock = xhard[start_idx:end_idx]
plist_block = plist.unsqueeze(1)
idx = (xblock.unsqueeze(0) - plist_block).abs().min(dim=0)[1]
xhard_block = plist[idx].view(xblock.shape)
xout[start_idx:end_idx] = (xhard_block - xblock).detach() + xblock
xout = xout.view(shape)
return xout
# for symmetric signed quantization, get the maximum of the quantization range
def get_qmax(quant_type,num_bits=None, e_bits=None):
if quant_type == 'INT':
qmax = 2. ** (num_bits - 1) - 1
elif quant_type == 'POT':
qmax = 1
else: #FLOAT
m_bits = num_bits - 1 - e_bits
dist_m = 2 ** (-m_bits)
e = 2 ** (e_bits - 1)
expo = 2 ** e
m = 2 ** m_bits -1
frac = 1. + m * dist_m
qmax = frac * expo
return qmax
# signed quantization throughout, so zero_point is always 0
def calcScaleZeroPoint(min_val, max_val, qmax):
scale = torch.max(max_val.abs(),min_val.abs()) / qmax
zero_point = torch.tensor(0.)
return scale, zero_point
# quantize the input; both input and output are tensors
def quantize_tensor(quant_type, x, scale, zero_point, qmax, is_bias=False):
# the quantized range is determined directly by the bit width
qmin = -qmax
q_x = zero_point + x / scale
q_x.clamp_(qmin, qmax)
q_x = get_nearest_val(quant_type, q_x, is_bias)
return q_x
# bias uses a different precision; num_bits/e_bits are chosen per quantization type
def bias_qmax(quant_type):
if quant_type == 'INT':
return get_qmax(quant_type, 64)
elif quant_type == 'POT':
return get_qmax(quant_type)
else:
return get_qmax(quant_type, 16, 7)
# convert back to FP32; no further clamping needed
def dequantize_tensor(q_x, scale, zero_point):
return scale * (q_x - zero_point)
class QParam(nn.Module):
def __init__(self,quant_type, num_bits=8, e_bits=3):
super(QParam, self).__init__()
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.qmax = get_qmax(quant_type, num_bits, e_bits)
scale = torch.tensor([], requires_grad=False)
zero_point = torch.tensor([], requires_grad=False)
min = torch.tensor([], requires_grad=False)
max = torch.tensor([], requires_grad=False)
# registering them as buffers lets them be recorded in the state_dict
self.register_buffer('scale', scale)
self.register_buffer('zero_point', zero_point)
self.register_buffer('min', min)
self.register_buffer('max', max)
# update the tracked range and the quantization parameters
def update(self, tensor):
if self.max.nelement() == 0 or self.max.data < tensor.max().data:
self.max.data = tensor.max().data
self.max.clamp_(min=0)
if self.min.nelement() == 0 or self.min.data > tensor.min().data:
self.min.data = tensor.min().data
self.min.clamp_(max=0)
self.scale, self.zero_point = calcScaleZeroPoint(self.min, self.max, self.qmax)
def quantize_tensor(self, tensor):
return quantize_tensor(self.quant_type, tensor, self.scale, self.zero_point, self.qmax)
def dequantize_tensor(self, q_x):
return dequantize_tensor(q_x, self.scale, self.zero_point)
# this method ensures the values can be restored from a state_dict
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
key_names = ['scale', 'zero_point', 'min', 'max']
for key in key_names:
value = getattr(self, key)
value.data = state_dict[prefix + key].data
state_dict.pop(prefix + key)
# the return value of this method is what gets printed for the object
def __str__(self):
info = 'scale: %.10f ' % self.scale
info += 'zp: %.6f ' % self.zero_point
info += 'min: %.6f ' % self.min
info += 'max: %.6f' % self.max
return info
# base class of the concrete quantized layers; qi and qo are the quantized input/output parameters
class QModule(nn.Module):
def __init__(self,quant_type, qi=False, qo=True, num_bits=8, e_bits=3):
super(QModule, self).__init__()
if qi:
self.qi = QParam(quant_type,num_bits, e_bits)
if qo:
self.qo = QParam(quant_type,num_bits, e_bits)
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.bias_qmax = bias_qmax(quant_type)
def freeze(self):
pass # no-op
def quantize_inference(self, x):
raise NotImplementedError('quantize_inference should be implemented.')
def fakefreeze(self):
pass
"""
QModule 量化卷积
:quant_type: 量化类型
:conv_module: 卷积模块
:qi: 是否量化输入特征图
:qo: 是否量化输出特征图
:num_bits: 8位bit数
"""
class QConv2d(QModule):
def __init__(self, quant_type, conv_module, qi=False, qo=True, num_bits=8, e_bits=3):
super(QConv2d, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.qw = QParam(quant_type, num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
# freeze fixes the truly quantized weight parameters and writes them back into the original full-precision layer, which is convenient for the later divergence computation
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
# When the input comes from a pooling or activation layer, no extra min/max statistics are needed; the same output quantization parameters are shared
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
# Following https://zhuanlan.zhihu.com/p/156835141, this is the coefficient of Eq. (3)
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
self.conv_module.weight.data = self.qw.quantize_tensor(self.conv_module.weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
# Handle conv layers without bias: passing None to forward and inference also works
if self.conv_module.bias is not None:
self.conv_module.bias.data = quantize_tensor(self.quant_type,
self.conv_module.bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0.,qmax=self.bias_qmax, is_bias=True)
def fakefreeze(self):
self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
if self.conv_module.bias is not None:
self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data,scale=self.qi.scale*self.qw.scale,zero_point=0.)
def forward(self, x): # forward pass; x is a floating-point input tensor
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi) # fake-quantize the input tensor x
# update qw before the forward pass so that the scale is correct when quantizing the weights
self.qw.update(self.conv_module.weight.data)
# Note: this mainly collects the range statistics of x and the weights for each layer; the bias is not quantized here
tmp_wgt = FakeQuantize.apply(self.conv_module.weight, self.qw)
x = F.conv2d(x, tmp_wgt, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
# Uses the formula q_a = M * (sum((q_w - Z_w) * (q_x - Z_x)) + q_b)
def quantize_inference(self, x): # the input here is the already-quantized q_x
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
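# Illustrative calibration/inference sketch (not part of the original code):
# wrap a Conv2d, calibrate qi/qw/qo with a few FP32 batches, freeze, then run
# the integer inference path and map the result back to FP32.
def _demo_qconv2d():
    conv = nn.Conv2d(3, 8, 3, padding=1)
    qconv = QConv2d('INT', conv, qi=True, qo=True, num_bits=8)
    for _ in range(4):                       # calibration passes collect the ranges
        qconv(torch.randn(2, 3, 16, 16))
    qconv.freeze()                           # write quantized weight/bias back into conv
    x = torch.randn(2, 3, 16, 16)
    q_x = qconv.qi.quantize_tensor(x)        # quantize the input
    q_y = qconv.quantize_inference(q_x)      # integer-domain convolution
    return qconv.qo.dequantize_tensor(q_y)   # back to FP32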
class QLinear(QModule):
def __init__(self, quant_type, fc_module, qi=False, qo=True, num_bits=8, e_bits=3):
super(QLinear, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.fc_module = fc_module
self.qw = QParam(quant_type, num_bits, e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
self.fc_module.weight.data = self.qw.quantize_tensor(self.fc_module.weight.data)
self.fc_module.weight.data = self.fc_module.weight.data - self.qw.zero_point
if self.fc_module.bias is not None:
self.fc_module.bias.data = quantize_tensor(self.quant_type,
self.fc_module.bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax, is_bias=True)
def fakefreeze(self):
self.fc_module.weight.data = self.qw.dequantize_tensor(self.fc_module.weight.data)
if self.fc_module.bias is not None:
self.fc_module.bias.data = dequantize_tensor(self.fc_module.bias.data,scale=self.qi.scale*self.qw.scale,zero_point=0.)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
self.qw.update(self.fc_module.weight.data)
tmp_wgt = FakeQuantize.apply(self.fc_module.weight, self.qw)
x = F.linear(x, tmp_wgt, self.fc_module.bias)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = self.fc_module(x)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
class QReLU(QModule):
def __init__(self,quant_type, qi=False, qo=False, num_bits=8, e_bits=3):
super(QReLU, self).__init__(quant_type, qi, qo, num_bits, e_bits)
def freeze(self, qi=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if qi is not None:
self.qi = qi
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
x = F.relu(x)
return x
def quantize_inference(self, x):
x = x.clone()
x[x < self.qi.zero_point] = self.qi.zero_point
return x
class QReLU6(QModule):
def __init__(self,quant_type, qi=False, qo=False, num_bits=8, e_bits=3):
super(QReLU6, self).__init__(quant_type, qi, qo, num_bits, e_bits)
def freeze(self, qi=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if qi is not None:
self.qi = qi
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
x = F.relu6(x)
return x
def quantize_inference(self, x):
x = x.clone()
upper = torch.tensor(6)
qupper = self.qi.quantize_tensor(upper)
x.clamp_(min=0,max=qupper.item())
return x
class QMaxPooling2d(QModule):
def __init__(self, quant_type, kernel_size=3, stride=1, padding=0, qi=False, qo=False, num_bits=8,e_bits=3):
super(QMaxPooling2d, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
def freeze(self, qi=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if qi is not None:
self.qi = qi
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding)
return x
def quantize_inference(self, x):
return F.max_pool2d(x, self.kernel_size, self.stride, self.padding)
class QAdaptiveAvgPool2d(QModule):
def __init__(self, quant_type, output_size, qi=False, qo=True, num_bits=8,e_bits=3):
super(QAdaptiveAvgPool2d, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.output_size = output_size
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qi.scale / self.qo.scale).data
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
x = F.adaptive_avg_pool2d(x, self.output_size)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = F.adaptive_avg_pool2d(x, self.output_size)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x+self.qo.zero_point
return x
class QConvBN(QModule):
def __init__(self, quant_type, conv_module, bn_module, qi=False, qo=True, num_bits=8, e_bits=3):
super(QConvBN, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.bn_module = bn_module
self.qw = QParam(quant_type, num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = self.qw.quantize_tensor(weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
if self.conv_module.bias is not None:
self.conv_module.bias.data = quantize_tensor(self.quant_type,
bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
else:
bias = quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
self.conv_module.bias = torch.nn.Parameter(bias)
def fakefreeze(self):
self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data,scale=self.qi.scale*self.qw.scale,zero_point=0.)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
self.qw.update(weight.data)
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
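# Illustrative sanity check (not part of the original code): with the running BN
# statistics, the (weight, bias) returned by fold_bn should reproduce bn(conv(x))
# in eval mode up to numerical error.
def _demo_fold_bn():
    conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
    bn = nn.BatchNorm2d(8)
    bn.eval()
    qcb = QConvBN('INT', conv, bn)
    std = torch.sqrt(bn.running_var + bn.eps)
    weight, bias = qcb.fold_bn(bn.running_mean, std)
    x = torch.randn(2, 3, 16, 16)
    y_folded = F.conv2d(x, weight, bias, stride=conv.stride, padding=conv.padding)
    y_ref = bn(conv(x))
    print('max folding error:', (y_folded - y_ref).abs().max().item())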
class QConvBNReLU(QModule):
def __init__(self, quant_type, conv_module, bn_module, qi=False, qo=True, num_bits=8, e_bits=3):
super(QConvBNReLU, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.bn_module = bn_module
self.qw = QParam(quant_type, num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = self.qw.quantize_tensor(weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
if self.conv_module.bias is not None:
self.conv_module.bias.data = quantize_tensor(self.quant_type,
bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
else:
bias = quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
self.conv_module.bias = torch.nn.Parameter(bias)
def fakefreeze(self):
self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data,scale=self.qi.scale*self.qw.scale,zero_point=0.)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
self.qw.update(weight.data)
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
x = F.relu(x)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
x.clamp_(min=0)
return x
class QConvBNReLU6(QModule):
def __init__(self, quant_type, conv_module, bn_module, qi=False, qo=True, num_bits=8, e_bits=3):
super(QConvBNReLU6, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.bn_module = bn_module
self.qw = QParam(quant_type, num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = self.qw.quantize_tensor(weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
if self.conv_module.bias is not None:
self.conv_module.bias.data = quantize_tensor(self.quant_type,
bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
else:
bias = quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
self.conv_module.bias = torch.nn.Parameter(bias)
def fakefreeze(self):
self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data,scale=self.qi.scale*self.qw.scale,zero_point=0.)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
self.qw.update(weight.data)
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
x = F.relu6(x)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
upper = torch.tensor(6)
qupper = self.qo.quantize_tensor(upper)
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
x.clamp_(min=0,max=qupper.item())
return x
# Parent class of quantized layers that combine the results of several layers,
# with the input quantization parameters passed in as an array
class QModule_array(nn.Module):
def __init__(self,quant_type,len,qi_array=False, qo=True, num_bits=8, e_bits=3):
super(QModule_array, self).__init__()
if qi_array:
for i in range(len):
self.add_module('qi%d'%i,QParam(quant_type,num_bits, e_bits))
if qo:
self.qo = QParam(quant_type,num_bits, e_bits)
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.bias_qmax = bias_qmax(quant_type)
self.len = len
def freeze(self):
pass # no-op
def quantize_inference(self, x):
raise NotImplementedError('quantize_inference should be implemented.')
class QElementwiseAdd(QModule_array):
def __init__(self, quant_type, len, qi_array=False, qo=True, num_bits=8, e_bits=3):
super(QElementwiseAdd,self).__init__(quant_type,len,qi_array,qo,num_bits,e_bits)
for i in range(len):
self.register_buffer('M%d'%i,torch.tensor([], requires_grad=False))
def freeze(self, qi_array=None, qo=None):
if qi_array is None:
raise ValueError('qi_array should be provided')
elif len(qi_array) != self.len:
raise ValueError('length of qi_array does not match')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
for i in range(self.len):
self.add_module('qi%d'%i,qi_array[i])
if qo is not None:
self.qo = qo
for i in range(self.len):
getattr(self,'M%d'%i).data = (getattr(self,'qi%d'%i).scale / self.qo.scale).data
def forward(self,x_array):
outs=[]
for i in range(self.len):
x = x_array[i]
if hasattr(self,'qi%d'%i):
qi = getattr(self,'qi%d'%i)
qi.update(x)
x = FakeQuantize.apply(x,qi)
outs.append(x)
out = outs[0]+outs[1]
if hasattr(self,'qo'):
self.qo.update(out)
out = FakeQuantize.apply(out,self.qo)
return out
def quantize_inference(self, x_array):
outs=[]
for i in range(self.len):
qi = getattr(self,'qi%d'%i)
x = x_array[i] - qi.zero_point
x = getattr(self,'M%d'%i) * x
outs.append(x)
out = outs[0]+outs[1]
out = get_nearest_val(self.quant_type,out)
out = out + self.qo.zero_point
return out
class QConcat(QModule_array):
def __init__(self, quant_type, len, qi_array=False, qo=True, num_bits=8, e_bits=3):
super(QConcat,self).__init__(quant_type, len, qi_array, qo, num_bits, e_bits)
for i in range(len):
self.register_buffer('M%d'%i,torch.tensor([], requires_grad=False))
def freeze(self, qi_array=None, qo=None):
if qi_array is None:
raise ValueError('qi_array should be provided')
elif len(qi_array) != self.len:
raise ValueError('length of qi_array does not match')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
for i in range(self.len):
self.add_module('qi%d'%i,qi_array[i])
if qo is not None:
self.qo = qo
for i in range(self.len):
getattr(self,'M%d'%i).data = (getattr(self,'qi%d'%i).scale / self.qo.scale).data
def forward(self,x_array):
outs=[]
for i in range(self.len):
x = x_array[i]
if hasattr(self,'qi%d'%i):
qi = getattr(self,'qi%d'%i)
qi.update(x)
x = FakeQuantize.apply(x,qi)
outs.append(x)
out = torch.cat(outs,1)
if hasattr(self,'qo'):
self.qo.update(out)
out = FakeQuantize.apply(out,self.qo)
return out
def quantize_inference(self, x_array):
outs=[]
for i in range(self.len):
qi = getattr(self,'qi%d'%i)
x = x_array[i] - qi.zero_point
x = getattr(self,'M%d'%i) * x
outs.append(x)
out = torch.cat(outs,1)
out = get_nearest_val(self.quant_type,out)
out = out + self.qo.zero_point
return out
\ No newline at end of file
import os
import shutil
from pyhocon import ConfigFactory
from utils.opt_static import NetOption
class Option(NetOption):
def __init__(self, conf_path):
super(Option, self).__init__()
self.conf = ConfigFactory.parse_file(conf_path)
# ------------ General options ----------------------------------------
self.save_path = self.conf['save_path']
self.dataPath = self.conf['dataPath'] # path for loading data set
# does the dataset here only support cifar100 and imagenet?
self.dataset = self.conf['dataset'] # options: imagenet | cifar100
self.nGPU = self.conf['nGPU'] # number of GPUs to use by default
self.GPU = self.conf['GPU'] # default gpu to use, options: range(nGPU)
self.visible_devices = self.conf['visible_devices']
# ------------- Data options -------------------------------------------
self.nThreads = self.conf['nThreads'] # number of data loader threads
# ---------- Optimization options --------------------------------------
self.nEpochs = self.conf['nEpochs'] # number of total epochs to train
self.batchSize = self.conf['batchSize'] # mini-batch size
self.momentum = self.conf['momentum'] # momentum
self.weightDecay = float(self.conf['weightDecay']) # weight decay
# e.g. SGD, Adam, etc.
self.opt_type = self.conf['opt_type']
self.warmup_epochs = self.conf['warmup_epochs'] # number of epochs for warmup
self.lr_S = self.conf['lr_S'] # initial learning rate
# the hocon config uses multi_step
self.lrPolicy_S = self.conf['lrPolicy_S'] # options: multi_step | linear | exp | const | step
self.step_S = self.conf['step_S'] # step for linear or exp learning rate policy
self.decayRate_S = self.conf['decayRate_S'] # lr decay rate
# ---------- Model options ---------------------------------------------
self.experimentID = self.conf['experimentID']
self.nClasses = self.conf['nClasses'] # number of classes in the dataset
# ---------- Quantization options ---------------------------------------------
# W4A4 in quantization refers to these two: W is the weight bit width, A is the activation (e.g. ReLU) bit width; both are 4 in the hocon config
self.qw = self.conf['qw']
self.qa = self.conf['qa']
# ----------KD options ---------------------------------------------
self.temperature = self.conf['temperature']
self.alpha = self.conf['alpha']
# ----------Generator options ---------------------------------------------
# generator hyper-parameters
self.latent_dim = self.conf['latent_dim']
self.img_size = self.conf['img_size']
self.channels = self.conf['channels']
self.lr_G = self.conf['lr_G']
# also uses multi_step
self.lrPolicy_G = self.conf['lrPolicy_G'] # options: multi_step | linear | exp | const | step
self.step_G = self.conf['step_G'] # step for linear or exp learning rate policy
self.decayRate_G = self.conf['decayRate_G'] # lr decay rate
self.b1 = self.conf['b1']
self.b2 = self.conf['b2']
def set_save_path(self):
self.save_path = self.save_path + "{}_bs{:d}_lr{:.4f}_{}_epoch{}/".format(
self.experimentID,
self.batchSize, self.lr, self.opt_type,
self.nEpochs)
if os.path.exists(self.save_path):
shutil.rmtree(self.save_path)
# print("{} file exist!".format(self.save_path))
# action = input("Select Action: d (delete) / q (quit):").lower().strip()
# act = action
# if act == 'd':
# shutil.rmtree(self.save_path)
# else:
# raise OSError("Directory {} exits!".format(self.save_path))
if not os.path.exists(self.save_path):
os.makedirs(self.save_path)
def paramscheck(self, logger):
logger.info("|===>The used PyTorch version is {}".format(
self.torch_version))
if self.dataset in ["cifar10", "mnist"]:
self.nClasses = 10
elif self.dataset == "cifar100":
self.nClasses = 100
elif self.dataset == "imagenet" or "thi_imgnet":
self.nClasses = 1000
elif self.dataset == "imagenet100":
self.nClasses = 100
\ No newline at end of file
# *
# @file Different utility functions
# Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami
# All rights reserved.
# This file is part of ZeroQ repository.
#
# ZeroQ is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ZeroQ is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ZeroQ repository. If not, see <http://www.gnu.org/licenses/>.
# *
import torch
import time
import math
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module, Parameter
from .quant_utils import *
import sys
class QuantAct(Module):
"""
Class to quantize given activations
"""
def __init__(self,
activation_bit,
full_precision_flag=False,
running_stat=True,
beta=0.9):
"""
activation_bit: bit-setting for activation
full_precision_flag: full precision or not
running_stat: determines whether the activation range is updated or frozen
"""
super(QuantAct, self).__init__()
self.activation_bit = activation_bit
self.full_precision_flag = full_precision_flag
self.running_stat = running_stat
self.register_buffer('x_min', torch.zeros(1))
self.register_buffer('x_max', torch.zeros(1))
self.register_buffer('beta', torch.Tensor([beta]))
self.register_buffer('beta_t', torch.ones(1))
self.act_function = AsymmetricQuantFunction.apply
def __repr__(self):
return "{0}(activation_bit={1}, full_precision_flag={2}, running_stat={3}, Act_min: {4:.2f}, Act_max: {5:.2f})".format(
self.__class__.__name__, self.activation_bit,
self.full_precision_flag, self.running_stat, self.x_min.item(),
self.x_max.item())
# fix/unfix control whether the running statistics may still be updated
def fix(self):
"""
fix the activation range by setting running stat
"""
self.running_stat = False
def unfix(self):
"""
unfix the activation range by re-enabling running stat
"""
self.running_stat = True
def forward(self, x):
"""
quantize given activation x
"""
if self.running_stat:
x_min = x.data.min()
x_max = x.data.max()
# in-place operation used on multi-gpus
# self.x_min += -self.x_min + min(self.x_min, x_min)
# self.x_max += -self.x_max + max(self.x_max, x_max)
self.beta_t = self.beta_t * self.beta
self.x_min = (self.x_min * self.beta + x_min * (1 - self.beta))/(1 - self.beta_t)
self.x_max = (self.x_max * self.beta + x_max * (1 - self.beta)) / (1 - self.beta_t)
if not self.full_precision_flag:
# perform quantization
quant_act = self.act_function(x, self.activation_bit, self.x_min,
self.x_max)
return quant_act
else:
return x
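# Illustrative usage sketch (not part of the original code): QuantAct keeps an EMA
# of the activation range while running_stat is True; fix() freezes that range.
def _demo_quantact():
    qa = QuantAct(activation_bit=4)
    for _ in range(10):                  # calibration: updates x_min/x_max via the EMA
        qa(torch.randn(8, 16))
    qa.fix()                             # freeze the range for deployment
    return qa(torch.randn(8, 16))        # quantize-dequantize with the frozen range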
class Quant_Linear(Module):
"""
Class to quantize given linear layer weights
"""
def __init__(self, weight_bit, full_precision_flag=False):
"""
weight_bit: bit-setting for weights
full_precision_flag: full precision or not
running_stat: determines whether the activation range is updated or frozen
"""
super(Quant_Linear, self).__init__()
self.full_precision_flag = full_precision_flag
self.weight_bit = weight_bit
self.weight_function = AsymmetricQuantFunction.apply
def __repr__(self):
s = super(Quant_Linear, self).__repr__()
s = "(" + s + " weight_bit={}, full_precision_flag={})".format(
self.weight_bit, self.full_precision_flag)
return s
def set_param(self, linear):
self.in_features = linear.in_features
self.out_features = linear.out_features
self.weight = Parameter(linear.weight.data.clone())
try:
self.bias = Parameter(linear.bias.data.clone())
except AttributeError:
self.bias = None
def forward(self, x):
"""
using quantized weights to forward activation x
"""
w = self.weight
x_transform = w.data.detach()
w_min = x_transform.min(dim=1).values
w_max = x_transform.max(dim=1).values
if not self.full_precision_flag:
w = self.weight_function(self.weight, self.weight_bit, w_min,
w_max)
else:
w = self.weight
return F.linear(x, weight=w, bias=self.bias)
class Quant_Conv2d(Module):
"""
Class to quantize given convolutional layer weights
"""
def __init__(self, weight_bit, full_precision_flag=False):
super(Quant_Conv2d, self).__init__()
self.full_precision_flag = full_precision_flag
self.weight_bit = weight_bit
self.weight_function = AsymmetricQuantFunction.apply
def __repr__(self):
s = super(Quant_Conv2d, self).__repr__()
s = "(" + s + " weight_bit={}, full_precision_flag={})".format(
self.weight_bit, self.full_precision_flag)
return s
def set_param(self, conv):
self.in_channels = conv.in_channels
self.out_channels = conv.out_channels
self.kernel_size = conv.kernel_size
self.stride = conv.stride
self.padding = conv.padding
self.dilation = conv.dilation
self.groups = conv.groups
self.weight = Parameter(conv.weight.data.clone())
try:
self.bias = Parameter(conv.bias.data.clone())
except AttributeError:
self.bias = None
def forward(self, x):
"""
using quantized weights to forward activation x
"""
w = self.weight
x_transform = w.data.contiguous().view(self.out_channels, -1)
w_min = x_transform.min(dim=1).values
w_max = x_transform.max(dim=1).values
if not self.full_precision_flag:
# the weights are quantized here; the bias stays unchanged
w = self.weight_function(self.weight, self.weight_bit, w_min,
w_max)
else:
w = self.weight
return F.conv2d(x, w, self.bias, self.stride, self.padding,
self.dilation, self.groups)
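# Illustrative usage sketch (not part of the original code): wrap an existing conv
# layer so that its weights are asymmetrically quantized per output channel on the fly.
def _demo_quant_conv2d():
    conv = nn.Conv2d(3, 8, 3, padding=1)
    qconv = Quant_Conv2d(weight_bit=4)
    qconv.set_param(conv)                # copy weight/bias and the conv hyper-parameters
    return qconv(torch.randn(2, 3, 16, 16))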
#*
# @file Different utility functions
# Copyright (c) Yaohui Cai, Zhewei Yao, Zhen Dong, Amir Gholami
# All rights reserved.
# This file is part of ZeroQ repository.
#
# ZeroQ is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# ZeroQ is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with ZeroQ repository. If not, see <http://www.gnu.org/licenses/>.
#*
import math
import numpy as np
from torch.autograd import Function, Variable
import torch
def clamp(input, min, max, inplace=False):
"""
Clamp tensor input to (min, max).
input: input tensor to be clamped
"""
if inplace:
input.clamp_(min, max)
return input
return torch.clamp(input, min, max)
def linear_quantize(input, scale, zero_point, inplace=False):
"""
Quantize single-precision input tensor to integers with the given scaling factor and zeropoint.
input: single-precision input tensor to be quantized
scale: scaling factor for quantization
zero_point: shift for quantization
"""
# determined by the input's shape: gather all per-channel information onto the first dimension
# reshape scale and zeropoint for convolutional weights and activation
if len(input.shape) == 4:
scale = scale.view(-1, 1, 1, 1)
zero_point = zero_point.view(-1, 1, 1, 1)
# reshape scale and zeropoint for linear weights
elif len(input.shape) == 2:
scale = scale.view(-1, 1)
zero_point = zero_point.view(-1, 1)
# mapping single-precision input to integer values with the given scale and zeropoint
if inplace:
# this is the actual quantization step
input.mul_(scale).sub_(zero_point).round_()
return input
return torch.round(scale * input - zero_point)
def linear_dequantize(input, scale, zero_point, inplace=False):
"""
Map integer input tensor to fixed point float point with given scaling factor and zeropoint.
input: integer input tensor to be mapped
scale: scaling factor for quantization
zero_point: shift for quantization
"""
# reshape scale and zeropoint for convolutional weights and activation
if len(input.shape) == 4:
scale = scale.view(-1, 1, 1, 1)
zero_point = zero_point.view(-1, 1, 1, 1)
# reshape scale and zeropoint for linear weights
elif len(input.shape) == 2:
scale = scale.view(-1, 1)
zero_point = zero_point.view(-1, 1)
# mapping integer input to fixed point float point value with given scaling factor and zeropoint
if inplace:
input.add_(zero_point).div_(scale)
return input
return (input + zero_point) / scale
# asymmetric linear quantization
def asymmetric_linear_quantization_params(num_bits,
saturation_min,
saturation_max,
integral_zero_point=True,
signed=True):
"""
Compute the scaling factor and zeropoint with the given quantization range.
saturation_min: lower bound for quantization range
saturation_max: upper bound for quantization range
"""
n = 2**num_bits - 1
# note: the min/max bookkeeping here is the opposite of our framework's
scale = n / torch.clamp((saturation_max - saturation_min), min=1e-8)
zero_point = scale * saturation_min
if integral_zero_point:
if isinstance(zero_point, torch.Tensor):
zero_point = zero_point.round()
else:
zero_point = float(round(zero_point))
if signed:
zero_point += 2**(num_bits - 1)
return scale, zero_point
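# Illustrative example (not part of the original code): for an 8-bit range
# [-1.0, 3.0], n = 255 and scale = 255 / 4 = 63.75; with signed=True,
# zero_point = round(-63.75) + 128 = 64, so linear_quantize maps -1.0 to
# round(63.75 * -1.0) - 64 = -128 and 3.0 to round(63.75 * 3.0) - 64 = 127.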
class AsymmetricQuantFunction(Function):
"""
Class to quantize the given floating-point values with given range and bit-setting.
Currently only support inference, but not support back-propagation.
"""
@staticmethod
def forward(ctx, x, k, x_min=None, x_max=None):
"""
x: single-precision value to be quantized
k: bit-setting for x
x_min: lower bound for quantization range
x_max: upper bound for quantization range
"""
# if x_min is None or x_max is None or (sum(x_min == x_max) == 1
# and x_min.numel() == 1):
# x_min, x_max = x.min(), x.max()
scale, zero_point = asymmetric_linear_quantization_params(
k, x_min, x_max)
# quantize the input
new_quant_x = linear_quantize(x, scale, zero_point, inplace=False)
n = 2**(k - 1)
new_quant_x = torch.clamp(new_quant_x, -n, n - 1)
quant_x = linear_dequantize(new_quant_x,
scale,
zero_point,
inplace=False)
# wrap in a Variable so that autograd is enabled
return torch.autograd.Variable(quant_x)
@staticmethod
def backward(ctx, grad_output):
return grad_output, None, None, None
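# Illustrative usage sketch (not part of the original code): quantize-dequantize a
# tensor with an explicit range; the backward pass is a straight-through estimator,
# so gradients flow through unchanged.
def _demo_asymmetric_quant():
    x = torch.randn(4, 8, requires_grad=True)
    x_q = AsymmetricQuantFunction.apply(x, 8, x.detach().min(), x.detach().max())
    x_q.sum().backward()                 # grad of x is all ones (straight-through)
    print('max quantization error:', (x.detach() - x_q.detach()).abs().max().item())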
# numpy==1.16.4
# requests==2.21.0
pyhocon==0.3.51
# torchvision==0.4.0
# torch==1.2.0+cu92
# Pillow==7.2.0
termcolor==1.1.0
"""
basic trainer
"""
import time
import torch.autograd
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import utils as utils
import numpy as np
import torch
__all__ = ["Trainer"]
class Trainer(object):
"""
trainer for training network, use SGD
"""
def __init__(self, model, model_teacher, generator, lr_master_S, lr_master_G,
train_loader, test_loader, settings, logger, tensorboard_logger=None,
opt_type="SGD", optimizer_state=None, run_count=0):
"""
init trainer
"""
self.settings = settings
self.model = utils.data_parallel(
model, self.settings.nGPU, self.settings.GPU)
self.model_teacher = utils.data_parallel(
model_teacher, self.settings.nGPU, self.settings.GPU)
self.generator = utils.data_parallel(
generator, self.settings.nGPU, self.settings.GPU)
self.train_loader = train_loader
self.test_loader = test_loader
self.tensorboard_logger = tensorboard_logger
self.criterion = nn.CrossEntropyLoss().cuda()
# 'logits' here means unnormalized log-scores; BCE targets a binary classification problem
self.bce_logits = nn.BCEWithLogitsLoss().cuda()
# MSE is mainly used for regression-style objectives
self.MSE_loss = nn.MSELoss().cuda()
# learning rate of the student model, i.e. the quantized model
self.lr_master_S = lr_master_S
# learning rate of the generator
self.lr_master_G = lr_master_G
self.opt_type = opt_type
if opt_type == "SGD":
self.optimizer_S = torch.optim.SGD(
params=self.model.parameters(),
lr=self.lr_master_S.lr,
momentum=self.settings.momentum,
weight_decay=self.settings.weightDecay,
nesterov=True,
)
elif opt_type == "RMSProp":
self.optimizer_S = torch.optim.RMSprop(
params=self.model.parameters(),
lr=self.lr_master_S.lr,
eps=1.0,
weight_decay=self.settings.weightDecay,
momentum=self.settings.momentum,
alpha=self.settings.momentum
)
elif opt_type == "Adam":
self.optimizer_S = torch.optim.Adam(
params=self.model.parameters(),
lr=self.lr_master_S.lr,
eps=1e-5,
weight_decay=self.settings.weightDecay
)
else:
assert False, "invalid type: %d" % opt_type
# the betas below are the exponential decay rates of the first- and second-moment estimates
if optimizer_state is not None:
self.optimizer_S.load_state_dict(optimizer_state)
self.optimizer_G = torch.optim.Adam(self.generator.parameters(), lr=self.settings.lr_G,
betas=(self.settings.b1, self.settings.b2))
self.logger = logger
self.run_count = run_count
self.scalar_info = {}
self.mean_list = []
self.var_list = []
self.teacher_running_mean = []
self.teacher_running_var = []
self.save_BN_mean = []
self.save_BN_var = []
self.fix_G = False
def update_lr(self, epoch):
"""
update learning rate of optimizers
:param epoch: current training epoch
"""
lr_S = self.lr_master_S.get_lr(epoch)
lr_G = self.lr_master_G.get_lr(epoch)
# update learning rate of model optimizer
for param_group in self.optimizer_S.param_groups:
param_group['lr'] = lr_S
for param_group in self.optimizer_G.param_groups:
param_group['lr'] = lr_G
def loss_fn_kd(self, output, labels, teacher_outputs):
"""
Compute the knowledge-distillation (KD) loss given outputs, labels.
"Hyperparameters": temperature and alpha
NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher
and student expects the input tensor to be log probabilities! See Issue #2
"""
# KLDivLoss expects log-probabilities for the student; compares the teacher's and the student's predictions:
# KD_loss = alpha * T^2 * KL(log_softmax(student / T), softmax(teacher / T)) + CE(student, labels)
criterion_d = nn.CrossEntropyLoss().cuda()
kdloss = nn.KLDivLoss().cuda()
alpha = self.settings.alpha
T = self.settings.temperature
a = F.log_softmax(output / T, dim=1)
b = F.softmax(teacher_outputs / T, dim=1)
c = (alpha * T * T)
d = criterion_d(output, labels)
KD_loss = kdloss(a,b)*c + d
return KD_loss
def forward(self, images, teacher_outputs, labels=None):
"""
forward propagation
"""
# forward and backward and optimize
# in the original code a second argument enabled intermediate feature outputs; our framework does not include it
output = self.model(images)
if labels is not None:
loss = self.loss_fn_kd(output, labels, teacher_outputs)
return output, loss
else:
return output, None
def backward_G(self, loss_G):
"""
backward propagation
"""
# backpropagate and update the generator
self.optimizer_G.zero_grad()
loss_G.backward()
self.optimizer_G.step()
def backward_S(self, loss_S):
"""
backward propagation
"""
self.optimizer_S.zero_grad()
loss_S.backward()
self.optimizer_S.step()
def backward(self, loss):
"""
backward propagation
"""
# only the generator and the student model need to be updated
self.optimizer_G.zero_grad()
self.optimizer_S.zero_grad()
loss.backward()
self.optimizer_G.step()
self.optimizer_S.step()
def hook_fn_forward(self,module, input, output):
# TODO: check where the input comes from
input = input[0]
mean = input.mean([0, 2, 3])
# use biased var in train
var = input.var([0, 2, 3], unbiased=False)
# collect the mean and variance(?)
self.mean_list.append(mean)
self.var_list.append(var)
self.teacher_running_mean.append(module.running_mean)
self.teacher_running_var.append(module.running_var)
def hook_fn_forward_saveBN(self,module, input, output):
self.save_BN_mean.append(module.running_mean.cpu())
self.save_BN_var.append(module.running_var.cpu())
def train(self, epoch):
"""
training
"""
# used to print running averages at regular intervals
top1_error = utils.AverageMeter()
top1_loss = utils.AverageMeter()
top5_error = utils.AverageMeter()
fp_acc = utils.AverageMeter()
# iters = 200
iters = 200
self.update_lr(epoch) # update the learning rate according to the policy
# only the generator is trained here
self.model.eval()
self.model_teacher.eval()
self.generator.train()
start_time = time.time()
end_time = start_time
if epoch==0:
for m in self.model_teacher.modules():
if isinstance(m, nn.BatchNorm2d):
# track the BN layers
# TODO: hook_fn_forward
m.register_forward_hook(self.hook_fn_forward)
for i in range(iters):
start_time = time.time()
data_time = start_time - end_time
z = Variable(torch.randn(self.settings.batchSize, self.settings.latent_dim)).cuda()
# Get labels ranging from 0 to n_classes for n rows
# i.e. produce a length-batchSize vector whose entries are random class labels
labels = Variable(torch.randint(0, self.settings.nClasses, (self.settings.batchSize,))).cuda()
z = z.contiguous()
labels = labels.contiguous()
# TODO: inspect the generator input
# the teacher's input is produced by the generator from random labels plus Gaussian noise
images = self.generator(z, labels)
self.mean_list.clear()
self.var_list.clear()
# Get the teacher's output; hook_fn_forward simultaneously fills the mean and var lists
# the official model's out_feature returns the intermediate feature-layer outputs; our framework does not include this
output_teacher_batch = self.model_teacher(images)
# One hot loss
# loss between the teacher's output and the random labels
loss_one_hot = self.criterion(output_teacher_batch, labels)
# BN statistic loss
# loss between the statistics of the fake data and the teacher's BN-layer statistics
BNS_loss = torch.zeros(1).cuda()
for num in range(len(self.mean_list)):
BNS_loss += self.MSE_loss(self.mean_list[num], self.teacher_running_mean[num]) + self.MSE_loss(
self.var_list[num], self.teacher_running_var[num])
BNS_loss = BNS_loss / len(self.mean_list)
# loss of Generator
# 0.1 is the beta from the paper
loss_G = loss_one_hot + 0.1 * BNS_loss
self.backward_G(loss_G)
# # these are the teacher model's outputs
# # TODO: loss_S still to be confirmed
# output, loss_S = self.forward(images.detach(), output_teacher_batch.detach(), labels)
# # backpropagation through the student only starts after the warmup epochs
# if epoch>= self.settings.warmup_epochs:
# self.backward_S(loss_S)
# # TODO: check what this does
# single_error, single_loss, single5_error = utils.compute_singlecrop(
# outputs=output, labels=labels,
# loss=loss_S, top5_flag=True, mean_flag=True)
# top1_error.update(single_error, images.size(0))
# top1_loss.update(single_loss, images.size(0))
# top5_error.update(single5_error, images.size(0))
end_time = time.time()
gt = labels.data.cpu().numpy()
d_acc = np.mean(np.argmax(output_teacher_batch.data.cpu().numpy(), axis=1) == gt)
# the teacher's accuracy?
fp_acc.update(d_acc)
# used to keep the best generator checkpoint
gen_acc = 100 * fp_acc.avg
# Corresponds to the first output line: the teacher's accuracy on the noisy random labels and the inputs the generator produced from them
# a higher acc means a better generator whose outputs are closer to the real data distribution
# print(
# "[Epoch %d/%d] [Batch %d/%d] [acc: %.4f%%] [G loss: %f] [One-hot loss: %f] [BNS_loss:%f] [S loss: %f] "
# % (epoch + 1, self.settings.nEpochs, i+1, iters, gen_acc, loss_G.item(), loss_one_hot.item(), BNS_loss.item(),
# loss_S.item())
# )
print(
"[Epoch %d/%d] [Batch %d/%d] [acc: %.4f%%] [G loss: %f] [One-hot loss: %f] [BNS_loss:%f]"
% (epoch + 1, self.settings.nEpochs, i+1, iters, gen_acc, loss_G.item(), loss_one_hot.item(), BNS_loss.item())
)
self.scalar_info['accuracy every epoch'] = 100 * d_acc
self.scalar_info['G loss every epoch'] = loss_G
self.scalar_info['One-hot loss every epoch'] = loss_one_hot
# self.scalar_info['S loss every epoch'] = loss_S
# self.scalar_info['training_top1error'] = top1_error.avg
# self.scalar_info['training_top5error'] = top5_error.avg
# self.scalar_info['training_loss'] = top1_loss.avg
if self.tensorboard_logger is not None:
for tag, value in list(self.scalar_info.items()):
self.tensorboard_logger.scalar_summary(tag, value, self.run_count)
self.scalar_info = {}
return gen_acc#, top1_error.avg, top1_loss.avg, top5_error.avg
def test(self, epoch):
"""
testing
"""
# measure the student's accuracy
top1_error = utils.AverageMeter()
top1_loss = utils.AverageMeter()
top5_error = utils.AverageMeter()
self.model.eval()
self.model_teacher.eval()
iters = len(self.test_loader)
start_time = time.time()
end_time = start_time
with torch.no_grad():
# iterate over the test loader
for i, (images, labels) in enumerate(self.test_loader):
start_time = time.time()
labels = labels.cuda()
images = images.cuda()
output = self.model(images)
loss = torch.ones(1)
# the model collects mean and var; clear them here
self.mean_list.clear()
self.var_list.clear()
single_error, single_loss, single5_error = utils.compute_singlecrop(
outputs=output, loss=loss,
labels=labels, top5_flag=True, mean_flag=True)
top1_error.update(single_error, images.size(0))
top1_loss.update(single_loss, images.size(0))
top5_error.update(single5_error, images.size(0))
end_time = time.time()
# clean labels with the quantized student model
print(
"[Epoch %d/%d] [Batch %d/%d] [acc: %.4f%%]"
% (epoch + 1, self.settings.nEpochs, i + 1, iters, (100.00-top1_error.avg))
)
self.scalar_info['testing_top1error'] = top1_error.avg
self.scalar_info['testing_top5error'] = top5_error.avg
self.scalar_info['testing_loss'] = top1_loss.avg
if self.tensorboard_logger is not None:
for tag, value in self.scalar_info.items():
self.tensorboard_logger.scalar_summary(tag, value, self.run_count)
self.scalar_info = {}
self.run_count += 1
return top1_error.avg, top1_loss.avg, top5_error.avg
def test_teacher(self, epoch):
"""
testing
"""
top1_error = utils.AverageMeter()
top1_loss = utils.AverageMeter()
top5_error = utils.AverageMeter()
self.model_teacher.eval()
iters = len(self.test_loader)
start_time = time.time()
end_time = start_time
with torch.no_grad():
for i, (images, labels) in enumerate(self.test_loader):
start_time = time.time()
data_time = start_time - end_time
labels = labels.cuda()
# tenCrop: apply ten different crops to each image, producing ten samples
if self.settings.tenCrop:
image_size = images.size()
images = images.view(
image_size[0] * 10, image_size[1] // 10, image_size[2], image_size[3])
# the argument of split is the chunk size, so each chunk has shape like (10, size[1]/10, size[2], size[3])
images_tuple = images.split(image_size[0])
output = None
for img in images_tuple:
if self.settings.nGPU == 1:
img = img.cuda()
img_var = Variable(img, volatile=True)
temp_output, _ = self.forward(img_var)
if output is None:
output = temp_output.data
else:
output = torch.cat((output, temp_output.data))
single_error, single_loss, single5_error = utils.compute_tencrop(
outputs=output, labels=labels)
else:
if self.settings.nGPU == 1:
images = images.cuda()
output = self.model_teacher(images)
loss = torch.ones(1)
self.mean_list.clear()
self.var_list.clear()
single_error, single_loss, single5_error = utils.compute_singlecrop(
outputs=output, loss=loss,
labels=labels, top5_flag=True, mean_flag=True)
top1_error.update(single_error, images.size(0))
top1_loss.update(single_loss, images.size(0))
top5_error.update(single5_error, images.size(0))
end_time = time.time()
iter_time = end_time - start_time
print(
"Teacher network: [Epoch %d/%d] [Batch %d/%d] [acc: %.4f%%]"
% (epoch + 1, self.settings.nEpochs, i + 1, iters, (100.00 - top1_error.avg))
)
self.run_count += 1
return top1_error.avg, top1_loss.avg, top5_error.avg
from utils.lr_policy import *
from utils.compute import *
from utils.log_print import *
from utils.model_transform import *
# from utils.ifeige import *
\ No newline at end of file
import numpy as np
import math
import torch
__all__ = ["compute_tencrop", "compute_singlecrop", "AverageMeter"]
def compute_tencrop(outputs, labels):
output_size = outputs.size()
outputs = outputs.view(output_size[0] // 10, 10, output_size[1])
outputs = outputs.sum(1).squeeze(1)
# compute top1
_, pred = outputs.topk(1, 1, True, True)
pred = pred.t()
top1_count = pred.eq(labels.data.view(
1, -1).expand_as(pred)).view(-1).float().sum(0)
top1_error = 100.0 - 100.0 * top1_count / labels.size(0)
top1_error = float(top1_error.cpu().numpy())
# compute top5
_, pred = outputs.topk(5, 1, True, True)
pred = pred.t()
top5_count = pred.eq(labels.data.view(
1, -1).expand_as(pred)).view(-1).float().sum(0)
top5_error = 100.0 - 100.0 * top5_count / labels.size(0)
top5_error = float(top5_error.cpu().numpy())
return top1_error, 0, top5_error
def compute_singlecrop(outputs, labels, loss, top5_flag=False, mean_flag=False):
with torch.no_grad():
if isinstance(outputs, list):
top1_loss = []
top1_error = []
top5_error = []
for i in range(len(outputs)):
top1_accuracy, top5_accuracy = accuracy(outputs[i], labels, topk=(1, 5))
top1_error.append(100 - top1_accuracy)
top5_error.append(100 - top5_accuracy)
top1_loss.append(loss[i].item())
else:
top1_accuracy, top5_accuracy = accuracy(outputs, labels, topk=(1,5))
top1_error = 100 - top1_accuracy
top5_error = 100 - top5_accuracy
top1_loss = loss.item()
if top5_flag:
return top1_error, top1_loss, top5_error
else:
return top1_error, top1_loss
# compute accuracy
def accuracy(output, target, topk=(1,)):
"""Computes the precision@k for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size).item())
return res
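# Illustrative usage sketch (not part of the original code): top-1/top-5 accuracy of
# random logits against random labels (expected to be near 1% / 5% for 100 classes).
def _demo_accuracy():
    logits = torch.randn(256, 100)
    labels = torch.randint(0, 100, (256,))
    top1, top5 = accuracy(logits, labels, topk=(1, 5))
    print('top1 %.2f%%, top5 %.2f%%' % (top1, top5))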
class AverageMeter(object):
"""Computes and stores the average and current value"""
# tracks the average value over a given interval
def __init__(self):
self.reset()
def reset(self):
"""
reset all parameters
"""
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
"""
update parameters
"""
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
\ No newline at end of file
from termcolor import colored
import numpy as np
import datetime
__all__ = ["compute_remain_time", "print_result", "print_weight", "print_grad"]
single_train_time = 0
single_test_time = 0
single_train_iters = 0
single_test_iters = 0
def compute_remain_time(epoch, nEpochs, count, iters, data_time, iter_time, mode="Train"):
global single_train_time, single_test_time
global single_train_iters, single_test_iters
# compute cost time
if mode == "Train":
single_train_time = single_train_time * \
0.95 + 0.05 * (data_time + iter_time)
# single_train_time = data_time + iter_time
single_train_iters = iters
train_left_iter = single_train_iters - count + \
(nEpochs - epoch - 1) * single_train_iters
# print "train_left_iters", train_left_iter
test_left_iter = (nEpochs - epoch) * single_test_iters
else:
single_test_time = single_test_time * \
0.95 + 0.05 * (data_time + iter_time)
# single_test_time = data_time+iter_time
single_test_iters = iters
train_left_iter = (nEpochs - epoch - 1) * single_train_iters
test_left_iter = single_test_iters - count + \
(nEpochs - epoch - 1) * single_test_iters
left_time = single_train_time * train_left_iter + \
single_test_time * test_left_iter
total_time = (single_train_time * single_train_iters +
single_test_time * single_test_iters) * nEpochs
time_str = "TTime: {}, RTime: {}".format(datetime.timedelta(seconds=total_time),
datetime.timedelta(seconds=left_time))
return time_str, total_time, left_time
def print_result(epoch, nEpochs, count, iters, lr, data_time, iter_time, error, loss, top5error=None,
mode="Train", logger=None):
log_str = ">>> {}: [{:0>3d}|{:0>3d}], Iter: [{:0>3d}|{:0>3d}], LR: {:.6f}, DataTime: {:.4f}, IterTime: {:.4f}, ".format(
mode, epoch + 1, nEpochs, count, iters, lr, data_time, iter_time)
if isinstance(error, list) or isinstance(error, np.ndarray):
for i in range(len(error)):
log_str += "Error_{:d}: {:.4f}, Loss_{:d}: {:.4f}, ".format(i, error[i], i, loss[i])
else:
log_str += "Error: {:.4f}, Loss: {:.4f}, ".format(error, loss)
if top5error is not None:
if isinstance(top5error, list) or isinstance(top5error, np.ndarray):
for i in range(len(top5error)):
log_str += " Top5_Error_{:d}: {:.4f}, ".format(i, top5error[i])
else:
log_str += " Top5_Error: {:.4f}, ".format(top5error)
time_str, total_time, left_time = compute_remain_time(epoch, nEpochs, count, iters, data_time, iter_time, mode)
logger.info(log_str + time_str)
return total_time, left_time
def print_weight(layers, logger):
if isinstance(layers, MD.qConv2d):
logger.info(layers.weight)
elif isinstance(layers, MD.qLinear):
logger.info(layers.weight)
logger.info(layers.weight_mask)
logger.info("------------------------------------")
def print_grad(m, logger):
if isinstance(m, MD.qLinear):
logger.info(m.weight.data)
"""
class LRPolicy
"""
import math
__all__ = ["LRPolicy"]
class LRPolicy:
"""
learning rate policy
"""
def __init__(self, lr, n_epochs, lr_policy="multi_step"):
self.lr_policy = lr_policy
self.params_dict = {}
self.n_epochs = n_epochs
self.base_lr = lr
self.lr = lr
def set_params(self, params_dict=None):
"""
set parameters of lr policy
"""
if self.lr_policy == "multi_step":
"""
params: decay_rate, step
"""
self.params_dict['decay_rate'] = params_dict['decay_rate']
self.params_dict['step'] = sorted(params_dict['step'])
if max(self.params_dict['step']) <= 1:
new_step_list = []
for ratio in self.params_dict['step']:
new_step_list.append(int(self.n_epochs * ratio))
self.params_dict['step'] = new_step_list
elif self.lr_policy == "step":
"""
params: end_lr, step
step: lr = base_lr*gamma^(floor(iter/step))
"""
self.params_dict['end_lr'] = params_dict['end_lr']
self.params_dict['step'] = params_dict['step']
max_iter = math.floor((self.n_epochs - 1.0) /
self.params_dict['step'])
if self.params_dict['end_lr'] == -1:
self.params_dict['gamma'] = params_dict['decay_rate']
else:
self.params_dict['gamma'] = math.pow(
self.params_dict['end_lr'] / self.base_lr, 1. / max_iter)
elif self.lr_policy == "linear":
"""
params: end_lr, step
"""
self.params_dict['end_lr'] = params_dict['end_lr']
self.params_dict['step'] = params_dict['step']
elif self.lr_policy == "exp":
"""
params: end_lr
exp: lr = base_lr*gamma^iter
"""
self.params_dict['end_lr'] = params_dict['end_lr']
self.params_dict['gamma'] = math.pow(
self.params_dict['end_lr'] / self.base_lr, 1. / (self.n_epochs - 1))
elif self.lr_policy == "inv":
"""
params: end_lr
inv: lr = base_lr*(1+gamma*iter)^(-power)
"""
self.params_dict['end_lr'] = params_dict['end_lr']
self.params_dict['power'] = params_dict['power']
self.params_dict['gamma'] = (math.pow(
self.base_lr / self.params_dict['end_lr'],
1. / self.params_dict['power']) - 1.) / (self.n_epochs - 1.)
elif self.lr_policy == "const":
"""
no params
const: lr = base_lr
"""
self.params_dict = None
else:
assert False, "invalid lr_policy" + self.lr_policy
def get_lr(self, epoch):
"""
get current learning rate
"""
if self.lr_policy == "multi_step":
gamma = 0
for step in self.params_dict['step']:
if epoch + 1.0 > step:
gamma += 1
lr = self.base_lr * math.pow(self.params_dict['decay_rate'], gamma)
elif self.lr_policy == "step":
lr = self.base_lr * \
math.pow(self.params_dict['gamma'], math.floor(
epoch * 1.0 / self.params_dict['step']))
elif self.lr_policy == "linear":
k = (self.params_dict['end_lr'] - self.base_lr) / \
math.ceil(self.n_epochs / self.params_dict['step'])
lr = k * math.ceil((epoch + 1) /
self.params_dict['step']) + self.base_lr
elif self.lr_policy == "inv":
lr = self.base_lr * \
math.pow(
1 + self.params_dict['gamma'] * epoch, -self.params_dict['power'])
elif self.lr_policy == "exp":
# power = math.floor((epoch + 1) / self.params_dict['step'])
# lr = self.base_lr * math.pow(self.params_dict['gamma'], power)
lr = self.base_lr * math.pow(self.params_dict['gamma'], epoch)
elif self.lr_policy == "const":
lr = self.base_lr
else:
assert False, "invalid lr_policy: " + self.lr_policy
self.lr = lr
return lr
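# Illustrative usage sketch (not part of the original code): a multi_step policy that
# decays the learning rate by 10x after 60% and 80% of the epochs, mirroring the
# default step = [0.6, 0.8] used elsewhere in this repo.
def _demo_lr_policy():
    policy = LRPolicy(lr=0.1, n_epochs=100, lr_policy="multi_step")
    policy.set_params({'decay_rate': 0.1, 'step': [0.6, 0.8]})
    for epoch in (0, 59, 60, 80, 99):
        print(epoch, policy.get_lr(epoch))   # 0.1, 0.1, 0.01, 0.001, 0.001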
import torch.nn as nn
import torch
import numpy as np
__all__ = ["data_parallel", "model2list",
"list2sequential", "model2state_dict"]
def data_parallel(model, ngpus, gpu0=0):
"""
assign model to multi-gpu mode
:params model: target model
:params ngpus: number of gpus to use
:params gpu0: id of the master gpu
:return: model, type is Module or Sequential or DataParallel
"""
if ngpus == 0:
assert False, "only support gpu mode"
gpu_list = list(range(gpu0, gpu0 + ngpus))
assert torch.cuda.device_count() >= gpu0 + ngpus, "Invalid Number of GPUs"
if isinstance(model, list):
for i in range(len(model)):
if ngpus >= 2:
if not isinstance(model[i], nn.DataParallel):
model[i] = torch.nn.DataParallel(model[i], gpu_list).cuda()
else:
model[i] = model[i].cuda()
else:
if ngpus >= 2:
if not isinstance(model, nn.DataParallel):
model = torch.nn.DataParallel(model, gpu_list).cuda()
else:
model = model.cuda()
return model
def model2list(model):
"""
convert model to list type
:param model: should be type of list or nn.DataParallel or nn.Sequential
:return: the model as a list of modules (returned unchanged if no conversion applies)
"""
if isinstance(model, nn.DataParallel):
model = list(model.module)
elif isinstance(model, nn.Sequential):
model = list(model)
return model
def list2sequential(model):
if isinstance(model, list):
model = nn.Sequential(*model)
return model
def model2state_dict(file_path):
model = torch.load(file_path)
if model['model'] is not None:
model_state_dict = model['model'].state_dict()
torch.save(model_state_dict, file_path.replace(
'.pth', 'state_dict.pth'))
else:
print((type(model)))
print(model)
print("skip")
"""
TODO: add doc for module
"""
import torch
__all__ = ["NetOption"]
"""
You can run your script with CUDA_VISIBLE_DEVICES=5,6 python your_script.py,
or set the environment variable inside the script via os.environ['CUDA_VISIBLE_DEVICES'] = '5,6',
to map GPUs 5 and 6 to device_ids 0 and 1, respectively.
"""
# This object is instantiated in main and then overridden with the hocon config; most of its fields ultimately come from the hocon file.
class NetOption(object):
def __init__(self):
# ------------ General options ----------------------------------------
self.save_path = "" # log path
# dataset root directory
self.dataPath = "/home/dataset/" # path for loading data set
self.dataset = "cifar10" # options: imagenet | cifar10 | cifar100 | imagenet100 | mnist
self.manualSeed = 1 # manually set RNG seed
self.nGPU = 1 # number of GPUs to use by default
self.GPU = 0 # default gpu to use, options: range(nGPU)
# ------------- Data options -------------------------------------------
self.nThreads = 4 # number of data loader threads
# ------------- Training options ---------------------------------------
self.testOnly = False # run on validation set only
self.tenCrop = False # Ten-crop testing
# ---------- Optimization options --------------------------------------
self.nEpochs = 200 # number of total epochs to train
self.batchSize = 128 # mini-batch size
self.momentum = 0.9 # momentum
self.weightDecay = 1e-4 # weight decay 1e-4
self.opt_type = "SGD"
self.lr = 0.1 # initial learning rate
self.lrPolicy = "multi_step" # options: multi_step | linear | exp | fixed
self.power = 1 # power for learning rate policy (inv)
self.step = [0.6, 0.8] # step for linear or exp learning rate policy
self.endlr = 0.001 # final learning rate, oly for "linear lrpolicy"
self.decayRate = 0.1 # lr decay rate
# ---------- Model options ---------------------------------------------
self.netType = "PreResNet" # options: ResNet | PreResNet | GreedyNet | NIN | LeNet5
self.experimentID = "refator-test-01"
self.depth = 20 # resnet depth: (n-2)%6==0
self.nClasses = 10 # number of classes in the dataset
self.wideFactor = 1 # wide factor for wide-resnet
# ---------- Resume or Retrain options ---------------------------------------------
self.retrain = None # path to model to retrain with, load model state_dict only
self.resume = None # path to directory containing checkpoint, load state_dicts of model and optimizer, as well as training epoch
# ---------- Visualization options -------------------------------------
self.drawNetwork = True
self.drawInterval = 30
self.torch_version = torch.__version__
torch_version_split = self.torch_version.split("_")
self.torch_version = torch_version_split[0]
# check parameters
# self.paramscheck()
def paramscheck(self):
if self.torch_version != "0.2.0":
self.drawNetwork = False
print("|===>DrawNetwork is supported by PyTorch with version: 0.2.0. The used version is ", self.torch_version)
if self.netType in ["PreResNet", "ResNet"]:
self.save_path = "log_%s%d_%s_bs%d_lr%0.3f_%s/" % (
self.netType, self.depth, self.dataset,
self.batchSize, self.lr, self.experimentID)
else:
self.save_path = "log_%s_%s_bs%d_lr%0.3f_%s/" % (
self.netType, self.dataset,
self.batchSize, self.lr, self.experimentID)
if self.dataset in ["cifar10", "mnist"]:
self.nClasses = 10
elif self.dataset == "cifar100":
self.nClasses = 100
elif self.dataset in ["imagenet", "thi_imgnet"]:
self.nClasses = 1000
elif self.dataset == "imagenet100":
self.nClasses = 100
if self.depth >= 100:
self.drawNetwork = False
print("|===>draw network with depth over 100 layers, skip this step")
# Sanity check: iterate over the imagenet MIO train/val datasets and verify
# that every record can be fetched without errors.
from torchlearning.mio import MIO

train_dataset = MIO("/home/datasets/imagenet_mio/train/")
test_dataset = MIO("/home/datasets/imagenet_mio/val/")

for i in range(train_dataset.size):
    print(i)
    train_dataset.fetchone(i)

for i in range(test_dataset.size):
    print(i)
    test_dataset.fetchone(i)
# Change Log
## update: 2023/05/29
+ GDFQ: integrated with the existing framework and trained generators for all models. Evaluation and decision-boundary sample augmentation will follow.
## update2: 2023/05/26
+ Added cifar100 dataset support; see ALL-cifar100 for details. The original ALL folder has been renamed to ALL-cifar10.