Commit c312eace by Zhihong Ma
parents 8fae747b 55f47e27
## Inception_BN Quantization Notes
## New Framework Notes
+ `cfg_table` is used for fast model construction and quantization (the scheme can be applied to other models as well). It describes the flattened forward structure of the whole network, including the Inception (`Inc`) modules; each `Inc` module is in turn built and quantized from `inc_ch_table` and `inc_cfg_table`. See the files for the exact rules.
+ Each `cfg_table` entry corresponds to one unit that can be fused for quantization: e.g. adjacent conv, bn and relu layers can be fused, which appears in `cfg_table` as `['C','BR',...]`. This makes it straightforward to drive both quantization and flops/param weight extraction from the same table.
+ Changed the `fold_ratio` method to support folding layers that belong to the same `cfg_table` entry. It now also accounts for the relu layer and drops the old adjacency assumption. Approach: when a conv layer is encountered, collect the layers that share its prefix and suffix and sum their ratios (see the sketch at the end of this list).
+ Changed `module` so that quantized layers accept a wrapped layer whose bias is `None`. The original `fold_bn` already handled this case, so only `freeze` needed to change.
+ For `Conv`, there is no bias either before or after `freeze`.
+ For `ConvBN`, a bias exists both before and after `freeze`: `forward` uses a temporary folded value, while `inference` uses the frozen value (stored on the corresponding `conv_module`).
+ Since `conv.bias = None` is now allowed, the full-precision model's `fold_bn` was changed accordingly, so that the comparable parameters are identical before and after quantization. The rewrite mirrors the quantized layers.
+ Changed the `js_div` computation: when a layer has several parameters (e.g. weight and bias), their weights should sum to 1 in total. For now a plain average is used (the JS divergence is divided by the layer's parameter count); weighted averaging may be considered later. PS: in Inception_BN, the outer conv layers have a bias, while convs inside the Inception modules use `bias=False` because a bn layer follows.
+ Since the `named_parameters` iterator has no fixed length, its entries are first collected into a fixed list before processing, so that the number of parameters per layer can be counted; see the change in ptq.py. Doing this for the full-precision model is sufficient.
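+ As a rough illustration of the prefix/suffix folding described above (a sketch only, not the exact `fold_ratio` implementation; the helper name and dict handling are invented for the example):
```python
# Sketch: fold the par/flop ratios of bn/relu layers into their conv layer when
# they share the same name prefix and suffix (i.e. the same cfg_table entry).
def fold_ratio_sketch(layer, par_ratio, flop_ratio):
    par = dict(zip(layer, par_ratio))
    flop = dict(zip(layer, flop_ratio))
    for name in layer:
        if 'conv' not in name:
            continue
        prefix, suffix = name.split('conv')   # e.g. 'inc0_br1_' and '1'
        for sibling in (prefix + 'bn' + suffix, prefix + 'relu' + suffix):
            if sibling in par:
                par[name] += par[sibling]
                flop[name] += flop[sibling]
    return par, flop
```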
## PTQ
+ Quantization results
![Inception_BN_table](image/Inception_BN_table.png)
+ Fitting results ![flops](image/flops.png)
+ ![param](image/param.png)
### debug
+ The quantization results show that POT quantization suffers a large accuracy loss. To locate the cause, we dequantize at different positions in the Inception_BN network and observe the accuracy of POT quantization. The overall pipeline becomes: quantize -> first part as quantized layers -> dequantize -> remaining part as full-precision layers (see the sketch after the tables below).
+ Model structure:
```python
Inception_BN_cfg_table = [
    ['C','',True,3,64,3,1,1],
    ['R'],
    ['C','',False,64,64,3,1,1],
    ['R'],
    ['Inc',0],
    ['Inc',1],
    ['MP',3,2,1],
    ['Inc',2],
    ['Inc',3],
    ['Inc',4],
    ['Inc',5],
    ['Inc',6],
    ['MP',3,2,1],
    ['Inc',7],
    ['Inc',8],
    ['AAP',1],
    ['C','',False,1024,10,1,1,0],
    ['F']
]
```
+ Dequantization position:
+ At the very end of the network
| Title | ptq2 | ptq3 | ptq4 | ptq5 | ptq6 | ptq7 | ptq8 |
| ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
| acc | 10.00 | 10.03 | 17.55 | 21.00 | 22.80 | 24.38 | 16.54 |
+ Before the AAP layer
| Title | ptq2 | ptq3 | ptq4 | ptq5 | ptq6 | ptq7 | ptq8 |
| ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
| acc | 10.00 | 12.39 | 21.06 | 26.21 | 23.88 | 25.56 | 30.04 |
+ Before the first inc
| Title | ptq2 | ptq3 | ptq4 | ptq5 | ptq6 | ptq7 | ptq8 |
| ----- | ---- | ----- | ----- | ----- | ----- | ----- | ----- |
| acc | 9.99 | 22.07 | 85.71 | 85.72 | 85.36 | 85.66 | 85.70 |
+ Before the second inc
| Title | ptq2 | ptq3 | ptq4 | ptq5 | ptq6 | ptq7 | ptq8 |
| ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
| acc | 10.01 | 14.67 | 82.92 | 81.54 | 82.05 | 82.41 | 83.10 |
+ After the second inc
| Title | ptq2 | ptq3 | ptq4 | ptq5 | ptq6 | ptq7 | ptq8 |
| ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
| acc | 10.00 | 12.53 | 42.75 | 64.89 | 20.54 | 66.84 | 34.05 |
+ Judging from the accuracies at the different dequantization positions, the preliminary conclusion is that the POT accuracy loss is not caused by the quantization framework, but by a mismatch between this model and the POT quantization scheme itself.
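+ For reference, the partial-dequantization experiment reuses `model_utils` below: `end_pos` marks the `cfg_table` index at which activations are dequantized, after which the remaining layers run in full precision. A minimal sketch of the surrounding PTQ flow (bit width, loop and loader details are illustrative and follow ptq.py; the POT value lists are assumed to have been registered via `gol` beforehand):
```python
# Sketch of the debug flow: POT-quantize, calibrate, freeze, then run inference;
# model_utils dequantizes at end_pos and finishes the network in full precision.
model.quantize('POT', num_bits=4, e_bits=0)            # build q_ layers from cfg_table
model.eval()
for data, _ in test_loader:                            # calibration pass (abbreviated)
    model.quantize_forward(data.to(device))
model.freeze()                                         # fold BN, fix scales / zero points
acc = quantize_inference(model, test_loader, device)   # model_utils(..., func='inference')
```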
import sys
import os
# Extract the per-layer parameter and flops ratios from param_flops.txt,
# i.e. the file that the output of get_param.py was redirected to
def extract_ratio():
fr = open('param_flops.txt','r')
lines = fr.readlines()
layer = []
par_ratio = []
flop_ratio = []
for line in lines:
if '(' in line and ')' in line:
layer.append(line.split(')')[0].split('(')[1])
r1 = line.split('%')[0].split(',')[-1]
r1 = float(r1)
par_ratio.append(r1)
r2 = line.split('%')[-2].split(',')[-1]
r2 = float(r2)
flop_ratio.append(r2)
return layer, par_ratio, flop_ratio
if __name__ == "__main__":
layer, par_ratio, flop_ratio = extract_ratio()
print(layer)
print(par_ratio)
print(flop_ratio)
from torch.autograd import Function
class FakeQuantize(Function):
@staticmethod
def forward(ctx, x, qparam):
x = qparam.quantize_tensor(x)
x = qparam.dequantize_tensor(x)
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
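# Usage note (illustrative): FakeQuantize is a straight-through estimator.
# Forward: quantize then immediately dequantize with the given QParam, so the
# returned tensor carries the quantization error; backward: the gradient is
# passed through unchanged. E.g., with a calibrated QParam `qp`:
#     y = FakeQuantize.apply(x, qp)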
from model import *
import torch
from ptflops import get_model_complexity_info
if __name__ == "__main__":
model = Inception_BN()
full_file = 'ckpt/cifar10_Inception_BN.pt'
model.load_state_dict(torch.load(full_file))
flops, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
# -*- coding: utf-8 -*-
# Shared global variables across multiple modules
def _init(): # initialize
global _global_dict
_global_dict = {}
def set_value(value,is_bias=False):
    # store a value in the shared global dict
if is_bias:
_global_dict[0] = value
else:
_global_dict[1] = value
def get_value(is_bias=False): # bias gets its own precision, independent of the other values
if is_bias:
return _global_dict[0]
else:
return _global_dict[1]
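# Usage note (illustrative): the PTQ script is expected to call gol._init() once,
# then gol.set_value(plist) for the weight/activation candidate value list and
# gol.set_value(bias_plist, is_bias=True) for the higher-precision bias list;
# get_nearest_val() in module.py reads them back via gol.get_value(is_bias).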
import torch
import torch.nn as nn
import torch.nn.functional as F
from module import *
import module
# conv: 'C', ''/'B'/'BR', qi, in_ch, out_ch, kernel_size, stride, padding
# relu: 'R'
# inception: 'Inc'
# maxpool: 'MP', kernel_size, stride, padding
# adaptiveavgpool: 'AAP', output_size
# flatten: 'F'
# num_classes: 10
Inception_BN_cfg_table = [
['C','',True,3,64,3,1,1],
['R'],
['C','',False,64,64,3,1,1],
['R'],
['Inc',0],
['Inc',1],
['MP',3,2,1],
['Inc',2],
['Inc',3],
['Inc',4],
['Inc',5],
['Inc',6],
['MP',3,2,1],
['Inc',7],
['Inc',8],
['AAP',1],
['C','',False,1024,10,1,1,0],
['F']
]
inc_ch_table=[
[64, 64, 96,128, 16, 32, 32],#3a
[256,128,128,192, 32, 96, 64],#3b
[480,192, 96,208, 16, 48, 64],#4a
[512,160,112,224, 24, 64, 64],#4b
[512,128,128,256, 24, 64, 64],#4c
[512,112,144,288, 32, 64, 64],#4d
[528,256,160,320, 32,128,128],#4e
[832,256,160,320, 32,128,128],#5a
[832,384,192,384, 48,128,128] #5b
]
# br0,br1,br2,br3 <- br1x1,br3x3,br5x5,brM
# The 2nd and 3rd parameters of a 'C' entry are indices into the corresponding inc_ch_table row.
# For future cfg extensions, 'C' can be regarded as carrying a 'BR' flag; each entry here corresponds to one structure that is fusable after quantization.
inc_cfg_table = [
[['C',0,1,1,1,0]],
[['C',0,2,1,1,0],
['C',2,3,3,1,1]],
[['C',0,4,1,1,0],
['C',4,5,5,1,2]],
[['MP',3,1,1],
['R'],
['C',0,6,1,1,0]]
]
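# Example (reading the tables above): for inc0, inc_ch_table[0] = [64,64,96,128,16,32,32]
# and branch 1 is [['C',0,2,1,1,0], ['C',2,3,3,1,1]], i.e.
#   Conv2d(ch[0]=64, ch[2]=96, kernel_size=1, stride=1, padding=0) + BN + ReLU
#   Conv2d(ch[2]=96, ch[3]=128, kernel_size=3, stride=1, padding=1) + BN + ReLU
# which matches inc0_br1_conv0 / inc0_br1_conv1 in the ptflops output further below.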
def make_layers(model,cfg_table):
for i in range(len(cfg_table)):
cfg = cfg_table[i]
if cfg[0] == 'Inc':
make_inc_layers(model,cfg[1])
elif cfg[0] == 'C':
name = 'conv%d'%i
layer = nn.Conv2d(cfg[3],cfg[4],kernel_size=cfg[5],stride=cfg[6],padding=cfg[7])
model.add_module(name,layer)
if 'B' in cfg[1]:
name = 'bn%d'%i
layer = nn.BatchNorm2d(cfg[4])
model.add_module(name,layer)
if 'R' in cfg[1]:
name = 'relu%d'%i
layer = nn.ReLU(True)
model.add_module(name,layer)
elif cfg[0] == 'R':
name = 'relu%d'%i
layer = nn.ReLU(True)
model.add_module(name,layer)
elif cfg[0] == 'MP':
name = 'pool%d'%i
layer = nn.MaxPool2d(kernel_size=cfg[1],stride=cfg[2],padding=cfg[3])
model.add_module(name,layer)
elif cfg[0] == 'AAP':
name = 'aap%d'%i
layer = nn.AdaptiveAvgPool2d(cfg[1])
model.add_module(name,layer)
def model_forward(model,cfg_table,x):
for i in range(len(cfg_table)):
cfg = cfg_table[i]
if cfg[0] == 'Inc':
x = inc_forward(model,cfg[1],x)
elif cfg[0] == 'C':
name = 'conv%d'%i
layer = getattr(model,name)
x = layer(x)
if 'B' in cfg[1]:
name = 'bn%d'%i
layer = getattr(model,name)
x = layer(x)
if 'R' in cfg[1]:
name = 'relu%d'%i
layer = getattr(model,name)
x = layer(x)
elif cfg[0] == 'R':
name = 'relu%d'%i
layer = getattr(model,name)
x = layer(x)
elif cfg[0] == 'MP':
name = 'pool%d'%i
layer = getattr(model,name)
x = layer(x)
elif cfg[0] == 'AAP':
name = 'aap%d'%i
layer = getattr(model,name)
x = layer(x)
elif cfg[0] == 'F':
x = torch.flatten(x, start_dim=1)
return x
def model_quantize(model,cfg_table,quant_type,num_bits,e_bits):
for i in range(len(cfg_table)):
cfg = cfg_table[i]
if cfg[0] == 'Inc':
inc_quantize(model,cfg[1],quant_type,num_bits,e_bits)
elif cfg[0] == 'C':
conv_name = 'conv%d'%i
conv_layer = getattr(model,conv_name)
qname = 'q_'+conv_name
if 'B' in cfg[1]:
bn_name = 'bn%d'%i
bn_layer = getattr(model,bn_name)
if 'R' in cfg[1]:
qlayer = QConvBNReLU(quant_type,conv_layer,bn_layer,qi=cfg[2],num_bits=num_bits,e_bits=e_bits)
else:
qlayer = QConvBN(quant_type,conv_layer,bn_layer,qi=cfg[2],num_bits=num_bits,e_bits=e_bits)
else:
qlayer = QConv2d(quant_type,conv_layer,qi=cfg[2],num_bits=num_bits,e_bits=e_bits)
model.add_module(qname,qlayer)
elif cfg[0] == 'R':
name = 'relu%d'%i
qname = 'q_'+name
qlayer = QReLU(quant_type,num_bits=num_bits,e_bits=e_bits)
model.add_module(qname,qlayer)
elif cfg[0] == 'MP':
name = 'pool%d'%i
qname = 'q_'+name
qlayer = QMaxPooling2d(quant_type,kernel_size=cfg[1],stride=cfg[2],padding=cfg[3],num_bits=num_bits,e_bits=e_bits)
model.add_module(qname,qlayer)
elif cfg[0] == 'AAP':
name = 'aap%d'%i
qname = 'q_'+name
qlayer = QAdaptiveAvgPool2d(quant_type,output_size=cfg[1],num_bits=num_bits,e_bits=e_bits)
model.add_module(qname,qlayer)
# Supports choosing the dequantization position for debugging; can be removed for the final release.
# end_pos == -1 means dequantize only at the very end; otherwise dequantize at cfg_table index end_pos.
def model_utils(model,cfg_table,func,x=None):
end_flag = False
end_pos = 6
last_qo = None
for i in range(len(cfg_table)):
cfg = cfg_table[i]
if i == end_pos:
end_flag = True
if func == 'inference':
x = last_qo.dequantize_tensor(x)
if cfg[0] == 'Inc':
if end_flag:
if func != 'freeze':
x = inc_forward(model,cfg[1],x)
continue
x,last_qo = inc_utils(model,cfg[1],func,x,last_qo)
elif cfg[0] == 'C':
if end_flag:
name = 'conv%d'%i
layer = getattr(model,name)
if func != 'freeze':
x = layer(x)
continue
qname = 'q_conv%d'%i
qlayer = getattr(model,qname)
if func == 'forward':
x = qlayer(x)
elif func == 'inference':
                # cfg[2] == True marks the network's input layer, whose input still needs to be quantized
if cfg[2]:
x = qlayer.qi.quantize_tensor(x)
x = qlayer.quantize_inference(x)
else: #freeze
qlayer.freeze(last_qo)
last_qo = qlayer.qo
elif cfg[0] == 'R':
if end_flag:
name = 'relu%d'%i
layer = getattr(model,name)
if func != 'freeze':
x = layer(x)
continue
qname = 'q_relu%d'%i
qlayer = getattr(model,qname)
if func == 'forward':
x = qlayer(x)
elif func == 'inference':
x = qlayer.quantize_inference(x)
else: #freeze
qlayer.freeze(last_qo)
elif cfg[0] == 'MP':
if end_flag:
name = 'pool%d'%i
layer = getattr(model,name)
if func != 'freeze':
x = layer(x)
continue
qname = 'q_pool%d'%i
qlayer = getattr(model,qname)
if func == 'forward':
x = qlayer(x)
elif func == 'inference':
x = qlayer.quantize_inference(x)
else: #freeze
qlayer.freeze(last_qo)
elif cfg[0] == 'AAP':
if end_flag:
name = 'aap%d'%i
layer = getattr(model,name)
if func != 'freeze':
x = layer(x)
continue
qname = 'q_aap%d'%i
qlayer = getattr(model,qname)
if func == 'forward':
x = qlayer(x)
elif func == 'inference':
x = qlayer.quantize_inference(x)
else: #freeze
qlayer.freeze(last_qo)
last_qo = qlayer.qo
elif cfg[0] == 'F':
if func != 'freeze':
x = torch.flatten(x,start_dim=1)
if func == 'inference' and not end_flag:
x = last_qo.dequantize_tensor(x)
return x
def make_inc_layers(model,inc_idx):
inc_name = 'inc%d'%inc_idx
ch = inc_ch_table[inc_idx]
for i in range(4): # branch
prefix = inc_name+'_br%d_'%i
for j in range(len(inc_cfg_table[i])):
cfg = inc_cfg_table[i][j]
if cfg[0] == 'MP':
name = prefix+'pool%d'%j
layer =nn.MaxPool2d(kernel_size=cfg[1],stride=cfg[2],padding=cfg[3])
model.add_module(name,layer)
elif cfg[0] == 'R':
name=prefix+'relu%d'%j
layer=nn.ReLU(True)
model.add_module(name,layer)
elif cfg[0] == 'C':
name=prefix+'conv%d'%j
layer=nn.Conv2d(ch[cfg[1]],ch[cfg[2]],kernel_size=cfg[3],stride=cfg[4],padding=cfg[5],bias=False)
model.add_module(name,layer)
name=prefix+'bn%d'%j
layer=nn.BatchNorm2d(ch[cfg[2]])
model.add_module(name,layer)
name=prefix+'relu%d'%j
layer=nn.ReLU(True)
model.add_module(name,layer)
def inc_forward(model,inc_idx,x=None):
inc_name = 'inc%d'%inc_idx
outs = []
for i in range(4):
prefix = inc_name+'_br%d_'%i
tmp = x
for j in range(len(inc_cfg_table[i])):
cfg = inc_cfg_table[i][j]
if cfg[0] == 'MP':
name = prefix+'pool%d'%j
layer = getattr(model,name)
tmp = layer(tmp)
elif cfg[0] == 'R':
name=prefix+'relu%d'%j
layer = getattr(model,name)
tmp = layer(tmp)
else:
name=prefix+'conv%d'%j
layer = getattr(model,name)
tmp = layer(tmp)
name=prefix+'bn%d'%j
layer = getattr(model,name)
tmp = layer(tmp)
name=prefix+'relu%d'%j
layer = getattr(model,name)
tmp = layer(tmp)
outs.append(tmp)
out = torch.cat(outs,1)
return out
def inc_quantize(model,inc_idx,quant_type,num_bits,e_bits):
inc_name = 'inc%d'%inc_idx
for i in range(4):
prefix = inc_name+'_br%d_'%i
for j in range(len(inc_cfg_table[i])):
cfg = inc_cfg_table[i][j]
if cfg[0] == 'MP':
name = prefix+'pool%d'%j
qname = 'q_'+name
qlayer = QMaxPooling2d(quant_type,kernel_size=cfg[1],stride=cfg[2],padding=cfg[3],num_bits=num_bits,e_bits=e_bits)
model.add_module(qname,qlayer)
elif cfg[0] == 'R':
name = prefix+'relu%d'%j
qname = 'q_'+name
qlayer = QReLU(quant_type, num_bits=num_bits, e_bits=e_bits)
model.add_module(qname,qlayer)
else:
conv_name=prefix+'conv%d'%j
conv_layer=getattr(model,conv_name)
bn_name=prefix+'bn%d'%j
bn_layer=getattr(model,bn_name)
qname='q_'+conv_name
qlayer=QConvBNReLU(quant_type, conv_layer, bn_layer, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
model.add_module(qname,qlayer)
qname = 'q_'+inc_name+'_concat'
qlayer = QConcat(quant_type,4,qi_array=False,qo=True,num_bits=num_bits,e_bits=e_bits)
model.add_module(qname,qlayer)
def inc_utils(model,inc_idx,func,x=None,qo=None):
inc_name = 'inc%d'%inc_idx
outs=[]
qos=[]
for i in range(4):
qprefix = 'q_'+inc_name+'_br%d_'%i
tmp = x
last_qo = qo
for j in range(len(inc_cfg_table[i])):
cfg = inc_cfg_table[i][j]
if cfg[0] == 'MP':
qname = qprefix+'pool%d'%j
qlayer = getattr(model,qname)
if func == 'forward':
tmp = qlayer(tmp)
elif func == 'inference':
tmp = qlayer.quantize_inference(tmp)
else: #freeze
qlayer.freeze(last_qo)
elif cfg[0] == 'R':
qname = qprefix+'relu%d'%j
qlayer = getattr(model,qname)
if func == 'forward':
tmp = qlayer(tmp)
elif func == 'inference':
tmp = qlayer.quantize_inference(tmp)
else: #freeze
qlayer.freeze(last_qo)
else:
qname = qprefix+'conv%d'%j
qlayer = getattr(model,qname)
if func == 'forward':
tmp = qlayer(tmp)
elif func == 'inference':
tmp = qlayer.quantize_inference(tmp)
else: #freeze
qlayer.freeze(last_qo)
last_qo = qlayer.qo
outs.append(tmp)
qos.append(last_qo)
qname = 'q_'+inc_name+'_concat'
qlayer = getattr(model,qname)
out = None
if func == 'forward':
out = qlayer(outs)
elif func == 'inference':
out = qlayer.quantize_inference(outs)
else: #freeze
qlayer.freeze(qos)
last_qo = qlayer.qo
return out,last_qo
class Inception_BN(nn.Module):
def __init__(self,num_classes=10):
super(Inception_BN, self).__init__()
self.cfg_table = Inception_BN_cfg_table
make_layers(self,self.cfg_table)
def forward(self,x):
x = model_forward(self,self.cfg_table,x)
return x
def quantize(self, quant_type, num_bits=8, e_bits=3):
model_quantize(self,self.cfg_table,quant_type,num_bits,e_bits)
def quantize_forward(self,x):
return model_utils(self,self.cfg_table,func='forward',x=x)
def freeze(self):
model_utils(self,self.cfg_table,func='freeze')
def quantize_inference(self,x):
return model_utils(self,self.cfg_table,func='inference',x=x)
if __name__ == "__main__":
model = Inception_BN()
model.quantize('INT',8,3)
print(model.named_modules)
print('-------')
print(model.named_parameters)
    print(len(list(model.conv0.named_parameters())))  # named_parameters() is a generator; materialize before len()
import math
import numpy as np
import gol
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from function import FakeQuantize
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
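# Note (illustrative): nn.KLDivLoss(reduction='sum')(log_m, p) computes KL(p || m),
# so the value returned above is (KL(p||m) + KL(q||m)) / 2 = JS(p, q) with m = (p+q)/2.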
# Get the nearest quantized value (for POT/FLOAT, the nearest entry of the candidate value list)
def get_nearest_val(quant_type, x, is_bias=False, block_size=1000000):
if quant_type == 'INT':
return x.round_()
plist = gol.get_value(is_bias)
shape = x.shape
xhard = x.view(-1)
xout = torch.zeros_like(xhard)
plist = plist.type_as(x)
n_blocks = (x.numel() + block_size - 1) // block_size
for i in range(n_blocks):
start_idx = i * block_size
end_idx = min(start_idx + block_size, xhard.numel())
block_size_i = end_idx - start_idx
# print(x.numel())
# print(block_size_i)
# print(start_idx)
# print(end_idx)
xblock = xhard[start_idx:end_idx]
# xblock = xblock.view(shape[start_idx:end_idx])
plist_block = plist.unsqueeze(1) #.expand(-1, block_size_i)
idx = (xblock.unsqueeze(0) - plist_block).abs().min(dim=0)[1]
# print(xblock.shape)
xhard_block = plist[idx].view(xblock.shape)
xout[start_idx:end_idx] = (xhard_block - xblock).detach() + xblock
xout = xout.view(shape)
return xout
# For symmetric signed quantization, get the maximum of the quantization range
def get_qmax(quant_type,num_bits=None, e_bits=None):
if quant_type == 'INT':
qmax = 2. ** (num_bits - 1) - 1
elif quant_type == 'POT':
qmax = 1
else: #FLOAT
m_bits = num_bits - 1 - e_bits
dist_m = 2 ** (-m_bits)
e = 2 ** (e_bits - 1)
expo = 2 ** e
m = 2 ** m_bits -1
frac = 1. + m * dist_m
qmax = frac * expo
return qmax
# Signed quantization is used throughout, so the zero point is always 0
def calcScaleZeroPoint(min_val, max_val, qmax):
scale = torch.max(max_val.abs(),min_val.abs()) / qmax
zero_point = torch.tensor(0.)
return scale, zero_point
# Quantize the input; both input and output are tensors
def quantize_tensor(quant_type, x, scale, zero_point, qmax, is_bias=False):
    # The quantized range is determined directly by the bit width
qmin = -qmax
q_x = zero_point + x / scale
q_x.clamp_(qmin, qmax)
q_x = get_nearest_val(quant_type, q_x, is_bias)
return q_x
# bias uses a different precision; num_bits/e_bits are chosen according to the quantization type
def bias_qmax(quant_type):
if quant_type == 'INT':
return get_qmax(quant_type, 64)
elif quant_type == 'POT':
return get_qmax(quant_type)
else:
return get_qmax(quant_type, 16, 7)
# Convert back to FP32; no further clamping is needed
def dequantize_tensor(q_x, scale, zero_point):
return scale * (q_x - zero_point)
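# Worked example (INT, num_bits=8): qmax = 2**7 - 1 = 127. If max(|x|) = 2.54,
# then scale = 2.54 / 127 = 0.02 and zero_point = 0, so x = 1.0 quantizes to
# round(1.0 / 0.02) = 50 and dequantizes back to 50 * 0.02 = 1.0.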
class QParam(nn.Module):
def __init__(self,quant_type, num_bits=8, e_bits=3):
super(QParam, self).__init__()
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.qmax = get_qmax(quant_type, num_bits, e_bits)
scale = torch.tensor([], requires_grad=False)
zero_point = torch.tensor([], requires_grad=False)
min = torch.tensor([], requires_grad=False)
max = torch.tensor([], requires_grad=False)
        # Registering these as buffers lets them be recorded in the state_dict
self.register_buffer('scale', scale)
self.register_buffer('zero_point', zero_point)
self.register_buffer('min', min)
self.register_buffer('max', max)
    # Update the observed range and the quantization parameters
def update(self, tensor):
if self.max.nelement() == 0 or self.max.data < tensor.max().data:
self.max.data = tensor.max().data
self.max.clamp_(min=0)
if self.min.nelement() == 0 or self.min.data > tensor.min().data:
self.min.data = tensor.min().data
self.min.clamp_(max=0)
self.scale, self.zero_point = calcScaleZeroPoint(self.min, self.max, self.qmax)
def quantize_tensor(self, tensor):
return quantize_tensor(self.quant_type, tensor, self.scale, self.zero_point, self.qmax)
def dequantize_tensor(self, q_x):
return dequantize_tensor(q_x, self.scale, self.zero_point)
    # This method ensures the parameters can be restored from a state_dict
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
key_names = ['scale', 'zero_point', 'min', 'max']
for key in key_names:
value = getattr(self, key)
value.data = state_dict[prefix + key].data
state_dict.pop(prefix + key)
    # The return value of this method is what printing the object shows
def __str__(self):
info = 'scale: %.10f ' % self.scale
info += 'zp: %.6f ' % self.zero_point
info += 'min: %.6f ' % self.min
info += 'max: %.6f' % self.max
return info
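# Usage note (illustrative): a QParam observes tensors and derives scale/zero_point, e.g.
#     qp = QParam('INT', num_bits=8)
#     qp.update(t)                                  # track min/max, recompute scale
#     q = qp.quantize_tensor(t); t2 = qp.dequantize_tensor(q)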
# Parent class of the concrete quantized layers; qi/qo are the quantization parameters of the input/output
class QModule(nn.Module):
def __init__(self,quant_type, qi=False, qo=True, num_bits=8, e_bits=3):
super(QModule, self).__init__()
if qi:
self.qi = QParam(quant_type,num_bits, e_bits)
if qo:
self.qo = QParam(quant_type,num_bits, e_bits)
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.bias_qmax = bias_qmax(quant_type)
def freeze(self):
        pass # no-op
def quantize_inference(self, x):
raise NotImplementedError('quantize_inference should be implemented.')
"""
QModule 量化卷积
:quant_type: 量化类型
:conv_module: 卷积模块
:qi: 是否量化输入特征图
:qo: 是否量化输出特征图
:num_bits: 8位bit数
"""
class QConv2d(QModule):
def __init__(self, quant_type, conv_module, qi=False, qo=True, num_bits=8, e_bits=3):
super(QConv2d, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.qw = QParam(quant_type, num_bits,e_bits)
        self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
    # freeze fixes the truly-quantized weight parameters and writes them back onto the original full-precision layer, which makes the later divergence computation easier
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
        # Pooling/activation inputs need no extra min/max statistics of their own; they share the qparams of the preceding output
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
        # Following https://zhuanlan.zhihu.com/p/156835141, this is the coefficient of Eq. (3)
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
self.conv_module.weight.data = self.qw.quantize_tensor(self.conv_module.weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
        # Conv layers may have no bias; in that case forward and inference can simply pass None
if self.conv_module.bias is not None:
self.conv_module.bias.data = quantize_tensor(self.quant_type,
self.conv_module.bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0.,qmax=self.bias_qmax, is_bias=True)
    def forward(self, x): # forward pass; x is a floating-point tensor
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi) # 对输入张量X完成量化
        # Update qw before the forward conv so the scale is correct when quantizing the weight
self.qw.update(self.conv_module.weight.data)
        # Note: this mainly collects the ranges of x and weight per layer; the bias is not quantized here
tmp_wgt = FakeQuantize.apply(self.conv_module.weight, self.qw)
x = F.conv2d(x, tmp_wgt, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
    # Apply the formula q_a = M * (sum((q_w - Z_w)(q_x - Z_x)) + q_b)
    def quantize_inference(self, x): # the input here is the already-quantized q_x
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
class QLinear(QModule):
def __init__(self, quant_type, fc_module, qi=False, qo=True, num_bits=8, e_bits=3):
super(QLinear, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.fc_module = fc_module
self.qw = QParam(quant_type, num_bits, e_bits)
        self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
self.fc_module.weight.data = self.qw.quantize_tensor(self.fc_module.weight.data)
self.fc_module.weight.data = self.fc_module.weight.data - self.qw.zero_point
if self.fc_module.bias is not None:
self.fc_module.bias.data = quantize_tensor(self.quant_type,
self.fc_module.bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax, is_bias=True)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
self.qw.update(self.fc_module.weight.data)
tmp_wgt = FakeQuantize.apply(self.fc_module.weight, self.qw)
x = F.linear(x, tmp_wgt, self.fc_module.bias)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = self.fc_module(x)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
class QReLU(QModule):
def __init__(self,quant_type, qi=False, qo=False, num_bits=8, e_bits=3):
super(QReLU, self).__init__(quant_type, qi, qo, num_bits, e_bits)
def freeze(self, qi=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if qi is not None:
self.qi = qi
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
x = F.relu(x)
return x
def quantize_inference(self, x):
x = x.clone()
x[x < self.qi.zero_point] = self.qi.zero_point
return x
class QMaxPooling2d(QModule):
def __init__(self, quant_type, kernel_size=3, stride=1, padding=0, qi=False, qo=False, num_bits=8,e_bits=3):
super(QMaxPooling2d, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
def freeze(self, qi=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if qi is not None:
self.qi = qi
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding)
return x
def quantize_inference(self, x):
return F.max_pool2d(x, self.kernel_size, self.stride, self.padding)
class QAdaptiveAvgPool2d(QModule):
def __init__(self, quant_type, output_size, qi=False, qo=True, num_bits=8,e_bits=3):
super(QAdaptiveAvgPool2d, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.output_size = output_size
        self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qi.scale / self.qo.scale).data
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
x = F.adaptive_avg_pool2d(x, self.output_size)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = F.adaptive_avg_pool2d(x, self.output_size)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x+self.qo.zero_point
return x
class QConvBNReLU(QModule):
def __init__(self, quant_type, conv_module, bn_module, qi=False, qo=True, num_bits=8, e_bits=3):
super(QConvBNReLU, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.bn_module = bn_module
self.qw = QParam(quant_type, num_bits,e_bits)
        self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = self.qw.quantize_tensor(weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
if self.conv_module.bias is not None:
self.conv_module.bias.data = quantize_tensor(self.quant_type,
bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
else:
bias = quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
self.conv_module.bias = torch.nn.Parameter(bias)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
self.qw.update(weight.data)
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
x = F.relu(x)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
x.clamp_(min=0)
return x
class QConvBN(QModule):
def __init__(self, quant_type, conv_module, bn_module, qi=False, qo=True, num_bits=8, e_bits=3):
super(QConvBN, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.bn_module = bn_module
self.qw = QParam(quant_type, num_bits,e_bits)
        self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = self.qw.quantize_tensor(weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
if self.conv_module.bias is not None:
self.conv_module.bias.data = quantize_tensor(self.quant_type,
bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
else:
bias = quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
self.conv_module.bias = torch.nn.Parameter(bias)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
self.qw.update(weight.data)
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
# Parent class for quantized layers that take multiple inputs (and their qi) as an array,
# e.g. when merging the outputs of several branches
class QModule_array(nn.Module):
def __init__(self,quant_type,len,qi_array=False, qo=True, num_bits=8, e_bits=3):
super(QModule_array, self).__init__()
if qi_array:
for i in range(len):
self.add_module('qi%d'%i,QParam(quant_type,num_bits, e_bits))
if qo:
self.qo = QParam(quant_type,num_bits, e_bits)
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.bias_qmax = bias_qmax(quant_type)
self.len = len
def freeze(self):
        pass # no-op
def quantize_inference(self, x):
raise NotImplementedError('quantize_inference should be implemented.')
class QConcat(QModule_array):
def __init__(self, quant_type, len, qi_array=False, qo=True, num_bits=8, e_bits=3):
super(QConcat,self).__init__(quant_type, len, qi_array, qo, num_bits, e_bits)
for i in range(len):
self.register_buffer('M%d'%i,torch.tensor([], requires_grad=False))
def freeze(self, qi_array=None, qo=None):
if qi_array is None:
raise ValueError('qi_array should be provided')
elif len(qi_array) != self.len:
raise ValueError('qi_array len no match')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
for i in range(self.len):
self.add_module('qi%d'%i,qi_array[i])
if qo is not None:
self.qo = qo
for i in range(self.len):
getattr(self,'M%d'%i).data = (getattr(self,'qi%d'%i).scale / self.qo.scale).data
def forward(self,x_array):
outs=[]
for i in range(self.len):
x = x_array[i]
if hasattr(self,'qi%d'%i):
qi = getattr(self,'qi%d'%i)
qi.update(x)
x = FakeQuantize.apply(x,qi)
outs.append(x)
out = torch.cat(outs,1)
if hasattr(self,'qo'):
            self.qo.update(out)  # update qo with the concatenated output, not the last branch
out = FakeQuantize.apply(out,self.qo)
return out
def quantize_inference(self, x_array):
outs=[]
for i in range(self.len):
qi = getattr(self,'qi%d'%i)
x = x_array[i] - qi.zero_point
x = getattr(self,'M%d'%i) * x
outs.append(x)
out = torch.concat(outs,1)
out = get_nearest_val(self.quant_type,out)
out = out + self.qo.zero_point
return out
Inception_BN(
5.88 M, 100.000% Params, 1.46 GMac, 100.000% MACs,
(conv0): Conv2d(1.79 k, 0.030% Params, 1.84 MMac, 0.125% MACs, 3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu1): ReLU(0, 0.000% Params, 65.54 KMac, 0.004% MACs, inplace=True)
(conv2): Conv2d(36.93 k, 0.628% Params, 37.81 MMac, 2.583% MACs, 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(relu3): ReLU(0, 0.000% Params, 65.54 KMac, 0.004% MACs, inplace=True)
(inc0_br0_conv0): Conv2d(4.1 k, 0.070% Params, 4.19 MMac, 0.287% MACs, 64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc0_br0_bn0): BatchNorm2d(128, 0.002% Params, 131.07 KMac, 0.009% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc0_br0_relu0): ReLU(0, 0.000% Params, 65.54 KMac, 0.004% MACs, inplace=True)
(inc0_br1_conv0): Conv2d(6.14 k, 0.105% Params, 6.29 MMac, 0.430% MACs, 64, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc0_br1_bn0): BatchNorm2d(192, 0.003% Params, 196.61 KMac, 0.013% MACs, 96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc0_br1_relu0): ReLU(0, 0.000% Params, 98.3 KMac, 0.007% MACs, inplace=True)
(inc0_br1_conv1): Conv2d(110.59 k, 1.881% Params, 113.25 MMac, 7.737% MACs, 96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(inc0_br1_bn1): BatchNorm2d(256, 0.004% Params, 262.14 KMac, 0.018% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc0_br1_relu1): ReLU(0, 0.000% Params, 131.07 KMac, 0.009% MACs, inplace=True)
(inc0_br2_conv0): Conv2d(1.02 k, 0.017% Params, 1.05 MMac, 0.072% MACs, 64, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc0_br2_bn0): BatchNorm2d(32, 0.001% Params, 32.77 KMac, 0.002% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc0_br2_relu0): ReLU(0, 0.000% Params, 16.38 KMac, 0.001% MACs, inplace=True)
(inc0_br2_conv1): Conv2d(12.8 k, 0.218% Params, 13.11 MMac, 0.895% MACs, 16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
(inc0_br2_bn1): BatchNorm2d(64, 0.001% Params, 65.54 KMac, 0.004% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc0_br2_relu1): ReLU(0, 0.000% Params, 32.77 KMac, 0.002% MACs, inplace=True)
(inc0_br3_pool0): MaxPool2d(0, 0.000% Params, 65.54 KMac, 0.004% MACs, kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(inc0_br3_relu1): ReLU(0, 0.000% Params, 65.54 KMac, 0.004% MACs, inplace=True)
(inc0_br3_conv2): Conv2d(2.05 k, 0.035% Params, 2.1 MMac, 0.143% MACs, 64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc0_br3_bn2): BatchNorm2d(64, 0.001% Params, 65.54 KMac, 0.004% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc0_br3_relu2): ReLU(0, 0.000% Params, 32.77 KMac, 0.002% MACs, inplace=True)
(inc1_br0_conv0): Conv2d(32.77 k, 0.557% Params, 33.55 MMac, 2.292% MACs, 256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc1_br0_bn0): BatchNorm2d(256, 0.004% Params, 262.14 KMac, 0.018% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc1_br0_relu0): ReLU(0, 0.000% Params, 131.07 KMac, 0.009% MACs, inplace=True)
(inc1_br1_conv0): Conv2d(32.77 k, 0.557% Params, 33.55 MMac, 2.292% MACs, 256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc1_br1_bn0): BatchNorm2d(256, 0.004% Params, 262.14 KMac, 0.018% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc1_br1_relu0): ReLU(0, 0.000% Params, 131.07 KMac, 0.009% MACs, inplace=True)
(inc1_br1_conv1): Conv2d(221.18 k, 3.763% Params, 226.49 MMac, 15.474% MACs, 128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(inc1_br1_bn1): BatchNorm2d(384, 0.007% Params, 393.22 KMac, 0.027% MACs, 192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc1_br1_relu1): ReLU(0, 0.000% Params, 196.61 KMac, 0.013% MACs, inplace=True)
(inc1_br2_conv0): Conv2d(8.19 k, 0.139% Params, 8.39 MMac, 0.573% MACs, 256, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc1_br2_bn0): BatchNorm2d(64, 0.001% Params, 65.54 KMac, 0.004% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc1_br2_relu0): ReLU(0, 0.000% Params, 32.77 KMac, 0.002% MACs, inplace=True)
(inc1_br2_conv1): Conv2d(76.8 k, 1.306% Params, 78.64 MMac, 5.373% MACs, 32, 96, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
(inc1_br2_bn1): BatchNorm2d(192, 0.003% Params, 196.61 KMac, 0.013% MACs, 96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc1_br2_relu1): ReLU(0, 0.000% Params, 98.3 KMac, 0.007% MACs, inplace=True)
(inc1_br3_pool0): MaxPool2d(0, 0.000% Params, 262.14 KMac, 0.018% MACs, kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(inc1_br3_relu1): ReLU(0, 0.000% Params, 262.14 KMac, 0.018% MACs, inplace=True)
(inc1_br3_conv2): Conv2d(16.38 k, 0.279% Params, 16.78 MMac, 1.146% MACs, 256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc1_br3_bn2): BatchNorm2d(128, 0.002% Params, 131.07 KMac, 0.009% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc1_br3_relu2): ReLU(0, 0.000% Params, 65.54 KMac, 0.004% MACs, inplace=True)
(pool6): MaxPool2d(0, 0.000% Params, 491.52 KMac, 0.034% MACs, kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(inc2_br0_conv0): Conv2d(92.16 k, 1.568% Params, 23.59 MMac, 1.612% MACs, 480, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc2_br0_bn0): BatchNorm2d(384, 0.007% Params, 98.3 KMac, 0.007% MACs, 192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc2_br0_relu0): ReLU(0, 0.000% Params, 49.15 KMac, 0.003% MACs, inplace=True)
(inc2_br1_conv0): Conv2d(46.08 k, 0.784% Params, 11.8 MMac, 0.806% MACs, 480, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc2_br1_bn0): BatchNorm2d(192, 0.003% Params, 49.15 KMac, 0.003% MACs, 96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc2_br1_relu0): ReLU(0, 0.000% Params, 24.58 KMac, 0.002% MACs, inplace=True)
(inc2_br1_conv1): Conv2d(179.71 k, 3.057% Params, 46.01 MMac, 3.143% MACs, 96, 208, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(inc2_br1_bn1): BatchNorm2d(416, 0.007% Params, 106.5 KMac, 0.007% MACs, 208, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc2_br1_relu1): ReLU(0, 0.000% Params, 53.25 KMac, 0.004% MACs, inplace=True)
(inc2_br2_conv0): Conv2d(7.68 k, 0.131% Params, 1.97 MMac, 0.134% MACs, 480, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc2_br2_bn0): BatchNorm2d(32, 0.001% Params, 8.19 KMac, 0.001% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc2_br2_relu0): ReLU(0, 0.000% Params, 4.1 KMac, 0.000% MACs, inplace=True)
(inc2_br2_conv1): Conv2d(19.2 k, 0.327% Params, 4.92 MMac, 0.336% MACs, 16, 48, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
(inc2_br2_bn1): BatchNorm2d(96, 0.002% Params, 24.58 KMac, 0.002% MACs, 48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc2_br2_relu1): ReLU(0, 0.000% Params, 12.29 KMac, 0.001% MACs, inplace=True)
(inc2_br3_pool0): MaxPool2d(0, 0.000% Params, 122.88 KMac, 0.008% MACs, kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(inc2_br3_relu1): ReLU(0, 0.000% Params, 122.88 KMac, 0.008% MACs, inplace=True)
(inc2_br3_conv2): Conv2d(30.72 k, 0.523% Params, 7.86 MMac, 0.537% MACs, 480, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc2_br3_bn2): BatchNorm2d(128, 0.002% Params, 32.77 KMac, 0.002% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc2_br3_relu2): ReLU(0, 0.000% Params, 16.38 KMac, 0.001% MACs, inplace=True)
(inc3_br0_conv0): Conv2d(81.92 k, 1.394% Params, 20.97 MMac, 1.433% MACs, 512, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc3_br0_bn0): BatchNorm2d(320, 0.005% Params, 81.92 KMac, 0.006% MACs, 160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc3_br0_relu0): ReLU(0, 0.000% Params, 40.96 KMac, 0.003% MACs, inplace=True)
(inc3_br1_conv0): Conv2d(57.34 k, 0.975% Params, 14.68 MMac, 1.003% MACs, 512, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc3_br1_bn0): BatchNorm2d(224, 0.004% Params, 57.34 KMac, 0.004% MACs, 112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc3_br1_relu0): ReLU(0, 0.000% Params, 28.67 KMac, 0.002% MACs, inplace=True)
(inc3_br1_conv1): Conv2d(225.79 k, 3.841% Params, 57.8 MMac, 3.949% MACs, 112, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(inc3_br1_bn1): BatchNorm2d(448, 0.008% Params, 114.69 KMac, 0.008% MACs, 224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc3_br1_relu1): ReLU(0, 0.000% Params, 57.34 KMac, 0.004% MACs, inplace=True)
(inc3_br2_conv0): Conv2d(12.29 k, 0.209% Params, 3.15 MMac, 0.215% MACs, 512, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc3_br2_bn0): BatchNorm2d(48, 0.001% Params, 12.29 KMac, 0.001% MACs, 24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc3_br2_relu0): ReLU(0, 0.000% Params, 6.14 KMac, 0.000% MACs, inplace=True)
(inc3_br2_conv1): Conv2d(38.4 k, 0.653% Params, 9.83 MMac, 0.672% MACs, 24, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
(inc3_br2_bn1): BatchNorm2d(128, 0.002% Params, 32.77 KMac, 0.002% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc3_br2_relu1): ReLU(0, 0.000% Params, 16.38 KMac, 0.001% MACs, inplace=True)
(inc3_br3_pool0): MaxPool2d(0, 0.000% Params, 131.07 KMac, 0.009% MACs, kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(inc3_br3_relu1): ReLU(0, 0.000% Params, 131.07 KMac, 0.009% MACs, inplace=True)
(inc3_br3_conv2): Conv2d(32.77 k, 0.557% Params, 8.39 MMac, 0.573% MACs, 512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc3_br3_bn2): BatchNorm2d(128, 0.002% Params, 32.77 KMac, 0.002% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc3_br3_relu2): ReLU(0, 0.000% Params, 16.38 KMac, 0.001% MACs, inplace=True)
(inc4_br0_conv0): Conv2d(65.54 k, 1.115% Params, 16.78 MMac, 1.146% MACs, 512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc4_br0_bn0): BatchNorm2d(256, 0.004% Params, 65.54 KMac, 0.004% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc4_br0_relu0): ReLU(0, 0.000% Params, 32.77 KMac, 0.002% MACs, inplace=True)
(inc4_br1_conv0): Conv2d(65.54 k, 1.115% Params, 16.78 MMac, 1.146% MACs, 512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc4_br1_bn0): BatchNorm2d(256, 0.004% Params, 65.54 KMac, 0.004% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc4_br1_relu0): ReLU(0, 0.000% Params, 32.77 KMac, 0.002% MACs, inplace=True)
(inc4_br1_conv1): Conv2d(294.91 k, 5.017% Params, 75.5 MMac, 5.158% MACs, 128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(inc4_br1_bn1): BatchNorm2d(512, 0.009% Params, 131.07 KMac, 0.009% MACs, 256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc4_br1_relu1): ReLU(0, 0.000% Params, 65.54 KMac, 0.004% MACs, inplace=True)
(inc4_br2_conv0): Conv2d(12.29 k, 0.209% Params, 3.15 MMac, 0.215% MACs, 512, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc4_br2_bn0): BatchNorm2d(48, 0.001% Params, 12.29 KMac, 0.001% MACs, 24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc4_br2_relu0): ReLU(0, 0.000% Params, 6.14 KMac, 0.000% MACs, inplace=True)
(inc4_br2_conv1): Conv2d(38.4 k, 0.653% Params, 9.83 MMac, 0.672% MACs, 24, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
(inc4_br2_bn1): BatchNorm2d(128, 0.002% Params, 32.77 KMac, 0.002% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc4_br2_relu1): ReLU(0, 0.000% Params, 16.38 KMac, 0.001% MACs, inplace=True)
(inc4_br3_pool0): MaxPool2d(0, 0.000% Params, 131.07 KMac, 0.009% MACs, kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(inc4_br3_relu1): ReLU(0, 0.000% Params, 131.07 KMac, 0.009% MACs, inplace=True)
(inc4_br3_conv2): Conv2d(32.77 k, 0.557% Params, 8.39 MMac, 0.573% MACs, 512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc4_br3_bn2): BatchNorm2d(128, 0.002% Params, 32.77 KMac, 0.002% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc4_br3_relu2): ReLU(0, 0.000% Params, 16.38 KMac, 0.001% MACs, inplace=True)
(inc5_br0_conv0): Conv2d(57.34 k, 0.975% Params, 14.68 MMac, 1.003% MACs, 512, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc5_br0_bn0): BatchNorm2d(224, 0.004% Params, 57.34 KMac, 0.004% MACs, 112, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc5_br0_relu0): ReLU(0, 0.000% Params, 28.67 KMac, 0.002% MACs, inplace=True)
(inc5_br1_conv0): Conv2d(73.73 k, 1.254% Params, 18.87 MMac, 1.289% MACs, 512, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc5_br1_bn0): BatchNorm2d(288, 0.005% Params, 73.73 KMac, 0.005% MACs, 144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc5_br1_relu0): ReLU(0, 0.000% Params, 36.86 KMac, 0.003% MACs, inplace=True)
(inc5_br1_conv1): Conv2d(373.25 k, 6.349% Params, 95.55 MMac, 6.528% MACs, 144, 288, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(inc5_br1_bn1): BatchNorm2d(576, 0.010% Params, 147.46 KMac, 0.010% MACs, 288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc5_br1_relu1): ReLU(0, 0.000% Params, 73.73 KMac, 0.005% MACs, inplace=True)
(inc5_br2_conv0): Conv2d(16.38 k, 0.279% Params, 4.19 MMac, 0.287% MACs, 512, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc5_br2_bn0): BatchNorm2d(64, 0.001% Params, 16.38 KMac, 0.001% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc5_br2_relu0): ReLU(0, 0.000% Params, 8.19 KMac, 0.001% MACs, inplace=True)
(inc5_br2_conv1): Conv2d(51.2 k, 0.871% Params, 13.11 MMac, 0.895% MACs, 32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
(inc5_br2_bn1): BatchNorm2d(128, 0.002% Params, 32.77 KMac, 0.002% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc5_br2_relu1): ReLU(0, 0.000% Params, 16.38 KMac, 0.001% MACs, inplace=True)
(inc5_br3_pool0): MaxPool2d(0, 0.000% Params, 131.07 KMac, 0.009% MACs, kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(inc5_br3_relu1): ReLU(0, 0.000% Params, 131.07 KMac, 0.009% MACs, inplace=True)
(inc5_br3_conv2): Conv2d(32.77 k, 0.557% Params, 8.39 MMac, 0.573% MACs, 512, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc5_br3_bn2): BatchNorm2d(128, 0.002% Params, 32.77 KMac, 0.002% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc5_br3_relu2): ReLU(0, 0.000% Params, 16.38 KMac, 0.001% MACs, inplace=True)
(inc6_br0_conv0): Conv2d(135.17 k, 2.299% Params, 34.6 MMac, 2.364% MACs, 528, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc6_br0_bn0): BatchNorm2d(512, 0.009% Params, 131.07 KMac, 0.009% MACs, 256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc6_br0_relu0): ReLU(0, 0.000% Params, 65.54 KMac, 0.004% MACs, inplace=True)
(inc6_br1_conv0): Conv2d(84.48 k, 1.437% Params, 21.63 MMac, 1.478% MACs, 528, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc6_br1_bn0): BatchNorm2d(320, 0.005% Params, 81.92 KMac, 0.006% MACs, 160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc6_br1_relu0): ReLU(0, 0.000% Params, 40.96 KMac, 0.003% MACs, inplace=True)
(inc6_br1_conv1): Conv2d(460.8 k, 7.839% Params, 117.96 MMac, 8.059% MACs, 160, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(inc6_br1_bn1): BatchNorm2d(640, 0.011% Params, 163.84 KMac, 0.011% MACs, 320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc6_br1_relu1): ReLU(0, 0.000% Params, 81.92 KMac, 0.006% MACs, inplace=True)
(inc6_br2_conv0): Conv2d(16.9 k, 0.287% Params, 4.33 MMac, 0.296% MACs, 528, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc6_br2_bn0): BatchNorm2d(64, 0.001% Params, 16.38 KMac, 0.001% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc6_br2_relu0): ReLU(0, 0.000% Params, 8.19 KMac, 0.001% MACs, inplace=True)
(inc6_br2_conv1): Conv2d(102.4 k, 1.742% Params, 26.21 MMac, 1.791% MACs, 32, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
(inc6_br2_bn1): BatchNorm2d(256, 0.004% Params, 65.54 KMac, 0.004% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc6_br2_relu1): ReLU(0, 0.000% Params, 32.77 KMac, 0.002% MACs, inplace=True)
(inc6_br3_pool0): MaxPool2d(0, 0.000% Params, 135.17 KMac, 0.009% MACs, kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(inc6_br3_relu1): ReLU(0, 0.000% Params, 135.17 KMac, 0.009% MACs, inplace=True)
(inc6_br3_conv2): Conv2d(67.58 k, 1.150% Params, 17.3 MMac, 1.182% MACs, 528, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc6_br3_bn2): BatchNorm2d(256, 0.004% Params, 65.54 KMac, 0.004% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc6_br3_relu2): ReLU(0, 0.000% Params, 32.77 KMac, 0.002% MACs, inplace=True)
(pool12): MaxPool2d(0, 0.000% Params, 212.99 KMac, 0.015% MACs, kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(inc7_br0_conv0): Conv2d(212.99 k, 3.623% Params, 13.63 MMac, 0.931% MACs, 832, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc7_br0_bn0): BatchNorm2d(512, 0.009% Params, 32.77 KMac, 0.002% MACs, 256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc7_br0_relu0): ReLU(0, 0.000% Params, 16.38 KMac, 0.001% MACs, inplace=True)
(inc7_br1_conv0): Conv2d(133.12 k, 2.265% Params, 8.52 MMac, 0.582% MACs, 832, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc7_br1_bn0): BatchNorm2d(320, 0.005% Params, 20.48 KMac, 0.001% MACs, 160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc7_br1_relu0): ReLU(0, 0.000% Params, 10.24 KMac, 0.001% MACs, inplace=True)
(inc7_br1_conv1): Conv2d(460.8 k, 7.839% Params, 29.49 MMac, 2.015% MACs, 160, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(inc7_br1_bn1): BatchNorm2d(640, 0.011% Params, 40.96 KMac, 0.003% MACs, 320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc7_br1_relu1): ReLU(0, 0.000% Params, 20.48 KMac, 0.001% MACs, inplace=True)
(inc7_br2_conv0): Conv2d(26.62 k, 0.453% Params, 1.7 MMac, 0.116% MACs, 832, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc7_br2_bn0): BatchNorm2d(64, 0.001% Params, 4.1 KMac, 0.000% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc7_br2_relu0): ReLU(0, 0.000% Params, 2.05 KMac, 0.000% MACs, inplace=True)
(inc7_br2_conv1): Conv2d(102.4 k, 1.742% Params, 6.55 MMac, 0.448% MACs, 32, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
(inc7_br2_bn1): BatchNorm2d(256, 0.004% Params, 16.38 KMac, 0.001% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc7_br2_relu1): ReLU(0, 0.000% Params, 8.19 KMac, 0.001% MACs, inplace=True)
(inc7_br3_pool0): MaxPool2d(0, 0.000% Params, 53.25 KMac, 0.004% MACs, kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(inc7_br3_relu1): ReLU(0, 0.000% Params, 53.25 KMac, 0.004% MACs, inplace=True)
(inc7_br3_conv2): Conv2d(106.5 k, 1.812% Params, 6.82 MMac, 0.466% MACs, 832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc7_br3_bn2): BatchNorm2d(256, 0.004% Params, 16.38 KMac, 0.001% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc7_br3_relu2): ReLU(0, 0.000% Params, 8.19 KMac, 0.001% MACs, inplace=True)
(inc8_br0_conv0): Conv2d(319.49 k, 5.435% Params, 20.45 MMac, 1.397% MACs, 832, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc8_br0_bn0): BatchNorm2d(768, 0.013% Params, 49.15 KMac, 0.003% MACs, 384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc8_br0_relu0): ReLU(0, 0.000% Params, 24.58 KMac, 0.002% MACs, inplace=True)
(inc8_br1_conv0): Conv2d(159.74 k, 2.717% Params, 10.22 MMac, 0.698% MACs, 832, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc8_br1_bn0): BatchNorm2d(384, 0.007% Params, 24.58 KMac, 0.002% MACs, 192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc8_br1_relu0): ReLU(0, 0.000% Params, 12.29 KMac, 0.001% MACs, inplace=True)
(inc8_br1_conv1): Conv2d(663.55 k, 11.288% Params, 42.47 MMac, 2.901% MACs, 192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(inc8_br1_bn1): BatchNorm2d(768, 0.013% Params, 49.15 KMac, 0.003% MACs, 384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc8_br1_relu1): ReLU(0, 0.000% Params, 24.58 KMac, 0.002% MACs, inplace=True)
(inc8_br2_conv0): Conv2d(39.94 k, 0.679% Params, 2.56 MMac, 0.175% MACs, 832, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc8_br2_bn0): BatchNorm2d(96, 0.002% Params, 6.14 KMac, 0.000% MACs, 48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc8_br2_relu0): ReLU(0, 0.000% Params, 3.07 KMac, 0.000% MACs, inplace=True)
(inc8_br2_conv1): Conv2d(153.6 k, 2.613% Params, 9.83 MMac, 0.672% MACs, 48, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
(inc8_br2_bn1): BatchNorm2d(256, 0.004% Params, 16.38 KMac, 0.001% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc8_br2_relu1): ReLU(0, 0.000% Params, 8.19 KMac, 0.001% MACs, inplace=True)
(inc8_br3_pool0): MaxPool2d(0, 0.000% Params, 53.25 KMac, 0.004% MACs, kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
(inc8_br3_relu1): ReLU(0, 0.000% Params, 53.25 KMac, 0.004% MACs, inplace=True)
(inc8_br3_conv2): Conv2d(106.5 k, 1.812% Params, 6.82 MMac, 0.466% MACs, 832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(inc8_br3_bn2): BatchNorm2d(256, 0.004% Params, 16.38 KMac, 0.001% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(inc8_br3_relu2): ReLU(0, 0.000% Params, 8.19 KMac, 0.001% MACs, inplace=True)
(aap15): AdaptiveAvgPool2d(0, 0.000% Params, 65.54 KMac, 0.004% MACs, output_size=1)
(conv16): Conv2d(10.25 k, 0.174% Params, 10.25 KMac, 0.001% MACs, 1024, 10, kernel_size=(1, 1), stride=(1, 1))
)
from torch.serialization import load
from model import *
from extract_ratio import *
from utils import *
import gol
import openpyxl
import sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import torch.utils.bottleneck as bn
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
def direct_quantize(model, test_loader,device):
for i, (data, target) in enumerate(test_loader, 1):
data = data.to(device)
output = model.quantize_forward(data).cpu()
if i % 500 == 0:
break
print('direct quantization finish')
def full_inference(model, test_loader, device):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data = data.to(device)
output = model(data).cpu()
pred = output.argmax(dim=1, keepdim=True)
# print(pred)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Full Model Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))
return 100. * correct / len(test_loader.dataset)
def quantize_inference(model, test_loader, device):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data = data.to(device)
output = model.quantize_inference(data).cpu()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('Test set: Quant Model Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))
return 100. * correct / len(test_loader.dataset)
if __name__ == "__main__":
batch_size = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=True, download=True,
transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
model = Inception_BN()
writer = SummaryWriter(log_dir='./log')
full_file = 'ckpt/cifar10_Inception_BN.pt'
model.load_state_dict(torch.load(full_file))
model.to(device)
load_ptq = True
store_ptq = False
ptq_file_prefix = 'ckpt/cifar10_Inception_BN_ptq_'
model.eval()
full_acc = full_inference(model, test_loader, device)
# the full-precision model is modified in place below (BN folded into conv)
fold_model(model)
layer, par_ratio, flop_ratio = extract_ratio()
par_ratio, flop_ratio = fold_ratio(layer, par_ratio, flop_ratio)
full_names = []
full_params = []
for name, param in model.named_parameters():
if 'conv' in name or 'fc' in name:
full_names.append(name)
param_norm = F.normalize(param.data.cpu(),p=2,dim=-1)
full_params.append(param_norm)
writer.add_histogram(tag='Full_' + name + '_data', values=param.data)
# count how many named parameters belong to each parameter's layer
full_par_num=[]
for name in full_names:
prefix = name.rsplit('.',1)[0]
cnt = 0
for str in full_names:
sprefix = str.rsplit('.',1)[0]
if prefix == sprefix:
cnt += 1
full_par_num.append(cnt)
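# e.g. conv16 exposes both conv16.weight and conv16.bias, so each of those names
# gets a count of 2 and later contributes half of that layer's js_div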
# print(full_names)
# print(full_par_num)
# print('-------')
# input()
gol._init()
# quant_type_list = ['INT','POT','FLOAT']
# quant_type_list = ['INT']
quant_type_list = ['POT']
title_list = []
js_flops_list = []
js_param_list = []
ptq_acc_list = []
acc_loss_list = []
for quant_type in quant_type_list:
num_bit_list = numbit_list(quant_type)
# for each quantization type the bias quantization list only needs to be built once
# INT bit widths are large, so a lookup list would be too costly; _round is used directly instead
if quant_type != 'INT':
bias_list = build_bias_list(quant_type)
gol.set_value(bias_list, is_bias=True)
for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list:
model_ptq = Inception_BN()
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
title = '%s_%d' % (quant_type, num_bits)
print('\nPTQ: '+title)
title_list.append(title)
# set the quantization value list
if quant_type != 'INT':
plist = build_list(quant_type, num_bits, e_bits)
gol.set_value(plist)
# check whether a previously saved ptq model should be loaded
if load_ptq is True and osp.exists(ptq_file_prefix + title + '.pt'):
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.load_state_dict(torch.load(ptq_file_prefix + title + '.pt'))
model_ptq.to(device)
print('Successfully load ptq model: ' + title)
else:
model_ptq.load_state_dict(torch.load(full_file))
model_ptq.to(device)
model_ptq.quantize(quant_type,num_bits,e_bits)
model_ptq.eval()
direct_quantize(model_ptq, train_loader, device)
if store_ptq:
torch.save(model_ptq.state_dict(), ptq_file_prefix + title + '.pt')
model_ptq.freeze()
ptq_acc = quantize_inference(model_ptq, test_loader, device)
ptq_acc_list.append(ptq_acc)
acc_loss = (full_acc - ptq_acc) / full_acc
acc_loss_list.append(acc_loss)
# accumulate js-div weighted by the flops / parameter ratios
js_flops = 0.
js_param = 0.
for name, param in model_ptq.named_parameters():
if 'conv' not in name and 'fc' not in name:
continue
prefix = name.rsplit('.',1)[0]
layer_idx = layer.index(prefix)
name_idx = full_names.index(name)
ptq_param = param.data.cpu()
# apply L2 normalization
ptq_norm = F.normalize(ptq_param,p=2,dim=-1)
writer.add_histogram(tag=title +':'+ name + '_data', values=ptq_param)
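# divide by the layer's parameter count so that weight and bias together carry a total weight of 1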
js = js_div(ptq_norm,full_params[name_idx])
js /= full_par_num[name_idx]
js = js.item()
if js < 0.:
js = 0.
js_flops = js_flops + js * flop_ratio[layer_idx]
js_param = js_param + js * par_ratio[layer_idx]
js_flops_list.append(js_flops)
js_param_list.append(js_param)
print(title + ': js_flops: %f js_param: %f acc_loss: %f' % (js_flops, js_param, acc_loss))
# 写入xlsx
workbook = openpyxl.Workbook()
worksheet = workbook.active
worksheet.cell(row=1,column=1,value='FP32-acc')
worksheet.cell(row=1,column=2,value=full_acc)
worksheet.cell(row=3,column=1,value='title')
worksheet.cell(row=3,column=2,value='js_flops')
worksheet.cell(row=3,column=3,value='js_param')
worksheet.cell(row=3,column=4,value='ptq_acc')
worksheet.cell(row=3,column=5,value='acc_loss')
for i in range(len(title_list)):
worksheet.cell(row=i+4, column=1, value=title_list[i])
worksheet.cell(row=i+4, column=2, value=js_flops_list[i])
worksheet.cell(row=i+4, column=3, value=js_param_list[i])
worksheet.cell(row=i+4, column=4, value=ptq_acc_list[i])
worksheet.cell(row=i+4, column=5, value=acc_loss_list[i])
workbook.save('ptq_result.xlsx')
writer.close()
ft = open('ptq_result.txt','w')
print('title_list:',file=ft)
print(" ".join(title_list),file=ft)
print('js_flops_list:',file=ft)
print(" ".join(str(i) for i in js_flops_list), file=ft)
print('js_param_list:',file=ft)
print(" ".join(str(i) for i in js_param_list), file=ft)
print('ptq_acc_list:',file=ft)
print(" ".join(str(i) for i in ptq_acc_list), file=ft)
print('acc_loss_list:',file=ft)
print(" ".join(str(i) for i in acc_loss_list), file=ft)
ft.close()
title_list:
POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8
js_flops_list:
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import sys
class BNInception(nn.Module):
def __init__(self, num_classes=1000):
super(BNInception, self).__init__()
inplace = True
self.conv1_7x7_s2 = nn.Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
self.conv1_7x7_s2_bn = nn.BatchNorm2d(64, affine=True)
self.conv1_relu_7x7 = nn.ReLU (inplace)
self.pool1_3x3_s2 = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
self.conv2_3x3_reduce = nn.Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
self.conv2_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
self.conv2_relu_3x3_reduce = nn.ReLU (inplace)
self.conv2_3x3 = nn.Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.conv2_3x3_bn = nn.BatchNorm2d(192, affine=True)
self.conv2_relu_3x3 = nn.ReLU (inplace)
self.pool2_3x3_s2 = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
self.inception_3a_1x1 = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_3a_1x1_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3a_relu_1x1 = nn.ReLU (inplace)
self.inception_3a_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_3a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3a_relu_3x3_reduce = nn.ReLU (inplace)
self.inception_3a_3x3 = nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_3a_3x3_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3a_relu_3x3 = nn.ReLU (inplace)
self.inception_3a_double_3x3_reduce = nn.Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_3a_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3a_relu_double_3x3_reduce = nn.ReLU (inplace)
self.inception_3a_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_3a_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True)
self.inception_3a_relu_double_3x3_1 = nn.ReLU (inplace)
self.inception_3a_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_3a_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True)
self.inception_3a_relu_double_3x3_2 = nn.ReLU (inplace)
self.inception_3a_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
self.inception_3a_pool_proj = nn.Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1))
self.inception_3a_pool_proj_bn = nn.BatchNorm2d(32, affine=True)
self.inception_3a_relu_pool_proj = nn.ReLU (inplace)
self.inception_3b_1x1 = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_3b_1x1_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3b_relu_1x1 = nn.ReLU (inplace)
self.inception_3b_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_3b_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3b_relu_3x3_reduce = nn.ReLU (inplace)
self.inception_3b_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_3b_3x3_bn = nn.BatchNorm2d(96, affine=True)
self.inception_3b_relu_3x3 = nn.ReLU (inplace)
self.inception_3b_double_3x3_reduce = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_3b_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3b_relu_double_3x3_reduce = nn.ReLU (inplace)
self.inception_3b_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_3b_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True)
self.inception_3b_relu_double_3x3_1 = nn.ReLU (inplace)
self.inception_3b_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_3b_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True)
self.inception_3b_relu_double_3x3_2 = nn.ReLU (inplace)
self.inception_3b_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
self.inception_3b_pool_proj = nn.Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_3b_pool_proj_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3b_relu_pool_proj = nn.ReLU (inplace)
self.inception_3c_3x3_reduce = nn.Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_3c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
self.inception_3c_relu_3x3_reduce = nn.ReLU (inplace)
self.inception_3c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.inception_3c_3x3_bn = nn.BatchNorm2d(160, affine=True)
self.inception_3c_relu_3x3 = nn.ReLU (inplace)
self.inception_3c_double_3x3_reduce = nn.Conv2d(320, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_3c_double_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
self.inception_3c_relu_double_3x3_reduce = nn.ReLU (inplace)
self.inception_3c_double_3x3_1 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_3c_double_3x3_1_bn = nn.BatchNorm2d(96, affine=True)
self.inception_3c_relu_double_3x3_1 = nn.ReLU (inplace)
self.inception_3c_double_3x3_2 = nn.Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.inception_3c_double_3x3_2_bn = nn.BatchNorm2d(96, affine=True)
self.inception_3c_relu_double_3x3_2 = nn.ReLU (inplace)
self.inception_3c_pool = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
self.inception_4a_1x1 = nn.Conv2d(576, 224, kernel_size=(1, 1), stride=(1, 1))
self.inception_4a_1x1_bn = nn.BatchNorm2d(224, affine=True)
self.inception_4a_relu_1x1 = nn.ReLU (inplace)
self.inception_4a_3x3_reduce = nn.Conv2d(576, 64, kernel_size=(1, 1), stride=(1, 1))
self.inception_4a_3x3_reduce_bn = nn.BatchNorm2d(64, affine=True)
self.inception_4a_relu_3x3_reduce = nn.ReLU (inplace)
self.inception_4a_3x3 = nn.Conv2d(64, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4a_3x3_bn = nn.BatchNorm2d(96, affine=True)
self.inception_4a_relu_3x3 = nn.ReLU (inplace)
self.inception_4a_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
self.inception_4a_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
self.inception_4a_relu_double_3x3_reduce = nn.ReLU (inplace)
self.inception_4a_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4a_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4a_relu_double_3x3_1 = nn.ReLU (inplace)
self.inception_4a_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4a_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4a_relu_double_3x3_2 = nn.ReLU (inplace)
self.inception_4a_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
self.inception_4a_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_4a_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4a_relu_pool_proj = nn.ReLU (inplace)
self.inception_4b_1x1 = nn.Conv2d(576, 192, kernel_size=(1, 1), stride=(1, 1))
self.inception_4b_1x1_bn = nn.BatchNorm2d(192, affine=True)
self.inception_4b_relu_1x1 = nn.ReLU (inplace)
self.inception_4b_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
self.inception_4b_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
self.inception_4b_relu_3x3_reduce = nn.ReLU (inplace)
self.inception_4b_3x3 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4b_3x3_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4b_relu_3x3 = nn.ReLU (inplace)
self.inception_4b_double_3x3_reduce = nn.Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1))
self.inception_4b_double_3x3_reduce_bn = nn.BatchNorm2d(96, affine=True)
self.inception_4b_relu_double_3x3_reduce = nn.ReLU (inplace)
self.inception_4b_double_3x3_1 = nn.Conv2d(96, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4b_double_3x3_1_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4b_relu_double_3x3_1 = nn.ReLU (inplace)
self.inception_4b_double_3x3_2 = nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4b_double_3x3_2_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4b_relu_double_3x3_2 = nn.ReLU (inplace)
self.inception_4b_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
self.inception_4b_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_4b_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4b_relu_pool_proj = nn.ReLU (inplace)
self.inception_4c_1x1 = nn.Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1))
self.inception_4c_1x1_bn = nn.BatchNorm2d(160, affine=True)
self.inception_4c_relu_1x1 = nn.ReLU (inplace)
self.inception_4c_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_4c_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4c_relu_3x3_reduce = nn.ReLU (inplace)
self.inception_4c_3x3 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4c_3x3_bn = nn.BatchNorm2d(160, affine=True)
self.inception_4c_relu_3x3 = nn.ReLU (inplace)
self.inception_4c_double_3x3_reduce = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_4c_double_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4c_relu_double_3x3_reduce = nn.ReLU (inplace)
self.inception_4c_double_3x3_1 = nn.Conv2d(128, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4c_double_3x3_1_bn = nn.BatchNorm2d(160, affine=True)
self.inception_4c_relu_double_3x3_1 = nn.ReLU (inplace)
self.inception_4c_double_3x3_2 = nn.Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4c_double_3x3_2_bn = nn.BatchNorm2d(160, affine=True)
self.inception_4c_relu_double_3x3_2 = nn.ReLU (inplace)
self.inception_4c_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
self.inception_4c_pool_proj = nn.Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_4c_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4c_relu_pool_proj = nn.ReLU (inplace)
self.inception_4d_1x1 = nn.Conv2d(608, 96, kernel_size=(1, 1), stride=(1, 1))
self.inception_4d_1x1_bn = nn.BatchNorm2d(96, affine=True)
self.inception_4d_relu_1x1 = nn.ReLU (inplace)
self.inception_4d_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_4d_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4d_relu_3x3_reduce = nn.ReLU (inplace)
self.inception_4d_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4d_3x3_bn = nn.BatchNorm2d(192, affine=True)
self.inception_4d_relu_3x3 = nn.ReLU (inplace)
self.inception_4d_double_3x3_reduce = nn.Conv2d(608, 160, kernel_size=(1, 1), stride=(1, 1))
self.inception_4d_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True)
self.inception_4d_relu_double_3x3_reduce = nn.ReLU (inplace)
self.inception_4d_double_3x3_1 = nn.Conv2d(160, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4d_double_3x3_1_bn = nn.BatchNorm2d(192, affine=True)
self.inception_4d_relu_double_3x3_1 = nn.ReLU (inplace)
self.inception_4d_double_3x3_2 = nn.Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4d_double_3x3_2_bn = nn.BatchNorm2d(192, affine=True)
self.inception_4d_relu_double_3x3_2 = nn.ReLU (inplace)
self.inception_4d_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
self.inception_4d_pool_proj = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_4d_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4d_relu_pool_proj = nn.ReLU (inplace)
self.inception_4e_3x3_reduce = nn.Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_4e_3x3_reduce_bn = nn.BatchNorm2d(128, affine=True)
self.inception_4e_relu_3x3_reduce = nn.ReLU (inplace)
self.inception_4e_3x3 = nn.Conv2d(128, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.inception_4e_3x3_bn = nn.BatchNorm2d(192, affine=True)
self.inception_4e_relu_3x3 = nn.ReLU (inplace)
self.inception_4e_double_3x3_reduce = nn.Conv2d(608, 192, kernel_size=(1, 1), stride=(1, 1))
self.inception_4e_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
self.inception_4e_relu_double_3x3_reduce = nn.ReLU (inplace)
self.inception_4e_double_3x3_1 = nn.Conv2d(192, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_4e_double_3x3_1_bn = nn.BatchNorm2d(256, affine=True)
self.inception_4e_relu_double_3x3_1 = nn.ReLU (inplace)
self.inception_4e_double_3x3_2 = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.inception_4e_double_3x3_2_bn = nn.BatchNorm2d(256, affine=True)
self.inception_4e_relu_double_3x3_2 = nn.ReLU (inplace)
self.inception_4e_pool = nn.MaxPool2d ((3, 3), stride=(2, 2), dilation=(1, 1), ceil_mode=True)
self.inception_5a_1x1 = nn.Conv2d(1056, 352, kernel_size=(1, 1), stride=(1, 1))
self.inception_5a_1x1_bn = nn.BatchNorm2d(352, affine=True)
self.inception_5a_relu_1x1 = nn.ReLU (inplace)
self.inception_5a_3x3_reduce = nn.Conv2d(1056, 192, kernel_size=(1, 1), stride=(1, 1))
self.inception_5a_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
self.inception_5a_relu_3x3_reduce = nn.ReLU (inplace)
self.inception_5a_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_5a_3x3_bn = nn.BatchNorm2d(320, affine=True)
self.inception_5a_relu_3x3 = nn.ReLU (inplace)
self.inception_5a_double_3x3_reduce = nn.Conv2d(1056, 160, kernel_size=(1, 1), stride=(1, 1))
self.inception_5a_double_3x3_reduce_bn = nn.BatchNorm2d(160, affine=True)
self.inception_5a_relu_double_3x3_reduce = nn.ReLU (inplace)
self.inception_5a_double_3x3_1 = nn.Conv2d(160, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_5a_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True)
self.inception_5a_relu_double_3x3_1 = nn.ReLU (inplace)
self.inception_5a_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_5a_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True)
self.inception_5a_relu_double_3x3_2 = nn.ReLU (inplace)
self.inception_5a_pool = nn.AvgPool2d (3, stride=1, padding=1, ceil_mode=True, count_include_pad=True)
self.inception_5a_pool_proj = nn.Conv2d(1056, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_5a_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
self.inception_5a_relu_pool_proj = nn.ReLU (inplace)
self.inception_5b_1x1 = nn.Conv2d(1024, 352, kernel_size=(1, 1), stride=(1, 1))
self.inception_5b_1x1_bn = nn.BatchNorm2d(352, affine=True)
self.inception_5b_relu_1x1 = nn.ReLU (inplace)
self.inception_5b_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1))
self.inception_5b_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
self.inception_5b_relu_3x3_reduce = nn.ReLU (inplace)
self.inception_5b_3x3 = nn.Conv2d(192, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_5b_3x3_bn = nn.BatchNorm2d(320, affine=True)
self.inception_5b_relu_3x3 = nn.ReLU (inplace)
self.inception_5b_double_3x3_reduce = nn.Conv2d(1024, 192, kernel_size=(1, 1), stride=(1, 1))
self.inception_5b_double_3x3_reduce_bn = nn.BatchNorm2d(192, affine=True)
self.inception_5b_relu_double_3x3_reduce = nn.ReLU (inplace)
self.inception_5b_double_3x3_1 = nn.Conv2d(192, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_5b_double_3x3_1_bn = nn.BatchNorm2d(224, affine=True)
self.inception_5b_relu_double_3x3_1 = nn.ReLU (inplace)
self.inception_5b_double_3x3_2 = nn.Conv2d(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
self.inception_5b_double_3x3_2_bn = nn.BatchNorm2d(224, affine=True)
self.inception_5b_relu_double_3x3_2 = nn.ReLU (inplace)
self.inception_5b_pool = nn.MaxPool2d ((3, 3), stride=(1, 1), padding=(1, 1), dilation=(1, 1), ceil_mode=True)
self.inception_5b_pool_proj = nn.Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1))
self.inception_5b_pool_proj_bn = nn.BatchNorm2d(128, affine=True)
self.inception_5b_relu_pool_proj = nn.ReLU (inplace)
self.last_linear = nn.Linear (1024, num_classes)
def features(self, input):
conv1_7x7_s2_out = self.conv1_7x7_s2(input)
conv1_7x7_s2_bn_out = self.conv1_7x7_s2_bn(conv1_7x7_s2_out)
conv1_relu_7x7_out = self.conv1_relu_7x7(conv1_7x7_s2_bn_out)
pool1_3x3_s2_out = self.pool1_3x3_s2(conv1_relu_7x7_out)
conv2_3x3_reduce_out = self.conv2_3x3_reduce(pool1_3x3_s2_out)
conv2_3x3_reduce_bn_out = self.conv2_3x3_reduce_bn(conv2_3x3_reduce_out)
conv2_relu_3x3_reduce_out = self.conv2_relu_3x3_reduce(conv2_3x3_reduce_bn_out)
conv2_3x3_out = self.conv2_3x3(conv2_relu_3x3_reduce_out)
conv2_3x3_bn_out = self.conv2_3x3_bn(conv2_3x3_out)
conv2_relu_3x3_out = self.conv2_relu_3x3(conv2_3x3_bn_out)
pool2_3x3_s2_out = self.pool2_3x3_s2(conv2_relu_3x3_out)
inception_3a_1x1_out = self.inception_3a_1x1(pool2_3x3_s2_out)
inception_3a_1x1_bn_out = self.inception_3a_1x1_bn(inception_3a_1x1_out)
inception_3a_relu_1x1_out = self.inception_3a_relu_1x1(inception_3a_1x1_bn_out)
inception_3a_3x3_reduce_out = self.inception_3a_3x3_reduce(pool2_3x3_s2_out)
inception_3a_3x3_reduce_bn_out = self.inception_3a_3x3_reduce_bn(inception_3a_3x3_reduce_out)
inception_3a_relu_3x3_reduce_out = self.inception_3a_relu_3x3_reduce(inception_3a_3x3_reduce_bn_out)
inception_3a_3x3_out = self.inception_3a_3x3(inception_3a_relu_3x3_reduce_out)
inception_3a_3x3_bn_out = self.inception_3a_3x3_bn(inception_3a_3x3_out)
inception_3a_relu_3x3_out = self.inception_3a_relu_3x3(inception_3a_3x3_bn_out)
inception_3a_double_3x3_reduce_out = self.inception_3a_double_3x3_reduce(pool2_3x3_s2_out)
inception_3a_double_3x3_reduce_bn_out = self.inception_3a_double_3x3_reduce_bn(inception_3a_double_3x3_reduce_out)
inception_3a_relu_double_3x3_reduce_out = self.inception_3a_relu_double_3x3_reduce(inception_3a_double_3x3_reduce_bn_out)
inception_3a_double_3x3_1_out = self.inception_3a_double_3x3_1(inception_3a_relu_double_3x3_reduce_out)
inception_3a_double_3x3_1_bn_out = self.inception_3a_double_3x3_1_bn(inception_3a_double_3x3_1_out)
inception_3a_relu_double_3x3_1_out = self.inception_3a_relu_double_3x3_1(inception_3a_double_3x3_1_bn_out)
inception_3a_double_3x3_2_out = self.inception_3a_double_3x3_2(inception_3a_relu_double_3x3_1_out)
inception_3a_double_3x3_2_bn_out = self.inception_3a_double_3x3_2_bn(inception_3a_double_3x3_2_out)
inception_3a_relu_double_3x3_2_out = self.inception_3a_relu_double_3x3_2(inception_3a_double_3x3_2_bn_out)
inception_3a_pool_out = self.inception_3a_pool(pool2_3x3_s2_out)
inception_3a_pool_proj_out = self.inception_3a_pool_proj(inception_3a_pool_out)
inception_3a_pool_proj_bn_out = self.inception_3a_pool_proj_bn(inception_3a_pool_proj_out)
inception_3a_relu_pool_proj_out = self.inception_3a_relu_pool_proj(inception_3a_pool_proj_bn_out)
inception_3a_output_out = torch.cat([inception_3a_relu_1x1_out,inception_3a_relu_3x3_out,inception_3a_relu_double_3x3_2_out ,inception_3a_relu_pool_proj_out], 1)
inception_3b_1x1_out = self.inception_3b_1x1(inception_3a_output_out)
inception_3b_1x1_bn_out = self.inception_3b_1x1_bn(inception_3b_1x1_out)
inception_3b_relu_1x1_out = self.inception_3b_relu_1x1(inception_3b_1x1_bn_out)
inception_3b_3x3_reduce_out = self.inception_3b_3x3_reduce(inception_3a_output_out)
inception_3b_3x3_reduce_bn_out = self.inception_3b_3x3_reduce_bn(inception_3b_3x3_reduce_out)
inception_3b_relu_3x3_reduce_out = self.inception_3b_relu_3x3_reduce(inception_3b_3x3_reduce_bn_out)
inception_3b_3x3_out = self.inception_3b_3x3(inception_3b_relu_3x3_reduce_out)
inception_3b_3x3_bn_out = self.inception_3b_3x3_bn(inception_3b_3x3_out)
inception_3b_relu_3x3_out = self.inception_3b_relu_3x3(inception_3b_3x3_bn_out)
inception_3b_double_3x3_reduce_out = self.inception_3b_double_3x3_reduce(inception_3a_output_out)
inception_3b_double_3x3_reduce_bn_out = self.inception_3b_double_3x3_reduce_bn(inception_3b_double_3x3_reduce_out)
inception_3b_relu_double_3x3_reduce_out = self.inception_3b_relu_double_3x3_reduce(inception_3b_double_3x3_reduce_bn_out)
inception_3b_double_3x3_1_out = self.inception_3b_double_3x3_1(inception_3b_relu_double_3x3_reduce_out)
inception_3b_double_3x3_1_bn_out = self.inception_3b_double_3x3_1_bn(inception_3b_double_3x3_1_out)
inception_3b_relu_double_3x3_1_out = self.inception_3b_relu_double_3x3_1(inception_3b_double_3x3_1_bn_out)
inception_3b_double_3x3_2_out = self.inception_3b_double_3x3_2(inception_3b_relu_double_3x3_1_out)
inception_3b_double_3x3_2_bn_out = self.inception_3b_double_3x3_2_bn(inception_3b_double_3x3_2_out)
inception_3b_relu_double_3x3_2_out = self.inception_3b_relu_double_3x3_2(inception_3b_double_3x3_2_bn_out)
inception_3b_pool_out = self.inception_3b_pool(inception_3a_output_out)
inception_3b_pool_proj_out = self.inception_3b_pool_proj(inception_3b_pool_out)
inception_3b_pool_proj_bn_out = self.inception_3b_pool_proj_bn(inception_3b_pool_proj_out)
inception_3b_relu_pool_proj_out = self.inception_3b_relu_pool_proj(inception_3b_pool_proj_bn_out)
inception_3b_output_out = torch.cat([inception_3b_relu_1x1_out,inception_3b_relu_3x3_out,inception_3b_relu_double_3x3_2_out,inception_3b_relu_pool_proj_out], 1)
inception_3c_3x3_reduce_out = self.inception_3c_3x3_reduce(inception_3b_output_out)
inception_3c_3x3_reduce_bn_out = self.inception_3c_3x3_reduce_bn(inception_3c_3x3_reduce_out)
inception_3c_relu_3x3_reduce_out = self.inception_3c_relu_3x3_reduce(inception_3c_3x3_reduce_bn_out)
inception_3c_3x3_out = self.inception_3c_3x3(inception_3c_relu_3x3_reduce_out)
inception_3c_3x3_bn_out = self.inception_3c_3x3_bn(inception_3c_3x3_out)
inception_3c_relu_3x3_out = self.inception_3c_relu_3x3(inception_3c_3x3_bn_out)
inception_3c_double_3x3_reduce_out = self.inception_3c_double_3x3_reduce(inception_3b_output_out)
inception_3c_double_3x3_reduce_bn_out = self.inception_3c_double_3x3_reduce_bn(inception_3c_double_3x3_reduce_out)
inception_3c_relu_double_3x3_reduce_out = self.inception_3c_relu_double_3x3_reduce(inception_3c_double_3x3_reduce_bn_out)
inception_3c_double_3x3_1_out = self.inception_3c_double_3x3_1(inception_3c_relu_double_3x3_reduce_out)
inception_3c_double_3x3_1_bn_out = self.inception_3c_double_3x3_1_bn(inception_3c_double_3x3_1_out)
inception_3c_relu_double_3x3_1_out = self.inception_3c_relu_double_3x3_1(inception_3c_double_3x3_1_bn_out)
inception_3c_double_3x3_2_out = self.inception_3c_double_3x3_2(inception_3c_relu_double_3x3_1_out)
inception_3c_double_3x3_2_bn_out = self.inception_3c_double_3x3_2_bn(inception_3c_double_3x3_2_out)
inception_3c_relu_double_3x3_2_out = self.inception_3c_relu_double_3x3_2(inception_3c_double_3x3_2_bn_out)
inception_3c_pool_out = self.inception_3c_pool(inception_3b_output_out)
inception_3c_output_out = torch.cat([inception_3c_relu_3x3_out,inception_3c_relu_double_3x3_2_out,inception_3c_pool_out], 1)
inception_4a_1x1_out = self.inception_4a_1x1(inception_3c_output_out)
inception_4a_1x1_bn_out = self.inception_4a_1x1_bn(inception_4a_1x1_out)
inception_4a_relu_1x1_out = self.inception_4a_relu_1x1(inception_4a_1x1_bn_out)
inception_4a_3x3_reduce_out = self.inception_4a_3x3_reduce(inception_3c_output_out)
inception_4a_3x3_reduce_bn_out = self.inception_4a_3x3_reduce_bn(inception_4a_3x3_reduce_out)
inception_4a_relu_3x3_reduce_out = self.inception_4a_relu_3x3_reduce(inception_4a_3x3_reduce_bn_out)
inception_4a_3x3_out = self.inception_4a_3x3(inception_4a_relu_3x3_reduce_out)
inception_4a_3x3_bn_out = self.inception_4a_3x3_bn(inception_4a_3x3_out)
inception_4a_relu_3x3_out = self.inception_4a_relu_3x3(inception_4a_3x3_bn_out)
inception_4a_double_3x3_reduce_out = self.inception_4a_double_3x3_reduce(inception_3c_output_out)
inception_4a_double_3x3_reduce_bn_out = self.inception_4a_double_3x3_reduce_bn(inception_4a_double_3x3_reduce_out)
inception_4a_relu_double_3x3_reduce_out = self.inception_4a_relu_double_3x3_reduce(inception_4a_double_3x3_reduce_bn_out)
inception_4a_double_3x3_1_out = self.inception_4a_double_3x3_1(inception_4a_relu_double_3x3_reduce_out)
inception_4a_double_3x3_1_bn_out = self.inception_4a_double_3x3_1_bn(inception_4a_double_3x3_1_out)
inception_4a_relu_double_3x3_1_out = self.inception_4a_relu_double_3x3_1(inception_4a_double_3x3_1_bn_out)
inception_4a_double_3x3_2_out = self.inception_4a_double_3x3_2(inception_4a_relu_double_3x3_1_out)
inception_4a_double_3x3_2_bn_out = self.inception_4a_double_3x3_2_bn(inception_4a_double_3x3_2_out)
inception_4a_relu_double_3x3_2_out = self.inception_4a_relu_double_3x3_2(inception_4a_double_3x3_2_bn_out)
inception_4a_pool_out = self.inception_4a_pool(inception_3c_output_out)
inception_4a_pool_proj_out = self.inception_4a_pool_proj(inception_4a_pool_out)
inception_4a_pool_proj_bn_out = self.inception_4a_pool_proj_bn(inception_4a_pool_proj_out)
inception_4a_relu_pool_proj_out = self.inception_4a_relu_pool_proj(inception_4a_pool_proj_bn_out)
inception_4a_output_out = torch.cat([inception_4a_relu_1x1_out,inception_4a_relu_3x3_out,inception_4a_relu_double_3x3_2_out,inception_4a_relu_pool_proj_out], 1)
inception_4b_1x1_out = self.inception_4b_1x1(inception_4a_output_out)
inception_4b_1x1_bn_out = self.inception_4b_1x1_bn(inception_4b_1x1_out)
inception_4b_relu_1x1_out = self.inception_4b_relu_1x1(inception_4b_1x1_bn_out)
inception_4b_3x3_reduce_out = self.inception_4b_3x3_reduce(inception_4a_output_out)
inception_4b_3x3_reduce_bn_out = self.inception_4b_3x3_reduce_bn(inception_4b_3x3_reduce_out)
inception_4b_relu_3x3_reduce_out = self.inception_4b_relu_3x3_reduce(inception_4b_3x3_reduce_bn_out)
inception_4b_3x3_out = self.inception_4b_3x3(inception_4b_relu_3x3_reduce_out)
inception_4b_3x3_bn_out = self.inception_4b_3x3_bn(inception_4b_3x3_out)
inception_4b_relu_3x3_out = self.inception_4b_relu_3x3(inception_4b_3x3_bn_out)
inception_4b_double_3x3_reduce_out = self.inception_4b_double_3x3_reduce(inception_4a_output_out)
inception_4b_double_3x3_reduce_bn_out = self.inception_4b_double_3x3_reduce_bn(inception_4b_double_3x3_reduce_out)
inception_4b_relu_double_3x3_reduce_out = self.inception_4b_relu_double_3x3_reduce(inception_4b_double_3x3_reduce_bn_out)
inception_4b_double_3x3_1_out = self.inception_4b_double_3x3_1(inception_4b_relu_double_3x3_reduce_out)
inception_4b_double_3x3_1_bn_out = self.inception_4b_double_3x3_1_bn(inception_4b_double_3x3_1_out)
inception_4b_relu_double_3x3_1_out = self.inception_4b_relu_double_3x3_1(inception_4b_double_3x3_1_bn_out)
inception_4b_double_3x3_2_out = self.inception_4b_double_3x3_2(inception_4b_relu_double_3x3_1_out)
inception_4b_double_3x3_2_bn_out = self.inception_4b_double_3x3_2_bn(inception_4b_double_3x3_2_out)
inception_4b_relu_double_3x3_2_out = self.inception_4b_relu_double_3x3_2(inception_4b_double_3x3_2_bn_out)
inception_4b_pool_out = self.inception_4b_pool(inception_4a_output_out)
inception_4b_pool_proj_out = self.inception_4b_pool_proj(inception_4b_pool_out)
inception_4b_pool_proj_bn_out = self.inception_4b_pool_proj_bn(inception_4b_pool_proj_out)
inception_4b_relu_pool_proj_out = self.inception_4b_relu_pool_proj(inception_4b_pool_proj_bn_out)
inception_4b_output_out = torch.cat([inception_4b_relu_1x1_out,inception_4b_relu_3x3_out,inception_4b_relu_double_3x3_2_out,inception_4b_relu_pool_proj_out], 1)
inception_4c_1x1_out = self.inception_4c_1x1(inception_4b_output_out)
inception_4c_1x1_bn_out = self.inception_4c_1x1_bn(inception_4c_1x1_out)
inception_4c_relu_1x1_out = self.inception_4c_relu_1x1(inception_4c_1x1_bn_out)
inception_4c_3x3_reduce_out = self.inception_4c_3x3_reduce(inception_4b_output_out)
inception_4c_3x3_reduce_bn_out = self.inception_4c_3x3_reduce_bn(inception_4c_3x3_reduce_out)
inception_4c_relu_3x3_reduce_out = self.inception_4c_relu_3x3_reduce(inception_4c_3x3_reduce_bn_out)
inception_4c_3x3_out = self.inception_4c_3x3(inception_4c_relu_3x3_reduce_out)
inception_4c_3x3_bn_out = self.inception_4c_3x3_bn(inception_4c_3x3_out)
inception_4c_relu_3x3_out = self.inception_4c_relu_3x3(inception_4c_3x3_bn_out)
inception_4c_double_3x3_reduce_out = self.inception_4c_double_3x3_reduce(inception_4b_output_out)
inception_4c_double_3x3_reduce_bn_out = self.inception_4c_double_3x3_reduce_bn(inception_4c_double_3x3_reduce_out)
inception_4c_relu_double_3x3_reduce_out = self.inception_4c_relu_double_3x3_reduce(inception_4c_double_3x3_reduce_bn_out)
inception_4c_double_3x3_1_out = self.inception_4c_double_3x3_1(inception_4c_relu_double_3x3_reduce_out)
inception_4c_double_3x3_1_bn_out = self.inception_4c_double_3x3_1_bn(inception_4c_double_3x3_1_out)
inception_4c_relu_double_3x3_1_out = self.inception_4c_relu_double_3x3_1(inception_4c_double_3x3_1_bn_out)
inception_4c_double_3x3_2_out = self.inception_4c_double_3x3_2(inception_4c_relu_double_3x3_1_out)
inception_4c_double_3x3_2_bn_out = self.inception_4c_double_3x3_2_bn(inception_4c_double_3x3_2_out)
inception_4c_relu_double_3x3_2_out = self.inception_4c_relu_double_3x3_2(inception_4c_double_3x3_2_bn_out)
inception_4c_pool_out = self.inception_4c_pool(inception_4b_output_out)
inception_4c_pool_proj_out = self.inception_4c_pool_proj(inception_4c_pool_out)
inception_4c_pool_proj_bn_out = self.inception_4c_pool_proj_bn(inception_4c_pool_proj_out)
inception_4c_relu_pool_proj_out = self.inception_4c_relu_pool_proj(inception_4c_pool_proj_bn_out)
inception_4c_output_out = torch.cat([inception_4c_relu_1x1_out,inception_4c_relu_3x3_out,inception_4c_relu_double_3x3_2_out,inception_4c_relu_pool_proj_out], 1)
inception_4d_1x1_out = self.inception_4d_1x1(inception_4c_output_out)
inception_4d_1x1_bn_out = self.inception_4d_1x1_bn(inception_4d_1x1_out)
inception_4d_relu_1x1_out = self.inception_4d_relu_1x1(inception_4d_1x1_bn_out)
inception_4d_3x3_reduce_out = self.inception_4d_3x3_reduce(inception_4c_output_out)
inception_4d_3x3_reduce_bn_out = self.inception_4d_3x3_reduce_bn(inception_4d_3x3_reduce_out)
inception_4d_relu_3x3_reduce_out = self.inception_4d_relu_3x3_reduce(inception_4d_3x3_reduce_bn_out)
inception_4d_3x3_out = self.inception_4d_3x3(inception_4d_relu_3x3_reduce_out)
inception_4d_3x3_bn_out = self.inception_4d_3x3_bn(inception_4d_3x3_out)
inception_4d_relu_3x3_out = self.inception_4d_relu_3x3(inception_4d_3x3_bn_out)
inception_4d_double_3x3_reduce_out = self.inception_4d_double_3x3_reduce(inception_4c_output_out)
inception_4d_double_3x3_reduce_bn_out = self.inception_4d_double_3x3_reduce_bn(inception_4d_double_3x3_reduce_out)
inception_4d_relu_double_3x3_reduce_out = self.inception_4d_relu_double_3x3_reduce(inception_4d_double_3x3_reduce_bn_out)
inception_4d_double_3x3_1_out = self.inception_4d_double_3x3_1(inception_4d_relu_double_3x3_reduce_out)
inception_4d_double_3x3_1_bn_out = self.inception_4d_double_3x3_1_bn(inception_4d_double_3x3_1_out)
inception_4d_relu_double_3x3_1_out = self.inception_4d_relu_double_3x3_1(inception_4d_double_3x3_1_bn_out)
inception_4d_double_3x3_2_out = self.inception_4d_double_3x3_2(inception_4d_relu_double_3x3_1_out)
inception_4d_double_3x3_2_bn_out = self.inception_4d_double_3x3_2_bn(inception_4d_double_3x3_2_out)
inception_4d_relu_double_3x3_2_out = self.inception_4d_relu_double_3x3_2(inception_4d_double_3x3_2_bn_out)
inception_4d_pool_out = self.inception_4d_pool(inception_4c_output_out)
inception_4d_pool_proj_out = self.inception_4d_pool_proj(inception_4d_pool_out)
inception_4d_pool_proj_bn_out = self.inception_4d_pool_proj_bn(inception_4d_pool_proj_out)
inception_4d_relu_pool_proj_out = self.inception_4d_relu_pool_proj(inception_4d_pool_proj_bn_out)
inception_4d_output_out = torch.cat([inception_4d_relu_1x1_out,inception_4d_relu_3x3_out,inception_4d_relu_double_3x3_2_out,inception_4d_relu_pool_proj_out], 1)
inception_4e_3x3_reduce_out = self.inception_4e_3x3_reduce(inception_4d_output_out)
inception_4e_3x3_reduce_bn_out = self.inception_4e_3x3_reduce_bn(inception_4e_3x3_reduce_out)
inception_4e_relu_3x3_reduce_out = self.inception_4e_relu_3x3_reduce(inception_4e_3x3_reduce_bn_out)
inception_4e_3x3_out = self.inception_4e_3x3(inception_4e_relu_3x3_reduce_out)
inception_4e_3x3_bn_out = self.inception_4e_3x3_bn(inception_4e_3x3_out)
inception_4e_relu_3x3_out = self.inception_4e_relu_3x3(inception_4e_3x3_bn_out)
inception_4e_double_3x3_reduce_out = self.inception_4e_double_3x3_reduce(inception_4d_output_out)
inception_4e_double_3x3_reduce_bn_out = self.inception_4e_double_3x3_reduce_bn(inception_4e_double_3x3_reduce_out)
inception_4e_relu_double_3x3_reduce_out = self.inception_4e_relu_double_3x3_reduce(inception_4e_double_3x3_reduce_bn_out)
inception_4e_double_3x3_1_out = self.inception_4e_double_3x3_1(inception_4e_relu_double_3x3_reduce_out)
inception_4e_double_3x3_1_bn_out = self.inception_4e_double_3x3_1_bn(inception_4e_double_3x3_1_out)
inception_4e_relu_double_3x3_1_out = self.inception_4e_relu_double_3x3_1(inception_4e_double_3x3_1_bn_out)
inception_4e_double_3x3_2_out = self.inception_4e_double_3x3_2(inception_4e_relu_double_3x3_1_out)
inception_4e_double_3x3_2_bn_out = self.inception_4e_double_3x3_2_bn(inception_4e_double_3x3_2_out)
inception_4e_relu_double_3x3_2_out = self.inception_4e_relu_double_3x3_2(inception_4e_double_3x3_2_bn_out)
inception_4e_pool_out = self.inception_4e_pool(inception_4d_output_out)
inception_4e_output_out = torch.cat([inception_4e_relu_3x3_out,inception_4e_relu_double_3x3_2_out,inception_4e_pool_out], 1)
inception_5a_1x1_out = self.inception_5a_1x1(inception_4e_output_out)
inception_5a_1x1_bn_out = self.inception_5a_1x1_bn(inception_5a_1x1_out)
inception_5a_relu_1x1_out = self.inception_5a_relu_1x1(inception_5a_1x1_bn_out)
inception_5a_3x3_reduce_out = self.inception_5a_3x3_reduce(inception_4e_output_out)
inception_5a_3x3_reduce_bn_out = self.inception_5a_3x3_reduce_bn(inception_5a_3x3_reduce_out)
inception_5a_relu_3x3_reduce_out = self.inception_5a_relu_3x3_reduce(inception_5a_3x3_reduce_bn_out)
inception_5a_3x3_out = self.inception_5a_3x3(inception_5a_relu_3x3_reduce_out)
inception_5a_3x3_bn_out = self.inception_5a_3x3_bn(inception_5a_3x3_out)
inception_5a_relu_3x3_out = self.inception_5a_relu_3x3(inception_5a_3x3_bn_out)
inception_5a_double_3x3_reduce_out = self.inception_5a_double_3x3_reduce(inception_4e_output_out)
inception_5a_double_3x3_reduce_bn_out = self.inception_5a_double_3x3_reduce_bn(inception_5a_double_3x3_reduce_out)
inception_5a_relu_double_3x3_reduce_out = self.inception_5a_relu_double_3x3_reduce(inception_5a_double_3x3_reduce_bn_out)
inception_5a_double_3x3_1_out = self.inception_5a_double_3x3_1(inception_5a_relu_double_3x3_reduce_out)
inception_5a_double_3x3_1_bn_out = self.inception_5a_double_3x3_1_bn(inception_5a_double_3x3_1_out)
inception_5a_relu_double_3x3_1_out = self.inception_5a_relu_double_3x3_1(inception_5a_double_3x3_1_bn_out)
inception_5a_double_3x3_2_out = self.inception_5a_double_3x3_2(inception_5a_relu_double_3x3_1_out)
inception_5a_double_3x3_2_bn_out = self.inception_5a_double_3x3_2_bn(inception_5a_double_3x3_2_out)
inception_5a_relu_double_3x3_2_out = self.inception_5a_relu_double_3x3_2(inception_5a_double_3x3_2_bn_out)
inception_5a_pool_out = self.inception_5a_pool(inception_4e_output_out)
inception_5a_pool_proj_out = self.inception_5a_pool_proj(inception_5a_pool_out)
inception_5a_pool_proj_bn_out = self.inception_5a_pool_proj_bn(inception_5a_pool_proj_out)
inception_5a_relu_pool_proj_out = self.inception_5a_relu_pool_proj(inception_5a_pool_proj_bn_out)
inception_5a_output_out = torch.cat([inception_5a_relu_1x1_out,inception_5a_relu_3x3_out,inception_5a_relu_double_3x3_2_out,inception_5a_relu_pool_proj_out], 1)
inception_5b_1x1_out = self.inception_5b_1x1(inception_5a_output_out)
inception_5b_1x1_bn_out = self.inception_5b_1x1_bn(inception_5b_1x1_out)
inception_5b_relu_1x1_out = self.inception_5b_relu_1x1(inception_5b_1x1_bn_out)
inception_5b_3x3_reduce_out = self.inception_5b_3x3_reduce(inception_5a_output_out)
inception_5b_3x3_reduce_bn_out = self.inception_5b_3x3_reduce_bn(inception_5b_3x3_reduce_out)
inception_5b_relu_3x3_reduce_out = self.inception_5b_relu_3x3_reduce(inception_5b_3x3_reduce_bn_out)
inception_5b_3x3_out = self.inception_5b_3x3(inception_5b_relu_3x3_reduce_out)
inception_5b_3x3_bn_out = self.inception_5b_3x3_bn(inception_5b_3x3_out)
inception_5b_relu_3x3_out = self.inception_5b_relu_3x3(inception_5b_3x3_bn_out)
inception_5b_double_3x3_reduce_out = self.inception_5b_double_3x3_reduce(inception_5a_output_out)
inception_5b_double_3x3_reduce_bn_out = self.inception_5b_double_3x3_reduce_bn(inception_5b_double_3x3_reduce_out)
inception_5b_relu_double_3x3_reduce_out = self.inception_5b_relu_double_3x3_reduce(inception_5b_double_3x3_reduce_bn_out)
inception_5b_double_3x3_1_out = self.inception_5b_double_3x3_1(inception_5b_relu_double_3x3_reduce_out)
inception_5b_double_3x3_1_bn_out = self.inception_5b_double_3x3_1_bn(inception_5b_double_3x3_1_out)
inception_5b_relu_double_3x3_1_out = self.inception_5b_relu_double_3x3_1(inception_5b_double_3x3_1_bn_out)
inception_5b_double_3x3_2_out = self.inception_5b_double_3x3_2(inception_5b_relu_double_3x3_1_out)
inception_5b_double_3x3_2_bn_out = self.inception_5b_double_3x3_2_bn(inception_5b_double_3x3_2_out)
inception_5b_relu_double_3x3_2_out = self.inception_5b_relu_double_3x3_2(inception_5b_double_3x3_2_bn_out)
inception_5b_pool_out = self.inception_5b_pool(inception_5a_output_out)
inception_5b_pool_proj_out = self.inception_5b_pool_proj(inception_5b_pool_out)
inception_5b_pool_proj_bn_out = self.inception_5b_pool_proj_bn(inception_5b_pool_proj_out)
inception_5b_relu_pool_proj_out = self.inception_5b_relu_pool_proj(inception_5b_pool_proj_bn_out)
inception_5b_output_out = torch.cat([inception_5b_relu_1x1_out,inception_5b_relu_3x3_out,inception_5b_relu_double_3x3_2_out,inception_5b_relu_pool_proj_out], 1)
return inception_5b_output_out
def logits(self, features):
adaptiveAvgPoolWidth = features.shape[2]
x = F.avg_pool2d(features, kernel_size=adaptiveAvgPoolWidth)
x = x.view(x.size(0), -1)
x = self.last_linear(x)
return x
def forward(self, input):
x = self.features(input)
x = self.logits(x)
return x
if __name__ == '__main__':
model = BNInception(num_classes=10)
print(model)
\ No newline at end of file
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch
class BasicConv2d(nn.Module):
def __init__(self, in_channels, out_channels, **kwargs):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
self.bn = nn.BatchNorm2d(out_channels)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class Inception(nn.Module):
def __init__(self,channel,batch_norm=False):
super(Inception, self).__init__()
if batch_norm==False:
self.branch1x1=nn.Conv2d(channel[0],channel[1],kernel_size=(1,1),stride=1)
self.branch3x3_1=nn.Conv2d(channel[0],channel[2],kernel_size=(1,1),stride=1)
self.branch3x3_2=nn.Conv2d(channel[2],channel[3],kernel_size=(3,3),stride=1,padding=1)
self.branch5x5_1=nn.Conv2d(channel[0],channel[4],kernel_size=(1,1),stride=1)
self.branch5x5_2=nn.Conv2d(channel[4],channel[5],kernel_size=(5,5),stride=1,padding=2)
self.branchM_1=nn.MaxPool2d(kernel_size=3,stride=1,padding=1)
self.branchM_2=nn.Conv2d(channel[0],channel[6],kernel_size=(1,1),stride=1)
else:
self.branch1x1=BasicConv2d(channel[0],channel[1],kernel_size=(1,1),stride=1)
self.branch3x3_1=BasicConv2d(channel[0],channel[2],kernel_size=(1,1),stride=1)
self.branch3x3_2=BasicConv2d(channel[2],channel[3],kernel_size=(3,3),stride=1,padding=1)
self.branch5x5_1=BasicConv2d(channel[0],channel[4],kernel_size=(1,1),stride=1)
self.branch5x5_2=BasicConv2d(channel[4],channel[5],kernel_size=(5,5),stride=1,padding=2)
self.branchM_1=nn.MaxPool2d(kernel_size=3,stride=1,padding=1)
self.branchM_2=BasicConv2d(channel[0],channel[6],kernel_size=(1,1),stride=1)
self.relu=nn.ReLU(True)
def forward(self,x):
branch1x1=self.relu(self.branch1x1(x))
branch3x3_1=self.relu(self.branch3x3_1(x))
branch3x3_2=self.relu(self.branch3x3_2(branch3x3_1))
branch5x5_1=self.relu(self.branch5x5_1(x))
branch5x5_2=self.relu(self.branch5x5_2(branch5x5_1))
branchM_1=self.relu(self.branchM_1(x))
branchM_2=self.relu(self.branchM_2(branchM_1))
outputs = [branch1x1, branch3x3_2, branch5x5_2, branchM_2]
return torch.cat(outputs,1)
channel=[
[192, 64, 96,128, 16, 32, 32],#3a
[256,128,128,192, 32, 96, 64],#3b
[480,192, 96,208, 16, 48, 64],#4a
[512,160,112,224, 24, 64, 64],#4b
[512,128,128,256, 24, 64, 64],#4c
[512,112,144,288, 32, 64, 64],#4d
[528,256,160,320, 32,128,128],#4e
[832,256,160,320, 32,128,128],#5a
[832,384,192,384, 48,128,128] #5b
]
class InceptionNet(nn.Module):
def __init__(self,num_classes=1000,batch_norm=False):
super(InceptionNet, self).__init__()
if num_classes==10:
channel[0][0]=64
self.begin=nn.Sequential(
nn.Conv2d(3,64,kernel_size=3,stride=1),
nn.ReLU(True),
nn.Conv2d(64,64,kernel_size=3,stride=1),
nn.ReLU(True)
)
self.auxout1=nn.Sequential(
nn.Conv2d(512,512,kernel_size=5,stride=3), #4x4x512
nn.ReLU(True),
nn.Conv2d(512,128,kernel_size=1), #4x4x128
nn.ReLU(True),
nn.Conv2d(128, 10,kernel_size=4) #1x1x10
)
self.auxout2=nn.Sequential(
nn.Conv2d(528,528,kernel_size=5,stride=3), #4x4x528,
nn.ReLU(True),
nn.Conv2d(528,128,kernel_size=1), #4x4x128,
nn.ReLU(True),
nn.Conv2d(128, 10,kernel_size=4) #1x1x10
)
else:
self.begin=nn.Sequential(
nn.Conv2d(3,64,kernel_size=7,stride=2,padding=3),
nn.ReLU(True),
nn.MaxPool2d(kernel_size=3,stride=2,padding=1),
nn.Conv2d(64,192,kernel_size=3,stride=1,padding=1),
nn.ReLU(True),
nn.MaxPool2d(kernel_size=3,stride=2,padding=1),
)
self.auxout1=nn.Sequential(
nn.Conv2d(512,512,kernel_size=5,stride=3),#4x4x512
nn.ReLU(True),
nn.Conv2d(512,128,kernel_size=1), #4x4x128
nn.ReLU(True)
)
self.auxout12=nn.Sequential(
nn.Linear(2048,1024),
nn.Dropout(0.5),
nn.Linear(1024,num_classes)
)
self.auxout2=nn.Sequential(
nn.Conv2d(528,528,kernel_size=5,stride=3),#4x4x528
nn.ReLU(True),
nn.Conv2d(528,128,kernel_size=1), #4x4x128
nn.ReLU(True)
)
self.auxout22=nn.Sequential(
nn.Linear(2048,1024),
nn.Dropout(0.5),
nn.Linear(1024,num_classes)
)
self.maxpool=nn.MaxPool2d(kernel_size=3,stride=2,padding=1)
self.inception3a=Inception(channel[0],batch_norm)
self.inception3b=Inception(channel[1],batch_norm)
self.inception4a=Inception(channel[2],batch_norm)
self.inception4b=Inception(channel[3],batch_norm)
self.inception4c=Inception(channel[4],batch_norm)
self.inception4d=Inception(channel[5],batch_norm)
self.inception4e=Inception(channel[6],batch_norm)
self.inception5a=Inception(channel[7],batch_norm)
self.inception5b=Inception(channel[8],batch_norm)
self.avgpool=nn.AdaptiveAvgPool2d((1,1))
self.conv1x1=nn.Conv2d(1024,num_classes,kernel_size=1)
self._initialize_weights()
'''
#follows the original paper, but not used here to reduce computation
self.drop=nn.Dropout()
self.linear=nn.Linear(1024,1000)
'''
def _initialize_weights(self):
for m in self.modules():
if isinstance(m,nn.Conv2d):
nn.init.kaiming_normal_(m.weight,mode='fan_out',nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias,0)
elif isinstance(m,nn.BatchNorm2d):
nn.init.constant_(m.weight,1)
nn.init.constant_(m.bias,0)
elif isinstance(m,nn.Linear):
nn.init.normal_(m.weight,0,0.01)
nn.init.constant_(m.bias,0)
def forward(self,x):
x=self.begin(x)
x=self.inception3a(x)
x=self.inception3b(x)
x=self.maxpool(x)
x=self.inception4a(x)
auxout1=self.auxout1(x)
auxout1=auxout1.view(auxout1.size(0),-1)
#if you train this network on ImageNet, uncomment the following line
#auxout1=self.auxout12(auxout1)
x=self.inception4b(x)
x=self.inception4c(x)
x=self.inception4d(x)
auxout2=self.auxout2(x)
auxout2=auxout2.view(auxout2.size(0),-1)
#if you train this network on ImageNet, uncomment the following line
#auxout2=self.auxout22(auxout2)
x=self.inception4e(x)
x=self.maxpool(x)
x=self.inception5a(x)
x=self.inception5b(x)
x=self.avgpool(x)
outputs=self.conv1x1(x)
outputs=outputs.view(outputs.size(0),-1)
return outputs,auxout1,auxout2
if __name__ == '__main__':
net=InceptionNet(num_classes=10,batch_norm=True)
print(net)
from model import *
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import os
import os.path as osp
def train(model, device, train_loader, optimizer, epoch):
model.train()
lossLayer = torch.nn.CrossEntropyLoss()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = lossLayer(output, target)
loss.backward()
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
lossLayer = torch.nn.CrossEntropyLoss(reduction='sum')
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += lossLayer(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
test_loss, 100. * correct / len(test_loader.dataset)
))
if __name__ == "__main__":
batch_size = 32
seed = 1
epochs_cfg = [20, 30, 30, 20, 20, 10, 10]
lr_cfg = [0.01, 0.008, 0.005, 0.002, 0.001, 0.0005, 0.0001]
momentum = 0.5
save_model = True
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=True, download=True,
transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
model = Inception_BN().to(device)
epoch_start = 1
for epochs,lr in zip(epochs_cfg,lr_cfg):
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
epoch_end = epoch_start+epochs
for epoch in range(epoch_start,epoch_end):
train(model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
epoch_start += epochs
if save_model:
if not osp.exists('ckpt'):
os.makedirs('ckpt')
torch.save(model.state_dict(), 'ckpt/cifar10_Inception_BN.pt')
\ No newline at end of file
import torch
import torch.nn as nn
def ebit_list(quant_type, num_bits):
if quant_type == 'FLOAT':
e_bit_list = list(range(1,num_bits-1))
else:
e_bit_list = [0]
return e_bit_list
def numbit_list(quant_type):
if quant_type == 'INT':
num_bit_list = list(range(2,17))
elif quant_type == 'POT':
num_bit_list = list(range(2,9))
else:
num_bit_list = list(range(2,9))
# num_bit_list = [8]
return num_bit_list
def build_bias_list(quant_type):
if quant_type == 'POT':
return build_pot_list(8)
else:
return build_float_list(16,7)
def build_list(quant_type, num_bits, e_bits):
if quant_type == 'POT':
return build_pot_list(num_bits)
else:
return build_float_list(num_bits,e_bits)
def build_pot_list(num_bits):
plist = [0.]
for i in range(-2 ** (num_bits-1) + 2, 1):
# i goes up to 0, so the largest POT level is 1
plist.append(2. ** i)
plist.append(-2. ** i)
plist = torch.Tensor(list(set(plist)))
# plist = plist.mul(1.0 / torch.max(plist))
return plist
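# example: build_pot_list(3) yields the 7 power-of-two levels {0, ±0.25, ±0.5, ±1.0},
# i.e. 0 together with ±2^i for i in [-2, 0]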
def build_float_list(num_bits,e_bits):
m_bits = num_bits - 1 - e_bits
plist = [0.]
# spacing between adjacent mantissa values
dist_m = 2 ** (-m_bits)
e = -2 ** (e_bits - 1) + 1
for m in range(1, 2 ** m_bits):
frac = m * dist_m # mantissa part
expo = 2 ** e # exponent part
flt = frac * expo
plist.append(flt)
plist.append(-flt)
for e in range(-2 ** (e_bits - 1) + 2, 2 ** (e_bits - 1) + 1):
expo = 2 ** e
for m in range(0, 2 ** m_bits):
frac = 1. + m * dist_m
flt = frac * expo
plist.append(flt)
plist.append(-flt)
plist = torch.Tensor(list(set(plist)))
return plist
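# example: build_float_list(4, 2) gives m_bits = 1, producing the 15 values
# {0, ±0.25 (subnormal), ±1, ±1.5, ±2, ±3, ±4, ±6}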
# no cfg table needed here: match layers with the same prefix and suffix, and take relu into account as well
def fold_ratio(layer, par_ratio, flop_ratio):
idx = -1
for name in layer:
if 'conv' in name:
conv_idx = layer.index(name)
[prefix,suffix] = name.split('conv')
bn_name = prefix+'bn'+suffix
relu_name = prefix+'relu'+suffix
if bn_name in layer:
bn_idx = layer.index(bn_name)
par_ratio[conv_idx]+=par_ratio[bn_idx]
flop_ratio[conv_idx]+=flop_ratio[bn_idx]
if relu_name in layer:
relu_idx = layer.index(relu_name)
par_ratio[conv_idx]+=par_ratio[relu_idx]
flop_ratio[conv_idx]+=flop_ratio[relu_idx]
return par_ratio,flop_ratio
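# naming example: for 'inc7_br2_conv1' the matching layers are 'inc7_br2_bn1' and
# 'inc7_br2_relu1', so their param/flop ratios are folded into the conv entry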
def fold_model(model):
for name, module in model.named_modules():
if 'conv' in name:
[prefix,suffix] = name.split('conv')
bn_name = prefix+'bn'+suffix
if hasattr(model,bn_name):
bn_layer = getattr(model,bn_name)
fold_bn(module,bn_layer)
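# fold_bn folds the BN statistics into the conv (affine case):
#   w' = w * gamma / sqrt(var + eps)
#   b' = (b - mean) * gamma / sqrt(var + eps) + beta    (b = 0 when conv.bias is None)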
def fold_bn(conv, bn):
# read the BN layer's statistics and parameters
mean = bn.running_mean
var = bn.running_var
eps = bn.eps
std = torch.sqrt(var + eps)
if bn.affine:
gamma_ = bn.weight / std
weight = conv.weight * gamma_.view(conv.out_channels, 1, 1, 1)
if conv.bias is not None:
bias = gamma_ * conv.bias - gamma_ * mean + bn.bias
else:
bias = bn.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = conv.weight * gamma_
if conv.bias is not None:
bias = gamma_ * conv.bias - gamma_ * mean
else:
bias = -gamma_ * mean
# write back the new weight and bias
conv.weight.data = weight.data
if conv.bias is not None:
conv.bias.data = bias.data
else:
conv.bias = torch.nn.Parameter(bias)
# Change Notes
## update: 2023/04/22
+ Added the Inception BN model, with the following framework changes:
  + A cfg_table is used for fast model deployment and quantization (it can be applied to other models as well). The cfg_table gives the flat forward structure of the whole network, including the inc modules; each inc module is in turn deployed and quantized from inc_ch_table and inc_cfg_table. The detailed rules are described in the files.
  + Each cfg_table entry corresponds to one module that can be quantization-fused, e.g. adjacent conv, bn and relu appear as ['C','BR',...]. This makes quantization and flops/param weight extraction from the table more convenient.
  + fold_ratio was changed to support per-entry fusion based on the cfg_table, additionally taking the relu layer into account and dropping the old adjacency assumption. Approach: when a conv is encountered, find the layers with the same prefix and suffix and add their ratios to it.
  + module was changed so that quantized layers accept a layer whose bias is None. The original fold_bn already handled this; only freeze needed to change.
    + For Conv there is no bias before or after freeze.
    + For ConvBN there is a bias before and after freeze. forward uses a temporary value, inference uses the fixed value (folded into the corresponding conv_module).
  + Since conv.bias=None is now allowed, the full-precision model's fold_bn was changed accordingly so that the comparable parameters match before and after quantization. The rewrite mirrors the quantized layers.
  + The js_div computation was changed: if a layer has several parameters, e.g. weight and bias, their weights should add up to 1 in total. Currently a plain average is taken (js divided by the layer's parameter count); weighting may be added later. PS: in Inception_BN the outer conv layers have a bias, while convs inside the Inception modules use bias=False because a bn layer follows.
  + Because the length of the named_parameters iterator is not fixed, the names are first collected into a fixed list so that the per-layer parameter count can be obtained; see the change in ptq.py. This only needs to be done for the full-precision model.
  + The model_utils methods of the new framework can locate bugs by moving the dequantization point. The current experiments suggest the accuracy problem arises inside the inception structure; details are in the Inception_BN section (a minimal sketch of the idea follows below). After checking, the quantization framework itself is not at fault; the issue is more likely a mismatch between this model's parameter distribution and the concentrated POT levels.
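
A minimal sketch of the dequantization-point experiment mentioned in the last item. The helper and the attribute names (quant_layers, full_layers, qi, qo) are assumptions for illustration only, not the repository's actual model_utils API:

def partial_quantize_inference(model, x, split):
    # hypothetical attributes: parallel lists of quantized / full-precision modules
    # built from the same cfg_table; qi/qo hold input/output quantization parameters
    x = model.qi.quantize_tensor(x)
    for qlayer in model.quant_layers[:split]:
        x = qlayer.quantize_inference(x)
    x = model.quant_layers[split - 1].qo.dequantize_tensor(x)  # single dequantization point
    for layer in model.full_layers[split:]:
        x = layer(x)
    return x

Sweeping split over the cfg_table entries reproduces the accuracy tables in the Inception_BN notes.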
## update: 2023/04/17
+ Switched to a new learning-rate schedule and retrained the full-precision model to reach a higher accuracy, then redid ptq and fit.