Commit f641275d by Klin

feat: add ptq:FP3-FP7

parent 06dfc82c
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
+ POT:取POT2-POT8 (POT8之后容易出现Overflow) + POT:取POT2-POT8 (POT8之后容易出现Overflow)
+ FP8:取E1-E6 (E0相当于INT量化,E7相当于POT量化,直接取相应策略效果更好) + FP8:取E1-E6 (E0相当于INT量化,E7相当于POT量化,直接取相应策略效果更好)
+ 支持调整FP的位宽 + 支持调整FP的位宽
+ 关于量化点选择,可以更改`utils.py`中的`bit_list`函数 + 关于量化点选择,可以更改`utils.py`中的`numbit_list`和`ebit_list`函数
+ 量化结果: + 量化结果:
...@@ -31,6 +31,8 @@ ...@@ -31,6 +31,8 @@
matlab导入数据,选择列向量 matlab导入数据,选择列向量
+ 加入FP3-FP7前:
+ js_flops - acc_loss + js_flops - acc_loss
Rational: Numerator degree 2 / Denominator degree 2 Rational: Numerator degree 2 / Denominator degree 2
...@@ -43,6 +45,7 @@ ...@@ -43,6 +45,7 @@
![fig2](image/fig2.png) ![fig2](image/fig2.png)
+ js_param - acc_loss + js_param - acc_loss
Rational: Numerator degree 2 / Denominator degree 2 Rational: Numerator degree 2 / Denominator degree 2
...@@ -54,3 +57,29 @@ ...@@ -54,3 +57,29 @@
- [x] center and scale - [x] center and scale
![fig4](image/fig4.png) ![fig4](image/fig4.png)
+ 加入FP3-FP7后:
+ js_flops - acc_loss
Rational: Numerator degree 2 / Denominator degree 2
- [ ] center and scale
![image-20230407010858191](image/fig5.png)
- [x] center and scale
![image-20230407011501987](image/fig6.png)
+ js_param - acc_loss
Rational: Numerator degree 2 / Denominator degree 2
- [ ] center and scale
![image-20230407010945342](image/fig7.png)
- [x] center and scale
![image-20230407010958875](image/fig8.png)
\ No newline at end of file
ykl/AlexNet/image/table.png

113 KB | W: | H:

ykl/AlexNet/image/table.png

18.4 KB | W: | H:

ykl/AlexNet/image/table.png
ykl/AlexNet/image/table.png
ykl/AlexNet/image/table.png
ykl/AlexNet/image/table.png
  • 2-up
  • Swipe
  • Onion skin
...@@ -116,7 +116,7 @@ if __name__ == "__main__": ...@@ -116,7 +116,7 @@ if __name__ == "__main__":
acc_loss_list = [] acc_loss_list = []
for quant_type in quant_type_list: for quant_type in quant_type_list:
num_bit_list, e_bit_list = bit_list(quant_type) num_bit_list = numbit_list(quant_type)
# 对一个量化类别,只需设置一次bias量化表 # 对一个量化类别,只需设置一次bias量化表
# int由于位宽大,使用量化表开销过大,直接_round即可 # int由于位宽大,使用量化表开销过大,直接_round即可
...@@ -125,6 +125,7 @@ if __name__ == "__main__": ...@@ -125,6 +125,7 @@ if __name__ == "__main__":
gol.set_value(bias_list, is_bias=True) gol.set_value(bias_list, is_bias=True)
for num_bits in num_bit_list: for num_bits in num_bit_list:
e_bit_list = ebit_list(quant_type,num_bits)
for e_bits in e_bit_list: for e_bits in e_bit_list:
model_ptq = AlexNet() model_ptq = AlexNet()
if quant_type == 'FLOAT': if quant_type == 'FLOAT':
......
title_list: title_list:
INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6 INT_2 INT_3 INT_4 INT_5 INT_6 INT_7 INT_8 INT_9 INT_10 INT_11 INT_12 INT_13 INT_14 INT_15 INT_16 POT_2 POT_3 POT_4 POT_5 POT_6 POT_7 POT_8 FLOAT_3_E1 FLOAT_4_E1 FLOAT_4_E2 FLOAT_5_E1 FLOAT_5_E2 FLOAT_5_E3 FLOAT_6_E1 FLOAT_6_E2 FLOAT_6_E3 FLOAT_6_E4 FLOAT_7_E1 FLOAT_7_E2 FLOAT_7_E3 FLOAT_7_E4 FLOAT_7_E5 FLOAT_8_E1 FLOAT_8_E2 FLOAT_8_E3 FLOAT_8_E4 FLOAT_8_E5 FLOAT_8_E6
js_flops_list: js_flops_list:
7507.750226317713 2739.698390971301 602.5613310246055 140.92197221503724 34.51721888016634 8.518508718865842 2.1353732883428638 0.5319411628570782 0.1316271020831477 0.03249564657892055 0.008037284252895557 0.0020460099353784723 0.00041867764927864105 0.0001321614950419231 5.841430176387608e-06 7507.667348902921 1654.3775934528933 136.7401730898288 134.5782970456457 134.57841422062364 134.5783939274636 134.5782945727605 33.31638902152266 32.12034308540418 0.6541880874259414 2.442034364817909 9.688117360231624 37.70544899186622 7507.750226317713 2739.698390971301 602.5613310246055 140.92197221503724 34.51721888016634 8.518508718865842 2.1353732883428638 0.5319411628570782 0.1316271020831477 0.03249564657892055 0.008037284252895557 0.0020460099353784723 0.00041867764927864105 0.0001321614950419231 5.841430176387608e-06 7507.667348902921 1654.3775934528933 136.7401730898288 134.5782970456457 134.57841422062364 134.5783939274636 134.5782945727605 1054.3432278105702 244.48311696489273 247.89704518368768 87.65672091651302 89.63831617681878 37.95288917539117 48.439491059469624 50.122494451137555 9.763717383777191 37.67666314899965 37.082531966568794 37.162725668876305 2.504495253790035 9.660221946273799 37.70544899186622 33.31638902152266 32.12034308540418 0.6541880874259414 2.442034364817909 9.688117360231624 37.70544899186622
js_param_list: js_param_list:
7507.750226317713 2739.698390971301 602.5613310246055 140.92197221503724 34.51721888016634 8.518508718865842 2.1353732883428638 0.5319411628570782 0.1316271020831477 0.03249564657892055 0.008037284252895557 0.0020460099353784723 0.00041867764927864105 0.0001321614950419231 5.841430176387608e-06 7507.667348902921 1654.3775934528933 136.7401730898288 134.5782970456457 134.57841422062364 134.5783939274636 134.5782945727605 33.31638902152266 32.12034308540418 0.6541880874259414 2.442034364817909 9.688117360231624 37.70544899186622 7507.750226317713 2739.698390971301 602.5613310246055 140.92197221503724 34.51721888016634 8.518508718865842 2.1353732883428638 0.5319411628570782 0.1316271020831477 0.03249564657892055 0.008037284252895557 0.0020460099353784723 0.00041867764927864105 0.0001321614950419231 5.841430176387608e-06 7507.667348902921 1654.3775934528933 136.7401730898288 134.5782970456457 134.57841422062364 134.5783939274636 134.5782945727605 1054.3432278105702 244.48311696489273 247.89704518368768 87.65672091651302 89.63831617681878 37.95288917539117 48.439491059469624 50.122494451137555 9.763717383777191 37.67666314899965 37.082531966568794 37.162725668876305 2.504495253790035 9.660221946273799 37.70544899186622 33.31638902152266 32.12034308540418 0.6541880874259414 2.442034364817909 9.688117360231624 37.70544899186622
ptq_acc_list: ptq_acc_list:
10.0 10.16 51.21 77.39 83.03 84.73 84.84 85.01 85.08 85.07 85.06 85.08 85.08 85.08 85.08 10.0 14.32 72.49 72.65 72.95 72.08 72.23 82.73 83.3 85.01 84.77 59.86 51.87 10.0 10.16 51.21 77.39 83.03 84.73 84.84 85.01 85.08 85.07 85.06 85.08 85.08 85.08 85.08 10.0 14.32 72.49 72.65 72.95 72.08 72.23 24.42 66.66 47.53 77.89 76.18 81.78 81.76 81.37 84.11 81.87 82.02 82.5 84.72 84.18 52.15 82.73 83.3 85.01 84.77 59.86 51.87
acc_loss_list: acc_loss_list:
0.8824635637047484 0.8805829807240245 0.3980959097320169 0.0903855195110484 0.02409496944052653 0.004113775270333736 0.0028208744710859768 0.0008227550540666805 0.0 0.00011753643629531167 0.0002350728725904563 0.0 0.0 0.0 0.0 0.8824635637047484 0.8316878232251997 0.14797837329572172 0.14609779031499756 0.14257169722614005 0.152797367183827 0.15103432063939815 0.027621062529384042 0.020921485660554785 0.0008227550540666805 0.0036436295251528242 0.29642689233662434 0.39033850493653033 0.8824635637047484 0.8805829807240245 0.3980959097320169 0.0903855195110484 0.02409496944052653 0.004113775270333736 0.0028208744710859768 0.0008227550540666805 0.0 0.00011753643629531167 0.0002350728725904563 0.0 0.0 0.0 0.0 0.8824635637047484 0.8316878232251997 0.14797837329572172 0.14609779031499756 0.14257169722614005 0.152797367183827 0.15103432063939815 0.7129760225669958 0.21650211565585334 0.44134931828866947 0.08450869769628583 0.10460742830277377 0.03878702397743297 0.039022096850023426 0.04360601786553824 0.011401034320639386 0.03772919605077567 0.035966149506346995 0.030324400564174875 0.004231311706629048 0.010578279266572538 0.3870474847202633 0.027621062529384042 0.020921485660554785 0.0008227550540666805 0.0036436295251528242 0.29642689233662434 0.39033850493653033
import torch import torch
def bit_list(quant_type):
def ebit_list(quant_type, num_bits):
    """Return the exponent bit-widths to sweep for one quantization setting.

    Args:
        quant_type: Name of the quantization scheme. Only 'FLOAT' produces a
            real sweep; every other value gets the placeholder list.
        num_bits: Total bit-width of the quantized format.

    Returns:
        For 'FLOAT', the list 1, 2, ..., num_bits - 2 (empty when
        num_bits < 3, so such widths yield no configurations).
        For any other type, the single-element list [0].
    """
    if quant_type == 'FLOAT':
        # NOTE(review): the num_bits - 2 upper bound presumably reserves a
        # sign bit and at least one mantissa bit -- confirm with the quantizer.
        return list(range(1, num_bits - 1))
    # Non-FLOAT schemes have no exponent field to vary; [0] keeps the
    # caller's inner loop running exactly once.
    return [0]
def numbit_list(quant_type):
if quant_type == 'INT': if quant_type == 'INT':
num_bit_list = list(range(2,17)) num_bit_list = list(range(2,17))
e_bit_list = [0]
elif quant_type == 'POT': elif quant_type == 'POT':
num_bit_list = list(range(2,9)) num_bit_list = list(range(2,9))
e_bit_list = [0]
else: else:
num_bit_list = [8] num_bit_list = list(range(2,9))
e_bit_list = list(range(1,7))
return num_bit_list, e_bit_list return num_bit_list
def build_bias_list(quant_type): def build_bias_list(quant_type):
if quant_type == 'POT': if quant_type == 'POT':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment