Commit 8352e4cf by Zhihong Ma
Parents: 2b7aac0c, c5562a31
ykl/AlexNet/image/AlexNet_table.png: image updated, 21.4 KB → 20.9 KB
ykl/AlexNet/image/flops.png: image updated, 33.6 KB → 33.7 KB
ykl/AlexNet/image/param.png: image updated, 35.5 KB → 32.2 KB
......@@ -8,6 +8,19 @@ from torch.autograd import Variable
from function import FakeQuantize
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
# Get the nearest quantized value
def get_nearest_val(quant_type,x,is_bias=False):
if quant_type=='INT':
......@@ -64,7 +77,7 @@ def bias_qmax(quant_type):
elif quant_type == 'POT':
return get_qmax(quant_type)
else:
return get_qmax(quant_type, 16, 5)
return get_qmax(quant_type, 16, 7)
# Convert back to FP32; no further clamping is needed
......
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
7507.750630903518 2739.696686260571 602.5610368972597 140.92196362794522 34.51723630314398 8.518501248514761 2.1353880852742875 0.5319393032307673 0.13161172461255044 0.03248745580294116 0.008041228126669929 0.002041263178018999 0.0004401365998344318 0.0001238196358571173 2.6684157208189957e-07 7507.6677482789855 1654.3776790768084 136.73977548990493 134.5782553082339 134.57822914515503 134.57815728826475 134.57822793682794 1054.343125709169 244.48315063136482 247.8970440086859 87.65672287018292 89.63829780248474 37.952907162652885 48.439500893999075 50.122513003094504 9.763710160467824 37.67667145925756 37.08251159654393 37.1627098092451 2.504499223445029 9.660234735230102 37.70545171879492 33.31639481832546 32.12035369527305 0.6541863956484224 2.4420428195112773 9.68812852009895 37.70545171879492
7596.158354929403 3342.630224687211 803.7421515317819 193.09625855378914 47.294430438986815 11.617764000749704 2.8409912757716502 0.7002681399367439 0.17784808389949117 0.0436832697576101 0.011038433846755926 0.0027589181015611444 0.0006869155287621905 0.0001593813201971203 0.0001203343519776774 7596.100728242634 2133.4439203566753 134.29969268901075 130.9402314639402 130.94019523793065 130.9402499980709 130.94030592234444 1277.7023613784083 289.82743245674186 302.10808960554925 92.40277930093731 107.5666593496828 36.700788172354116 44.407474473657686 60.09530043518254 9.507477271994002 36.255231843587175 31.324843649420213 44.77256767415278 2.4592434345083602 9.346483017228227 36.2552369826946 27.3682985947986 38.75922946625283 0.6654846714873894 2.3614977284616923 9.346479578793366 36.25830464235682
js_param_list:
2099.4012978223045 756.8852841760424 165.48611276543275 38.661077867143234 9.465502346081184 2.3380773681828533 0.58692184552745 0.14620052066785197 0.03612237294214898 0.008920127583145731 0.002220705680193553 0.0005652568568322304 0.00012068297675722306 4.6465238698738316e-05 3.978560588757452e-06 2099.3806976929545 455.6542423459878 38.279215341990685 37.69143863173502 37.69138152272717 37.691410138985525 37.6914533221459 292.0013646264601 67.95022753681195 68.50200422208128 24.58495282264393 24.797314677833658 10.623403633392762 13.713035034152393 13.858532642027827 2.732777962743172 10.54799818234084 10.549391145072232 10.275098560105489 0.7004983113671814 2.704403793583995 10.559890009589033 9.493248004294491 8.881733731714347 0.1827908118651683 0.6834420451874454 2.716044028457789 10.559890009589033
2124.6901484547043 924.4573506960227 220.4801169201279 52.8504167013976 12.953296497807385 3.1853253087773608 0.7775012981543739 0.19165897232008872 0.04881389226260701 0.011927057061170981 0.003025213095666009 0.0007628906685143884 0.00019075741735802466 4.5081513290070486e-05 3.399542619773311e-05 2124.672207113351 587.9132319943658 37.55091744333372 36.635514313284844 36.63549899741198 36.63550255682983 36.63554387096417 352.76533450990684 80.14627267392913 83.49214909064641 25.77374926912755 29.78985840193075 10.26657509432266 12.541439619277627 16.648877037050198 2.658753685911974 10.145569681413992 8.911419624239798 12.401484350536222 0.6875383356517406 2.6148026740520978 10.145568951377017 7.808360107452724 10.735180287226289 0.18568651930732344 0.6608485065237775 2.614801761970877 10.151356149109917
ptq_acc_list:
10.0 10.09 56.08 77.58 83.1 84.84 84.88 85.06 85.06 85.11 85.07 85.08 85.08 85.08 85.08 10.0 13.0 71.41 73.08 72.96 72.65 73.37 24.5 66.46 51.17 77.72 77.3 82.21 81.77 81.53 84.03 81.85 81.93 82.88 84.83 84.21 51.75 82.91 83.36 85.13 84.77 59.81 51.53
10.0 10.1 54.63 76.21 84.14 85.66 86.1 86.1 86.12 86.09 86.09 86.08 86.08 86.05 86.08 10.0 13.77 74.96 75.06 74.89 74.91 74.79 20.04 67.57 50.1 77.82 79.72 82.7 81.61 83.88 85.78 82.82 82.84 84.97 86.16 85.76 82.66 82.83 85.16 86.17 86.22 85.95 82.78
acc_loss_list:
0.8824635637047484 0.8814057357780911 0.34085566525622946 0.08815232722143865 0.02327221438645985 0.0028208744710859768 0.002350728725905064 0.0002350728725904563 0.0002350728725904563 -0.00035260930888576797 0.00011753643629531167 0.0 0.0 0.0 0.0 0.8824635637047484 0.847202632816173 0.16067230841560887 0.14104372355430184 0.1424541607898449 0.14609779031499756 0.13763516690173946 0.7120357310766338 0.2188528443817584 0.3985660554771979 0.08650681711330512 0.0914433474377057 0.03373295721673724 0.038904560413728285 0.04172543488481426 0.012341325811001377 0.03796426892336629 0.03702397743300413 0.02585801598495537 0.0029384109073812884 0.010225669957686936 0.39174894217207334 0.0255054066760696 0.02021626704278325 -0.0005876821814762242 0.0036436295251528242 0.2970145745181006 0.3943347437705689
0.8838289962825279 0.8826672862453532 0.3653578066914498 0.11466078066914503 0.022537174721189566 0.004879182156133849 -0.000232342007434898 -0.000232342007434898 -0.00046468401486996115 -0.00011617100371753155 -0.00011617100371753155 0.0 0.0 0.00034851301115242957 0.0 0.8838289962825279 0.8400325278810409 0.12918215613382905 0.12802044609665422 0.12999535315985128 0.12976301115241637 0.13115706319702594 0.7671933085501857 0.21503252788104096 0.41798327137546465 0.09595724907063204 0.07388475836431226 0.03926579925650552 0.05192843866171003 0.0255576208178439 0.0034851301115241306 0.03787174721189597 0.03763940520446091 0.012894981412639398 -0.0009293680297397572 0.0037174721189590287 0.03973048327137549 0.03775557620817844 0.010687732342007455 -0.0010455390334572886 -0.0016263940520446162 0.0015102230483270847 0.03833643122676577
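A note on how these lists relate (a hedged reading, inferred from the numbers rather than stated in the commit): acc_loss_list appears to be the relative accuracy drop of each PTQ configuration against the full-precision baseline of its row (85.08 and 86.08 here). For example:

```python
# Hedged sketch: acc_loss seems to be (fp32_acc - ptq_acc) / fp32_acc
fp32_acc = 85.08   # full-precision accuracy for the first row
ptq_acc = 10.0     # INT_2 entry of ptq_acc_list
acc_loss = (fp32_acc - ptq_acc) / fp32_acc
print(acc_loss)    # ~0.882463..., matching the first entry of acc_loss_list
```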
......@@ -46,14 +46,9 @@ def test(model, device, test_loader):
if __name__ == "__main__":
batch_size = 32
test_batch_size = 32
seed = 1
epochs1 = 15
epochs2 = epochs1+10
epochs3 = epochs2+10
lr1 = 0.01
lr2 = 0.001
lr3 = 0.0001
epochs_cfg = [20, 30, 20, 20, 10]
lr_cfg = [0.01, 0.005, 0.001, 0.0005, 0.0001]
momentum = 0.5
save_model = True
......@@ -62,7 +57,7 @@ if __name__ == "__main__":
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=True, download=True,
datasets.CIFAR10('../data', train=True, download=True,
transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
......@@ -73,31 +68,24 @@ if __name__ == "__main__":
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=False, transform=transforms.Compose([
datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
model = AlexNet().to(device)
optimizer1 = optim.SGD(model.parameters(), lr=lr1, momentum=momentum)
optimizer2 = optim.SGD(model.parameters(), lr=lr2, momentum=momentum)
optimizer3 = optim.SGD(model.parameters(), lr=lr3, momentum=momentum)
for epoch in range(1, epochs1 + 1):
train(model, device, train_loader, optimizer1, epoch)
test(model, device, test_loader)
for epoch in range(epochs1 + 1, epochs2 + 1):
train(model, device, train_loader, optimizer2, epoch)
test(model, device, test_loader)
for epoch in range(epochs2 + 1, epochs3 + 1):
train(model, device, train_loader, optimizer3, epoch)
test(model, device, test_loader)
epoch_start = 1
for epochs,lr in zip(epochs_cfg,lr_cfg):
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
epoch_end = epoch_start+epochs
for epoch in range(epoch_start,epoch_end):
train(model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
epoch_start += epochs
if save_model:
if not osp.exists('ckpt'):
......
import torch
import torch.nn as nn
import torch.nn.functional as F
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
def ebit_list(quant_type, num_bits):
if quant_type == 'FLOAT':
......
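For reference, the js_div helper added across these files computes the Jensen–Shannon divergence between the two output distributions (softmaxed first when get_softmax=True). Since PyTorch's KLDivLoss(log_input, target) evaluates KL(target ‖ input) with log_input already in log space, KLDivLoss(log_mean_output, p_output) is KL(P ‖ M), and the returned value is

$$\mathrm{JS}(P \,\|\, Q) = \tfrac{1}{2}\,\mathrm{KL}(P \,\|\, M) + \tfrac{1}{2}\,\mathrm{KL}(Q \,\|\, M), \qquad M = \tfrac{1}{2}(P + Q).$$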
ykl/AlexNet_BN/image/flops.png: image updated, 37.4 KB → 36.8 KB
ykl/AlexNet_BN/image/param.png: image updated, 36.2 KB → 36.4 KB
......@@ -8,6 +8,19 @@ from torch.autograd import Variable
from function import FakeQuantize
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
# Get the nearest quantized value
def get_nearest_val(quant_type,x,is_bias=False):
if quant_type=='INT':
......@@ -64,7 +77,7 @@ def bias_qmax(quant_type):
elif quant_type == 'POT':
return get_qmax(quant_type)
else:
return get_qmax(quant_type, 16, 5)
return get_qmax(quant_type, 16, 7)
# Convert back to FP32; no further clamping is needed
......
......@@ -51,18 +51,6 @@ def quantize_inference(model, test_loader, device):
return 100. * correct / len(test_loader.dataset)
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
if __name__ == "__main__":
batch_size = 32
......
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
7398.262055529559 2629.3751622617588 590.6821895683953 140.07310170087658 33.86048167345483 8.284908066398648 2.0380663672630033 0.5092279871870147 0.12684254729585437 0.031863777467731946 0.007841108109205986 0.0019867625414859602 0.000524805638519184 0.00015510881465292402 4.128433975522605e-05 7398.228137001189 1620.2559603222214 133.7304911846874 131.62907676756663 131.62871096032845 131.6289253991913 131.62875302977494 1069.390471249252 255.89338592444176 239.72194344867773 94.18685807791533 86.02821389442595 36.77673254387978 54.05171226668418 47.849590815698406 9.560264345209177 36.522392073949895 42.22214551260318 35.435660108435656 2.4388559138472727 9.464983665314236 36.53673927235493 38.25590966717735 30.62508628497586 0.6417078348950557 2.3826050378452384 9.478115589519865 36.536710369840804
7315.84863077462 2447.6990359739557 554.9539055398842 130.3057613397198 31.72854680369856 7.739889012396159 1.9503014204883173 0.47476923180404534 0.11812657336071226 0.03002884209117736 0.007400953973311122 0.001932886049776867 0.00044452340426047595 9.522531481526753e-05 2.9848570843407086e-06 7315.828621431085 1509.6239277040813 132.669444214901 130.7354806274539 130.73505398794512 130.73538243275962 130.7352909353478 1040.798298841082 253.67138749112172 224.5416693145032 97.71660734169551 80.35741771506743 36.57394847264219 58.08625652423495 44.466842189195326 9.423972626618646 36.3279041181201 46.176068440021766 32.78338584270386 2.408595598953836 9.337974512410408 36.32789770432847 41.972310900423636 28.345412246072254 0.6323033918133659 2.3565326800100244 9.337984500288826 36.35255195223587
js_param_list:
2072.6257003788373 729.7410477619161 163.5885811737197 38.802816944087496 9.38263941212123 2.2969966695619557 0.5644727100537247 0.1411507960630098 0.035134381675847655 0.008839262947625631 0.0021727478905621904 0.0005508804306955368 0.00014678034905458157 4.388988379219416e-05 1.1738754839340505e-05 2072.6054855115713 448.9787413512609 37.516807535381915 36.936289293152434 36.93619699724665 36.93626107722789 36.936217327546416 297.895038404521 71.53561204104008 66.59311683208784 26.516888484693222 23.910032735846137 10.313884888190998 15.311977796193256 13.29234301305728 2.6803832302061097 10.243483526646026 11.995040344505467 9.842200480067097 0.6838675134469407 2.65413470553611 10.251089473912115 10.878179543033623 8.50328530859542 0.17976968612984542 0.66836506379984 2.6613930698773456 10.251080477543102
2049.6929693507795 679.6962663951226 153.76103950048937 36.08036236374624 8.775254036917314 2.1428681018671267 0.540685220427713 0.13159647793378718 0.03272457732749732 0.008309703792123481 0.002061388543162421 0.0005373214024740875 0.00012932799821557616 3.358989301840365e-05 1.9910609472068016e-06 2049.682281691373 418.6989301615079 37.19252279367339 36.66040315748872 36.66028729152239 36.660375447269956 36.66034282835435 290.1050259737521 70.90992813610583 62.413211449939176 27.472679564357314 22.33696197321313 10.256081766983586 16.41390223671071 12.354416475776839 2.642681146872554 10.188908715006807 13.078305876994543 9.106313700653434 0.6753022926659958 2.6192195012070716 10.1889069543645 11.896313258304504 7.873233459924485 0.17713436485239206 0.6610392339533913 2.61922152391777 10.201050207834063
ptq_acc_list:
10.0 17.19 49.41 81.48 85.79 86.8 87.02 87.02 87.16 87.1 87.12 87.06 87.08 87.08 87.08 10.0 22.87 42.57 40.51 42.66 40.73 42.47 15.91 58.82 69.34 75.77 81.84 82.5 79.29 84.62 86.41 77.1 80.77 85.88 86.82 77.91 36.61 81.32 85.81 86.93 75.9 41.38 37.26
10.0 17.5 55.9 80.28 86.05 87.16 87.25 87.27 87.38 87.29 87.32 87.32 87.31 87.32 87.3 10.0 23.75 64.18 66.4 65.77 66.53 58.32 19.99 54.54 68.29 72.88 82.37 83.32 77.04 85.29 86.28 83.07 79.34 86.0 87.05 86.34 83.1 79.9 86.01 87.06 87.01 86.26 81.18
acc_loss_list:
0.8851762544494202 0.8026179813985532 0.432655873234585 0.06441612125387529 0.014927086921575348 0.0033298886209668878 0.0008037662188541438 0.0008037662188541438 -0.0008037662188539806 -0.00011482374555047542 -0.0003444712366517526 0.0003444712366517526 0.0001148237455506386 0.0001148237455506386 0.0001148237455506386 0.8851762544494202 0.7373980939258238 0.5111953151911816 0.534849006774601 0.5101619014812264 0.5323228843724883 0.5123435526466874 0.8173154208290275 0.32460672867148926 0.20381214835227923 0.12998047996325648 0.06028246641405442 0.052704099207716196 0.08956252152945225 0.02836146515099321 0.007808014697439508 0.11470892180502938 0.07256860718796655 0.013893673211620253 0.0031002411298657736 0.10540819841543239 0.5796302675393271 0.06625330118268469 0.014697439430474234 0.0018371799288092385 0.12848777127109884 0.5248593409117005 0.5721667240785395
0.8854393401305991 0.7995188452285485 0.3596059113300493 0.08030702256845004 0.014205521823805809 0.0014892885783023217 0.000458242639477675 0.0002291213197389189 -0.0010310459388244838 0.0 -0.00034368197960805275 -0.00034368197960805275 -0.0002291213197387561 -0.00034368197960805275 -0.00011456065986929664 0.8854393401305991 0.727918432810173 0.26474968495818535 0.23931721846717835 0.24653454003895073 0.2378279298888762 0.3318822316416543 0.7709932409210678 0.3751861610722878 0.2176652537518616 0.16508191087180674 0.05636384465574523 0.04548058196815228 0.11742467636613586 0.022912131973880166 0.011570626646809544 0.048344598464887305 0.0910757245961737 0.01477832512315278 0.0027494558368657243 0.010883262687593112 0.048000916485279085 0.08466032764348723 0.01466376446328332 0.002634895176996265 0.0032076984763432367 0.011799747966548299 0.0699965631802039
......@@ -46,14 +46,9 @@ def test(model, device, test_loader):
if __name__ == "__main__":
batch_size = 32
test_batch_size = 32
seed = 1
epochs1 = 15
epochs2 = epochs1+10
epochs3 = epochs2+10
lr1 = 0.01
lr2 = 0.001
lr3 = 0.0001
epochs_cfg = [15, 20, 20, 20, 10, 10]
lr_cfg = [0.01, 0.005, 0.002, 0.001, 0.0005, 0.0001]
momentum = 0.5
save_model = True
......@@ -62,7 +57,7 @@ if __name__ == "__main__":
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=True, download=True,
datasets.CIFAR10('../data', train=True, download=True,
transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(),
......@@ -73,31 +68,24 @@ if __name__ == "__main__":
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=False, transform=transforms.Compose([
datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
transforms.Resize((32, 32), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
model = AlexNet_BN().to(device)
optimizer1 = optim.SGD(model.parameters(), lr=lr1, momentum=momentum)
optimizer2 = optim.SGD(model.parameters(), lr=lr2, momentum=momentum)
optimizer3 = optim.SGD(model.parameters(), lr=lr3, momentum=momentum)
for epoch in range(1, epochs1 + 1):
train(model, device, train_loader, optimizer1, epoch)
test(model, device, test_loader)
for epoch in range(epochs1 + 1, epochs2 + 1):
train(model, device, train_loader, optimizer2, epoch)
test(model, device, test_loader)
for epoch in range(epochs2 + 1, epochs3 + 1):
train(model, device, train_loader, optimizer3, epoch)
test(model, device, test_loader)
epoch_start = 1
for epochs,lr in zip(epochs_cfg,lr_cfg):
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
epoch_end = epoch_start+epochs
for epoch in range(epoch_start,epoch_end):
train(model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
epoch_start += epochs
if save_model:
if not osp.exists('ckpt'):
......
# Change Notes
## Update: 2023/04/17
+ Specified a new staged learning-rate schedule, retrained the full-precision models to higher accuracy, and re-ran PTQ and the fitting
## Update: 2023/04/16
+ Added MATLAB fitting and plotting scripts; they support per-model category markers, and the fitted curves are smoother than those produced by cftool
+ Fixed a typo in the js_param computation in ptq.py: flop_ratio should be par_ratio, otherwise the flops fit and the param fit are identical
+ The bias_qmax method in module.py should pass num_bits=16 and e_bits=7 for the FLOAT type.
+ The key parameter here is e_bits. The fitting outliers were mainly FLOAT_7_E5 / FLOAT_8_E5 / FLOAT_8_E6, where the bias collapses towards the two extremes, much like the earlier bias-overflow problem in INT quantization.
+ e_bits was previously 5. Because the bias scale is the product of the input scale and the weight scale, the bias quantization range should be roughly the square of the x/weight range. The largest x/weight range the code currently supports is about $2^{2^{6}}$, so the bias range should reach roughly $2^{2^{7}}$, i.e. e_bits should be 7.
+ After this change the outliers disappear and the fit improves markedly; a short sketch of the range argument follows below.
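A minimal, self-contained illustration of that argument, assuming (as in the note) that the FLOAT quantizer with e_bits exponent bits covers magnitudes up to roughly 2^(2^e_bits); approx_float_range is a hypothetical helper, not the repository's get_qmax:

```python
# Hedged sketch, not the repo's get_qmax: approximate the dynamic range of the
# FLOAT quantizer as 2 ** (2 ** e_bits), the convention used in the note above.
def approx_float_range(e_bits: int) -> float:
    return 2.0 ** (2 ** e_bits)

x_or_weight_range = approx_float_range(6)      # ~2**64, the largest x / weight range in use
bias_range_needed = x_or_weight_range ** 2     # bias scale = scale_x * scale_w  =>  squared range

assert approx_float_range(5) < bias_range_needed    # e_bits = 5 cannot hold the bias (clipping at both ends)
assert approx_float_range(7) >= bias_range_needed   # e_bits = 7 covers it, hence get_qmax(quant_type, 16, 7)
```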
ykl/VGG_16/image/VGG16_table.png: image updated, 21.8 KB → 21 KB
ykl/VGG_16/image/flops.png: image updated, 33.7 KB → 33.2 KB
ykl/VGG_16/image/param.png: image updated, 34.2 KB → 33.4 KB
......@@ -14,7 +14,7 @@ import module
feature_cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
classifier_cfg = [4096, 4096, 'LF']
def make_feature_layers(cfg, batch_norm=False):
def make_feature_layers(cfg, batch_norm=True):
layers = []
names = []
input_channel = 3
......@@ -103,13 +103,14 @@ def quantize_classifier_layers(model,name_list,quant_type,num_bits,e_bits):
return names, layers
def quantize_utils(model,qfeature_name,qclassifier_name,func, x=None):
# Both the original layers and the quantized layers can run their forward through this function
def model_utils(model,feature_name,classifier_name,func, x=None):
if func == 'inference':
layer=getattr(model,qfeature_name[0])
layer=getattr(model,feature_name[0])
x = layer.qi.quantize_tensor(x)
last_qo = None
for name in qfeature_name:
for name in feature_name:
layer = getattr(model,name)
if func == 'forward':
x = layer(x)
......@@ -123,7 +124,7 @@ def quantize_utils(model,qfeature_name,qclassifier_name,func, x=None):
if func != 'freeze':
x = torch.flatten(x, start_dim=1)
for name in qclassifier_name:
for name in classifier_name:
layer = getattr(model,name)
if func == 'forward':
x = layer(x)
......@@ -164,26 +165,9 @@ class VGG_16(nn.Module):
# self.fc3 = nn.Linear(4096, num_class)
def forward(self, x):
#feature
for name in self.feature_name:
layer = getattr(self,name)
x = layer(x)
x = torch.flatten(x, start_dim=1)
#classifier
for name in self.classifier_name:
layer = getattr(self,name)
x = layer(x)
# x = self.fc1(x)
# x = self.crelu1(x)
# x = self.drop1(x)
# x = self.fc2(x)
# x = self.crelu2(x)
# x = self.drop2(x)
# x = self.fc3(x)
x = model_utils(self, self.feature_name, self.classifier_name,
func='forward', x=x)
return x
def quantize(self, quant_type, num_bits=8, e_bits=3):
......@@ -207,17 +191,17 @@ class VGG_16(nn.Module):
# self.qfc3 = QLinear(quant_type, self.fc3, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
def quantize_forward(self,x):
x = quantize_utils(self, self.qfeature_name, self.qclassifier_name,
x = model_utils(self, self.qfeature_name, self.qclassifier_name,
func='forward', x=x)
return x
def freeze(self):
quantize_utils(self, self.qfeature_name, self.qclassifier_name,
model_utils(self, self.qfeature_name, self.qclassifier_name,
func='freeze', x=None)
def quantize_inference(self,x):
x = quantize_utils(self, self.qfeature_name, self.qclassifier_name,
x = model_utils(self, self.qfeature_name, self.qclassifier_name,
func='inference', x=x)
return x
\ No newline at end of file
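The refactor above collapses three near-identical layer loops into one dispatcher keyed on func ('forward' / 'inference' / 'freeze') that works for both the original and the quantized layer-name lists. Below is a minimal, self-contained sketch of that dispatch pattern; TinyNet and model_utils_sketch are hypothetical stand-ins, and the real model_utils additionally handles qi.quantize_tensor, qo bookkeeping, and the freeze path:

```python
import torch
import torch.nn as nn

class TinyNet(nn.Module):
    # hypothetical model following the same "layer names stored as lists" convention
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, 3, padding=1)
        self.relu1 = nn.ReLU()
        self.fc1 = nn.Linear(8 * 32 * 32, 10)
        self.feature_name = ['conv1', 'relu1']
        self.classifier_name = ['fc1']

def model_utils_sketch(model, feature_name, classifier_name, func, x=None):
    # walk the named feature layers, flatten, then the named classifier layers
    for name in feature_name:
        layer = getattr(model, name)
        if func == 'forward':
            x = layer(x)
    if func != 'freeze':
        x = torch.flatten(x, start_dim=1)
    for name in classifier_name:
        layer = getattr(model, name)
        if func == 'forward':
            x = layer(x)
    return x

net = TinyNet()
out = model_utils_sketch(net, net.feature_name, net.classifier_name,
                         func='forward', x=torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 10])
```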
......@@ -8,6 +8,19 @@ from torch.autograd import Variable
from function import FakeQuantize
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
# Get the nearest quantized value
def get_nearest_val(quant_type, x, is_bias=False, block_size=1000000):
if quant_type == 'INT':
......@@ -80,7 +93,7 @@ def bias_qmax(quant_type):
elif quant_type == 'POT':
return get_qmax(quant_type)
else:
return get_qmax(quant_type, 16, 5)
return get_qmax(quant_type, 16, 7)
# Convert back to FP32; no further clamping is needed
......
......@@ -51,18 +51,6 @@ def quantize_inference(model, test_loader, device):
return 100. * correct / len(test_loader.dataset)
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
if __name__ == "__main__":
batch_size = 32
......
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
9536.471074104704 2226.062767717981 479.08937301581057 110.29737790259509 26.51254683672561 6.543175408097222 1.6082547229117354 0.4010994193665803 0.09957335073633722 0.025111075393055696 0.0061973498640195265 0.0015194423491633707 0.00038995051898299435 7.942830031237921e-05 5.2415376317010985e-05 9536.459434888866 1346.1955987678602 186.0432674146965 184.66847178581048 184.66811538285353 184.66788024897852 184.66821779056917 1162.9049832114217 334.88738457777345 213.59354200345794 162.9083436791504 74.46976130202673 51.093367953882556 114.00986841709613 39.88037886445475 13.180682344029526 50.92046138218437 97.04451701067691 28.848989954377853 3.333524419321254 13.121054175424074 50.939926759367474 90.27653874927127 24.69379491562014 0.8539878687852623 3.300629416372651 13.138977654051457 50.93992491304259
9678.738526193363 2754.445091033171 599.5177886615332 139.45004407102132 33.61972833441705 8.327290718295096 2.0621958796228017 0.507271590567208 0.12638558978042777 0.03172725299824697 0.00787383157101955 0.0019659234261853076 0.0004929713753570182 0.00013082873931775967 6.58462511070724e-05 9678.721786560653 1669.6727215221636 181.2126746364518 179.2569959125805 179.25701832538817 179.2571686197933 179.2571438439814 1252.5642980710986 326.4085549727165 257.15456948396184 140.83288095342002 91.39911193737521 49.66634546455573 91.6028944861708 50.03401260763902 12.92006159573897 49.428214414171954 75.71805318590293 36.69642341667939 3.284018584054309 12.831006283541836 49.4282155693062 69.7655678600058 31.59066644962877 0.8486327021161753 3.2312358866685056 12.83099812618559 49.46498235942715
js_param_list:
6887.257101138355 1340.7850727629186 280.73809433329967 63.94496255963192 15.289535102058974 3.7565034795627565 0.9259750231943055 0.2288807692260617 0.057107940186636466 0.014183206997691562 0.003557619868239359 0.000877363417377142 0.00022626713835566516 3.542348440969017e-05 1.7145957804444833e-05 6887.256730096761 806.0595586054458 139.7164856361964 139.0401533816038 139.04027885232878 139.0402887905267 139.0402936537755 795.7370567027401 249.62616743230893 132.75014085752875 132.0024940072834 45.6707525557016 38.23747791072957 96.5944785050857 23.885526002706396 9.845974027938183 38.151876742443726 83.61522220389249 17.043439224317503 2.4726691755122503 9.815686765994785 38.15642983446397 78.20951581132825 14.518605994908436 0.6269759716394059 2.4555938958183803 9.819743080524256 38.156425344733975
6979.995257775382 1676.8544762024308 357.07416358579195 82.21848562208979 19.71659591761924 4.858750496313133 1.204106769484056 0.2970651538514128 0.07389159465364287 0.01847686680473944 0.004583820583860123 0.0011690277325588727 0.0003001505161697722 8.054369623471168e-05 3.8545500662813255e-05 6979.99738972489 1008.874309342299 135.7925378028479 134.79040735180877 134.7905911406337 134.79058895538347 134.79061404359229 841.4093324566583 236.70954280105926 160.0719446844072 112.93571614580846 56.470927826653764 36.93548609003767 78.37733457244148 30.396125541872564 9.673987287874596 36.81191428618332 66.55321174809056 22.060163046203872 2.446769401900793 9.628841768242337 36.811916984902815 61.86831406132784 18.91259887812061 0.6248787079335093 2.4197557895059423 9.62881847680995 36.822419165985735
ptq_acc_list:
10.0 11.77 60.12 86.44 88.81 89.35 89.3 89.5 89.51 89.44 89.46 89.43 89.43 89.44 89.44 10.0 17.19 66.85 66.15 66.9 65.02 65.06 14.46 64.07 78.19 78.84 87.19 86.74 83.16 88.48 88.84 75.33 83.87 88.84 89.36 67.65 10.23 84.71 88.79 89.43 65.56 10.35 10.19
10.0 11.13 56.25 87.31 89.42 89.68 89.8 89.82 89.8 89.75 89.77 89.81 89.79 89.8 89.8 10.0 19.8 68.55 69.75 70.61 69.72 67.62 12.29 54.35 78.59 80.14 87.26 87.17 83.65 88.89 89.41 87.03 84.94 89.19 89.78 89.18 87.11 85.27 89.42 89.76 89.81 89.41 74.75
acc_loss_list:
0.8881932021466905 0.8684033989266547 0.3278175313059034 0.03354203935599284 0.007043828264758446 0.0010062611806798236 0.001565295169946339 -0.0006708407871198823 -0.000782647584973249 0.0 -0.00022361359570657448 0.0001118067978532078 0.0001118067978532078 0.0 0.0 0.8881932021466905 0.807804114490161 0.25257155635062617 0.2603980322003577 0.2520125223613595 0.2730322003577818 0.2725849731663685 0.8383273703041144 0.2836538461538462 0.12578264758497318 0.118515205724508 0.025156529516994635 0.03018783542039359 0.07021466905187837 0.01073345259391764 0.006708407871198505 0.15775939177101966 0.0622763864042933 0.006708407871198505 0.0008944543828264568 0.24362701252236127 0.8856216457960644 0.05288461538461543 0.007267441860465021 0.0001118067978532078 0.266994633273703 0.8842799642218248 0.8860688729874776
0.888641425389755 0.8760579064587973 0.3736080178173719 0.027728285077950946 0.004231625835189259 0.0013363028953228323 0.0 -0.00022271714922044567 0.0 0.0005567928730511933 0.00033407572383074765 -0.00011135857461030197 0.00011135857461014371 0.0 0.0 0.888641425389755 0.779510022271715 0.23663697104677062 0.22327394209354118 0.2136971046770601 0.2236080178173719 0.2469933184855233 0.8631403118040089 0.39476614699331847 0.12483296213808456 0.10757238307349662 0.02828507795100214 0.029287305122494382 0.06848552338530058 0.010133630289532257 0.004342984409799561 0.03084632516703782 0.05412026726057906 0.006792873051224938 0.00022271714922044567 0.006904231625835082 0.02995545657015588 0.05044543429844099 0.004231625835189259 0.00044543429844089133 -0.00011135857461030197 0.004342984409799561 0.16759465478841867
......@@ -46,14 +46,9 @@ def test(model, device, test_loader):
if __name__ == "__main__":
batch_size = 32
test_batch_size = 32
seed = 1
epochs1 = 15
epochs2 = epochs1+10
epochs3 = epochs2+10
lr1 = 0.01
lr2 = 0.001
lr3 = 0.0001
epochs_cfg = [25, 30, 30, 20, 20, 10, 10]
lr_cfg = [0.01, 0.008, 0.005, 0.002, 0.001, 0.0005, 0.0001]
momentum = 0.5
save_model = True
......@@ -78,26 +73,19 @@ if __name__ == "__main__":
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
model = VGG_16().to(device)
optimizer1 = optim.SGD(model.parameters(), lr=lr1, momentum=momentum)
optimizer2 = optim.SGD(model.parameters(), lr=lr2, momentum=momentum)
optimizer3 = optim.SGD(model.parameters(), lr=lr3, momentum=momentum)
for epoch in range(1, epochs1 + 1):
train(model, device, train_loader, optimizer1, epoch)
test(model, device, test_loader)
for epoch in range(epochs1 + 1, epochs2 + 1):
train(model, device, train_loader, optimizer2, epoch)
test(model, device, test_loader)
for epoch in range(epochs2 + 1, epochs3 + 1):
train(model, device, train_loader, optimizer3, epoch)
test(model, device, test_loader)
epoch_start = 1
for epochs,lr in zip(epochs_cfg,lr_cfg):
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
epoch_end = epoch_start+epochs
for epoch in range(epoch_start,epoch_end):
train(model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
epoch_start += epochs
if save_model:
if not osp.exists('ckpt'):
......
......@@ -12,8 +12,6 @@
+ Fitting results
+ The low R² is most likely caused by FLOAT quantization points whose JS divergence is small but whose accuracy is only 10
![flops](image/flops.png)
![param](image/param.png)
......
ykl/VGG_19/image/VGG19_table.png: image updated, 22 KB → 21.4 KB
ykl/VGG_19/image/flops.png: image updated, 35.9 KB → 33.4 KB
ykl/VGG_19/image/param.png: image updated, 37.6 KB → 34.7 KB
......@@ -8,6 +8,19 @@ from torch.autograd import Variable
from function import FakeQuantize
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
# Get the nearest quantized value
def get_nearest_val(quant_type, x, is_bias=False, block_size=1000000):
if quant_type == 'INT':
......@@ -80,7 +93,7 @@ def bias_qmax(quant_type):
elif quant_type == 'POT':
return get_qmax(quant_type)
else:
return get_qmax(quant_type, 16, 5)
return get_qmax(quant_type, 16, 7)
# Convert back to FP32; no further clamping is needed
......
......@@ -51,18 +51,6 @@ def quantize_inference(model, test_loader, device):
return 100. * correct / len(test_loader.dataset)
def js_div(p_output, q_output, get_softmax=True):
"""
Function that measures JS divergence between target and output logits:
"""
KLDivLoss = nn.KLDivLoss(reduction='sum')
if get_softmax:
p_output = F.softmax(p_output)
q_output = F.softmax(q_output)
log_mean_output = ((p_output + q_output)/2).log()
return (KLDivLoss(log_mean_output, p_output) + KLDivLoss(log_mean_output, q_output))/2
if __name__ == "__main__":
batch_size = 32
......
title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list:
10125.068777303843 2125.8746059892997 448.1684269275655 102.72227108791047 24.664131592897466 6.028952690039004 1.4808848911979104 0.3654121554844982 0.09188678563921866 0.022862010935799246 0.005681738921256071 0.0014282347355091774 0.0003390110858486199 0.00010615918609977952 3.489823651098289e-05 10125.058309270402 1275.5847525159818 202.11004564591005 200.95585017589545 200.95539368760936 200.95529631550048 200.95542079271524 1204.751815912724 367.27809530828097 207.51154219067092 188.4010741493603 71.77026853544704 55.978675489218205 135.4187671720789 37.94656880869658 14.343459472762477 55.84021870450606 116.4026136353036 27.25354110333402 3.606044260373801 14.295147200396164 55.85538278221231 108.58905223595819 23.270040334049046 0.9179187347666412 3.5799479606638203 14.308950780559794 55.85538134242835
10335.410852393428 2702.374876526504 586.8085986640358 136.19783508263313 32.76686433160855 8.108710198358402 2.017359787423094 0.5043751741104074 0.12519450821641712 0.03078375848492338 0.007759517552654125 0.001915080842767542 0.0004650357624047513 0.00010331804808074696 5.051867857804825e-05 10335.404291659883 1639.8833350853 197.21747754065524 195.40539442483635 195.40543311681287 195.40540281991127 195.40546678069262 1301.4075583496397 353.91370485500016 255.05728255399467 161.1495511523927 90.18686133261879 54.04437942140043 108.573018962868 48.961241334358604 14.062757558415433 53.804179496106855 91.04968596750315 35.698608089109584 3.568426481512405 13.974762150284537 53.80416850535668 84.32169873971353 30.648531893335964 0.9205644289947652 3.5162305719167306 13.974779013299235 53.8384368168758
js_param_list:
7919.9999946725075 1368.1564225363845 283.16934566896754 64.2942694436596 15.321270118010474 3.745277689238688 0.9178676871947216 0.22807088096042888 0.056904570420633724 0.014193230067183694 0.0035684240604433215 0.0008905876260726569 0.0002163833713661727 7.152718652977152e-05 3.2130914002021176e-05 7920.002545364124 818.8838272570443 164.4618730779971 163.84526969797696 163.84537575503464 163.84534091325256 163.84535644152615 912.7525366351027 304.20939173305254 138.97594773808592 167.73729464269246 46.84141581944599 45.95750081299213 124.95158148181119 23.98366626289621 11.613224296186155 45.88300992073913 108.73234446521248 16.892437923675633 2.904718790614561 11.58708466778401 45.88710644854956 101.79740772746169 14.310069748579027 0.7348164303199256 2.8903894156159375 11.59069207978545 45.887101821422974
8185.907782380622 1800.9911591065213 381.9598699018292 87.55210816175092 20.97233721922892 5.157556218764395 1.273967796169326 0.3167936166061732 0.07878698713657759 0.01962915205843054 0.0049230819182810885 0.001218414473634084 0.00029934385592657504 6.540185677536515e-05 1.9836379728102926e-05 8185.911496782605 1085.762959717068 161.20373387464056 160.2127414792563 160.2128133174403 160.21278421049763 160.21274821008095 958.9015917875203 279.7889215111697 174.41461406276403 139.022264040724 61.1200446333161 43.83433305556267 98.6990126657353 32.570105813234655 11.541062350360075 43.70912334388383 84.54385769669656 23.46464257859416 2.916976983569494 11.494558828605104 43.70911423954581 78.86080293161805 20.057860207412013 0.7429960203642413 2.8894343747747815 11.494570441006683 43.722130173425036
ptq_acc_list:
10.0 12.8 62.98 87.16 89.05 89.19 89.23 89.21 89.26 89.28 89.26 89.26 89.25 89.25 89.25 10.0 17.81 67.69 68.82 65.58 65.75 66.39 12.63 57.06 79.98 79.55 87.25 85.88 83.01 88.34 88.74 74.07 84.41 88.6 89.41 63.55 10.0 84.55 88.99 89.33 61.4 10.0 10.0
10.0 10.47 69.02 88.3 89.73 89.92 89.99 90.08 90.09 90.08 90.11 90.1 90.11 90.1 90.09 10.0 14.58 73.48 72.92 65.82 72.06 71.49 10.35 50.78 80.42 79.11 87.53 87.77 84.05 89.15 89.39 87.83 84.85 89.4 89.9 89.32 87.61 85.69 89.65 90.18 89.87 89.44 78.02
acc_loss_list:
0.8879551820728291 0.8565826330532214 0.2943417366946779 0.02341736694677875 0.0022408963585434493 0.0006722689075630507 0.00022408963585429716 0.0004481792717087535 -0.00011204481792722819 -0.00033613445378152535 -0.00011204481792722819 -0.00011204481792722819 0.0 0.0 0.0 0.8879551820728291 0.8004481792717086 0.24156862745098043 0.22890756302521015 0.26521008403361346 0.26330532212885155 0.2561344537815126 0.8584873949579832 0.360672268907563 0.10386554621848736 0.10868347338935577 0.022408963585434174 0.03775910364145663 0.06991596638655456 0.010196078431372511 0.0057142857142857715 0.17008403361344546 0.05422969187675074 0.00728291316526617 -0.0017927170868346958 0.28795518207282916 0.8879551820728291 0.05266106442577034 0.0029131652661065 -0.0008963585434173479 0.31204481792717087 0.8879551820728291 0.8879551820728291
0.8890122086570477 0.883795782463929 0.2339622641509434 0.01997780244173138 0.0041065482796891276 0.001997780244173059 0.0012208657047724687 0.0002219755826858604 0.00011098779134285134 0.0002219755826858604 -0.00011098779134300907 0.0 -0.00011098779134300907 0.0 0.00011098779134285134 0.8890122086570477 0.8381798002219756 0.1844617092119866 0.19067702552719193 0.26947835738068815 0.20022197558268584 0.2065482796892342 0.8851276359600444 0.4364039955604883 0.10743618201997773 0.1219755826859045 0.02852386237513866 0.025860155382907864 0.0671476137624861 0.01054384017758034 0.007880133185349542 0.025194228634850123 0.05826859045504995 0.007769145394006534 0.0022197558268589193 0.00865704772475029 0.02763596004439506 0.04894561598224192 0.004994450610432726 -0.000887902330743757 0.0025527192008877888 0.007325194228634813 0.13407325194228634
......@@ -3,7 +3,6 @@ from model import *
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import os
......@@ -47,16 +46,9 @@ def test(model, device, test_loader):
if __name__ == "__main__":
batch_size = 32
test_batch_size = 32
seed = 1
# epoch = 35
# lr = 0.01
epochs1 = 15
epochs2 = epochs1+10
epochs3 = epochs2+10
lr1 = 0.01
lr2 = 0.001
lr3 = 0.0001
epochs_cfg = [30, 40, 30, 20, 20, 10, 10]
lr_cfg = [0.01, 0.008, 0.005, 0.002, 0.001, 0.0005, 0.0001]
momentum = 0.5
save_model = True
......@@ -81,33 +73,19 @@ if __name__ == "__main__":
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
model = VGG_19().to(device)
# optimizer = optim.Adam(model.parameters(), lr=lr)
# lr_scheduler = CosineAnnealingLR(optimizer, T_max=epoch)
# for epoch in range(1, epoch + 1):
# train(model, device, train_loader, optimizer, epoch)
# # lr_scheduler.step()
# test(model, device, test_loader)
optimizer1 = optim.SGD(model.parameters(), lr=lr1, momentum=momentum)
optimizer2 = optim.SGD(model.parameters(), lr=lr2, momentum=momentum)
optimizer3 = optim.SGD(model.parameters(), lr=lr3, momentum=momentum)
for epoch in range(1, epochs1 + 1):
train(model, device, train_loader, optimizer1, epoch)
test(model, device, test_loader)
for epoch in range(epochs1 + 1, epochs2 + 1):
train(model, device, train_loader, optimizer2, epoch)
test(model, device, test_loader)
for epoch in range(epochs2 + 1, epochs3 + 1):
train(model, device, train_loader, optimizer3, epoch)
test(model, device, test_loader)
epoch_start = 1
for epochs,lr in zip(epochs_cfg,lr_cfg):
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
epoch_end = epoch_start+epochs
for epoch in range(epoch_start,epoch_end):
train(model, device, train_loader, optimizer, epoch)
test(model, device, test_loader)
epoch_start += epochs
if save_model:
if not osp.exists('ckpt'):
......