Commit 50852902 by Zhihong Ma

fix : new mzh : PTQ ResNet18/50/152 old mzh: saved & reconstructed

parent 62550290
......@@ -3,15 +3,16 @@ import os
# 从get_param.py输出重定向文件val.txt中提取参数量和计算量
def extract_ratio():
fr = open('param_flops.txt','r')
def extract_ratio(md='ResNet18'):
fr = open('param_flops_' + md + '.txt','r')
lines = fr.readlines()
layer = []
par_ratio = []
flop_ratio = []
for line in lines:
if '(' in line and ')' in line:
layer.append(line.split(')')[0].split('(')[1])
# if '(' in line and ')' in line:
if 'Conv' in line or 'BatchNorm2d' in line or 'Linear' in line:
layer.append(line.split(':')[1].split('(')[0])
r1 = line.split('%')[0].split(',')[-1]
r1 = float(r1)
par_ratio.append(r1)
......@@ -24,6 +25,6 @@ def extract_ratio():
if __name__ == "__main__":
layer, par_ratio, flop_ratio = extract_ratio()
print(layer)
print(par_ratio)
print(flop_ratio)
\ No newline at end of file
print(len(layer))
print(len(par_ratio))
print(len(flop_ratio))
\ No newline at end of file
from model import *
import torch
from ptflops import get_model_complexity_info
import argparse
def get_children(model: torch.nn.Module):
# get children form model!
# 为了后续也能够更新参数,需要用nn.ModuleList来承载
# children = nn.ModuleList(model.children())
# print(children)
# 方便对其中的module进行后续的更新
# flatt_children = nn.ModuleList()
children = list(model.children())
# flatt_children = nn.ModuleList()
flatt_children = []
if len(children) == 0:
# if model has no children; model is last child! :O
return model
else:
# look for children from children... to the last child!
for child in children:
try:
flatt_children.extend(get_children(child))
except TypeError:
flatt_children.append(get_children(child))
# print(flatt_children)
return flatt_children
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Model Analysis --- params & flops')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='resnet18')
args = parser.parse_args()
if args.model == 'ResNet18':
model = resnet18()
elif args.model == 'ResNet50':
model = resnet50()
elif args.model == 'ResNet152':
model = resnet152()
full_file = 'ckpt/cifar10_' + args.model + '.pt'
model.load_state_dict(torch.load(full_file))
# flat = get_children(model)
# print(flat)
# flat = get_children(model)
# new_model = nn.Sequential(*flat)
flops, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
......@@ -625,10 +625,10 @@ class QModule_2(nn.Module):
if qi1:
self.qi1 = QParam(quant_type,num_bits, e_bits) # qi在此处就已经被num_bits和mode赋值了
if qo:
self.qo = QParam(quant_type,32, e_bits) # qo在此处就已经被num_bits和mode赋值了
self.qo = QParam(quant_type,num_bits, e_bits) # qo在此处就已经被num_bits和mode赋值了
self.quant_type = quant_type
self.num_bits = 32
self.num_bits = num_bits
self.e_bits = e_bits
self.bias_qmax = bias_qmax(quant_type)
......
......@@ -228,17 +228,17 @@ if __name__ == "__main__":
print(f"test accuracy: {acc:.2f}%")
for name, module in model.named_modules():
print(f"{name}: {module}\n")
# for name, module in model.named_modules():
# print(f"{name}: {module}\n")
print('========================================================')
print('========================================================')
# print('========================================================')
# print('========================================================')
model.quantize()
for name , layer in model.quantize_layers.items():
print(f"Layer {name}: {layer} ") # 足够遍历了
# model.quantize()
# for name , layer in model.quantize_layers.items():
# print(f"Layer {name}: {layer} ") # 足够遍历了
......
ResNet(
714.09 k, 101.626% Params, 36.09 MMac, 100.000% MACs,
(conv1): Conv2d(448, 0.064% Params, 458.75 KMac, 1.271% MACs, 3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn1): BatchNorm2d(32, 0.005% Params, 32.77 KMac, 0.091% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 16.38 KMac, 0.045% MACs, )
(layer1): MakeLayer(
9.41 k, 1.339% Params, 9.7 MMac, 26.879% MACs,
(blockdict): ModuleDict(
9.41 k, 1.339% Params, 9.7 MMac, 26.879% MACs,
(block1): BasicBlock(
4.7 k, 0.669% Params, 4.85 MMac, 13.440% MACs,
(conv1): Conv2d(2.32 k, 0.330% Params, 2.38 MMac, 6.584% MACs, 16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn1): BatchNorm2d(32, 0.005% Params, 32.77 KMac, 0.091% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(2.32 k, 0.330% Params, 2.38 MMac, 6.584% MACs, 16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(32, 0.005% Params, 32.77 KMac, 0.091% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 32.77 KMac, 0.091% MACs, )
)
(block2): BasicBlock(
4.7 k, 0.669% Params, 4.85 MMac, 13.440% MACs,
(conv1): Conv2d(2.32 k, 0.330% Params, 2.38 MMac, 6.584% MACs, 16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn1): BatchNorm2d(32, 0.005% Params, 32.77 KMac, 0.091% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(2.32 k, 0.330% Params, 2.38 MMac, 6.584% MACs, 16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(32, 0.005% Params, 32.77 KMac, 0.091% MACs, 16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 32.77 KMac, 0.091% MACs, )
)
)
)
(layer2): MakeLayer(
33.86 k, 4.818% Params, 8.7 MMac, 24.109% MACs,
(downsample): Sequential(
608, 0.087% Params, 155.65 KMac, 0.431% MACs,
(0): Conv2d(544, 0.077% Params, 139.26 KMac, 0.386% MACs, 16, 32, kernel_size=(1, 1), stride=(2, 2))
(1): BatchNorm2d(64, 0.009% Params, 16.38 KMac, 0.045% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(blockdict): ModuleDict(
33.25 k, 4.732% Params, 8.54 MMac, 23.678% MACs,
(block1): BasicBlock(
14.62 k, 2.081% Params, 3.76 MMac, 10.420% MACs,
(conv1): Conv2d(4.64 k, 0.660% Params, 1.19 MMac, 3.292% MACs, 16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(bn1): BatchNorm2d(64, 0.009% Params, 16.38 KMac, 0.045% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(9.25 k, 1.316% Params, 2.37 MMac, 6.561% MACs, 32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(64, 0.009% Params, 16.38 KMac, 0.045% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 16.38 KMac, 0.045% MACs, )
)
(block2): BasicBlock(
18.62 k, 2.650% Params, 4.78 MMac, 13.258% MACs,
(conv1): Conv2d(9.25 k, 1.316% Params, 2.37 MMac, 6.561% MACs, 32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn1): BatchNorm2d(64, 0.009% Params, 16.38 KMac, 0.045% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(9.25 k, 1.316% Params, 2.37 MMac, 6.561% MACs, 32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(64, 0.009% Params, 16.38 KMac, 0.045% MACs, 32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 16.38 KMac, 0.045% MACs, )
)
)
)
(layer3): MakeLayer(
134.27 k, 19.109% Params, 8.61 MMac, 23.860% MACs,
(downsample): Sequential(
2.24 k, 0.319% Params, 143.36 KMac, 0.397% MACs,
(0): Conv2d(2.11 k, 0.301% Params, 135.17 KMac, 0.375% MACs, 32, 64, kernel_size=(1, 1), stride=(2, 2))
(1): BatchNorm2d(128, 0.018% Params, 8.19 KMac, 0.023% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(blockdict): ModuleDict(
132.03 k, 18.790% Params, 8.47 MMac, 23.462% MACs,
(block1): BasicBlock(
57.92 k, 8.243% Params, 3.72 MMac, 10.295% MACs,
(conv1): Conv2d(18.5 k, 2.632% Params, 1.18 MMac, 3.280% MACs, 32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(bn1): BatchNorm2d(128, 0.018% Params, 8.19 KMac, 0.023% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(36.93 k, 5.255% Params, 2.36 MMac, 6.550% MACs, 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(128, 0.018% Params, 8.19 KMac, 0.023% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 8.19 KMac, 0.023% MACs, )
)
(block2): BasicBlock(
74.11 k, 10.547% Params, 4.75 MMac, 13.167% MACs,
(conv1): Conv2d(36.93 k, 5.255% Params, 2.36 MMac, 6.550% MACs, 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn1): BatchNorm2d(128, 0.018% Params, 8.19 KMac, 0.023% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(36.93 k, 5.255% Params, 2.36 MMac, 6.550% MACs, 64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(128, 0.018% Params, 8.19 KMac, 0.023% MACs, 64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 8.19 KMac, 0.023% MACs, )
)
)
)
(layer4): MakeLayer(
534.78 k, 76.108% Params, 8.56 MMac, 23.735% MACs,
(downsample): Sequential(
8.58 k, 1.220% Params, 137.22 KMac, 0.380% MACs,
(0): Conv2d(8.32 k, 1.184% Params, 133.12 KMac, 0.369% MACs, 64, 128, kernel_size=(1, 1), stride=(2, 2))
(1): BatchNorm2d(256, 0.036% Params, 4.1 KMac, 0.011% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(blockdict): ModuleDict(
526.21 k, 74.887% Params, 8.43 MMac, 23.355% MACs,
(block1): BasicBlock(
230.53 k, 32.808% Params, 3.69 MMac, 10.233% MACs,
(conv1): Conv2d(73.86 k, 10.511% Params, 1.18 MMac, 3.275% MACs, 64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(bn1): BatchNorm2d(256, 0.036% Params, 4.1 KMac, 0.011% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(147.58 k, 21.003% Params, 2.36 MMac, 6.544% MACs, 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(256, 0.036% Params, 4.1 KMac, 0.011% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 4.1 KMac, 0.011% MACs, )
)
(block2): BasicBlock(
295.68 k, 42.080% Params, 4.73 MMac, 13.122% MACs,
(conv1): Conv2d(147.58 k, 21.003% Params, 2.36 MMac, 6.544% MACs, 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn1): BatchNorm2d(256, 0.036% Params, 4.1 KMac, 0.011% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(147.58 k, 21.003% Params, 2.36 MMac, 6.544% MACs, 128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(bn2): BatchNorm2d(256, 0.036% Params, 4.1 KMac, 0.011% MACs, 128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(0, 0.000% Params, 4.1 KMac, 0.011% MACs, )
)
)
)
(avgpool): AdaptiveAvgPool2d(0, 0.000% Params, 2.05 KMac, 0.006% MACs, output_size=(1, 1))
(fc): Linear(1.29 k, 0.184% Params, 1.29 KMac, 0.004% MACs, in_features=128, out_features=10, bias=True)
)
\ No newline at end of file
......@@ -68,7 +68,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description='PyTorch FP32 Training')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='resnet18')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='ResNet18')
parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='BATCH SIZE', help='mini-batch size (default: 128)')
parser.add_argument('-j','--workers', default=4, type=int, metavar='WORKERS',help='number of data loading workers (default: 4)')
# parser.add_argument('-t', '--test', dest='test', action='store_true', help='test model on test set')
......@@ -121,19 +121,28 @@ if __name__ == "__main__":
model.eval()
full_acc = full_inference(model, test_loader, device)
model_fold = fold_model(model)
model_fold = fold_model(model) #
full_params = []
layer, par_ratio, flop_ratio = extract_ratio()
print(layer)
layer, par_ratio, flop_ratio = extract_ratio(args.model)
# print(layer)
layer = []
for name, param in model.named_parameters():
if 'weight' in name:
n = name.split('.')
pre = '.'.join(n[:len(n)-1])
layer.append(pre)
# print(name)
print('===================')
# print(layer)
par_ratio, flop_ratio = fold_ratio(layer, par_ratio, flop_ratio)
# sys.exit()
for name, param in model_fold.named_parameters():
if 'bn' in name:
if 'bn' in name or 'sample.1' in name:
continue
param_norm = F.normalize(param.data.cpu(),p=2,dim=-1)
full_params.append(param_norm)
full_params.append(param_norm) # 没统计bn的 只统计了conv的 而且还是fold后的
writer.add_histogram(tag='Full_' + name + '_data', values=param.data)
......@@ -156,7 +165,14 @@ if __name__ == "__main__":
for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list:
model_ptq = resnet18()
# model_ptq = resnet18()
if args.model == 'ResNet18':
model_ptq = resnet18()
elif args.model == 'ResNet50':
model_ptq = resnet50()
elif args.model == 'ResNet152':
model_ptq = resnet152()
if quant_type == 'FLOAT':
title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
else:
......@@ -194,25 +210,30 @@ if __name__ == "__main__":
js_flops = 0.
js_param = 0.
for name, param in model_ptq.named_parameters():
if '.' not in name or 'bn' in name:
# if '.' not in name or 'bn' in name:
if 'bn' in name or 'sample.1' in name:
continue
writer.add_histogram(tag=title +':'+ name + '_data', values=param.data)
# idx = idx + 1
# prefix = name.split('.')[0]
# if prefix in layer:
# layer_idx = layer.index(prefix)
# ptq_param = param.data.cpu()
# # 取L2范数
# ptq_norm = F.normalize(ptq_param,p=2,dim=-1)
# writer.add_histogram(tag=title +':'+ name + '_data', values=ptq_param)
# js = js_div(ptq_norm,full_params[idx])
# js = js.item()
# if js < 0.:
# js = 0.
# js_flops = js_flops + js * flop_ratio[layer_idx]
# js_param = js_param + js * flop_ratio[layer_idx]
# js_flops_list.append(js_flops)
# js_param_list.append(js_param)
idx = idx + 1
# renset中有多个. 需要改写拼一下
# prefix = name.split('.')[0]
n = name.split('.')
prefix = '.'.join(n[:len(n) - 1])
# weight和bias 1:1 ? 对于ratio,是按层赋予的,此处可以对weight和bias再单独赋予不同的权重,比如(8:2)
if prefix in layer:
layer_idx = layer.index(prefix)
ptq_param = param.data.cpu()
# 取L2范数
ptq_norm = F.normalize(ptq_param,p=2,dim=-1)
writer.add_histogram(tag=title +':'+ name + '_data', values=ptq_param)
js = js_div(ptq_norm,full_params[idx]) # 这里算了fold后的量化前后模型的js距离
js = js.item()
if js < 0.:
js = 0.
js_flops = js_flops + js * flop_ratio[layer_idx]
js_param = js_param + js * par_ratio[layer_idx]
js_flops_list.append(js_flops)
js_param_list.append(js_param)
print(title + ': js_flops: %f js_param: %f acc_loss: %f' % (js_flops, js_param, acc_loss))
......@@ -233,10 +254,10 @@ if __name__ == "__main__":
worksheet.cell(row=i+4, column=4, value=ptq_acc_list[i])
worksheet.cell(row=i+4, column=5, value=acc_loss_list[i])
workbook.save('ptq_result.xlsx')
workbook.save('ptq_result_' + args.model + '.xlsx')
writer.close()
ft = open('ptq_result.txt','w')
ft = open('ptq_result_' + args.model + '.txt','w')
print('title_list:',file=ft)
print(" ".join(title_list),file=ft)
......
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
1575.126077030527 980.8324825038856 447.4871705577316 203.8177281153719 94.1658153206219 44.73944284292641 21.730716696253086 10.687903335080755 5.2935009924434775 2.6865031426677675 1.345978185346981 0.6738058971124082 0.34590930672785625 0.16620132379306904 0.09185943251823848 1575.0264663456858 767.5068295225365 59.80415491853343 17.32189175246257 17.160386413787755 17.15972613238827 17.160655554562823 547.0296470821636 228.09197712606053 153.9307141697144 102.8744121697856 63.04910506966272 11.893784458090247 49.68929151890493 30.72369295281706 4.336553462330601 4.810517948583543 25.62475856077897 16.963161148931942 1.7730239215421446 1.2492962287085048 4.844787354857122 14.21240714817728 10.605240065475499 0.7963437572573967 0.32131797583853794 1.3061700599586734 4.844787523330232
js_param_list:
2231.9475377209037 1458.7430817370525 656.866021106162 290.661557510572 132.0211812900384 62.06574209045005 29.96287022906031 14.768791159744465 7.344364349715033 3.757019554618513 1.896182903527843 0.9241808205303167 0.45857306080932436 0.2269121111102425 0.12261352661167306 2231.901673608193 1143.359470049635 82.82637961696304 24.06635574752677 23.843136397545287 23.842358607630732 23.84306741528584 799.9775130544906 323.8336430582792 218.61973701520765 143.18120884416584 88.72081224892759 16.52024912262558 68.08470436272326 43.20128678260041 6.041579655336327 6.686327875421352 34.6238061335222 24.064747116161215 2.491426419987749 1.7403336017725606 6.690842031928857 18.94797143083834 15.257619881935225 1.0957373786589855 0.44768947355373956 1.7705741794826835 6.690842738428997
ptq_acc_list:
10.0 10.0 10.0 78.52 86.7 89.95 90.73 90.96 90.64 87.4 74.21 52.1 40.65 30.51 20.3 10.0 10.0 10.0 39.21 40.15 44.33 34.83 10.0 19.98 10.0 34.59 85.82 80.56 57.06 88.62 90.17 81.06 68.03 89.75 90.85 88.77 10.0 72.61 90.02 91.08 89.55 10.0 10.0
acc_loss_list:
0.8900978129464776 0.8900978129464776 0.8900978129464776 0.1370480272557424 0.04714803824596101 0.011429827453566238 0.0028574568633914815 0.0003297065611605796 0.0038465765468732203 0.03945488515221441 0.18441586987581055 0.4274096054511484 0.5532476096274316 0.6646884272997032 0.7768985602813496 0.8900978129464776 0.8900978129464776 0.8900978129464776 0.5690735245631388 0.5587427189801077 0.5128036047917354 0.6172106824925816 0.8900978129464776 0.7804154302670623 0.8900978129464776 0.6198483349818661 0.05681943070667109 0.11462798109682375 0.3728981206726013 0.026046818331684696 0.00901197933838876 0.10913287174414764 0.2523354214748873 0.013627871194636718 0.0015386306187493194 0.024398285525881955 0.8900978129464776 0.20200021980437408 0.010660512144191657 -0.0009891196834817388 0.015825914935707196 0.8900978129464776 0.8900978129464776
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
1833.4576454973073 814.7863891368864 217.7654229387627 54.07616924802023 13.731802945455469 3.5847427020530582 0.9118541432904458 0.2622900218848318 0.07627003915874074 0.027745791769400664 0.015915006254486226 0.012409352705166696 0.0077479353538904274 0.0062617873011873975 0.005917287498327866 1833.2003417254284 544.2136113656462 35.21026365121499 33.83804856891729 33.83703344984572 33.83750169488491 33.84147193704756 342.096219925368 82.6043808610444 75.92517125989443 27.82235802343243 26.574672151466128 9.6049355988981 14.044291246668882 14.55114135659603 2.4864347446515884 9.426150750133262 10.07193874315086 10.701781541507756 0.6597191298214471 2.4197650833272104 9.550487563237345 8.849135643873504 9.216705123201871 0.1929881628940372 0.6207588325434388 2.5428780026080493 9.550487563237345
js_param_list:
3613.037168160796 1825.7907466758202 512.0932785883192 129.26071365337654 33.314456921282606 8.673843570791789 2.1826018682118424 0.6138186833325912 0.1691841503982388 0.05180905191755439 0.02266508641878177 0.014530378356803484 0.00975786055068809 0.005063431812688739 0.00398069855228542 3612.992302272399 1246.9340617899438 71.14710558688047 67.61964317269017 67.6172664356203 67.61753548832318 67.6175100773394 755.379587970111 181.41267691267066 170.89087380459807 56.989989927129535 59.371069176236894 19.274735775346528 26.031672719261728 32.363778392002544 5.0043194398511135 18.814548222792805 17.309141148134536 23.84953967534161 1.332034978863292 4.83191046013193 18.864051408815957 14.787650268158211 20.519388091926267 0.3942680972083926 1.231435885110694 4.879394902995963 18.864051408815957
ptq_acc_list:
10.0 10.0 31.15 81.89 84.93 85.69 85.78 85.63 82.63 74.8 51.56 29.34 13.78 11.57 10.17 10.0 10.0 44.45 44.64 46.43 44.18 38.58 9.92 38.85 70.91 65.34 82.3 80.82 73.99 84.14 85.05 76.68 75.95 84.95 85.55 81.54 10.0 77.73 85.18 85.98 81.93 10.01 13.34
acc_loss_list:
0.8835991153532767 0.8835991153532767 0.6374112443254569 0.04679315562798273 0.011407286695378766 0.0025608194622279 0.0015132115004073503 0.003259224770108266 0.038179490164125265 0.1293213828425096 0.3998370387614945 0.6584798044465138 0.8395995809568153 0.8653241764637412 0.8816203003142824 0.8835991153532767 0.8835991153532767 0.48259806774531483 0.4803864509370271 0.45955069258526365 0.4857408916307764 0.5509253870329415 0.8845303224304505 0.5477825631474799 0.17460132697008499 0.2394366197183098 0.04202071935746711 0.05924805028518221 0.1387498544988942 0.02060295658246998 0.010010476079618198 0.10743801652892551 0.11593528110813635 0.011174484926085367 0.004190431847282033 0.05086718659061798 0.8835991153532767 0.09521592364101959 0.008497264579210684 -0.0008148061925271492 0.04632755208939576 0.8834827144686299 0.844721219881271
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
1489.6432790793892 858.47390911721 350.38842997977486 146.66108726257698 65.51871772345022 30.802447738403625 15.015633081763848 7.372939759214539 3.602748170145869 1.7596017349024324 0.9023980469489912 0.42604559053986407 0.2086617904343124 0.11696138076612213 0.06650877266410397 1489.4950585919782 648.6726890563766 44.0313761461945 14.813184296979202 14.70886411284525 14.708637223793453 14.708329981851291 442.110054514285 167.03961080744105 110.912352356486 73.14257321252117 44.75826464643717 8.918392355710786 35.41728607793805 22.00069249787684 3.0807357022322006 4.133106679411769 18.786975210869198 12.291142909976228 1.2420341267864268 1.0780820385160967 4.196963771246701 10.816659177228358 7.780715811154463 0.513961854917128 0.2801788104206083 1.150574461896198 4.196963771246701
js_param_list:
2988.8747567488617 1887.9793935628438 794.5371505720092 330.3960680775245 145.92231495119452 67.9559314292448 33.03981244952361 16.124047122726786 8.021401990398326 3.943098007875918 1.9811299823118427 0.9460539051395199 0.44709418282093033 0.22449034273754867 0.12425914862692854 2988.8363531204886 1451.7681143260804 94.67273844326954 30.460878266197444 30.244231409403923 30.2446749589304 30.244134610251493 984.9086948427197 371.60971497639866 248.5749360354289 159.90777702378026 99.54631101875773 19.048673214252524 75.87671359764475 48.95576239520067 6.683113070521427 8.485231215526596 39.31778320380456 27.44412247810391 2.6854627255413566 2.207580403630901 8.479439151405776 21.80574465505866 17.614834435129385 1.148945392883737 0.5553705895013917 2.2254689905601692 8.479439151405776
ptq_acc_list:
10.0 10.0 10.01 72.69 87.21 89.67 90.45 90.33 89.37 79.82 61.97 35.21 22.84 21.47 13.47 10.0 10.0 12.81 17.49 27.49 30.18 34.97 10.0 15.78 21.89 33.3 82.29 82.49 58.04 87.21 88.9 82.42 67.65 88.34 90.33 87.15 10.05 70.35 89.06 90.52 88.78 9.99 10.0
acc_loss_list:
0.8896369054188279 0.8896369054188279 0.8895265423242467 0.19777066548946035 0.037523452157598565 0.010374130890630148 0.0017658095132987153 0.00309016664827283 0.01368502372806528 0.11908177905308472 0.31607990288047677 0.6114115439796932 0.747930691976603 0.7630504359342236 0.8513409115991613 0.8896369054188279 0.8896369054188279 0.8586248758415186 0.8069749475775301 0.6966118529963581 0.6669241805540227 0.6140602582496413 0.8896369054188279 0.8258470367509105 0.7584151859618143 0.6324908950446971 0.09182209469153507 0.08961483279991177 0.3594525990508774 0.037523452157598565 0.018872089173380353 0.09038737446197989 0.253393665158371 0.025052422469926013 0.00309016664827283 0.03818563072508546 0.8890850899459222 0.22359562962145466 0.017106279660081637 0.0009932678512305862 0.020196446308354467 0.8897472685134091 0.8896369054188279
## update: <br>2023.4.12<br>
- 已修改get_param_flops.py, extract_ratio.py, ptq.py中与计算数据分布相似度、计算计算量、参数量加权权重相关的程序,使其能够适应于ResNet系列网络。<br>
(注: 目前需要在得到了param_flops_ResNet18/50/152.txt后,手动删除在layer中重复出现了一次的downsample module的统计数据,考虑到resnet18的只需要删3处,resnet50,152中的需要删4处,故直接手动进行了)
- 已补充ResNet18/50/152的PTQ训练测试结果,在ptq_result_ResNet18/50/152.xlsx中
- 拟合曲线图:
1. ResNet18 - 参数量加权
<img src = "fig/res18-1.png" class="h-90 auto">
ResNet18 - 计算量加权
<img src = "fig/res18-2.png" class="h-90 auto">
2. ResNet50 - 参数量加权
<img src = "fig/res50-2.png" class="h-90 auto">
ResNet50 - 计算量加权
<img src = "fig/res50.png" class="h-90 auto">
3. ResNet152 - 参数量加权
<img src = "fig/res152-1.png" class="h-90 auto">
ResNet152 - 计算量加权
<img src = "fig/res152-2.png" class="h-90 auto">
- 所有图中在左上方均存在一系列离群点,其中包括了INT量化在量化位宽较高时,acc骤降对应的点,也包括了一些其他的异常点,还需要再对程序、数据进一步分析,调整参数,使部分由于程序设置导致异常的点恢复过来。
- 后续待补充:
1. 分析为什么INT量化位宽较大时,acc骤降,并解决该问题
2. 检查网络量化是否存在细节错误
3. 检查计算数据分布相似度时是否正确加权计算
4. ResNet QAT相关实验
<br><br>
分割线
===============================================<br><br><br>
## update:<br> 2023.4.10 <br>
- 注:new_mzh中的程序改用了与游昆霖同学统一的度量方式、以及一些量化细节约定,将代码重新建立在游昆霖同学版本的程序上。<br>
在量化BN层的过程中遇到了较多问题,感谢游昆霖同学的帮助:D
### 程序改动:
为量化ResNet18,在module.py中新增的量化层包括QConvBNReLu层,QConvBN层,QElementwiseAdd层,QAdaptiveAvgPool2d层。在model.py中建立了ResNet18的量化架构,通过class BasicBlock, class Bottleneck, class MakeLayer等保障了ResNet的扩展性,能够较为方便的扩展成ResNet50和152
### 待完善:
- ResNet的网络架构相比于AlexNet,VGG等更加的跳跃,各种MakeLayer, Residual的结构使得其不是一个平铺开来的网络,则过去的很多计算相似度等的算法不能直接适用在ResNet上(直接遍历网络参数时,会有包装在conv,bn等层外面的layer, sequential, block等),关于参数相似度、梯度相似度的分析有待后续研究补充。
QAT方面有待后续补充
- 我在加上正确的QElementwiseAdd层前,PTQ后的acc都不超过15%,足以见到该层的重要性,他是负责残差的相加部分,因为两个层的输出结果是在不同量化范围,所以不能直接相加,而是需要做rescale。
- 目前看到INT量化随位宽增加而先增大后下降,我查看了量化后的参数分布,其整体趋势与全精度模型是较为相似的,因此问题不在Conv,BN等普通的量化层上,我猜想可能是因为量化位宽较大的时候,QElementwiseAdd做rescale的过程中出现了溢出,还有待后续观察确认。
\ No newline at end of file
......@@ -82,8 +82,8 @@ echo "Use GPU ${CUDA_VISIBLE_DEVICES}" # which gpus
sleep 2s
hostname
echo "python ./new_train.py -m ResNet18 -e 60 -b 128 -j 4 -lr 0.001 -wd 0.0001"
python ./new_train.py -m ResNet18 -e 60 -b 128 -j 4 -lr 0.001 -wd 0.0001
echo "python ./new_train.py -m ResNet152 -e 300 -b 128 -j 4 -lr 0.001 -wd 0.0001"
python ./new_train.py -m ResNet152 -e 300 -b 128 -j 4 -lr 0.001 -wd 0.0001
......
......@@ -5,7 +5,7 @@
# (TODO)
# Please modify job name
#SBATCH -J ResNet18_trial # The job name
#SBATCH -J ResNet50_trial # The job name
#SBATCH -o ./info/ret-%j.out # Write the standard output to file named 'ret-<job_number>.out'
#SBATCH -e ./info/ret-%j.err # Write the standard error to file named 'ret-<job_number>.err'
......@@ -83,7 +83,12 @@ sleep 2s
hostname
echo "python ./ptq.py -m ResNet18 -b 128 -j 4"
echo "python ./ptq.py -m ResNet50 -b 128 -j 4"
echo "python ./ptq.py -m ResNet152 -b 128 -j 4"
# python ./ptq.py -m ResNet18 -b 128 -j 4
python ./ptq.py -m ResNet18 -b 128 -j 4
python ./ptq.py -m ResNet50 -b 128 -j 4
python ./ptq.py -m ResNet152 -b 128 -j 4
......
from model import *
import torch
from ptflops import get_model_complexity_info
if __name__ == "__main__":
model = resnet18()
full_file = 'ckpt/cifar10_ResNet18.pt'
model.load_state_dict(torch.load(full_file))
flops, params = get_model_complexity_info(model, (3, 32, 32), as_strings=True, print_per_layer_stat=True)
2023.4.10
注:new_mzh中的程序改用了与游昆霖同学统一的度量方式、以及一些量化细节约定,将代码重新建立在游昆霖同学版本的程序上。
在量化BN层的过程中遇到了较多问题,感谢游昆霖同学的帮助:D
程序改动:
为量化ResNet18,在module.py中新增的量化层包括QConvBNReLu层,QConvBN层,QElementwiseAdd层,QAdaptiveAvgPool2d层。在model.py中建立了ResNet18的量化架构,通过class BasicBlock, class Bottleneck, class MakeLayer等保障了ResNet的扩展性,能够较为方便的扩展成ResNet50和152
待完善:
ResNet的网络架构相比于AlexNet,VGG等更加的跳跃,各种MakeLayer, Residual的结构使得其不是一个平铺开来的网络,则过去的很多计算相似度等的算法不能直接适用在ResNet上(直接遍历网络参数时,会有包装在conv,bn等层外面的layer, sequential, block等),关于参数相似度、梯度相似度的分析有待后续研究补充。
QAT方面有待后续补充
下面的实验是关于ResNet18的PTQ结果:(js_flops, js_param等还未修改计算方式,因而暂时未计算,赋值为0)
```
PTQ: INT_2
direct quantization finish
Test set: Quant Model Accuracy: 10.00%
INT_2: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.883599
PTQ: INT_3
direct quantization finish
Test set: Quant Model Accuracy: 10.00%
INT_3: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.883599
PTQ: INT_4
direct quantization finish
Test set: Quant Model Accuracy: 49.76%
INT_4: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.420789
PTQ: INT_5
direct quantization finish
Test set: Quant Model Accuracy: 80.86%
INT_5: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.058782
PTQ: INT_6
direct quantization finish
Test set: Quant Model Accuracy: 84.91%
INT_6: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.011640
PTQ: INT_7
direct quantization finish
Test set: Quant Model Accuracy: 85.60%
INT_7: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.003608
PTQ: INT_8
direct quantization finish
Test set: Quant Model Accuracy: 85.85%
INT_8: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.000698
PTQ: INT_9
direct quantization finish
Test set: Quant Model Accuracy: 85.64%
INT_9: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.003143
PTQ: INT_10
direct quantization finish
Test set: Quant Model Accuracy: 82.81%
INT_10: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.036084
PTQ: INT_11
direct quantization finish
Test set: Quant Model Accuracy: 74.91%
INT_11: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.128041
PTQ: INT_12
direct quantization finish
Test set: Quant Model Accuracy: 56.50%
INT_12: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.342335
PTQ: INT_13
direct quantization finish
Test set: Quant Model Accuracy: 26.25%
INT_13: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.694448
PTQ: INT_14
direct quantization finish
Test set: Quant Model Accuracy: 14.16%
INT_14: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.835176
PTQ: INT_15
direct quantization finish
Test set: Quant Model Accuracy: 11.29%
INT_15: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.868583
PTQ: INT_16
direct quantization finish
Test set: Quant Model Accuracy: 10.25%
INT_16: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.880689
PTQ: POT_2
direct quantization finish
Test set: Quant Model Accuracy: 10.00%
POT_2: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.883599
PTQ: POT_3
direct quantization finish
Test set: Quant Model Accuracy: 10.00%
POT_3: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.883599
PTQ: POT_4
direct quantization finish
Test set: Quant Model Accuracy: 44.75%
POT_4: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.479106
PTQ: POT_5
direct quantization finish
Test set: Quant Model Accuracy: 40.29%
POT_5: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.531021
PTQ: POT_6
direct quantization finish
Test set: Quant Model Accuracy: 50.13%
POT_6: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.416482
PTQ: POT_7
direct quantization finish
Test set: Quant Model Accuracy: 45.75%
POT_7: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.467466
PTQ: POT_8
direct quantization finish
Test set: Quant Model Accuracy: 39.79%
POT_8: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.536841
PTQ: FLOAT_3_E1
direct quantization finish
Test set: Quant Model Accuracy: 9.93%
FLOAT_3_E1: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.884414
PTQ: FLOAT_4_E1
direct quantization finish
Test set: Quant Model Accuracy: 39.63%
FLOAT_4_E1: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.538703
PTQ: FLOAT_4_E2
direct quantization finish
Test set: Quant Model Accuracy: 70.74%
FLOAT_4_E2: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.176580
PTQ: FLOAT_5_E1
direct quantization finish
Test set: Quant Model Accuracy: 65.04%
FLOAT_5_E1: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.242929
PTQ: FLOAT_5_E2
direct quantization finish
Test set: Quant Model Accuracy: 82.65%
FLOAT_5_E2: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.037947
PTQ: FLOAT_5_E3
direct quantization finish
Test set: Quant Model Accuracy: 80.86%
FLOAT_5_E3: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.058782
PTQ: FLOAT_6_E1
direct quantization finish
Test set: Quant Model Accuracy: 74.17%
FLOAT_6_E1: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.136655
PTQ: FLOAT_6_E2
direct quantization finish
Test set: Quant Model Accuracy: 84.28%
FLOAT_6_E2: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.018973
PTQ: FLOAT_6_E3
direct quantization finish
Test set: Quant Model Accuracy: 84.81%
FLOAT_6_E3: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.012804
PTQ: FLOAT_6_E4
direct quantization finish
Test set: Quant Model Accuracy: 78.06%
FLOAT_6_E4: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.091375
PTQ: FLOAT_7_E1
direct quantization finish
Test set: Quant Model Accuracy: 76.20%
FLOAT_7_E1: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.113025
PTQ: FLOAT_7_E2
direct quantization finish
Test set: Quant Model Accuracy: 84.83%
FLOAT_7_E2: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.012571
PTQ: FLOAT_7_E3
direct quantization finish
Test set: Quant Model Accuracy: 85.55%
FLOAT_7_E3: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.004190
PTQ: FLOAT_7_E4
direct quantization finish
Test set: Quant Model Accuracy: 82.00%
FLOAT_7_E4: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.045513
PTQ: FLOAT_7_E5
direct quantization finish
Test set: Quant Model Accuracy: 10.00%
FLOAT_7_E5: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.883599
PTQ: FLOAT_8_E1
direct quantization finish
Test set: Quant Model Accuracy: 77.39%
FLOAT_8_E1: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.099174
PTQ: FLOAT_8_E2
direct quantization finish
Test set: Quant Model Accuracy: 85.21%
FLOAT_8_E2: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.008148
PTQ: FLOAT_8_E3
direct quantization finish
Test set: Quant Model Accuracy: 86.00%
FLOAT_8_E3: js_flops: 0.000000 js_param: 0.000000 acc_loss: -0.001048
PTQ: FLOAT_8_E4
direct quantization finish
Test set: Quant Model Accuracy: 83.26%
FLOAT_8_E4: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.030846
PTQ: FLOAT_8_E5
direct quantization finish
Test set: Quant Model Accuracy: 10.02%
FLOAT_8_E5: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.883366
PTQ: FLOAT_8_E6
direct quantization finish
Test set: Quant Model Accuracy: 13.09%
FLOAT_8_E6: js_flops: 0.000000 js_param: 0.000000 acc_loss: 0.847631
```
我在加上正确的QElementwiseAdd层前,PTQ后的acc都不超过15%,足以见到该层的重要性,他是负责残差的相加部分,因为两个层的输出结果是在不同量化范围,所以不能直接相加,而是需要做rescale。
目前看到INT量化随位宽增加而先增大后下降,我查看了量化后的参数分布,其整体趋势与全精度模型是较为相似的,因此问题不在Conv,BN等普通的量化层上,我猜想可能是因为量化位宽较大的时候,QElementwiseAdd做rescale的过程中出现了溢出,还有待后续观察确认。
\ No newline at end of file
# -*- coding: utf-8 -*-
import numpy
import numpy as np
import torch
import sys
from mmd_loss import *
from collections import OrderedDict
d1 = sys.argv[1] # bit
d2 = sys.argv[2] # epoch
# d1=4
# d2=5
sum=0
flag=0
total_quan_list=list()
total_base_list=list()
# CNN FLOPs = Cout * Hout * Wout * (2 * Cin * K * K ) 是考虑bias 否则-1
# FCN FLOPs = Cout * Cin 是考虑bias 否则-1
# 把相关的relu,pool也考虑进去了
# MAdd
# weight0 =np.array( [ 705600.0+4704.0+ 3528.0 , 480000.0+ 1600.0 + 1200.0 , 95880.0 + 120.0,
# 20076.0 + 84.0 , 1670.0 ])
# weight1=np.array([705,600.0 , 480,000.0,+ 95,880.0 ,
# 20,076.0 , 1,670.0 ])
# flops
weight_f0= np.array([357504+4704+4704, 241600+1600+1600,48000+120,10080+84,840])
weight_f1=np.array([357504, 241600,48000,10080,840])
summary_quan_dict=OrderedDict()
summary_base_dict=OrderedDict()
losses=[]
# 最外层:不同epoch的字典 内层:各个网络层的grads
for i in range(int(d2)):
total_quan_list.append(torch.load('./project/p/checkpoint/cifar-10_lenet_bn_quant/' + str(d1) + '/ckpt_cifar-10_lenet_bn_quant_'+str(i+1)+'.pth'))
#total_quan_list.append(torch.load('checkpoint/cifar-10_lenet_bn/full' + '/ckpt_cifar-10_lenet_bn_' + str(d2) + '.pth'))
total_base_list.append(torch.load('./project/p/checkpoint/cifar-10_lenet_bn/full' + '/ckpt_cifar-10_lenet_bn_' + str(i+1) + '.pth'))
for k, _ in total_base_list[i]['grads'].items():
if flag == 0:
summary_quan_dict[k] = total_quan_list[i]['grads'][k].reshape(1,-1)
summary_base_dict[k] = total_base_list[i]['grads'][k].reshape(1,-1)
else :
# 字典里的数据不能直接改,需要重新赋值
a=summary_quan_dict[k]
b=total_quan_list[i]['grads'][k].reshape(1,-1)
c=np.vstack((a,b))
summary_quan_dict[k] = c
a = summary_base_dict[k]
b = total_base_list[i]['grads'][k].reshape(1,-1)
c = np.vstack((a, b))
summary_base_dict[k] = c
flag = 1
cnt = 0
flag = 0
for k, _ in summary_quan_dict.items():
if flag == 0:
sum += 0.99*weight_f1[cnt] * MK_MMD(source=summary_base_dict[k], target=summary_quan_dict[k]) # weight
else:
sum += 0.01*weight_f1[cnt] * MK_MMD(source=summary_base_dict[k], target=summary_quan_dict[k]) #bias
if flag == 1:
cnt = cnt + 1
flag = 0
else:
flag=1
sum=sum/(weight_f0.sum()*2)
print(sum)
f = open('./project/p/lenet_ptq_similarity.txt','a')
f.write('bit:' + str(d1) + ' epoch_num:' + str(d2) +': '+str(sum)+'\n')
f.close()
# for k,v in summary_base_dict.items():
# if k== 'conv_layers.conv1.weight':
# print(v)
# print('===========')
# print(summary_quan_dict[k])
\ No newline at end of file
# -*- coding: utf-8 -*-
import numpy
import numpy as np
import torch
import sys
from collections import OrderedDict
import scipy.stats
import pandas as pd
import os
# 整体思路: 本函数实现的是关于bit的,在不同epoch节点(5, 10, ...) 的梯度分布相似度计算 (考虑到是不同epoch节点,则需要在这一段epoch内取平均相似度?)
# 外界调用: 会用不同的bit分别调用该函数
# csv中每行记录的是该bit量化情况下,不同epoch节点的平均加权梯度分布相似度
#
d1 = sys.argv[1] # bit
d2 = sys.argv[2] # mode
d3 = sys.argv[3] # n_exp
# d2 = sys.argv[2] # epoch
# d1=4
# d2=5
tag = 0
dirpath = './project/p/qat_analysis_data/mode' + str(d2)
if not os.path.isdir(dirpath):
os.makedirs(dirpath, mode=0o777)
os.chmod(dirpath, mode=0o777)
# if int(d2) == 1:
# csvpath = './project/p/qat_analysis_data/wasserstein_distance.csv'
# else:
if int(d2) != 3:
csvpath = './project/p/qat_analysis_data/mode' + str(d2) + '/wasserstein_distance.csv'
else:
csvpath = './project/p/qat_analysis_data/mode' + str(d2) + '/wasserstein_distance_' + str(d3) + '.csv'
# if os.path.exists("./qat_analysis_data/wasserstein_distance.csv"):
if os.path.exists(csvpath):
tag = 1
if tag == 0: # 还没有csv
df = pd.DataFrame()
else: # 已有csv
# df = pd.read_csv("./qat_analysis_data/wasserstein_distance.csv", index_col=0)
df = pd.read_csv(csvpath, index_col=0)
df2 = pd.DataFrame()
# CNN FLOPs = Cout * Hout * Wout * (2 * Cin * K * K ) 是考虑bias 否则-1
# FCN FLOPs = Cout * Cin 是考虑bias 否则-1
# 把相关的relu,pool也考虑进去了
# MAdd
# weight0 =np.array( [ 705600.0+4704.0+ 3528.0 , 480000.0+ 1600.0 + 1200.0 , 95880.0 + 120.0,
# 20076.0 + 84.0 , 1670.0 ])
# weight1=np.array([705,600.0 , 480,000.0,+ 95,880.0 ,
# 20,076.0 , 1,670.0 ])
# flops
weight_f0= np.array([357504+4704+4704, 241600+1600+1600,48000+120,10080+84,840])
weight_f1=np.array([357504, 241600,48000,10080,840])
# 对不同的epoch节点
for epoch in [5, 10, 15, 20, 25, 30]:
total_quan_list = []
total_base_list = []
summary_quan_dict = OrderedDict()
summary_base_dict = OrderedDict()
flag = 0
result = 0
# 最外层:不同epoch的字典 内层:各个网络层的grads
# 遍历epoch节点内的epoch,收集梯度信息
for i in range(epoch):
if int(d2) == 1:
total_quan_list.append(torch.load(
'./project/p/checkpoint/cifar-10_lenet_bn_quant/scratch/' + str(d1) + '/ckpt_cifar-10_lenet_bn_quant_' + str(
i + 1) + '.pth'))
elif int(d2) == 2:
total_quan_list.append(torch.load(
'./project/p/checkpoint/cifar-10_lenet_bn_quant/scratch/mode' + str(d2) + '/' + str(d1)+ '/ckpt_cifar-10_lenet_bn_quant_' + str(
epoch) + '.pth'))
else:
total_quan_list.append(torch.load(
'./project/p/checkpoint/cifar-10_lenet_bn_quant/scratch/mode' + str(d2) + '_' + str(d3) + '/' + str(d1)+ '/ckpt_cifar-10_lenet_bn_quant_' + str(
epoch) + '.pth'))
# total_quan_list.append(torch.load('checkpoint/cifar-10_lenet_bn/full' + '/ckpt_cifar-10_lenet_bn_' + str(d2) + '.pth'))
# full的数据数不够
total_base_list.append(
torch.load('./project/p/checkpoint/cifar-10_lenet_bn/full' + '/ckpt_cifar-10_lenet_bn_' + str(i + 1) + '.pth'))
for k, _ in total_base_list[i]['grads'].items(): # 得到每个epoch i 的各个层的梯度
if flag == 0: # 读的第一个epoch i 要新建立个数据矩阵的第一行,后续的epoch i都是在这行的基础上向下拓展
summary_quan_dict[k] = total_quan_list[i]['grads'][k].reshape(1, -1)
summary_base_dict[k] = total_base_list[i]['grads'][k].reshape(1, -1)
else:
# 字典里的数据不能直接改,需要重新赋值
a = summary_quan_dict[k]
b = total_quan_list[i]['grads'][k].reshape(1, -1)
c = np.vstack((a, b))
summary_quan_dict[k] = c
a = summary_base_dict[k]
b = total_base_list[i]['grads'][k].reshape(1, -1)
c = np.vstack((a, b))
summary_base_dict[k] = c
flag = 1
# loss = total_quan_list[i]['losses']
# print(loss)
# df = pd.read_csv('./data_analysis_folder/data.csv', index_col=0)
# # df = pd.DataFrame()
# df2 = pd.DataFrame()
# 上面是在收集数据,下面才是求和
for j in range(epoch):
flag0 = 0 # 各个layer的weight和bias
cnt = 0 # 依次遍历各个layer
sum = 0 # sum只是对一个epoch j 的加权梯度分布相似度记录
for k, _ in summary_quan_dict.items():
w = summary_base_dict[k][j, :] # 这里不合适 要改造
v = summary_quan_dict[k][j, :]
if flag0 == 0:
cur_weight = weight_f1[cnt] * scipy.stats.wasserstein_distance(w, v) # weight
# 不是很方便存 需要三维了(sheet)
# if tag == 1:
# df2[k] = [cur_weight]
# else:
# df[k] = [cur_weight]
sum += 0.99 * cur_weight
else:
cur_bias = weight_f1[cnt] * scipy.stats.wasserstein_distance(w, v) # bias
# if tag == 1:
# df2[k] = [cur_bias]
# else:
# df[k] = [cur_bias]
sum += 0.01 * cur_bias
if flag0 == 1:
cnt = cnt + 1
flag0 = 0
else:
flag0 = 1
sum = sum / (weight_f1.sum() * 2)
result += sum # 对各个epoch i的加权梯度相似度求和
print(sum)
result /= epoch # 对epoch节点阶段内的梯度相似度求平均
if tag == 1:
df2[str(epoch)] = [result]
else :
df[str(epoch)] = [result]
result = 0
if tag == 1 :
df = df.append(df2)
# df.to_csv('./qat_analysis_data/wasserstein_distance.csv')
df.to_csv(csvpath)
else :
# df.to_csv('./qat_analysis_data/wasserstein_distance.csv')
df.to_csv(csvpath)
# f = open('lenet_ptq_wasserstein_similarity.txt','a')
# f.write('bit:' + str(d1) + ' epoch_num:' + str(d2) +': '+str(sum)+'\n')
# f.close()
# -*- coding: utf-8 -*-
import numpy
import numpy as np
import torch
import sys
from collections import OrderedDict
import scipy.stats
import pandas as pd
from model import *
# from audtorch.metrics.functional import pearsonr
import math
# 该函数用于读出全精度、量化模型的weight和bias值,以作观察
if __name__ == "__main__":
d1 = sys.argv[1]
# d2 = sys.argv[2]
# d1=8
# df = pd.read_csv('./ptq_analysis_data/seperate_data.csv', index_col=0)
df = pd.DataFrame()
# df2 = pd.DataFrame()
base_data = torch.load('./project/p/ckpt/trail/model_trail.pt')
# checkpoint_data = torch.load('./project/p/ckpt/trail/model_trail.pt')
print('full_precision weight/bias loaded!')
checkpoint_dir = './project/p/checkpoint/cifar-10_trail_model'
# quan_data = torch.load('ckpt/cifar-10_lenet_bn_ptq_' + str(d1) + '_.pt')
# print('quantization bit ' + str(d1) + ' weight/bias loaded!')
sum=0
if int(d1) == 1:
print(base_data)
# for k, _ in base_data.items():
# base_data[k] = base_data[k].reshape(1, -1)
# # quan_data[k] = quan_data[k].reshape(1, -1)
# print(base_data[k])
else:
for i in [4,9,14,19]:
check_data = torch.load(checkpoint_dir + '/ckpt_cifar-10_trail_model%s.pt' % (str(i)))
print(check_data)
# if int(d2) == 1:
# print(base_data[k])
# else:
# print(quan_data[k])
# -*- coding: utf-8 -*-
from torch.autograd import Function
class FakeQuantize(Function):
@staticmethod
def forward(ctx, x, qparam): # 有qparam i.e. self 中记录的mode、scale、zeropoint、n_exp等信息,其实不用再额外传参
x = qparam.quantize_tensor(x, qparam.mode) # INT
x = qparam.dequantize_tensor(x, qparam.mode) # FP(int)
return x
@staticmethod
def backward(ctx, grad_output): # 用线性粗略近似 STE
return grad_output, None
\ No newline at end of file
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict
def get_model_histogram(model):
"""
Description:
- get norm gradients from model, and store in a OrderDict
Args:
- model: (torch.nn.Module), torch model
Returns:
- grads in OrderDict
"""
gradshisto = OrderedDict()
grads = OrderedDict()
for name, params in model.named_parameters():
grad = params.grad
if grad is not None:
tmp = {}
params_np = grad.cpu().numpy()
histogram, bins = np.histogram(params_np.flatten(),bins=20)
tmp['histogram'] = list(histogram)
tmp['bins'] = list(bins)
gradshisto[name] = tmp
grads[name] = params_np
return gradshisto,grads
def get_model_norm_gradient(model):
"""
Description:
- get norm gradients from model, and store in a OrderDict
Args:
- model: (torch.nn.Module), torch model
Returns:
- grads in OrderDict
"""
grads = OrderedDict()
for name, params in model.named_parameters():
grad = params.grad
if grad is not None:
grads[name] = grad.norm().item()
return grads
def get_grad_histogram(grads_sum):
gradshisto = OrderedDict()
# grads = OrderedDict()
for name, params in grads_sum.items():
grad = params
if grad is not None:
tmp = {}
#params_np = grad.cpu().numpy()
params_np = grad
histogram, bins = np.histogram(params_np.flatten(),bins=20)
tmp['histogram'] = list(histogram)
tmp['bins'] = list(bins)
gradshisto[name] = tmp #每层一个histogram (tmp中的是描述直方图的信息)
# grads[name] = params_np
return gradshisto
\ No newline at end of file
# -*- coding: utf-8 -*-
import numpy
import numpy as np
import torch
import sys
from collections import OrderedDict
import scipy.stats
import pandas as pd
import os
import os.path
#
d1 = sys.argv[1] # bit
d2 = sys.argv[2] # mode
d3 = sys.argv[3] # n_exp
# d1=2
# d2 = sys.argv[2] # epoch
# d1=2
# d2=3
sum=0
flag=0
# CNN FLOPs = Cout * Hout * Wout * (2 * Cin * K * K ) 是考虑bias 否则-1
# FCN FLOPs = Cout * Cin 是考虑bias 否则-1
# 把相关的relu,pool也考虑进去了
# MAdd
# weight0 =np.array( [ 705600.0+4704.0+ 3528.0 , 480000.0+ 1600.0 + 1200.0 , 95880.0 + 120.0,
# 20076.0 + 84.0 , 1670.0 ])
# weight1=np.array([705,600.0 , 480,000.0,+ 95,880.0 ,
# 20,076.0 , 1,670.0 ])
# flops
weight_f0= np.array([357504+4704+4704, 241600+1600+1600,48000+120,10080+84,840])
weight_f1=np.array([357504, 241600,48000,10080,840])
summary_quan_dict=OrderedDict()
summary_base_dict=OrderedDict()
# 最外层:不同epoch的字典 内层:各个网络层的grads
flag = 0
dirpath = './project/p/qat_analysis_data/mode' + str(d2)
if not os.path.isdir(dirpath):
os.makedirs(dirpath, mode=0o777)
os.chmod(dirpath, mode=0o777)
if int(d2) == 1 or int(d2) == 2:
csvpath = dirpath + '/scratch_loss.csv'
else:
csvpath = dirpath + '/scratch_loss_' + str(d3) + '.csv'
if os.path.exists(csvpath):
flag = 1
if flag == 0: # 还没有csv
df = pd.DataFrame()
else: # 已有csv
df = pd.read_csv(csvpath, index_col=0)
df2 = pd.DataFrame()
for epoch in ([5, 10, 15, 20, 25, 30]):
sums = []
total_quan_list = []
total_base_list = []
for i in range(int(epoch)):
if int(d2) == 1:
total_quan_list.append(torch.load(
'./project/p/checkpoint/cifar-10_lenet_bn_quant/scratch/' + str(d1) + '/ckpt_cifar-10_lenet_bn_quant_' + str(
i + 1) + '.pth'))
elif int(d2) == 2:
total_quan_list.append(torch.load(
'./project/p/checkpoint/cifar-10_lenet_bn_quant/scratch/mode' + str(d2) + '/' + str(
d1) + '/ckpt_cifar-10_lenet_bn_quant_' + str(
i + 1) + '.pth'))
else:
total_quan_list.append(torch.load(
'./project/p/checkpoint/cifar-10_lenet_bn_quant/scratch/mode' + str(d2) + '_' + str(d3) + '/' + str(
d1) + '/ckpt_cifar-10_lenet_bn_quant_' + str(
i + 1) + '.pth'))
sum_loss = 0
loss = total_quan_list[i]['losses']
# print(len(loss))
# 每个epoch的不同batch的
for j in range(len(loss)):
sum_loss += loss[j].cpu()
# print(sum_loss)
sum_loss /= j
sums.append(sum_loss)
# print(sums)
#print(sums[0] - sums[int(d1) - 1])
if flag == 0:
df[str(epoch)] = [(sums[0] - sums[int(epoch) - 1]).detach().numpy()]
else:
df2[str(epoch)] = [(sums[0] - sums[int(epoch) - 1]).detach().numpy()]
if flag == 0:
# df.to_csv('./qat_analysis_data/scratch_loss.csv')
df.to_csv(csvpath)
else:
df = df.append(df2)
# df.to_csv('./qat_analysis_data/scratch_loss.csv')
df.to_csv(csvpath)
\ No newline at end of file
# -*- coding: utf-8 -*-
from torch.serialization import load
from model import *
import argparse
import torch
import sys
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
# 为了得到PTQ的权重数据的伪量化版 (先quantize再dequantize,与full precision的权重数据分布相似,便于用wasserstein距离求相似度)
def direct_quantize(model, test_loader, device):
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model.quantize_forward(data) # 这里会依次调用model中各个层的forward,则会update qw
if i % 5000 == 0:
break
print('direct quantization finish')
def full_inference(model, test_loader, device):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Full Model Accuracy: {:.4f}%\n'.format(100. * correct / len(test_loader.dataset)))
def quantize_inference(model, test_loader, device):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model.quantize_inference(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
acc = 100. * correct / len(test_loader.dataset)
print('\nTest set: Quant Model Accuracy: {:.4f}%\n'.format(acc))
return acc
if __name__ == "__main__":
d1 = sys.argv[1]
batch_size = 32
using_bn = True
load_quant_model_file = None
# load_model_file = None
net = 'LeNet' # 1:
acc = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
)
if using_bn:
model = LeNet().to(device)
# 生成梯度分布图的时候是从0开始训练的
model.load_state_dict(torch.load('ckpt/cifar-10_lenet_bn.pt', map_location='cpu'))
# else:
# model = Net()
# model.load_state_dict(torch.load('ckpt/mnist_cnn.pt', map_location='cpu'))
# save_file = "ckpt/mnist_cnn_ptq.pt"
# model.to(device)
model.eval()
full_inference(model, test_loader, device)
num_bits = int(d1)
model.quantize(num_bits=num_bits)
model.eval()
print('Quantization bit: %d' % num_bits)
dir_name = './ptq_fake_log/' + 'quant_bit_' + str(d1) + '_log'
if not os.path.isdir(dir_name):
os.makedirs(dir_name, mode=0o777)
os.chmod(dir_name, mode=0o777)
qwriter = SummaryWriter(log_dir=dir_name)
# for name, param in model.named_parameters():
# qwriter.add_histogram(tag=name + '_data', values=param.data)
if load_quant_model_file is not None:
model.load_state_dict(torch.load(load_quant_model_file))
print("Successfully load quantized model %s" % load_quant_model_file)
direct_quantize(model, train_loader, device)
model.fakefreeze() # 权重量化
for name, param in model.named_parameters():
qwriter.add_histogram(tag=name + '_data', values=param.data)
dir_name ='ckpt/ptq_fakefreeze'
if not os.path.isdir(dir_name):
os.makedirs(dir_name, mode=0o777)
os.chmod(dir_name, mode=0o777)
save_file = 'ckpt/ptq_fakefreeze/cifar-10_lenet_bn_ptq_' + str(d1) + '_.pt'
torch.save(model.state_dict(), save_file)
# -*- coding: utf-8 -*-
from torch.serialization import load
# from model import *
import argparse
import torch
import sys
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
from resnet import *
def direct_quantize(model, test_loader, device):
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model.quantize_forward(data) # 这里会依次调用model中各个层的forward,则会update qw
if i % 5000 == 0:
break
print('direct quantization finish')
def full_inference(model, test_loader, device):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Full Model Accuracy: {:.4f}%\n'.format(100. * correct / len(test_loader.dataset)))
def quantize_inference(model, test_loader, device):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model.quantize_inference(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
acc = 100. * correct / len(test_loader.dataset)
print('\nTest set: Quant Model Accuracy: {:.4f}%\n'.format(acc))
return acc
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='PTQ Training')
parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='resnet18')
parser.add_argument('-n','--num_bits', default=8, type=int, metavar='BITS', help='number of bits')
parser.add_argument('-t','--mode', default=1, type=int, metavar='MODES', help='PTQ mode(1:INT 2:PoT 3:FP)')
parser.add_argument('-e','--n_exp', default=4, type=int, metavar='N_EXP', help='number of exp')
# d1 = sys.argv[1] # num_bits
# d2 = sys.argv[2] # mode
# d3 = sys.argv[3] # n_exp
# d1 = 8
# d2 = 3
# d3 = 4
args = parser.parse_args()
d1 = args.num_bits
d2 = args.mode
d3 = args.n_exp
batch_size = 128
using_bn = True
load_quant_model_file = None
# load_model_file = None
acc = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./project/p/data', train=True, download=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./project/p/data', train=False, download=False ,transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
)
if using_bn:
# model = LeNet(n_exp=int(d3), mode=int(d2)).to(device)
if args.model == 'resnet18' :
model = resnet18(n_exp=int(d3), mode=int(d2)).to(device)
elif args.model == 'resnet50' :
model = resnet50(n_exp=int(d3), mode=int(d2)).to(device)
elif args.model == 'resnet152' :
model = resnet152(n_exp=int(d3), mode=int(d2)).to(device)
elif args.model == 'LeNet' :
model = LeNet(n_exp=int(d3), mode=int(d2)).to(device)
elif args.model == 'NetBN' :
model = NetBN().to(device)
# model = resnet18(n_exp=int(d3), mode=int(d2)).to(device)
# 生成梯度分布图的时候是从0开始训练的
# model.load_state_dict(torch.load('./project/p/ckpt/cifar-10_lenet_bn.pt', map_location='cpu'))
model.load_state_dict(torch.load('./project/p/ckpt/' + args.model + '/' + args.model + '.pt', map_location='cpu'))
# else:
# model = Net()
# model.load_state_dict(torch.load('ckpt/mnist_cnn.pt', map_location='cpu'))
# save_file = "ckpt/mnist_cnn_ptq.pt"
# model.to(device)
model.eval()
full_inference(model, test_loader, device)
full_writer = SummaryWriter(log_dir='./project/p/' + args.model +'/ptqlog_mode' + str(d2) + '/' + str(d3) + '/' + 'full_log')
for name, param in model.named_parameters():
full_writer.add_histogram(tag=name + '_data', values=param.data)
num_bits = int(d1)
model.quantize(num_bits=num_bits)
model.eval()
print('Quantization bit: %d' % num_bits)
writer = SummaryWriter(log_dir='./project/p/'+ args.model + '/ptqlog_mode' + str(d2) + '/' + str(d3) + '/' + 'quant_bit_' + str(d1) + '_log')
if load_quant_model_file is not None:
model.load_state_dict(torch.load(load_quant_model_file))
print("Successfully load quantized model %s" % load_quant_model_file)
direct_quantize(model, train_loader, device)
model.freeze() # 权重量化
for name, param in model.named_parameters():
writer.add_histogram(tag=name + '_data', values=param.data)
# 原PTQ mode=1时
# save_file = 'ckpt/cifar-10_lenet_bn_ptq_' + str(d1) + '_.pt'
dir_name ='./project/p/ckpt/' + args.model + '/mode'+ str(d2) + '_' + str(d3) + '/ptq'
if not os.path.isdir(dir_name):
os.makedirs(dir_name, mode=0o777)
os.chmod(dir_name, mode=0o777)
save_file = './project/p/ckpt/' + args.model + '/mode'+ str(d2) + '_' + str(d3) + '/ptq' + '/cifar-10_' + args.model + '_ptq_' + str(d1) + '_.pt'
torch.save(model.state_dict(), save_file)
# 测试是否设备转移是否正确
# model.cuda()
# print(model.qconv1.M.device)
# model.cpu()
# print(model.qconv1.M.device)
acc = quantize_inference(model, test_loader, device)
f = open('./project/p/' + args.model + '_ptq_acc' + '.txt', 'a')
f.write('bit ' + str(d1) + ': ' + str(acc) + '\n')
f.close()
# -*- coding: utf-8 -*-
from model import *
from get_weight import *
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
import sys
import time
# import matplotlib.pyplot as plt
# import matplotlib
from torchvision.datasets import ImageFolder
from torch.utils.tensorboard import SummaryWriter
from absl import app, flags
# from easydict import EasyDict
# from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method
# from cleverhans.torch.attacks.projected_gradient_descent import (
# projected_gradient_descent,
# )
def train(model, device, train_loader, optimizer, epoch):
model.train()
lossLayer = torch.nn.CrossEntropyLoss()
flag = 0
cnt = 0
for batch_idx, (data, target) in enumerate(train_loader):
cnt = cnt + 1
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = lossLayer(output, target)
loss.backward()
histo, grads = (get_model_histogram(model))
if flag == 0:
flag = 1
grads_sum = grads
else:
for k,v in grads_sum.items():
grads_sum[k] += grads[k]
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
for k, v in grads_sum.items():
grads_sum[k] = v / len(train_loader.dataset)
return grads_sum
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
acc=0
lossLayer = torch.nn.CrossEntropyLoss(reduction='sum')
# report = EasyDict(nb_test=0, correct=0, correct_fgm=0, correct_pgd=0)
with torch.no_grad:
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
# x_fgm = fast_gradient_method(model, data, 0.01, np.inf)
# x_pgd = projected_gradient_descent(model, data, 0.01, 0.01, 40, np.inf)
# model prediction on clean examples
# _, y_pred = model(data).max(1)
# # model prediction on FGM adversarial examples
# _, y_pred_fgm = model(x_fgm).max(1)
#
# # model prediction on PGD adversarial examples
# _, y_pred_pgd = model(x_pgd).max(1)
# report.nb_test += target.size(0)
# report.correct += y_pred.eq(target).sum().item()
# report.correct_fgm += y_pred_fgm.eq(target).sum().item()
# report.correct_pgd += y_pred_pgd.eq(target).sum().item()
test_loss += lossLayer(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
acc=100. * correct / len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {:.0f}%\n'.format(
test_loss, acc
))
# print(
# "test acc on clean examples (%): {:.3f}".format(
# report.correct / report.nb_test * 100.0
# )
# )
# print(
# "test acc on FGM adversarial examples (%): {:.3f}".format(
# report.correct_fgm / report.nb_test * 100.0
# )
# )
# print(
# "test acc on PGD adversarial examples (%): {:.3f}".format(
# report.correct_pgd / report.nb_test * 100.0
# )
# )
return acc
batch_size = 32
test_batch_size = 32
seed = 1
# epochs = 15
d1 = sys.argv[1]
epochs = int(d1)
lr = 0.001
momentum = 0.5
save_model = False
using_bn = True
net = 'LeNet'
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
)
#if using_bn:
if (net == 'VGG19') == True:
model = VGG_19().to(device)
elif (net == 'LeNet') == True:
model = LeNet().to(device)
# else:
# model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
writer = SummaryWriter(log_dir='./fullprecision_log')
#optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9,0.999),eps=1e-08,weight_decay=0,amsgrad=False)
for epoch in range(1, epochs + 1):
grads_sum = train(model, device, train_loader, optimizer, epoch)
acc = test(model, device, test_loader)
print('epoch:', epoch)
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch,
'grads': grads_sum,
'accuracy':acc
}
# for name, param in model.named_parameters():
# writer.add_histogram(tag=name + '_grad', values=param.grad, global_step=epoch)
# writer.add_histogram(tag=name + '_data', values=param.data, global_step=epoch)
for name, param in grads_sum.items():
# 此处的grad是累加值吧 不是平均值
writer.add_histogram(tag=name + '_grad', values=param, global_step=epoch)
# 取这个epoch最后一个batch算完之后的weight
for name, param in model.named_parameters():
writer.add_histogram(tag=name + '_data', values=param.data, global_step=epoch)
if (net == 'LeNet') == True:
torch.save(checkpoint, 'checkpoint/cifar-10_lenet_bn/full/ckpt_cifar-10_lenet_bn_%s.pth' % (str(epoch)))
#保存参数
# if (net == 'VGG19') == True:
# torch.save(checkpoint, 'checkpoint/cifar-10_vgg19_bn/ckpt_cifar-10_vgg19_bn_%s.pth' % (str(epoch)))
# elif (net == 'LeNet') == True:
# torch.save(checkpoint, 'checkpoint/cifar-10_lenet_bn/ckpt_cifar-10_lenet_bn_%s.pth' % (str(epoch)))
#print('Saved all parameters!\n')
if save_model:
if not osp.exists('ckpt'):
os.makedirs('ckpt')
#if using_bn:
if (net == 'VGG19') == True:
torch.save(model.state_dict(), 'ckpt/cifar-10_vgg19_bn.pt')
elif (net == 'LeNet') == True:
torch.save(model.state_dict(), 'ckpt/cifar-10_lenet_bn.pt')
# else:
# torch.save(model.state_dict(), 'ckpt/cifar-10_vgg19.pt')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment