Commit 4a099b52 by Zhihong Ma

fix CONVBN & CONVBNRELU for resnet

parent 8358b5d7
# -*- coding: utf-8 -*-
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math
from function import FakeQuantize
def quantize_adaptivfloat(float_arr, n_bits=8, n_exp=4):
    """Round a tensor onto a low-bit floating-point grid.

    The format uses 1 sign bit, ``n_exp`` exponent bits and
    ``n_bits - 1 - n_exp`` mantissa bits. Magnitudes below the smallest
    representable value are flushed to zero, magnitudes above the largest are
    saturated, and the mantissa is truncated (floored) to the available bits.

    Args:
        float_arr: input tensor; it is not modified.
        n_bits: total bit width of the target format.
        n_exp: number of exponent bits.

    Returns:
        A tensor with the shape of ``float_arr`` holding the quantized values,
        placed on the same device as ``float_arr``.
    """
    # Return the result on the caller's device. Choosing the device from
    # torch.cuda.is_available() breaks CPU tensors on CUDA-capable hosts.
    device = float_arr.device
    n_mant = n_bits - 1 - n_exp

    # 1. Keep the sign apart and work on magnitudes only.
    sign = torch.sign(float_arr).cpu()
    # torch.abs allocates a new tensor, so the in-place clipping below never
    # touches the caller's data.
    magnitude = torch.abs(float_arr).detach().cpu().numpy()

    Emin = -2. ** (n_exp - 1) + 1
    Emax = 2. ** (n_exp - 1)
    min_value = 2. ** Emin
    max_value = 2. ** Emax * (2 - 2. ** (-n_mant))

    # 2. No denormals: flush tiny values to zero, saturate large ones.
    magnitude[magnitude < min_value] = 0
    magnitude[magnitude > max_value] = max_value

    # 3. Split into mantissa/exponent. frexp yields mant in [0.5, 1); a zero
    #    input gives mant == 0, which keeps the output at zero.
    mant, exp = np.frexp(magnitude)
    mant = torch.tensor(mant)
    exp = torch.tensor(exp)

    # 3.1 Convert to the IEEE convention (mantissa in [1, 2)).
    mant = 2 * mant
    exp = exp - 1
    exp = exp.clamp(Emin, Emax)  # guard against exponent under/overflow
    power_exp = torch.exp2(exp)

    # 4. Truncate the mantissa to n_mant fractional bits.
    scale = 2 ** (-n_mant)  # e.g. 2 mantissa bits -> step of 0.25
    mant = ((mant / scale).floor()) * scale

    float_out = sign * power_exp * mant
    return float_out.to(device)
def build_power_value(num_bits=8):
    """Build the set of power-of-two quantization levels.

    The candidate levels are 0 and 2**-1, 2**-2, ..., 2**-(2**num_bits - 1).
    They are de-duplicated, converted to a tensor and divided by the largest
    level, so the largest returned value is 1.

    Args:
        num_bits: controls how many negative powers of two are generated.

    Returns:
        A 1-D tensor of levels. The order follows Python set iteration and is
        not sorted.
    """
    levels = [0.] + [2 ** (-k - 1) for k in range(2 ** num_bits - 1)]
    unique_levels = torch.Tensor(list(set(levels)))
    # The largest raw level is 0.5, so this rescales the grid to reach 1.0.
    return unique_levels.mul(1.0 / torch.max(unique_levels))
def apot_quantization(tensor, alpha, proj_set):
    """Project a tensor onto a scaled set of quantization levels.

    The tensor is normalised by ``alpha``, clipped to [-1, 1], and each
    magnitude is snapped to the nearest entry of ``proj_set``. The sign is
    restored and the result is multiplied by ``alpha`` again.

    Args:
        tensor: values to quantize (any shape, contiguous or not).
        alpha: scaling coefficient applied before and after projection.
        proj_set: 1-D tensor of non-negative levels.

    Returns:
        A tensor shaped like ``tensor`` whose values are the projected ones.
        Gradients pass straight through to ``tensor`` (straight-through
        estimator).
    """

    def power_quant(x, value_s):
        shape = x.shape
        sign = x.sign()
        value_s = value_s.type_as(x)
        # reshape() instead of view(): element-wise ops keep the strides of
        # their input, so a transposed/permuted tensor arrives here
        # non-contiguous and view(-1) would raise.
        magnitude = x.reshape(-1).abs()
        # Distance from every element to every level; pick the closest level.
        idxs = (magnitude.unsqueeze(0) - value_s.unsqueeze(1)).abs().min(dim=0)[1]
        xhard = value_s[idxs].view(shape).mul(sign)  # restore shape and sign
        # Forward value equals xhard; backward sees the identity w.r.t. x.
        return (xhard - x).detach() + x

    data = tensor / alpha          # normalise by the scale
    data = data.clamp(-1, 1)       # clip to the representable range
    data_q = power_quant(data, proj_set)
    return data_q * alpha          # undo the normalisation
def calcScaleZeroPoint(min_val, max_val, num_bits=8, mode=1):
    """Derive quantization parameters from an observed value range.

    Args:
        min_val: single-element tensor with the observed minimum.
        max_val: single-element tensor with the observed maximum.
        num_bits: bit width used for the integer grid in mode 1.
        mode: 1 for affine unsigned integers, 2 or 3 for the scale-only
            schemes, where the scale is the largest absolute bound.

    Returns:
        ``(scale, zero_point)``. In mode 1 the zero point is rounded and
        limited to [0, 2**num_bits - 1]. In modes 2 and 3 the zero point is
        the constant tensor 0. Any other mode returns two zero tensors.
    """

    def _avoid_zero(s):
        # A constant-zero range gives scale == 0, which turns the later
        # divisions by scale into NaN. Only exact zeros are replaced, so every
        # non-degenerate scale is returned untouched.
        if s.is_floating_point():
            eps = torch.finfo(s.dtype).eps
            return torch.where(s == 0, torch.full_like(s, eps), s)
        return s

    scale = torch.tensor(0)
    zero_point = torch.tensor(0)
    if mode == 1:
        # Unsigned grid 0 .. 2**num_bits - 1.
        qmin = 0.
        qmax = 2. ** num_bits - 1.
        scale = _avoid_zero((max_val - min_val) / (qmax - qmin))
        zero_point = qmax - max_val / scale
        if zero_point < qmin:
            zero_point = torch.tensor([qmin], dtype=torch.float32).to(min_val.device)
        elif zero_point > qmax:
            zero_point = torch.tensor([qmax], dtype=torch.float32).to(max_val.device)
        zero_point.round_()
    elif mode == 2 or mode == 3:
        # The scale is the largest magnitude in the range.
        scale = max_val.abs() if max_val.abs() > min_val.abs() else min_val.abs()
        scale = _avoid_zero(scale)
    return scale, zero_point
def quantize_tensor(x, scale, zero_point, num_bits=8, signed=False, n_exp=4, mode=1):
    """Quantize ``x`` with the scheme selected by ``mode``.

    Args:
        x: tensor to quantize.
        scale: quantization scale (used by modes 1 and 2).
        zero_point: integer offset (used by mode 1 only).
        num_bits: bit width of the target format.
        signed: mode 1 only. Selects the signed range
            [-2**(num_bits-1), 2**(num_bits-1) - 1] instead of
            [0, 2**num_bits - 1].
        n_exp: number of exponent bits (mode 3 only).
        mode: 1 = affine integer, 2 = power-of-two levels, 3 = low-bit float.

    Returns:
        Mode 1: integer-valued tensor on the quantized grid.
        Modes 2 and 3: tensor already expressed in the original value range.

    Raises:
        ValueError: if ``mode`` is not 1, 2 or 3.
    """
    if mode == 1:
        if signed:
            qmin = - 2. ** (num_bits - 1)
            qmax = 2. ** (num_bits - 1) - 1
        else:
            qmin = 0.
            qmax = 2. ** num_bits - 1.
        q_x = zero_point + x / scale
        q_x.clamp_(qmin, qmax).round_()  # limit to [qmin, qmax], then round
        return q_x
    if mode == 2:
        proj_set = build_power_value(num_bits)
        return apot_quantization(x, scale, proj_set)
    if mode == 3:
        return quantize_adaptivfloat(float_arr=x, n_bits=num_bits, n_exp=n_exp)
    # Without this, control fell through to ``return q_x`` with q_x unbound.
    raise ValueError('unsupported quantization mode: %r' % (mode,))
def dequantize_tensor(q_x, scale, zero_point, mode):
    """Map a quantized tensor back to real values.

    Args:
        q_x: quantized tensor.
        scale: quantization scale (mode 1 only).
        zero_point: integer offset (mode 1 only).
        mode: 1 = affine integer, 2 or 3 = schemes whose quantization step
            already returns values in the original range.

    Returns:
        Mode 1: ``scale * (q_x - zero_point)``.
        Modes 2 and 3: ``q_x`` itself, unchanged.

    Raises:
        ValueError: if ``mode`` is not 1, 2 or 3.
    """
    if mode == 1:
        return scale * (q_x - zero_point)
    if mode == 2 or mode == 3:
        # Quantization in these modes already dequantizes, so nothing to undo.
        return q_x
    # Previously an unknown mode silently returned None.
    raise ValueError('unsupported quantization mode: %r' % (mode,))
# def search(M):
# P = 7000
# n = 1
# while True:
# Mo = int(round(2 ** n * M))
# # Mo
# approx_result = Mo * P >> n
# result = int(round(M * P))
# error = approx_result - result
#
# print("n=%d, Mo=%f, approx=%d, result=%d, error=%f" % \
# (n, Mo, approx_result, result, error))
#
# if math.fabs(error) < 1e-9 or n >= 22:
# return Mo, n
# n += 1
# quantize parameter 伪量化层
class QParam(nn.Module):
    """Holds the running range and quantization parameters of one tensor.

    ``mode`` selects the scheme: 1 = INT, 2 = PoT, 3 = FP.

    The four quantities ``scale``, ``zero_point``, ``min`` and ``max`` are
    registered as buffers: they receive no gradient, are stored in
    ``state_dict()`` and follow the module across devices. They start out as
    empty tensors until ``update`` or a state-dict load fills them.
    """

    _BUFFER_NAMES = ('scale', 'zero_point', 'min', 'max')

    def __init__(self, num_bits=8, n_exp=4, mode=1):
        super(QParam, self).__init__()
        self.num_bits = num_bits
        self.mode = mode
        self.n_exp = n_exp
        for name in self._BUFFER_NAMES:
            self.register_buffer(name, torch.tensor([], requires_grad=False))

    def update(self, tensor):
        """Widen the tracked range with ``tensor`` and refresh scale/zero_point.

        The range only ever grows, and it is forced to contain zero
        (``max >= 0`` and ``min <= 0``).
        """
        if self.max.nelement() == 0 or self.max.data < tensor.max().data:
            self.max.data = tensor.max().data
        self.max.clamp_(min=0)

        if self.min.nelement() == 0 or self.min.data > tensor.min().data:
            self.min.data = tensor.min().data
        self.min.clamp_(max=0)

        self.scale, self.zero_point = calcScaleZeroPoint(self.min, self.max, self.num_bits, self.mode)

    def quantize_tensor(self, tensor, mode):
        """Quantize ``tensor`` with this object's parameters.

        ``mode`` is accepted for call-site compatibility but ignored; the mode
        given at construction is used.
        """
        return quantize_tensor(x=tensor, scale=self.scale, zero_point=self.zero_point,
                               num_bits=self.num_bits, n_exp=self.n_exp, mode=self.mode)

    def dequantize_tensor(self, q_x, mode):
        """Dequantize ``q_x`` with this object's parameters (``mode`` is ignored)."""
        return dequantize_tensor(q_x=q_x, scale=self.scale, zero_point=self.zero_point, mode=self.mode)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        """Restore the four buffers from ``state_dict``.

        The buffers are replaced rather than copied into, because they are
        created empty and their stored shape differs from the initial one.
        Consumed entries are removed from ``state_dict``. A missing entry is
        reported through ``missing_keys`` when ``strict`` is set, following
        the nn.Module convention, instead of raising KeyError.
        """
        for name in self._BUFFER_NAMES:
            full_key = prefix + name
            if full_key not in state_dict:
                if strict:
                    missing_keys.append(full_key)
                continue
            getattr(self, name).data = state_dict.pop(full_key).data

    def __str__(self):
        """Return a one-line summary; requires the buffers to be populated."""
        info = 'scale: %.10f ' % self.scale
        info += 'zp: %d ' % self.zero_point
        info += 'min: %.6f ' % self.min
        info += 'max: %.6f' % self.max
        return info
# 伪量化层
class QModule(nn.Module):
    """Base class of the fake-quantized layers.

    Optionally owns ``qi`` (input) and ``qo`` (output) ``QParam`` objects,
    both configured with the same bit width, exponent width and mode.
    """

    def __init__(self, qi=True, qo=True, num_bits=8, n_exp=4, mode=1):
        super(QModule, self).__init__()
        # Create the requested range trackers, input first, then output.
        for attr, enabled in (('qi', qi), ('qo', qo)):
            if enabled:
                setattr(self, attr, QParam(num_bits=num_bits, n_exp=n_exp, mode=mode))

    def freeze(self):
        """Hook for subclasses; does nothing here."""
        return None

    def fakefreeze(self):
        """Hook for subclasses; does nothing here."""
        return None

    def quantize_inference(self, x):
        """Quantized inference entry point; subclasses must override it."""
        raise NotImplementedError('quantize_inference should be implemented.')
class QConv2d(QModule):
    """Fake-quantized wrapper around a 2-D convolution module.

    ``forward`` tracks the ranges of input, weight, bias and output and runs
    the convolution on fake-quantized operands. ``freeze`` and ``fakefreeze``
    write quantized (respectively quantized-then-dequantized) parameters back
    into ``conv_module``. ``quantize_inference`` runs the frozen layer.

    A convolution created with ``bias=False`` is supported: every bias step is
    skipped when ``conv_module.bias`` is None.
    """

    def __init__(self, conv_module, qi=True, qo=True, num_bits=8, n_exp=4, mode=1):
        # The base class creates qi/qo with the same bit width and mode.
        super(QConv2d, self).__init__(qi=qi, qo=qo, num_bits=num_bits, n_exp=n_exp, mode=mode)
        self.num_bits = num_bits
        self.conv_module = conv_module
        self.qw = QParam(num_bits=num_bits, n_exp=n_exp, mode=mode)  # weight range
        self.qb = QParam(num_bits=num_bits, n_exp=n_exp, mode=mode)  # bias range
        self.mode = mode
        self.n_exp = n_exp
        # Requantization multiplier, filled in by freeze/fakefreeze (mode 1).
        self.register_buffer('M', torch.tensor([], requires_grad=False))

    def _bind_qparams(self, qi, qo):
        """Validate and attach externally supplied qi/qo.

        Each of qi/qo must come from exactly one place: the constructor or
        the freeze call.
        """
        if hasattr(self, 'qi') and qi is not None:
            raise ValueError('qi has been provided in init function.')
        if not hasattr(self, 'qi') and qi is None:
            raise ValueError('qi is not existed, should be provided.')
        if hasattr(self, 'qo') and qo is not None:
            raise ValueError('qo has been provided in init function.')
        if not hasattr(self, 'qo') and qo is None:
            raise ValueError('qo is not existed, should be provided.')
        if qi is not None:
            self.qi = qi
        if qo is not None:
            self.qo = qo

    def _quantize_bias(self, scale):
        """Quantize the current bias as a signed value with the given scale."""
        return quantize_tensor(self.conv_module.bias.data, scale=scale,
                               zero_point=0, num_bits=self.num_bits, signed=True,
                               n_exp=self.n_exp, mode=self.mode)

    def freeze(self, qi=None, qo=None):
        """Replace weight and bias by their quantized values for inference.

        In mode 1 the stored weight has its zero point already subtracted and
        ``M`` is set to ``qw.scale * qi.scale / qo.scale``.
        """
        self._bind_qparams(qi, qo)
        has_bias = self.conv_module.bias is not None
        if self.mode == 1:
            self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
            q_weight = self.qw.quantize_tensor(self.conv_module.weight.data, self.mode)
            self.conv_module.weight.data = q_weight - self.qw.zero_point
            if has_bias:
                # The bias lives on the accumulator grid: scale = s_in * s_w.
                self.conv_module.bias.data = self._quantize_bias(self.qi.scale * self.qw.scale)
        elif self.mode == 2 or self.mode == 3:
            self.conv_module.weight.data = self.qw.quantize_tensor(self.conv_module.weight.data, self.mode)
            if has_bias:
                self.conv_module.bias.data = self._quantize_bias(self.qb.scale)

    def fakefreeze(self, qi=None, qo=None):
        """Replace weight and bias by quantized-then-dequantized values."""
        self._bind_qparams(qi, qo)
        has_bias = self.conv_module.bias is not None
        if self.mode == 1:
            self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
            bias_scale = self.qi.scale * self.qw.scale
        elif self.mode == 2 or self.mode == 3:
            bias_scale = self.qb.scale
        else:
            return
        q_weight = self.qw.quantize_tensor(self.conv_module.weight.data, self.mode)
        self.conv_module.weight.data = self.qw.dequantize_tensor(q_weight, self.mode)
        if has_bias:
            q_bias = self._quantize_bias(bias_scale)
            self.conv_module.bias.data = dequantize_tensor(q_bias, scale=bias_scale,
                                                           zero_point=0, mode=self.mode)

    def forward(self, x):
        """Run the convolution on fake-quantized input, weight and bias.

        ``FakeQuantize`` comes from the project's ``function`` module; it is
        assumed to quantize and immediately dequantize its first argument
        using the given QParam (TODO confirm against that module).
        """
        if hasattr(self, 'qi'):
            # qi tracks the range of the layer input, before the computation.
            self.qi.update(x)
            x = FakeQuantize.apply(x, self.qi)

        # Refresh the parameter ranges so the current scales are used below.
        self.qw.update(self.conv_module.weight.data)
        weight = FakeQuantize.apply(self.conv_module.weight, self.qw)
        bias = self.conv_module.bias
        if bias is not None:
            self.qb.update(bias.data)
            bias = FakeQuantize.apply(bias, self.qb)

        x = F.conv2d(x, weight, bias,
                     stride=self.conv_module.stride,
                     padding=self.conv_module.padding,
                     dilation=self.conv_module.dilation,
                     groups=self.conv_module.groups)

        if hasattr(self, 'qo'):
            self.qo.update(x)
            x = FakeQuantize.apply(x, self.qo)
        return x

    def quantize_inference(self, x):
        """Run the frozen layer on an already-quantized input ``x``."""
        if self.mode == 1:
            x = x - self.qi.zero_point
            x = self.conv_module(x)  # parameters were quantized by freeze()
            x = self.M * x           # rescale the accumulator to the output grid
            x.round_()
            x = x + self.qo.zero_point
            x.clamp_(0., 2. ** self.num_bits - 1.).round_()
            return x
        elif self.mode == 2 or self.mode == 3:
            x = self.conv_module(x)
            # Re-express the result in the layer's output format.
            x = FakeQuantize.apply(x, self.qo)
            return x
class QLinear(QModule):
    """Fake-quantized wrapper around a fully-connected module.

    Mirrors ``QConv2d``: ``forward`` tracks ranges and computes with
    fake-quantized operands, ``freeze``/``fakefreeze`` write quantized
    parameters back into ``fc_module``, and ``quantize_inference`` runs the
    frozen layer. A layer created with ``bias=False`` is supported: every bias
    step is skipped when ``fc_module.bias`` is None.
    """

    def __init__(self, fc_module, qi=True, qo=True, num_bits=8, n_exp=4, mode=1):
        super(QLinear, self).__init__(qi=qi, qo=qo, num_bits=num_bits, n_exp=n_exp, mode=mode)
        self.num_bits = num_bits
        self.fc_module = fc_module
        self.qw = QParam(num_bits=num_bits, n_exp=n_exp, mode=mode)  # weight range
        self.qb = QParam(num_bits=num_bits, n_exp=n_exp, mode=mode)  # bias range
        # Requantization multiplier, filled in by freeze/fakefreeze (mode 1).
        self.register_buffer('M', torch.tensor([], requires_grad=False))
        self.mode = mode
        self.n_exp = n_exp

    def _bind_qparams(self, qi, qo):
        """Validate and attach externally supplied qi/qo.

        Each of qi/qo must come from exactly one place: the constructor or
        the freeze call.
        """
        if hasattr(self, 'qi') and qi is not None:
            raise ValueError('qi has been provided in init function.')
        if not hasattr(self, 'qi') and qi is None:
            raise ValueError('qi is not existed, should be provided.')
        if hasattr(self, 'qo') and qo is not None:
            raise ValueError('qo has been provided in init function.')
        if not hasattr(self, 'qo') and qo is None:
            raise ValueError('qo is not existed, should be provided.')
        if qi is not None:
            self.qi = qi
        if qo is not None:
            self.qo = qo

    def _quantize_bias(self, scale):
        """Quantize the current bias as a signed value with the given scale."""
        return quantize_tensor(self.fc_module.bias.data, scale=scale,
                               zero_point=0, num_bits=self.num_bits, signed=True,
                               n_exp=self.n_exp, mode=self.mode)

    def freeze(self, qi=None, qo=None):
        """Replace weight and bias by their quantized values for inference.

        In mode 1 the stored weight has its zero point already subtracted and
        the bias uses ``qi.scale * qw.scale`` as its scale.
        """
        self._bind_qparams(qi, qo)
        has_bias = self.fc_module.bias is not None
        if self.mode == 1:
            self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
            q_weight = self.qw.quantize_tensor(self.fc_module.weight.data, self.mode)
            self.fc_module.weight.data = q_weight - self.qw.zero_point
            if has_bias:
                self.fc_module.bias.data = self._quantize_bias(self.qi.scale * self.qw.scale)
        elif self.mode == 2 or self.mode == 3:
            self.fc_module.weight.data = self.qw.quantize_tensor(self.fc_module.weight.data, self.mode)
            if has_bias:
                self.fc_module.bias.data = self._quantize_bias(self.qb.scale)

    def fakefreeze(self, qi=None, qo=None):
        """Replace weight and bias by quantized-then-dequantized values."""
        self._bind_qparams(qi, qo)
        has_bias = self.fc_module.bias is not None
        if self.mode == 1:
            self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
            bias_scale = self.qi.scale * self.qw.scale
        elif self.mode == 2 or self.mode == 3:
            bias_scale = self.qb.scale
        else:
            return
        q_weight = self.qw.quantize_tensor(self.fc_module.weight.data, self.mode)
        self.fc_module.weight.data = self.qw.dequantize_tensor(q_weight, self.mode)
        if has_bias:
            q_bias = self._quantize_bias(bias_scale)
            self.fc_module.bias.data = dequantize_tensor(q_bias, scale=bias_scale,
                                                         zero_point=0, mode=self.mode)

    def forward(self, x):
        """Apply the linear layer on fake-quantized input, weight and bias."""
        if hasattr(self, 'qi'):
            self.qi.update(x)
            x = FakeQuantize.apply(x, self.qi)

        self.qw.update(self.fc_module.weight.data)
        weight = FakeQuantize.apply(self.fc_module.weight, self.qw)
        bias = self.fc_module.bias
        if bias is not None:
            self.qb.update(bias.data)
            bias = FakeQuantize.apply(bias, self.qb)

        x = F.linear(x, weight, bias)

        if hasattr(self, 'qo'):
            self.qo.update(x)
            x = FakeQuantize.apply(x, self.qo)
        return x

    def quantize_inference(self, x):
        """Run the frozen layer on an already-quantized input ``x``."""
        if self.mode == 1:
            x = x - self.qi.zero_point
            x = self.fc_module(x)
            x = self.M * x
            x.round_()
            x = x + self.qo.zero_point
            x.clamp_(0., 2. ** self.num_bits - 1.).round_()
            return x
        elif self.mode == 2 or self.mode == 3:
            x = self.fc_module(x)
            # Re-express the result in the layer's output format.
            x = FakeQuantize.apply(x, self.qo)
            return x
class QReLU(QModule):
    """Fake-quantized ReLU.

    By default the layer has no ``qi`` of its own and receives one through
    ``freeze``/``fakefreeze``.

    NOTE(review): ``qo`` is not forwarded to the base constructor, so the base
    default (``qo=True``) still creates an unused ``qo`` QParam with
    ``num_bits=None``. Left as is because removing it would change the
    module's state_dict keys.
    """

    def __init__(self, qi=False, num_bits=None, n_exp=4, mode=1):
        super(QReLU, self).__init__(qi=qi, num_bits=num_bits, n_exp=n_exp, mode=mode)
        self.mode = mode
        self.n_exp = n_exp

    def _bind_qi(self, qi):
        """Validate and attach an externally supplied qi."""
        if hasattr(self, 'qi') and qi is not None:
            raise ValueError('qi has been provided in init function.')
        if not hasattr(self, 'qi') and qi is None:
            raise ValueError('qi is not existed, should be provided.')
        if qi is not None:
            self.qi = qi

    def freeze(self, qi=None):
        """Attach the input QParam used by ``quantize_inference``."""
        self._bind_qi(qi)

    def fakefreeze(self, qi=None):
        """Same as ``freeze``: the layer has no parameters to rewrite."""
        self._bind_qi(qi)

    def forward(self, x):
        """Optionally fake-quantize the input, then apply ReLU."""
        if hasattr(self, 'qi'):
            self.qi.update(x)  # refresh scale/zero point from this input
            x = FakeQuantize.apply(x, self.qi)
        x = F.relu(x)
        return x

    def quantize_inference(self, x):
        """Apply ReLU on the quantized grid by clamping at the zero point.

        Returns a new tensor; ``x`` is not modified.
        """
        x = x.clone()
        # Follow the input's device. Guessing it from
        # torch.cuda.is_available() breaks CPU inference on CUDA hosts.
        floor = self.qi.zero_point.float().to(x.device)
        x[x < floor] = floor
        return x
# class QDrop(QModule):
# # dropout的效果仅仅是丢掉tensor中的部分值而已,所以对改变整体上的min,max没有意义,可以不更新qi
#
# def __init__(self, drop_module, qi=True, qo=True, num_bits=8):
# super(QDrop, self).__init__(qi=qi, qo=qo, num_bits=num_bits)
# self.num_bits = num_bits
# self.drop_module = drop_module
# self.qw = QParam(num_bits=num_bits)
# #self.register_buffer('M', torch.tensor([], requires_grad=False)) # 将M注册为buffer
#
# def freeze(self, qi=None, qo=None):
#
# if hasattr(self, 'qi') and qi is not None:
# raise ValueError('qi has been provided in init function.')
# if not hasattr(self, 'qi') and qi is None:
# raise ValueError('qi is not existed, should be provided.')
#
# if hasattr(self, 'qo') and qo is not None:
# raise ValueError('qo has been provided in init function.')
# if not hasattr(self, 'qo') and qo is None:
# raise ValueError('qo is not existed, should be provided.')
#
# if qi is not None:
# self.qi = qi
# if qo is not None:
# self.qo = qo
#
#
# def forward(self, x):
# if hasattr(self, 'qi'):
# self.qi.update(x)
# x = FakeQuantize.apply(x, self.qi)
#
# self.qw.update(self.drop_module.weight.data)
#
# # 权重需要经过伪量化层量化 这里bias是不是没量化 可以用freeze里的方法改下这里吧
# x = nn.dropout(0.5)
#
# if hasattr(self, 'qo'):
# self.qo.update(x)
# x = FakeQuantize.apply(x, self.qo)
#
# return x
#
# def quantize_inference(self, x):
# x = x - self.qi.zero_point
# x = self.fc_module(x)
# x = self.M * x
# x.round_()
# x = x + self.qo.zero_point
# x.clamp_(0., 2.**self.num_bits-1.).round_()
# return x
class QMaxPooling2d(QModule):
    """Fake-quantized 2-D max pooling.

    Pooling only selects existing values, so quantized inference is the plain
    pooling operation. An input QParam can be owned by the layer (``qi=True``)
    or handed over through ``freeze``/``fakefreeze``.
    """

    def __init__(self, kernel_size=3, stride=1, padding=0, qi=False, num_bits=None, n_exp=4, mode=1):
        super(QMaxPooling2d, self).__init__(qi=qi, num_bits=num_bits, n_exp=n_exp, mode=mode)
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.mode = mode
        self.n_exp = n_exp

    def _bind_qi(self, qi):
        """Validate and attach an externally supplied qi."""
        owns_qi = hasattr(self, 'qi')
        if owns_qi and qi is not None:
            raise ValueError('qi has been provided in init function.')
        if not owns_qi and qi is None:
            raise ValueError('qi is not existed, should be provided.')
        if qi is not None:
            self.qi = qi

    def freeze(self, qi=None):
        """Attach the input QParam; the layer has no parameters to quantize."""
        self._bind_qi(qi)

    def fakefreeze(self, qi=None):
        """Same as ``freeze``."""
        self._bind_qi(qi)

    def _pool(self, x):
        """Max-pool ``x`` with the configured window, stride and padding."""
        return F.max_pool2d(x, self.kernel_size, self.stride, self.padding)

    def forward(self, x):
        """Optionally fake-quantize the input, then pool."""
        if hasattr(self, 'qi'):
            # Refresh the input range first, then express x in that format.
            self.qi.update(x)
            x = FakeQuantize.apply(x, self.qi)
        return self._pool(x)

    def quantize_inference(self, x):
        """Pool an already-quantized input."""
        return self._pool(x)
class QConvBNReLU(QModule):
    """Fake-quantized Conv2d + BatchNorm2d + ReLU with BN folded into the conv.

    The batch-norm statistics are folded into an effective weight and bias,
    which are then fake-quantized. The folded bias is always used, even when
    the convolution itself has no bias, because it carries the batch-norm
    shift.
    """

    def __init__(self, conv_module, bn_module, qi=True, qo=True, num_bits=8, n_exp=4, mode=1):
        super(QConvBNReLU, self).__init__(qi=qi, qo=qo, num_bits=num_bits, n_exp=n_exp, mode=mode)
        self.num_bits = num_bits
        self.conv_module = conv_module
        self.bn_module = bn_module
        self.qw = QParam(num_bits=num_bits, n_exp=n_exp, mode=mode)  # folded weight range
        self.qb = QParam(num_bits=num_bits, n_exp=n_exp, mode=mode)  # folded bias range
        # Requantization multiplier, filled in by freeze/fakefreeze.
        self.register_buffer('M', torch.tensor([], requires_grad=False))
        self.mode = mode
        self.n_exp = n_exp

    def fold_bn(self, mean, std):
        """Fold batch-norm statistics into the convolution parameters.

        Args:
            mean: per-output-channel mean.
            std: per-output-channel standard deviation (eps already included).

        Returns:
            ``(weight, bias)`` of the equivalent single convolution. ``bias``
            is always a tensor, also when ``conv_module.bias`` is None.
        """
        out_channels = self.conv_module.out_channels
        if self.bn_module.affine:
            gamma_ = self.bn_module.weight / std
            weight = self.conv_module.weight * gamma_.view(out_channels, 1, 1, 1)
            if self.conv_module.bias is not None:
                bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
            else:
                bias = self.bn_module.bias - gamma_ * mean
        else:
            gamma_ = 1 / std
            # Reshape to scale each output channel. A bare 1-D factor would
            # broadcast along the last (kernel-width) axis instead.
            weight = self.conv_module.weight * gamma_.view(out_channels, 1, 1, 1)
            if self.conv_module.bias is not None:
                bias = gamma_ * self.conv_module.bias - gamma_ * mean
            else:
                bias = -gamma_ * mean
        return weight, bias

    def _bind_qparams(self, qi, qo):
        """Validate and attach externally supplied qi/qo."""
        if hasattr(self, 'qi') and qi is not None:
            raise ValueError('qi has been provided in init function.')
        if not hasattr(self, 'qi') and qi is None:
            raise ValueError('qi is not existed, should be provided.')
        if hasattr(self, 'qo') and qo is not None:
            raise ValueError('qo has been provided in init function.')
        if not hasattr(self, 'qo') and qo is None:
            raise ValueError('qo is not existed, should be provided.')
        if qi is not None:
            self.qi = qi
        if qo is not None:
            self.qo = qo

    def _fold_running_stats(self):
        """Fold the BN running statistics; return ``(weight, bias)``."""
        std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
        weight, bias = self.fold_bn(self.bn_module.running_mean, std)
        if self.qb.scale.nelement() == 0:
            # The bias range was never observed (e.g. a checkpoint produced
            # before the folded bias was tracked for bias-free convolutions).
            self.qb.update(bias.data)
        return weight, bias

    def _quantize_bias(self, bias, scale):
        """Quantize the folded bias as a signed value with the given scale."""
        return quantize_tensor(bias.detach(), scale=scale, zero_point=0,
                               num_bits=self.num_bits, signed=True,
                               n_exp=self.n_exp, mode=self.mode)

    def _set_bias(self, value):
        """Store ``value`` as the conv bias, creating the parameter if needed."""
        if self.conv_module.bias is None:
            self.conv_module.bias = nn.Parameter(value.detach())
        else:
            self.conv_module.bias.data = value

    def forward(self, x):
        """Fake-quantized conv + folded BN + ReLU.

        In training mode the batch statistics are measured with an extra,
        gradient-free convolution and the BN running statistics are updated.
        In eval mode the running statistics are used.
        """
        if hasattr(self, 'qi'):
            self.qi.update(x)
            x = FakeQuantize.apply(x, self.qi)

        if self.training:
            # Only the statistics are needed here, so no autograd graph.
            with torch.no_grad():
                y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
                             stride=self.conv_module.stride,
                             padding=self.conv_module.padding,
                             dilation=self.conv_module.dilation,
                             groups=self.conv_module.groups)
                y = y.permute(1, 0, 2, 3)  # NCHW -> CNHW
                y = y.contiguous().view(self.conv_module.out_channels, -1)  # C, N*H*W
                mean = y.mean(1)
                var = y.var(1)
            momentum = self.bn_module.momentum
            self.bn_module.running_mean = \
                (1 - momentum) * self.bn_module.running_mean + momentum * mean
            self.bn_module.running_var = \
                (1 - momentum) * self.bn_module.running_var + momentum * var
        else:
            mean = self.bn_module.running_mean
            var = self.bn_module.running_var

        std = torch.sqrt(var + self.bn_module.eps)
        weight, bias = self.fold_bn(mean, std)

        self.qw.update(weight.data)
        # The folded bias exists even when the convolution has none; dropping
        # it would discard the batch-norm shift.
        self.qb.update(bias.data)

        x = F.conv2d(x, FakeQuantize.apply(weight, self.qw),
                     FakeQuantize.apply(bias, self.qb),
                     stride=self.conv_module.stride,
                     padding=self.conv_module.padding,
                     dilation=self.conv_module.dilation,
                     groups=self.conv_module.groups)
        x = F.relu(x)

        if hasattr(self, 'qo'):
            self.qo.update(x)
            x = FakeQuantize.apply(x, self.qo)
        return x

    def freeze(self, qi=None, qo=None):
        """Write the quantized folded weight and bias into ``conv_module``.

        If the convolution had no bias, a bias parameter is created so the
        folded batch-norm shift takes part in inference.
        """
        self._bind_qparams(qi, qo)
        if self.mode == 1:
            self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
            weight, bias = self._fold_running_stats()
            q_weight = self.qw.quantize_tensor(weight.data, self.mode)
            self.conv_module.weight.data = q_weight - self.qw.zero_point
            self._set_bias(self._quantize_bias(bias, self.qi.scale * self.qw.scale))
        elif self.mode == 2 or self.mode == 3:
            weight, bias = self._fold_running_stats()
            self.conv_module.weight.data = self.qw.quantize_tensor(weight.data, self.mode)
            self._set_bias(self._quantize_bias(bias, self.qb.scale))

    def fakefreeze(self, qi=None, qo=None):
        """Write quantized-then-dequantized folded parameters into ``conv_module``."""
        self._bind_qparams(qi, qo)
        if self.mode not in (1, 2, 3):
            return
        self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
        weight, bias = self._fold_running_stats()
        q_weight = self.qw.quantize_tensor(weight.data, self.mode)
        self.conv_module.weight.data = self.qw.dequantize_tensor(q_weight, self.mode)
        bias_scale = self.qi.scale * self.qw.scale if self.mode == 1 else self.qb.scale
        q_bias = self._quantize_bias(bias, bias_scale)
        self._set_bias(dequantize_tensor(q_bias, scale=bias_scale, zero_point=0, mode=self.mode))

    def quantize_inference(self, x):
        """Run the frozen layer on an already-quantized input ``x``."""
        if self.mode == 1:
            x = x - self.qi.zero_point
            x = self.conv_module(x)
            x = self.M * x
            x.round_()
            x = x + self.qo.zero_point
            # Clamping at 0 on the shifted grid stays exactly as before.
            x.clamp_(0., 2. ** self.num_bits - 1.).round_()
            return x
        elif self.mode == 2 or self.mode == 3:
            x = self.conv_module(x)
            x = FakeQuantize.apply(x, self.qo)
            return x
class QConvBN(QModule):
def __init__(self, conv_module, bn_module, qi=True, qo=True, num_bits=8, n_exp=4, mode=1):
    """Wrap a convolution and its batch norm for fake-quantized execution.

    Args:
        conv_module: the convolution whose parameters get folded/quantized.
        bn_module: the batch-norm module folded into the convolution.
        qi, qo: whether this layer owns input/output QParams.
        num_bits, n_exp, mode: quantization settings shared by all QParams.
    """
    super(QConvBN, self).__init__(qi=qi, qo=qo, num_bits=num_bits, n_exp=n_exp, mode=mode)
    self.mode, self.n_exp, self.num_bits = mode, n_exp, num_bits
    self.conv_module = conv_module
    self.bn_module = bn_module
    # One range tracker for the folded weight and one for the folded bias.
    qparam_kwargs = dict(num_bits=num_bits, n_exp=n_exp, mode=mode)
    self.qw = QParam(**qparam_kwargs)
    self.qb = QParam(**qparam_kwargs)
    # M is kept as a buffer so it is saved with the model.
    self.register_buffer('M', torch.tensor([], requires_grad=False))
def fold_bn(self, mean, std):
    """Fold batch-norm statistics into the convolution parameters.

    Args:
        mean: per-output-channel mean.
        std: per-output-channel standard deviation (eps already included).

    Returns:
        ``(weight, bias)`` of the equivalent single convolution. ``bias`` is
        always a tensor, also when ``conv_module.bias`` is None.
    """
    out_channels = self.conv_module.out_channels
    if self.bn_module.affine:
        gamma_ = self.bn_module.weight / std
        weight = self.conv_module.weight * gamma_.view(out_channels, 1, 1, 1)
        if self.conv_module.bias is not None:
            bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
        else:
            bias = self.bn_module.bias - gamma_ * mean
    else:
        gamma_ = 1 / std
        # Reshape to scale each output channel. A bare 1-D factor would
        # broadcast along the last (kernel-width) axis instead.
        weight = self.conv_module.weight * gamma_.view(out_channels, 1, 1, 1)
        if self.conv_module.bias is not None:
            bias = gamma_ * self.conv_module.bias - gamma_ * mean
        else:
            bias = -gamma_ * mean
    return weight, bias
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
self.qw.update(weight.data)
if self.conv_module.bias is not None:
self.qb.update(bias.data)
# x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
# stride=self.conv_module.stride,
# padding=self.conv_module.padding, dilation=self.conv_module.dilation,
# groups=self.conv_module.groups)
if self.conv_module.bias is not None:
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), FakeQuantize.apply(bias, self.qb),
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
else:
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw),
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
# x = F.relu(x)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
if self.mode == 1:
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = self.qw.quantize_tensor(weight.data, self.mode)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
if self.conv_module.bias is not None:
self.conv_module.bias.data = quantize_tensor(bias, scale=self.qi.scale * self.qw.scale,
zero_point=0, num_bits=self.num_bits, signed=True, n_exp=self.n_exp, mode=self.mode)
elif self.mode == 2 or self.mode == 3:
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = self.qw.quantize_tensor(weight.data, self.mode)
if self.conv_module.bias is not None:
self.conv_module.bias.data = quantize_tensor(bias, scale=self.qb.scale,
zero_point=0, num_bits=self.num_bits, signed=True, n_exp=self.n_exp, mode=self.mode)
def fakefreeze(self, qi=None, qo=None):
    """Fold BatchNorm and overwrite the conv parameters for fake-quantized use.

    In mode 1 the folded weight and bias are quantized and immediately
    dequantized again; in modes 2 and 3 they are quantized only.

    Args:
        qi: Input quantization parameters. Must be passed only when the
            module has no ``qi`` attribute of its own.
        qo: Output quantization parameters, under the same rule as ``qi``.

    Raises:
        ValueError: If ``qi``/``qo`` is passed although the module already
            has one, or is omitted although the module has none.
    """
    # Validate both arguments before mutating any state.
    owns_qi = hasattr(self, 'qi')
    if qi is not None:
        if owns_qi:
            raise ValueError('qi has been provided in init function.')
    elif not owns_qi:
        raise ValueError('qi is not existed, should be provided.')

    owns_qo = hasattr(self, 'qo')
    if qo is not None:
        if owns_qo:
            raise ValueError('qo has been provided in init function.')
    elif not owns_qo:
        raise ValueError('qo is not existed, should be provided.')

    if qi is not None:
        self.qi = qi
    if qo is not None:
        self.qo = qo

    mode = self.mode
    if mode not in (1, 2, 3):
        # Any other mode leaves the wrapped convolution untouched.
        return

    # All three handled modes store the combined rescaling factor M.
    self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data

    # Fold the BatchNorm running statistics into the conv weight and bias.
    bn = self.bn_module
    std = torch.sqrt(bn.running_var + bn.eps)
    weight, bias = self.fold_bn(bn.running_mean, std)

    conv = self.conv_module
    conv.weight.data = self.qw.quantize_tensor(weight.data, mode)
    if mode == 1:
        # Mode 1 round-trips the weight back through dequantization;
        # modes 2 and 3 keep the quantize_tensor result as is.
        conv.weight.data = self.qw.dequantize_tensor(conv.weight.data, mode)

    if conv.bias is None:
        return

    if mode == 1:
        bias_scale = self.qi.scale * self.qw.scale
        conv.bias.data = quantize_tensor(bias, scale=bias_scale,
                                         zero_point=0, num_bits=self.num_bits, signed=True,
                                         n_exp=self.n_exp, mode=mode)
        conv.bias.data = dequantize_tensor(conv.bias.data,
                                           scale=bias_scale,
                                           zero_point=0, mode=mode)
    else:
        # Modes 2 and 3: quantize with the dedicated bias parameters, no
        # dequantization step.
        conv.bias.data = quantize_tensor(bias, scale=self.qb.scale,
                                         zero_point=0, num_bits=self.num_bits, signed=True,
                                         n_exp=self.n_exp, mode=mode)
def quantize_inference(self, x):
    """Run the frozen convolution on an already-quantized input.

    Args:
        x: Input tensor handed over by the previous quantized stage.

    Returns:
        The layer output. In mode 1 it is rescaled by ``M``, shifted by the
        output zero point and clamped to ``[0, 2**num_bits - 1]``; in modes
        2 and 3 the convolution result is passed through ``FakeQuantize``
        with ``qo``. For any other mode nothing is returned (``None``).
    """
    if self.mode == 1:
        # Remove the input zero point, convolve, then rescale with M.
        acc = self.M * self.conv_module(x - self.qi.zero_point)
        acc.round_()
        # Shift into the output domain and saturate to the valid range.
        acc = acc + self.qo.zero_point
        acc.clamp_(0., 2. ** self.num_bits - 1.).round_()
        return acc
    if self.mode in (2, 3):
        conv_out = self.conv_module(x)
        return FakeQuantize.apply(conv_out, self.qo)
# TODO: needs revision — this layer probably needs a qo.
class QAdaptiveAvgPool2d(QModule):
    """Quantization-aware adaptive average pooling to a 1x1 spatial output.

    The input is optionally fake-quantized with ``qi`` and the pooled result
    is optionally fake-quantized with ``qo``; each step runs only when the
    corresponding attribute exists on the module.
    """

    def __init__(self, qi=False, qo=True, num_bits=None, n_exp=4, mode=1):
        """Forward the quantization settings to ``QModule`` and keep copies.

        Args:
            qi: Passed to ``QModule``; presumably controls whether a ``qi``
                attribute is created — confirm against ``QModule``.
            qo: Passed to ``QModule``; presumably controls whether a ``qo``
                attribute is created — confirm against ``QModule``.
            num_bits: Bit width, stored on the instance.
            n_exp: Exponent bit count, stored on the instance.
            mode: Quantization mode selector, stored on the instance.
        """
        super(QAdaptiveAvgPool2d, self).__init__(qi=qi, qo=qo, num_bits=num_bits, n_exp=n_exp, mode=mode)
        self.num_bits = num_bits
        self.mode = mode
        self.n_exp = n_exp

    def freeze(self, qi=None):
        """Adopt the input quantization parameters of the preceding layer.

        Args:
            qi: Input quantization parameters. Must be passed only when the
                module has no ``qi`` attribute of its own.

        Raises:
            ValueError: If ``qi`` is passed although the module already has
                one, or is omitted although the module has none.
        """
        if hasattr(self, 'qi') and qi is not None:
            raise ValueError('qi has been provided in init function.')
        if not hasattr(self, 'qi') and qi is None:
            raise ValueError('qi is not existed, should be provided.')
        if qi is not None:
            self.qi = qi

    def fakefreeze(self, qi=None):
        """Same contract as ``freeze``; pooling has no parameters to rewrite.

        Raises:
            ValueError: Under the same conditions as ``freeze``.
        """
        # The two methods were line-for-line duplicates; delegate instead.
        self.freeze(qi=qi)

    def forward(self, x):
        """Pool ``x`` to 1x1 while updating and applying ``qi``/``qo``."""
        if hasattr(self, 'qi'):
            # Same as ReLU: first update qi's statistics, then fake-quantize x.
            # (Usually the previous layer already has qo=True, in which case
            # x arrives here already quantized.)
            self.qi.update(x)
            x = FakeQuantize.apply(x, self.qi)

        x = F.adaptive_avg_pool2d(x, (1, 1))

        if hasattr(self, 'qo'):
            self.qo.update(x)
            x = FakeQuantize.apply(x, self.qo)
        return x

    def quantize_inference(self, x):
        """Pool ``x`` to 1x1 and, when ``qo`` exists, fake-quantize the result.

        NOTE(review): in mode 1 the input presumably already lives in the
        integer domain of the previous layer's output; passing it through
        ``FakeQuantize`` with ``qo`` may not rescale it correctly — confirm.
        """
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # Guarded like forward(): the constructor accepts qo=False, and an
        # unconditional self.qo access would then raise AttributeError.
        if hasattr(self, 'qo'):
            x = FakeQuantize.apply(x, self.qo)
        return x
\ No newline at end of file
...@@ -207,6 +207,9 @@ class ResNet(nn.Module): ...@@ -207,6 +207,9 @@ class ResNet(nn.Module):
super(ResNet, self).__init__() super(ResNet, self).__init__()
self.mode = mode
self.n_exp = n_exp
self.inplanes = 16 # 因为 CIFAR-10 图片较小,所以开始时需要更少的通道数 self.inplanes = 16 # 因为 CIFAR-10 图片较小,所以开始时需要更少的通道数
GlobalVariables.SELF_INPLANES = self.inplanes GlobalVariables.SELF_INPLANES = self.inplanes
print('resnet init:'+ str(GlobalVariables.SELF_INPLANES)) print('resnet init:'+ str(GlobalVariables.SELF_INPLANES))
...@@ -283,10 +286,19 @@ class ResNet(nn.Module): ...@@ -283,10 +286,19 @@ class ResNet(nn.Module):
x = x.view(x.size(0), -1) x = x.view(x.size(0), -1)
x = self.fc(x) x = self.fc(x)
return x out = F.softmax(x,dim = 1) # 这里不softmax也行 影响不大
return out
def quantize(self, num_bits=8): def quantize(self, num_bits=8):
pass self.qconvbnrelu1 = QConvBNReLU(self.conv1,self.bn1,qi=True,qo=True,num_bits=num_bits,n_exp=self.n_exp, mode=self.mode)
# 没有输入num_bits 需修改
self.layer1.quantize(num_bits=num_bits)
self.layer2.quantize(num_bits=num_bits)
self.layer3.quantize(num_bits=num_bits)
self.layer4.quantize(num_bits=num_bits)
self.qavgpool1 = QAdaptiveAvgPool2d(qi=False,qo=True,num_bits=num_bits,n_exp=self.n_exp, mode=self.mode)
self.qfc1 = QLinear(self.fc,qi=False,qo=True,num_bits=num_bits,n_exp=self.n_exp, mode=self.mode)
def quantize_forward(self, x): def quantize_forward(self, x):
# for _, layer in self.quantize_layers.items(): # for _, layer in self.quantize_layers.items():
...@@ -294,24 +306,56 @@ class ResNet(nn.Module): ...@@ -294,24 +306,56 @@ class ResNet(nn.Module):
# out = F.softmax(x, dim=1) # out = F.softmax(x, dim=1)
# return out # return out
pass x = self.qconvbnrelu1(x)
x = self.layer1.quantize_forward(x)
x = self.layer2.quantize_forward(x)
x = self.layer3.quantize_forward(x)
x = self.layer4.quantize_forward(x)
x = self.qavgpool1(x)
x = x.view(x.size(0), -1)
x = self.qfc1(x)
out = F.softmax(x,dim = 1) # 这里不softmax也行 影响不大
return out
def freeze(self): def freeze(self):
pass self.qconvbnrelu1.freeze() # 因为作为第一层是有qi的,所以freeze的时候无需再重新提供qi
qo = self.layer1.freeze(qinput = self.qconvbnrelu1.qo)
qo = self.layer2.freeze(qinput = qo)
qo = self.layer3.freeze(qinput = qo)
qo = self.layer4.freeze(qinput = qo)
self.qavgpool1.freeze(qo)
self.qfc1.freeze(qi=qo)
def fakefreeze(self): def fakefreeze(self):
pass pass
def quantize_inference(self, x): def quantize_inference(self, x):
pass qx = self.qconvbnrelu1.qi.quantize_tensor(x,mode=self.mode)
qx = self.qconvbnrelu1.quantize_inference(qx)
qx = self.layer1.quantize_inference(qx)
qx = self.layer2.quantize_inference(qx)
qx = self.layer3.quantize_inference(qx)
qx = self.layer4.quantize_inference(qx)
qx = self.qavgpool1.quantize_inference(qx)
qx = qx.view(qx.size(0), -1)
qx = self.qfc1.quantize_inference(qx)
if self.mode == 1:
qx = self.qfc1.qo.dequantize_tensor(qx,mode=self.mode)
out = F.softmax(qx,dim = 1) # 这里不softmax也行 影响不大
return out
# BasicBlock 类 # BasicBlock 类
class BasicBlock(nn.Module): class BasicBlock(nn.Module):
expansion = 1 expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None): def __init__(self, inplanes, planes, stride=1, downsample=None, n_exp=4, mode=1):
super(BasicBlock, self).__init__() super(BasicBlock, self).__init__()
# 第一个卷积层 # 第一个卷积层
...@@ -328,6 +372,8 @@ class BasicBlock(nn.Module): ...@@ -328,6 +372,8 @@ class BasicBlock(nn.Module):
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.downsample = downsample self.downsample = downsample
self.stride = stride self.stride = stride
self.mode = mode
self.n_exp = n_exp
def forward(self, x): def forward(self, x):
...@@ -349,13 +395,13 @@ class BasicBlock(nn.Module): ...@@ -349,13 +395,13 @@ class BasicBlock(nn.Module):
return out return out
def quantize(self, num_bits=8): def quantize(self, num_bits=8):
self.qconvbnrelu1 = QConvBNReLU(self.conv1,self.bn2,qi=False,qo=True,num_bits=num_bits) self.qconvbnrelu1 = QConvBNReLU(self.conv1,self.bn1,qi=False,qo=True,num_bits=num_bits,n_exp=self.n_exp,mode=self.mode)
self.qconvbn1 = QConvBN(self.conv2,self.bn2,qi=False,qo=True,num_bits=num_bits) self.qconvbn1 = QConvBN(self.conv2,self.bn2,qi=False,qo=True,num_bits=num_bits,n_exp=self.n_exp,mode=self.mode)
if self.downsample is not None: if self.downsample is not None:
self.qconvbn2 = QConvBN(self.downsample[0],self.downsample[1],qi=False,qo=True,num_bits=num_bits) self.qconvbn2 = QConvBN(self.downsample[0],self.downsample[1],qi=False,qo=True,num_bits=num_bits,n_exp=self.n_exp,mode=self.mode)
self.qrelu1 = QReLU() self.qrelu1 = QReLU(n_exp=self.n_exp,mode=self.mode)
def quantize_forward(self, x): def quantize_forward(self, x):
...@@ -371,24 +417,26 @@ class BasicBlock(nn.Module): ...@@ -371,24 +417,26 @@ class BasicBlock(nn.Module):
out = self.qrelu1(out) out = self.qrelu1(out)
return out return out
def freeze(self): def freeze(self, qinput):
# 这里的qconvbnrelu1其实是可以用前一层的qo的,但感觉不太好传参,就没用 # 这里的qconvbnrelu1其实是可以用前一层的qo的,但感觉不太好传参,就没用
# 还需仔细检查 # 还需仔细检查
self.qconvbnrelu1.freeze() self.qconvbnrelu1.freeze(qi= qinput) # 需要接前一个module的最后一个qo
self.qconvbn1.freeze(qi = self.qconvbnrelu1.qo) self.qconvbn1.freeze(qi = self.qconvbnrelu1.qo)
if self.downsample is not None: if self.downsample is not None:
self.qconvbn2.freeze(qi = self.qconvbn1) self.qconvbn2.freeze(qi = self.qconvbn1.qo)
self.qrelu1.freeze(self.qconvbn2) self.qrelu1.freeze(self.qconvbn2.qo)
return self.qconvbn2.qo
else: else:
self.qrelu1.freeze(self.qconvbn1) self.qrelu1.freeze(self.qconvbn1.qo)
return self.qconvbn1.qo
def quantize_inference(self, x): def quantize_inference(self, x):
# 感觉是不需要进行初始的quantize_tensor和dequantize_tensor,因为他不是最前/后一层,只要中间的每层都在量化后的领域内,就不需要这种处理。 # 感觉是不需要进行初始的quantize_tensor和dequantize_tensor,因为他不是最前/后一层,只要中间的每层都在量化后的领域内,就不需要这种处理。
identity = x identity = x
out = self.qconvbnrelu1.quantize_inference(x) out = self.qconvbnrelu1.quantize_inference(x)
out = self.qconvbn1.quantize_inference(x) out = self.qconvbn1.quantize_inference(out)
if self.downsample is not None: if self.downsample is not None:
identity = self.qconvbn2.quantize_inference(identity) identity = self.qconvbn2.quantize_inference(identity)
...@@ -453,7 +501,7 @@ class Bottleneck(nn.Module): ...@@ -453,7 +501,7 @@ class Bottleneck(nn.Module):
class MakeLayer(nn.Module): class MakeLayer(nn.Module):
def __init__(self, block, planes, blocks, stride=1): def __init__(self, block, planes, blocks, stride=1, n_exp=4, mode=1):
super(MakeLayer, self).__init__() super(MakeLayer, self).__init__()
print('makelayer init:'+ str(GlobalVariables.SELF_INPLANES)) print('makelayer init:'+ str(GlobalVariables.SELF_INPLANES))
self.downsample = None self.downsample = None
...@@ -462,13 +510,16 @@ class MakeLayer(nn.Module): ...@@ -462,13 +510,16 @@ class MakeLayer(nn.Module):
nn.Conv2d(GlobalVariables.SELF_INPLANES, planes * block.expansion,kernel_size=1, stride=stride, bias=False), nn.Conv2d(GlobalVariables.SELF_INPLANES, planes * block.expansion,kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion) nn.BatchNorm2d(planes * block.expansion)
) )
self.n_exp = n_exp
self.mode = mode
self.blockdict = nn.ModuleDict() self.blockdict = nn.ModuleDict()
self.blockdict['block1'] = block(GlobalVariables.SELF_INPLANES, planes, stride, self.downsample) self.blockdict['block1'] = block(inplanes=GlobalVariables.SELF_INPLANES, planes=planes, stride=stride, downsample=self.downsample,n_exp=self.n_exp,mode=self.mode)
GlobalVariables.SELF_INPLANES = planes * block.expansion GlobalVariables.SELF_INPLANES = planes * block.expansion
for i in range(1, blocks): # block的个数 这里只能用字典了 for i in range(1, blocks): # block的个数 这里只能用字典了
self.blockdict['block' + str(i+1)] = block(GlobalVariables.SELF_INPLANES, planes) # 此处进行实例化了 self.blockdict['block' + str(i+1)] = block(inplanes=GlobalVariables.SELF_INPLANES, planes=planes,n_exp=self.n_exp, mode=self.mode) # 此处进行实例化了
# def _make_layer(self, block, planes, blocks, stride=1): # def _make_layer(self, block, planes, blocks, stride=1):
# downsample = None # downsample = None
# # stride 是卷积层的步幅,而 self.inplanes 表示当前残差块输入的通道数, # # stride 是卷积层的步幅,而 self.inplanes 表示当前残差块输入的通道数,
...@@ -499,7 +550,7 @@ class MakeLayer(nn.Module): ...@@ -499,7 +550,7 @@ class MakeLayer(nn.Module):
def quantize(self, num_bits=8): def quantize(self, num_bits=8):
# 需检查 # 需检查
for _, layer in self.blockdict.items(): for _, layer in self.blockdict.items():
layer.quantize() # 这里是因为每一块都是block,而block中有具体的quantize策略 layer.quantize(num_bits=num_bits) # 这里是因为每一块都是block,而block中有具体的quantize策略, n_exp和mode已经在__init__中赋值了
def quantize_forward(self, x): def quantize_forward(self, x):
...@@ -509,11 +560,18 @@ class MakeLayer(nn.Module): ...@@ -509,11 +560,18 @@ class MakeLayer(nn.Module):
return x return x
def freeze(self): def freeze(self, qinput): # 需要在 Module Resnet的freeze里传出来
# 这里的qconvbnrelu1其实是可以用前一层的qo的,但感觉不太好传参,就没用 # 这里的qconvbnrelu1其实是可以用前一层的qo的,但感觉不太好传参,就没用
# 还需仔细检查 # 还需仔细检查
cnt = 0
for _, layer in self.blockdict.items(): for _, layer in self.blockdict.items():
layer.freeze() # 各个block中有具体的freeze if cnt == 0:
qo = layer.freeze(qinput = qinput)
cnt = 1
else:
qo = layer.freeze(qinput = qo) # 各个block中有具体的freeze
return qo # 供后续的层用
def quantize_inference(self, x): def quantize_inference(self, x):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment