Commit bc233acc by Zhihong Ma

fix: QAT - Predicting the Convergence Speed of Quantized Model

parent f4b96743
import torch
import torch.nn as nn
import torch.nn.functional as F
from module import *
import module
from global_var import GlobalVariables
# 定义 ResNet 模型
# 适用于Cifar10
class ResNet_fold(nn.Module):
def __init__(self, block, layers, num_classes=10): # 这里将类别数设置为10
super(ResNet_fold, self).__init__()
self.inplanes = 16 # 因为 CIFAR-10 图片较小,所以开始时需要更少的通道数
GlobalVariables.SELF_INPLANES = self.inplanes
# print('resnet init:'+ str(GlobalVariables.SELF_INPLANES))
# 输入层
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU()
self.convbnrelu1 = ConvBNReLU(self.conv1,self.bn1)
# 残差层(4 个阶段,每个阶段包含 6n+2 个卷积层)
self.layer1 = MakeLayer(block, 16, layers[0])
self.layer2 = MakeLayer(block, 32, layers[1], stride=2)
self.layer3 = MakeLayer(block, 64, layers[2], stride=2)
self.layer4 = MakeLayer(block, 128, layers[3], stride=2)
# 分类层
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(128 * block.expansion, num_classes)
# 参数初始化
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x):
# 输入层
# x = self.conv1(x)
# x = self.bn1(x)
# x = self.relu(x)
x = self.convbnrelu1(x)
# 这里相比于imagenet的,少了一个maxpool,因为cifar10本身图片就小,如果再pool就太小了
# 残差层
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
# 分类层
x = self.avgpool(x) # 输出的尺寸为 B,C,1,1
x = x.view(x.size(0), -1)
x = self.fc(x)
out = F.softmax(x,dim = 1) # 这里不softmax也行 影响不大
return out
def quantize(self, quant_type, num_bits=8, e_bits=3):
self.qconvbnrelu1 = QConvBNReLU(quant_type,self.conv1,self.bn1,qi=True,qo=True,num_bits=num_bits,e_bits=e_bits)
# 没有输入num_bits 需修改
self.layer1.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
self.layer2.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
self.layer3.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
self.layer4.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
self.qavgpool1 = QAdaptiveAvgPool2d(quant_type,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qfc1 = QLinear(quant_type, self.fc,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
# self.qfc1 = QLinear(quant_type, self.fc,qi=True,qo=True,num_bits=num_bits,e_bits=e_bits)
def quantize_forward(self, x):
# for _, layer in self.quantize_layers.items():
# x = layer(x)
# out = F.softmax(x, dim=1)
# return out
x = self.qconvbnrelu1(x)
x = self.layer1.quantize_forward(x)
x = self.layer2.quantize_forward(x)
x = self.layer3.quantize_forward(x)
x = self.layer4.quantize_forward(x)
x = self.qavgpool1(x)
x = x.view(x.size(0), -1)
x = self.qfc1(x)
out = F.softmax(x,dim = 1) # 这里不softmax也行 影响不大
return out
def freeze(self):
self.qconvbnrelu1.freeze() # 因为作为第一层是有qi的,所以freeze的时候无需再重新提供qi
qo = self.layer1.freeze(qinput = self.qconvbnrelu1.qo)
qo = self.layer2.freeze(qinput = qo)
qo = self.layer3.freeze(qinput = qo)
qo = self.layer4.freeze(qinput = qo)
self.qavgpool1.freeze(qi=qo)
self.qfc1.freeze(qi=self.qavgpool1.qo)
# self.qfc1.freeze()
def fakefreeze(self):
self.qconvbnrelu1.fakefreeze()
self.layer1.fakefreeze()
self.layer2.fakefreeze()
self.layer3.fakefreeze()
self.layer4.fakefreeze()
self.qfc1.fakefreeze()
def quantize_inference(self, x):
qx = self.qconvbnrelu1.qi.quantize_tensor(x)
qx = self.qconvbnrelu1.quantize_inference(qx)
qx = self.layer1.quantize_inference(qx)
qx = self.layer2.quantize_inference(qx)
qx = self.layer3.quantize_inference(qx)
qx = self.layer4.quantize_inference(qx)
qx = self.qavgpool1.quantize_inference(qx)
qx = qx.view(qx.size(0), -1)
qx = self.qfc1.quantize_inference(qx)
qx = self.qfc1.qo.dequantize_tensor(qx)
out = F.softmax(qx,dim = 1) # 这里不softmax也行 影响不大
return out
# BasicBlock 类
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
# 第一个卷积层
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.convbnrelu1 = ConvBNReLU(self.conv1,self.bn1)
# 第二个卷积层
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU()
self.convbn1 = ConvBN(self.conv2,self.bn2)
# shortcut
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
# out = self.conv1(x)
# out = self.bn1(out)
# out = self.relu(out)
out = self.convbnrelu1(x)
# out = self.conv2(out)
# out = self.bn2(out)
out = self.convbn1(out)
if self.downsample is not None:
identity = self.downsample(identity)
out += identity
out = self.relu(out)
return out
def quantize(self, quant_type ,num_bits=8, e_bits=3):
self.qconvbnrelu1 = QConvBNReLU(quant_type,self.conv1,self.bn1,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qconvbn1 = QConvBN(quant_type,self.conv2,self.bn2,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
if self.downsample is not None:
self.qconvbn2 = QConvBN(quant_type,self.downsample[0],self.downsample[1],qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qelementadd = QElementwiseAdd(quant_type,qi0=False, qi1=False, qo=True,num_bits=num_bits,e_bits=e_bits)
self.qrelu1 = QReLU(quant_type,qi= False,num_bits=num_bits,e_bits=e_bits) # 需要qi
def quantize_forward(self, x):
identity = x
out = self.qconvbnrelu1(x)
out = self.qconvbn1(out)
if self.downsample is not None:
identity = self.qconvbn2(identity)
# residual add
# out = identity + out # 这里是需要写一个elementwiseadd的变换的,待后续修改
out = self.qelementadd(out,identity)
out = self.qrelu1(out)
return out
def freeze(self, qinput):
# 这里的qconvbnrelu1其实是可以用前一层的qo的,但感觉不太好传参,就没用
# 还需仔细检查
self.qconvbnrelu1.freeze(qi= qinput) # 需要接前一个module的最后一个qo
self.qconvbn1.freeze(qi = self.qconvbnrelu1.qo)
if self.downsample is not None:
self.qconvbn2.freeze(qi = qinput) # 一条支路
self.qelementadd.freeze(qi0 = self.qconvbn1.qo, qi1 = self.qconvbn2.qo)
else:
self.qelementadd.freeze(qi0 = self.qconvbn1.qo, qi1 = qinput)
# 这里或许需要补充个层来处理elementwise add
self.qrelu1.freeze(qi = self.qelementadd.qo)
return self.qrelu1.qi # relu后的qo可用relu统计的qi
def fakefreeze(self):
# 这里的qconvbnrelu1其实是可以用前一层的qo的,但感觉不太好传参,就没用
# 还需仔细检查
self.qconvbnrelu1.fakefreeze() # 需要接前一个module的最后一个qo
self.qconvbn1.fakefreeze()
if self.downsample is not None:
self.qconvbn2.fakefreeze() # 一条支路
def quantize_inference(self, x):
# 感觉是不需要进行初始的quantize_tensor和dequantize_tensor,因为他不是最前/后一层,只要中间的每层都在量化后的领域内,就不需要这种处理。
identity = x
out = self.qconvbnrelu1.quantize_inference(x)
out = self.qconvbn1.quantize_inference(out)
if self.downsample is not None:
identity = self.qconvbn2.quantize_inference(identity)
# out = identity + out # 这里可能需要写一个elementwiseadd的变换的,待后续修改
out = self.qelementadd.quantize_inference(out,identity)
out = self.qrelu1.quantize_inference(out)
return out
# Bottleneck 类
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
# 1x1 卷积层
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.convbnrelu1 = ConvBNReLU(self.conv1,self.bn1)
# 3x3 卷积层
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.convbnrelu2 = ConvBNReLU(self.conv2,self.bn2)
# 1x1 卷积层
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.convbn1 = ConvBN(self.conv3,self.bn3)
# shortcut
self.relu = nn.ReLU()
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
# out = self.conv1(x)
# out = self.bn1(out)
# out = self.relu(out)
out = self.convbnrelu1(x)
# out = self.conv2(out)
# out = self.bn2(out)
# out = self.relu(out)
out = self.convbnrelu2(out)
# out = self.conv3(out)
# out = self.bn3(out)
out = self.convbn1(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity # 相加是在这里处理的
out = self.relu(out)
return out
def quantize(self, quant_type ,num_bits=8, e_bits=3):
self.qconvbnrelu1 = QConvBNReLU(quant_type,self.conv1,self.bn1,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qconvbnrelu2 = QConvBNReLU(quant_type,self.conv2,self.bn2,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qconvbn1 = QConvBN(quant_type,self.conv3,self.bn3,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
if self.downsample is not None:
self.qconvbn2 = QConvBN(quant_type,self.downsample[0],self.downsample[1],qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qelementadd = QElementwiseAdd(quant_type,qi0=False, qi1=False, qo=True,num_bits=num_bits,e_bits=e_bits)
self.qrelu1 = QReLU(quant_type,qi= False,num_bits=num_bits,e_bits=e_bits) # 需要qi
def quantize_forward(self, x):
identity = x
out = self.qconvbnrelu1(x)
out = self.qconvbnrelu2(out)
out = self.qconvbn1(out)
if self.downsample is not None:
identity = self.qconvbn2(identity)
# residual add
# out = identity + out # 这里是需要写一个elementwiseadd的变换的,待后续修改
out = self.qelementadd(out,identity)
out = self.qrelu1(out)
return out
def freeze(self, qinput):
# 这里的qconvbnrelu1其实是可以用前一层的qo的,但感觉不太好传参,就没用
# 还需仔细检查
self.qconvbnrelu1.freeze(qi= qinput) # 需要接前一个module的最后一个qo
self.qconvbnrelu2.freeze(qi=self.qconvbnrelu1.qo)
self.qconvbn1.freeze(qi = self.qconvbnrelu2.qo)
if self.downsample is not None:
self.qconvbn2.freeze(qi = qinput) # 一条支路
self.qelementadd.freeze(qi0 = self.qconvbn1.qo, qi1 = self.qconvbn2.qo)
else:
self.qelementadd.freeze(qi0 = self.qconvbn1.qo, qi1 = qinput)
# 这里或许需要补充个层来处理elementwise add
self.qrelu1.freeze(qi = self.qelementadd.qo) # 需要自己统计qi
return self.qrelu1.qi # relu后的qo可用relu统计的qi
def fakefreeze(self):
# 这里的qconvbnrelu1其实是可以用前一层的qo的,但感觉不太好传参,就没用
# 还需仔细检查
self.qconvbnrelu1.fakefreeze()
self.qconvbnrelu2.fakefreeze()
self.qconvbn1.fakefreeze()
if self.downsample is not None:
self.qconvbn2.fakefreeze() # 一条支路
def quantize_inference(self, x):
# 感觉是不需要进行初始的quantize_tensor和dequantize_tensor,因为他不是最前/后一层,只要中间的每层都在量化后的领域内,就不需要这种处理。
identity = x
out = self.qconvbnrelu1.quantize_inference(x)
out = self.qconvbnrelu2.quantize_inference(out)
out = self.qconvbn1.quantize_inference(out)
if self.downsample is not None:
identity = self.qconvbn2.quantize_inference(identity)
# out = identity + out # 这里可能需要写一个elementwiseadd的变换的,待后续修改
out = self.qelementadd.quantize_inference(out,identity)
out = self.qrelu1.quantize_inference(out)
return out
class MakeLayer(nn.Module):
def __init__(self, block, planes, blocks, stride=1):
super(MakeLayer, self).__init__()
# print('makelayer init:'+ str(GlobalVariables.SELF_INPLANES))
self.downsample = None
if stride != 1 or GlobalVariables.SELF_INPLANES != planes * block.expansion:
# self.downsample = nn.Sequential(
# nn.Conv2d(GlobalVariables.SELF_INPLANES, planes * block.expansion,kernel_size=1, stride=stride, bias=False),
# nn.BatchNorm2d(planes * block.expansion)
# )
self.conv1 = nn.Conv2d(GlobalVariables.SELF_INPLANES, planes * block.expansion,kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(planes * block.expansion)
self.convbn1 = ConvBN(self.conv1,self.bn1)
self.downsample = self.convbn1
self.blockdict = nn.ModuleDict()
self.blockdict['block1'] = block(inplanes=GlobalVariables.SELF_INPLANES, planes=planes, stride=stride, downsample=self.downsample)
GlobalVariables.SELF_INPLANES = planes * block.expansion
for i in range(1, blocks): # block的个数 这里只能用字典了
self.blockdict['block' + str(i+1)] = block(inplanes=GlobalVariables.SELF_INPLANES, planes=planes) # 此处进行实例化了
# def _make_layer(self, block, planes, blocks, stride=1):
# downsample = None
# # stride 是卷积层的步幅,而 self.inplanes 表示当前残差块输入的通道数,
# # planes * block.expansion 则表示当前残差块输出的通道数。因此,当 stride 不等于 1 或者 self.inplanes 不等于 planes * block.expansion 时,就需要进行下采样操作
# #该层中除了第一个残差块之外,其他所有残差块的输入通道数和输出通道数都相等,并且具有相同的步幅(都为 1 或者 2)。这些卷积层的输入张量大小不变, 输出张量高宽尺寸会随着残差块的堆叠而逐渐降低
# if stride != 1 or SELF_INPLANES != planes * block.expansion:
# downsample = nn.Sequential(
# nn.Conv2d(SELF_INPLANES, planes * block.expansion,
# kernel_size=1, stride=stride, bias=False),
# nn.BatchNorm2d(planes * block.expansion),
# )
# layers = []
# layers.append(block(SELF_INPLANES, planes, stride, downsample))
# SELF_INPLANES = planes * block.expansion
# for _ in range(1, blocks): # block的个数
# layers.append(block(SELF_INPLANES, planes))
# return nn.Sequential(*layers)
def forward(self,x):
for _, layer in self.blockdict.items():
x = layer(x)
return x
def quantize(self, quant_type, num_bits=8, e_bits=3):
# 需检查
for _, layer in self.blockdict.items():
layer.quantize(quant_type=quant_type,num_bits=num_bits,e_bits=e_bits) # 这里是因为每一块都是block,而block中有具体的quantize策略, n_exp和mode已经在__init__中赋值了
def quantize_forward(self, x):
for _, layer in self.blockdict.items():
x = layer.quantize_forward(x) # 各个block中有具体的quantize_forward
return x
def freeze(self, qinput): # 需要在 Module Resnet的freeze里传出来
# 这里的qconvbnrelu1其实是可以用前一层的qo的,但感觉不太好传参,就没用
# 还需仔细检查
cnt = 0
for _, layer in self.blockdict.items():
if cnt == 0:
qo = layer.freeze(qinput = qinput)
cnt = 1
else:
qo = layer.freeze(qinput = qo) # 各个block中有具体的freeze
return qo # 供后续的层用
def fakefreeze(self):
for _, layer in self.blockdict.items():
layer.fakefreeze()
def quantize_inference(self, x):
# 感觉是不需要进行初始的quantize_tensor和dequantize_tensor,因为他不是最前/后一层,只要中间的每层都在量化后的领域内,就不需要这种处理。
for _, layer in self.blockdict.items():
x = layer.quantize_inference(x) # 每个block中有具体的quantize_inference
return x
# 使用 ResNet18 模型
def resnet18_fold(**kwargs):
model = ResNet_fold(BasicBlock, [2, 2, 2, 2], **kwargs)
return model
# 使用 ResNet50 模型
def resnet50_fold(**kwargs):
model = ResNet_fold(Bottleneck, [3, 4, 6, 3], **kwargs)
return model
# 使用 ResNet152 模型
def resnet152_fold(**kwargs):
model = ResNet_fold(Bottleneck, [3, 8, 36, 3], **kwargs)
return model
......@@ -247,8 +247,13 @@ class QConv2d(QModule):
# foward前更新qw,保证量化weight时候scale正确
self.qw.update(self.conv_module.weight.data)
# 注意:此处主要为了统计各层x和weight范围,未对bias进行量化操作
tmp_wgt = FakeQuantize.apply(self.conv_module.weight, self.qw)
x = F.conv2d(x, tmp_wgt, self.conv_module.bias,
# tmp_wgt = FakeQuantize.apply(self.conv_module.weight, self.qw)
# x = F.conv2d(x, tmp_wgt, self.conv_module.bias,
# stride=self.conv_module.stride,
# padding=self.conv_module.padding, dilation=self.conv_module.dilation,
# groups=self.conv_module.groups)
x = F.conv2d(x, FakeQuantize.apply(self.conv_module.weight, self.qw), self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
......@@ -317,8 +322,9 @@ class QLinear(QModule):
x = FakeQuantize.apply(x, self.qi)
self.qw.update(self.fc_module.weight.data)
tmp_wgt = FakeQuantize.apply(self.fc_module.weight, self.qw)
x = F.linear(x, tmp_wgt, self.fc_module.bias)
# tmp_wgt = FakeQuantize.apply(self.fc_module.weight, self.qw)
# x = F.linear(x, tmp_wgt, self.fc_module.bias)
x = F.linear(x, FakeQuantize.apply(self.fc_module.weight, self.qw), self.fc_module.bias)
if hasattr(self, 'qo'):
self.qo.update(x)
......@@ -924,3 +930,260 @@ class QElementwiseAdd(QModule_2):
return x
# new modules for full-precision model - fold bn
# inference应该也需要相应的适配
class ConvBNReLU(nn.Module):
def __init__(self,conv_module, bn_module):
super(ConvBNReLU, self).__init__()
self.conv_module = conv_module
self.bn_module = bn_module
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self):
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = weight.data
if self.conv_module.bias is None:
self.conv_module.bias = nn.Parameter(bias)
else:
self.conv_module.bias.data = bias
def fakefreeze(self):
pass
def forward(self, x):
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
x = F.conv2d(x, weight, bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
x = F.relu(x)
return x
def quantize_inference(self, x):
x = self.conv_module(x)
x.clamp_(min=0)
return x
class ConvBN(nn.Module):
def __init__(self,conv_module, bn_module):
super(ConvBN, self).__init__()
self.conv_module = conv_module
self.bn_module = bn_module
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self):
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = weight.data
if self.conv_module.bias is None:
self.conv_module.bias = nn.Parameter(bias)
else:
self.conv_module.bias.data = bias
def fakefreeze(self):
pass
def forward(self, x):
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
x = F.conv2d(x, weight, bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
return x
def quantize_inference(self, x):
x = self.conv_module(x)
return x
class ConvBNReLU6(nn.Module):
def __init__(self,conv_module, bn_module):
super(ConvBNReLU6, self).__init__()
self.conv_module = conv_module
self.bn_module = bn_module
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self):
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = weight.data
if self.conv_module.bias is None:
self.conv_module.bias = nn.Parameter(bias)
else:
self.conv_module.bias.data = bias
def fakefreeze(self):
pass
def forward(self, x):
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
x = F.conv2d(x, weight, bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
x = F.relu6(x)
return x
def quantize_inference(self, x):
x = self.conv_module(x)
x.clamp_(min=0,max=6)
return x
\ No newline at end of file
from model import *
from model_foldbn import *
from extract_ratio import *
from utils import *
import openpyxl
import gol
import sys
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR
def js_div_norm(a,b):
a_norm = F.normalize(a.data,p=2,dim=-1)
b_norm = F.normalize(b.data,p=2,dim=-1)
return js_div(a_norm,b_norm).cpu().item()
def js_div_0(a,b):
return js_div(a,b).cpu().item()
def direct_quantize(model, test_loader,device):
for i, (data, target) in enumerate(test_loader, 1):
data = data.to(device)
output = model.quantize_forward(data).cpu()
if i % 500 == 0:
break
print('direct quantization finish')
def quantize_aware_training(model, device, train_loader, optimizer, epoch):
old_sub_str0 = "downsample.0"
new_sub_str0 = "conv1"
old_sub_str1 = "downsample.1"
new_sub_str1 = "bn1"
lossLayer = torch.nn.CrossEntropyLoss()
#统计loss和每个参数的grad
#初始化
loss_sum = 0.
grad_dict = {}
for name,param in model.named_parameters():
if old_sub_str0 in name:
name = name.replace(old_sub_str0, new_sub_str0)
elif old_sub_str1 in name:
name = name.replace(old_sub_str1, new_sub_str1)
grad_dict[name] = torch.zeros_like(param) #param.grad和param形状相同
for batch_idx, (data, target) in enumerate(train_loader, 1):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model.quantize_forward(data)
# 对一批数据求得的loss是平均值
loss = lossLayer(output, target)
loss.backward()
#loss和grads累加
loss_sum += loss
for name,param in model.named_parameters():
if param.grad is not None:
# print('-------'+name+'-------')
if old_sub_str0 in name:
name = name.replace(old_sub_str0, new_sub_str0)
elif old_sub_str1 in name:
name = name.replace(old_sub_str1, new_sub_str1)
grad_dict[name] += param.grad.detach()
# print(grad_dict[name])
# print(grad_dict.items())
# input()
optimizer.step()
if batch_idx % 50 == 0:
print('Quantize Aware Training Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
batch_size = len(train_loader.batch_sampler)
#对不同batch累加值求平均
for name,grad in grad_dict.items():
grad_dict[name] = grad / batch_size
loss_avg = loss_sum / batch_size
return loss_avg, grad_dict
def full_inference(model, test_loader):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Full Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))
def train(model, device, train_loader, optimizer, epoch):
model.train()
lossLayer = torch.nn.CrossEntropyLoss()
#统计loss和每个参数的grad
#初始化
loss_sum = 0.
grad_dict = {}
for name,param in model.named_parameters():
grad_dict[name] = torch.zeros_like(param) #param.grad和param形状相同
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = lossLayer(output, target)
loss.backward()
#loss和grads累加
loss_sum += loss
for name,param in model.named_parameters():
if param.grad is not None:
# print('-------'+name+'-------')
grad_dict[name] += param.grad.detach()
# print(grad_dict[name])
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
batch_size = len(train_loader.batch_sampler)
#对不同batch累加值求平均
for name,grad in grad_dict.items():
grad_dict[name] = grad / batch_size
loss_avg = loss_sum / batch_size
return loss_avg, grad_dict
def quantize_inference(model, test_loader):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model.quantize_inference(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Quant Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='QAT Training')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='resnet18')
parser.add_argument('-e','--epochs', default=15, type=int, metavar='EPOCHS', help='number of total epochs to run')
parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='BATCH SIZE', help='mini-batch size (default: 128)')
parser.add_argument('-j','--workers', default=1, type=int, metavar='WORKERS',help='number of data loading workers (default: 4)')
parser.add_argument('-lr', '--learning-rate', default=0.001, type=float, metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('-wd','--weight_decay',default=0.0001,type=float,metavar='WD',help='lr schduler weight decay',dest='wd')
parser.add_argument('-t', '--test', dest='test', action='store_true', help='test model on test set')
args = parser.parse_args()
batch_size = args.batch_size
seed = 1
epochs = args.epochs
lr = args.lr
# momentum = 0.5
weight_decay = args.wd
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
writer = SummaryWriter(log_dir='log/' + args.model + '/qat')
wb = openpyxl.Workbook()
ws = wb.active
old_sub_str0 = "downsample.0"
new_sub_str0 = "conv1"
old_sub_str1 = "downsample.1"
new_sub_str1 = "bn1"
if args.model == 'ResNet18':
model = resnet18_fold()
elif args.model == 'ResNet50':
model = resnet50_fold()
elif args.model == 'ResNet152':
model = resnet152_fold()
layer, par_ratio, flop_ratio = extract_ratio(args.model)
# TODO layer要重新读取
layer = []
# 此处得到的layer是为了标记par_ratio, flop_ratio 对应起来 一层一个名字 一个flop/flop ratio
for name, param in model.named_parameters():
if 'weight' in name:
n = name.split('.') # conv,bn,fc这些有param的层的名字都能提取出来
pre = '.'.join(n[:len(n)-1])
# 提取出weight前的名字(就是这个层的名字,if weight是避免bias重复提取一遍名字)
# 无downsample串
layer.append(pre)
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../../project/p/data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=args.workers, pin_memory=False
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../../project/p/data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=args.workers, pin_memory=False
)
# model.load_state_dict(torch.load(full_file))
model.to(device)
momentum = 0.9
# optimizer1 = optim.Adam(model.parameters(), lr=lr)
optimizer1 = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
lr_scheduler1 = CosineAnnealingLR(optimizer1, T_max=epochs)
# 没save .pt 无load
load_qat = False
ckpt_prefix = 'ckpt/qat/'+ args.model + '/'
loss_sum = 0.
full_grad_sum = {}
full_grad_avg = {}
for name,param in model.named_parameters():
full_grad_sum[name] = torch.zeros_like(param)
full_grad_avg[name] = torch.zeros_like(param)
for epoch in range(1, epochs+1):
# 训练原模型,获取梯度分布
loss,full_grad = train(model, device, train_loader, optimizer1, epoch)
if epoch == 1:
loss_start = loss
# print('loss:%f' % loss_avg)
writer.add_scalar('Full.loss',loss,epoch)
# for name,grad in grad_dict.items():
# writer.add_histogram('Full.'+name+'_grad',grad,global_step=epoch)
loss_sum += loss
loss_avg = loss_sum / epoch
# loss的变化量 越大说明收敛的越快(不同model在对比时,相同的epoch数,loss_delta大说明很快就进入了小loss的收敛期)
loss_delta = loss - loss_start
for name,grad in full_grad.items():
full_grad_sum[name] += full_grad[name]
full_grad_avg[name] = full_grad_sum[name] / epoch
if epoch % 5 == 0:
ws = wb.create_sheet('epoch_%d'%epoch)
ws.cell(row=1,column=2,value='loss')
ws.cell(row=1,column=3,value='loss_sum')
ws.cell(row=1,column=4,value='loss_avg')
ws.cell(row=1,column=5,value='loss_delta')
ws.cell(row=2,column=1,value='FP32')
ws.cell(row=2,column=2,value=loss.item())
ws.cell(row=2,column=3,value=loss_sum.item())
ws.cell(row=2,column=4,value=loss_avg.item())
ws.cell(row=2,column=5,value=loss_delta.item())
ws.cell(row=4,column=1,value='title')
ws.cell(row=4,column=2,value='loss')
ws.cell(row=4,column=3,value='loss_sum')
ws.cell(row=4,column=4,value='loss_avg')
ws.cell(row=4,column=5,value='loss_delta')
ws.cell(row=4,column=6,value='js_grad')
ws.cell(row=4,column=7,value='js_grad_sum')
ws.cell(row=4,column=8,value='js_grad_avg')
# lr_scheduler1.step()
quant_type_list = ['INT']
gol._init()
currow=4 #数据从哪行开始写
for quant_type in quant_type_list:
num_bit_list = numbit_list(quant_type)
# 对一个量化类别,只需设置一次bias量化表
# int由于位宽大,使用量化表开销过大,直接_round即可
if quant_type != 'INT':
bias_list = build_bias_list(quant_type)
gol.set_value(bias_list, is_bias=True)
for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list:
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
title = '%s_%d' % (quant_type, num_bits)
if load_qat is True and osp.exists(ckpt_prefix+'epoch_20/'+title+'.pt'):
continue
currow += 1
print('\nQAT: '+title)
if args.model == 'ResNet18':
model_ptq = resnet18()
elif args.model == 'ResNet50':
model_ptq = resnet50()
elif args.model == 'ResNet152':
model_ptq = resnet152()
# optimizer2 = optim.Adam(model_ptq.parameters(), lr=lr)
# lr_scheduler2 = CosineAnnealingLR(optimizer2, T_max=epochs)
optimizer2 = optim.SGD(model_ptq.parameters(), lr=lr, momentum=momentum)
# 设置量化表
if quant_type != 'INT':
plist = build_list(quant_type, num_bits, e_bits)
gol.set_value(plist)
model_ptq.to(device)
full_file = 'ckpt/cifar10_' + args.model + '.pt'
# model_ptq.load_state_dict(torch.load(full_file))
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.eval()
direct_quantize(model_ptq, train_loader, device)
model_ptq.train()
loss_sum = 0.
qat_grad_sum = {}
qat_grad_avg = {}
# 因为没有freeze 所以model和model_ptq的parameters其实一样,只是name在downsample处略有不同
for name,param in model_ptq.named_parameters():
if old_sub_str0 in name:
name = name.replace(old_sub_str0, new_sub_str0)
elif old_sub_str1 in name:
name = name.replace(old_sub_str1, new_sub_str1)
qat_grad_sum[name] = torch.zeros_like(param)
qat_grad_avg[name] = torch.zeros_like(param)
for epoch in range(1, epochs+1):
loss,qat_grad = quantize_aware_training(model_ptq, device, train_loader, optimizer2, epoch)
# print('loss:%f' % loss_avg)
if epoch == 1:
loss_start = loss
writer.add_scalar(title+'.loss',loss,epoch)
# for name,grad in qat_grad.items():
# writer.add_histogram(title+'.'+name+'_grad',grad,global_step=epoch)
loss_sum += loss
loss_avg = loss_sum / epoch
loss_delta = loss-loss_start
# 这里对各个epoch的梯度求和不太合理吧 修改成下面的每5个epoch只对那一个epoch的梯度求和
for name,param in model_ptq.named_parameters():
# qat_grad_sum[name] += qat_grad[name]
# 只是对name中的部分串做简单替换
if old_sub_str0 in name:
name = name.replace(old_sub_str0, new_sub_str0)
elif old_sub_str1 in name:
name = name.replace(old_sub_str1, new_sub_str1)
qat_grad_sum[name] += qat_grad[name]
qat_grad_avg[name] += qat_grad_sum[name] / epoch
# 应对每一个epoch都这样计算,而不是只计算在某一个epoch的情况
if epoch % 5 == 0:
ws = wb['epoch_%d'%epoch]
js_grad = 0.
js_grad_sum = 0.
js_grad_avg = 0.
for name,_ in model_ptq.named_parameters():
# TODO
# 可以把downsample换成对应conv,bn的名字
# downsample.0 => conv1 downsample.1 => bn1
# 由于没有freeze,因此model和model_ptq中的conv都是没有bias的
# 是否需要考虑BN的相似度和梯度还有待观察
n = name.split('.')
prefix = '.'.join(n[:len(n) - 1])
if old_sub_str0 in prefix:
prefix = prefix.replace(old_sub_str0, new_sub_str0)
elif old_sub_str1 in prefix:
prefix = prefix.replace(old_sub_str1, new_sub_str1)
if old_sub_str0 in name:
name = name.replace(old_sub_str0, new_sub_str0)
elif old_sub_str1 in name:
name = name.replace(old_sub_str1, new_sub_str1)
# layer中是层名的顺序排序,flop_ratio中也是按层名顺序排序的ratio
layer_idx = layer.index(prefix)
# 加权求和
# 这里相当于只记录了full precision时的最后一个epoch的grad
js = js_div_0(qat_grad[name],full_grad[name])
js_sum = js_div_0(qat_grad_sum[name],full_grad_sum[name])
js_avg = js_div_0(qat_grad_avg[name],full_grad_avg[name])
if js < 0:
js = 0
if js_sum < 0:
js_sum = 0
if js_avg < 0:
js_avg = 0
js_grad += flop_ratio[layer_idx] * js
print(f"name{name}\nqat_grad_avg[{name}]={qat_grad_avg[name]}\nfull_grad_avg[{name}]={full_grad_avg[name]}\njs:{js}\nidx:{layer_idx}")
js_grad_sum += flop_ratio[layer_idx] * js_sum
js_grad_avg += flop_ratio[layer_idx] * js_avg
ws.cell(row=currow,column=1,value=title)
ws.cell(row=currow,column=2,value=loss.item())
ws.cell(row=currow,column=3,value=loss_sum.item())
ws.cell(row=currow,column=4,value=loss_avg.item())
ws.cell(row=currow,column=5,value=loss_delta.item())
ws.cell(row=currow,column=6,value=js_grad)
ws.cell(row=currow,column=7,value=js_grad_sum)
ws.cell(row=currow,column=8,value=js_grad_avg)
print(f"name:{name},js_grad:{js_grad},js_sum:{js_grad_sum},js_avg:{js_grad_avg}")
# print(f"quan_type:{quant_type},num_bits:{num_bits},epoch:{epoch}")
# print(f"loss:{loss.item()},loss_sum:{loss_sum.item()},loss_avg:{loss_avg.item()},loss_delta:{loss_delta.item()}")
# print(f"js_grad:{js_grad},js_grad_sum:{js_grad_sum},js_grad_avg:{js_grad_avg}")
# lr_scheduler2.step()
wb.remove(wb['Sheet']) # 根据名称删除工作表
wb.save(args.model + 'qat_result.xlsx')
writer.close()
from model import *
from extract_ratio import *
from utils import *
import openpyxl
import gol
import sys
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR
def js_div_norm(a,b):
a_norm = F.normalize(a.data,p=2,dim=-1)
b_norm = F.normalize(b.data,p=2,dim=-1)
return js_div(a_norm,b_norm).cpu().item()
def js_div_0(a,b):
return js_div(a,b).cpu().item()
def quantize_aware_training(model, device, train_loader, optimizer, epoch):
lossLayer = torch.nn.CrossEntropyLoss()
#统计loss和每个参数的grad
#初始化
loss_sum = 0.
grad_dict = {}
for name,param in model.named_parameters():
grad_dict[name] = torch.zeros_like(param) #param.grad和param形状相同
for batch_idx, (data, target) in enumerate(train_loader, 1):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model.quantize_forward(data)
# 对一批数据求得的loss是平均值
loss = lossLayer(output, target)
loss.backward()
#loss和grads累加
loss_sum += loss
for name,param in model.named_parameters():
if param.grad is not None:
# print('-------'+name+'-------')
grad_dict[name] += param.grad.detach()
# print(grad_dict[name])
# print(grad_dict.items())
# input()
optimizer.step()
if batch_idx % 50 == 0:
print('Quantize Aware Training Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
batch_size = len(train_loader.batch_sampler)
#对不同batch累加值求平均
for name,grad in grad_dict.items():
grad_dict[name] = grad / batch_size
loss_avg = loss_sum / batch_size
return loss_avg, grad_dict
def full_inference(model, test_loader, device):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Full Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))
def train(model, device, train_loader, optimizer, epoch):
model.train()
lossLayer = torch.nn.CrossEntropyLoss()
#统计loss和每个参数的grad
#初始化
loss_sum = 0.
grad_dict = {}
for name,param in model.named_parameters():
grad_dict[name] = torch.zeros_like(param) #param.grad和param形状相同
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = lossLayer(output, target)
loss.backward()
#loss和grads累加
loss_sum += loss
for name,param in model.named_parameters():
if param.grad is not None:
# print('-------'+name+'-------')
grad_dict[name] += param.grad.detach()
# print(grad_dict[name])
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
batch_size = len(train_loader.batch_sampler)
#对不同batch累加值求平均
for name,grad in grad_dict.items():
grad_dict[name] = grad / batch_size
loss_avg = loss_sum / batch_size
return loss_avg, grad_dict
def quantize_inference(model, test_loader, device):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model.quantize_inference(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Quant Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='QAT Training')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='resnet18')
parser.add_argument('-e','--epochs', default=15, type=int, metavar='EPOCHS', help='number of total epochs to run')
parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='BATCH SIZE', help='mini-batch size (default: 128)')
parser.add_argument('-j','--workers', default=1, type=int, metavar='WORKERS',help='number of data loading workers (default: 4)')
parser.add_argument('-lr', '--learning-rate', default=0.001, type=float, metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('-wd','--weight_decay',default=0.0001,type=float,metavar='WD',help='lr schduler weight decay',dest='wd')
parser.add_argument('-t', '--test', dest='test', action='store_true', help='test model on test set')
args = parser.parse_args()
batch_size = args.batch_size
seed = 1
epochs = args.epochs
lr = args.lr
# momentum = 0.5
weight_decay = args.wd
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
writer = SummaryWriter(log_dir='log/' + args.model + '/qat')
wb = openpyxl.Workbook()
ws = wb.active
if args.model == 'ResNet18':
model = resnet18()
elif args.model == 'ResNet50':
model = resnet50()
elif args.model == 'ResNet152':
model = resnet152()
layer, par_ratio, flop_ratio = extract_ratio(args.model)
# TODO layer要重新读取
layer = []
for name, param in model.named_parameters():
if 'weight' in name:
n = name.split('.') # conv,bn,fc这些有param的层的名字都能提取出来
pre = '.'.join(n[:len(n)-1])
# 提取出weight前的名字(就是这个层的名字,if weight是避免bias重复提取一遍名字)
layer.append(pre)
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../../project/p/data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=args.workers, pin_memory=False
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../../project/p/data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=args.workers, pin_memory=False
)
# model.load_state_dict(torch.load(full_file))
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=epochs)
# 没save .pt 无load
quant_type_list = ['INT']
gol._init()
currow=4 #数据从哪行开始写
for quant_type in quant_type_list:
num_bit_list = numbit_list(quant_type)
# 对一个量化类别,只需设置一次bias量化表
# int由于位宽大,使用量化表开销过大,直接_round即可
if quant_type != 'INT':
bias_list = build_bias_list(quant_type)
gol.set_value(bias_list, is_bias=True)
for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list:
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
title = '%s_%d' % (quant_type, num_bits)
currow += 1
print('\nQAT: '+title)
if args.model == 'ResNet18':
model_ptq = resnet18()
elif args.model == 'ResNet50':
model_ptq = resnet50()
elif args.model == 'ResNet152':
model_ptq = resnet152()
model_ptq.to(device)
full_file = 'ckpt/cifar10_' + args.model + '.pt'
model_ptq.load_state_dict(torch.load(full_file))
model_ptq.eval()
full_acc = full_inference(model_ptq, test_loader, device)
# 设置量化表
if quant_type != 'INT':
plist = build_list(quant_type, num_bits, e_bits)
gol.set_value(plist)
# model_ptq.load_state_dict(torch.load(full_file))
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.train()
for epoch in range(1, epochs+1):
loss,qat_grad = quantize_aware_training(model_ptq, device, train_loader, optimizer, epoch)
# print('loss:%f' % loss_avg)
if epoch == 1:
loss_start = loss
writer.add_scalar(title+'.loss',loss,epoch)
lr_scheduler.step()
print(f"loss:{loss}")
model_ptq.freeze()
quantize_inference(model_ptq, test_loader, device)
# print(f"Final QAT ACC:{qat_acc}")
## update: <br>2023.4.28<br>
### 目标工作:尝试去解决“预测模型收敛速度”方面的问题
- 问题:按照原有思路,通过QAT from scratch获得前5/10/15/20个epoch的loss下降量和训练梯度相似度进行拟合。但根据qat.py得到的数据结果并不太好。<br>主要有两个方面的问题:<br>(1)出现了距离(即 相似度的差异性)过大、且变化过大(出现了显著的数量级差异,且规律与预期不符)的问题。<br>(2) 对不同量化方式的数据,loss的下降量有正有负,换言之,没有一个明显的loss在减小的趋势,数值较为随机。<br>
- 实验:针对上述问题,我进行了一系列观察、思考、实验,修改了qat.py中可能存在的问题,得到new_qat.py,还新增了model_foldbn.py, 修改了module.py.<br>
### 分析与实验:
1. 问题与方案:
- 量化模型中将BN fold进了Conv,因此我尝试仿照量化中的fold过程,在全精度模型训练时也将BN fold进Conv,具体的代码在module.py和model_foldbn.py中。我对fold后的全精度模型进行了训练验证,其可以正常更新权值参数,提升推理精度,但训练的收敛速度明显变慢了(ResNet18_foldbn在80个epoch时acc才40%)。
- qat.py中model和model_ptq都使用了同一个optimizer,在new_qat.py将其改为两个optimizer,分别为两个model的参数进行优化。
- 在实验中发现如果使用Adam优化器,得到的梯度会比较不稳定,我该用了SGD后稳定性提高了,趋势更显著。
- 对full_grad...等字典存储的是epoch上限时的各组梯度数据,如果直接用于与各个epoch节点的量化模型梯度数据去计算相似度,在大部分情况下是没有对应上的。这里我在实验中暂时只训练5个epoch,还没处理该问题。
- 对lr和momentum进行了一系列调整,但效果不明显。
2. 修改后的效果:
- 在INT量化中,随着量化位宽增大,量化模型的训练梯度与全精度模型的训练梯度的数量级逐渐接近至相同(但具体数值上仍有明显差异)。
- 得到明显改善的js_grad, loss_delta...等数据 (不过与预期仍不相符)
3. 还存在的问题与我的猜想:
- 对于INT量化,随着位宽增加,只有最开始出现了训练梯度相似度上升的趋势,后续呈现了波动趋势。且根据我对梯度数据的观察,他们并没有随着量化位宽增加而呈现出一致性,仅仅是数量级接近了,但具体数值仍有很大的差异。我猜想这是因为QAT from scratch较难,他们没有能有效的训练。也有可能是代码中还有一些未发现的bug。
- 在INT量化中,loss_delta(loss的减小量)也没有随着位宽增加而呈显著一致的增大,只有在位宽较小时有一段增大趋势,后续则呈现了较为随机的波动。
- 尝试进行了QAT from scratch训练,80个epoch左右没能看到明显的训练效果,可能这也是为什么出现上述问题的原因之一。
<br>(注:具体数据可见于ResNet18qat_result.xlsx中)
4. 尝试进行的拟合
loss_delta - js_grad (loss的下降量 - 第5个epoch的训练梯度加权相似度)
<img src = "fig/grad_delta.png" class="h-90 auto">
可以看到拟合效果非常差。
loss_avg - js_grad_avg (loss平均值 - 前5epoch的平均训练梯度的加权相似度)
<img src = "fig/qat.png" class="h-90 auto">
可以看到他们存在一个线性趋势,但这仅仅说明训练梯度相似度越大,loss越大,而loss的大小与训练收敛速度间并不能建立合理的逻辑关系。loss很大说明当前模型的效果很差,导致了loss大。
## update: <br>2023.4.24<br>
补充了一些数据和拟合图<br>
尝试将ResNet18,ResNet50,ResNet152,MobileNetV2四个模型的数据点拟合在同一张图上,效果还不错。不过考虑到这四个模型的结构较为相似,暂不确定与其他的结构差异较大的模型的数据点在一起拟合效果如何。
......
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment