haoyifan / Model-Transfer-Adaptability · Commits

Commit cea929c5, authored Apr 16, 2023 by Klin
    Merge branch 'master' of http://62.234.201.16/hao/Model-Transfer-Adaptability

Parents: 5f2495a6, b96538f0
Showing 2 changed files with 37 additions and 15 deletions (+37, -15):

    mzh/new_mzh/ResNet18-50-152/model.py    +6, -4
    mzh/new_mzh/ResNet18-50-152/module.py   +31, -11
mzh/new_mzh/ResNet18-50-152/model.py @ cea929c5
@@ -72,8 +72,9 @@ class ResNet(nn.Module):
         self.layer2.quantize(quant_type=quant_type, num_bits=num_bits, e_bits=e_bits)
         self.layer3.quantize(quant_type=quant_type, num_bits=num_bits, e_bits=e_bits)
         self.layer4.quantize(quant_type=quant_type, num_bits=num_bits, e_bits=e_bits)
-        self.qavgpool1 = QAdaptiveAvgPool2d(quant_type, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
-        self.qfc1 = QLinear(quant_type, self.fc, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
+        self.qavgpool1 = QAdaptiveAvgPool2d(quant_type, qi=False, num_bits=num_bits, e_bits=e_bits)
+        # self.qfc1 = QLinear(quant_type, self.fc, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
+        self.qfc1 = QLinear(quant_type, self.fc, qi=True, qo=True, num_bits=num_bits, e_bits=e_bits)

     def quantize_forward(self, x):
         # for _, layer in self.quantize_layers.items():
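The hunk above stops qavgpool1 from owning an output quantizer and switches qfc1 to qi=True, so the fully connected layer now tracks its own input statistics instead of inheriting the pool's qo. A minimal sketch of the qi/qo ownership convention these constructors appear to follow (QParam is the repo's class from module.py; this stand-in module is illustrative only, not the repo's QModule):

    import torch.nn as nn

    class QModuleSketch(nn.Module):
        # qi=True / qo=True: the module owns a QParam and updates it from the
        # tensors it sees during calibration. qi=False: the module has no input
        # quantizer of its own; the upstream layer's qo is wired in later via
        # freeze(qi=...).
        def __init__(self, quant_type, qi=False, qo=True, num_bits=8, e_bits=3):
            super().__init__()
            if qi:
                self.qi = QParam(quant_type, num_bits, e_bits)  # QParam from the repo's module.py
            if qo:
                self.qo = QParam(quant_type, num_bits, e_bits)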
@@ -101,7 +102,8 @@ class ResNet(nn.Module):
         qo = self.layer3.freeze(qinput=qo)
         qo = self.layer4.freeze(qinput=qo)
         self.qavgpool1.freeze(qi=qo)
-        self.qfc1.freeze(qi=self.qavgpool1.qo)
+        # self.qfc1.freeze(qi=self.qavgpool1.qo)
+        self.qfc1.freeze()

     def fakefreeze(self):
         pass
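This is the matching freeze-chain fix: since qavgpool1 no longer has a qo, qfc1 cannot freeze against self.qavgpool1.qo any more, and with qi=True it needs no argument at all. A hedged sketch of the freeze pass this hunk belongs to (layer names beyond those visible in the diff, such as qconvbnrelu1, are assumptions):

    def freeze_sketch(model):
        qo = model.qconvbnrelu1.freeze()          # assumed first block; owns its own qi/qo
        for layer in (model.layer1, model.layer2, model.layer3, model.layer4):
            qo = layer.freeze(qinput=qo)          # each stage consumes the upstream qo, returns its own
        model.qavgpool1.freeze(qi=qo)             # the pool inherits qo as its qi and exposes no qo
        model.qfc1.freeze()                       # the fc tracked its own qi (qi=True), so no argument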
@@ -204,7 +206,7 @@ class BasicBlock(nn.Module):
         else:
             self.qelementadd.freeze(qi0=self.qconvbn1.qo, qi1=qinput)  # a layer may need to be added here to handle the elementwise add
-        self.qrelu1.freeze(qi=self.qelementadd.qo)  # needs to track its own qi
+        self.qrelu1.freeze(qi=self.qelementadd.qo)
         return self.qrelu1.qi  # the qo after the ReLU can use the qi tracked by the ReLU

     def quantize_inference(self, x):
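The residual add gets its own QElementwiseAdd layer, frozen with two input quantizers (qi0, qi1), because the two branches arrive on different quantization grids. With r = S * (q - z) for each tensor, the real-valued sum r_out = r_0 + r_1 rearranges to

    q_out = z_out + (S_0 / S_out) * ((q_0 - z_0) + (S_1 / S_0) * (q_1 - z_1))

which appears to be exactly the M0 = S_0/S_out, M1 = S_1/S_0 form used by QElementwiseAdd.quantize_inference in the module.py hunk below; a numeric check follows that hunk.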
mzh/new_mzh/ResNet18-50-152/module.py @ cea929c5
@@ -64,7 +64,7 @@ def bias_qmax(quant_type):
     elif quant_type == 'POT':
         return get_qmax(quant_type)
     else:
-        return get_qmax(quant_type, 16, 5)
+        return get_qmax(quant_type, 16, 7)  # e7 m9 (e5 is not wide enough: data overflows past both ends of the range)
     # converted to FP32, so no further clamping is needed
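Why the exponent width matters for the 16-bit float bias format: the largest representable value grows double-exponentially in e_bits. A minimal sketch of that arithmetic, assuming an IEEE-style layout with a sign bit and a reserved top exponent code (this is not the repo's get_qmax, and the hunk's "e7 m9" comment suggests the repo may split the 16 bits slightly differently):

    def float_qmax(total_bits, e_bits):
        m_bits = total_bits - 1 - e_bits       # 1 sign bit, remainder mantissa
        bias = 2 ** (e_bits - 1) - 1           # IEEE-style exponent bias
        max_exp = (2 ** e_bits - 2) - bias     # top exponent code reserved
        max_mantissa = 2 - 2 ** (-m_bits)      # 1.111...1 in binary
        return max_mantissa * 2 ** max_exp

    print(float_qmax(16, 5))   # ~6.55e4 (fp16-like): large bias values overflow
    print(float_qmax(16, 7))   # ~1.84e19: far more headroom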
@@ -574,17 +574,23 @@ class QConvBN(QModule):

 # to be fixed: this probably needs a qo
 class QAdaptiveAvgPool2d(QModule):
     def __init__(self, quant_type, qi=False, qo=True, num_bits=8, e_bits=3):
         super(QAdaptiveAvgPool2d, self).__init__(quant_type, qi, qo, num_bits, e_bits)

-    def freeze(self, qi=None):
+    def freeze(self, qi=None, qo=None):
         if hasattr(self, 'qi') and qi is not None:
             raise ValueError('qi has been provided in init function.')
         if not hasattr(self, 'qi') and qi is None:
             raise ValueError('qi is not existed, should be provided.')
+        # if hasattr(self, 'qo') and qo is not None:
+        #     raise ValueError('qo has been provided in init function.')
+        # if not hasattr(self, 'qo') and qo is None:
+        #     raise ValueError('qo is not existed, should be provided.')
         if qi is not None:
             self.qi = qi
+        # if qo is not None:
+        #     self.qo = qo

     # def fakefreeze(self, qi=None):
     #     if hasattr(self, 'qi') and qi is not None:
@@ -603,15 +609,16 @@ class QAdaptiveAvgPool2d(QModule):
         x = F.adaptive_avg_pool2d(x, (1, 1))  # quantizing both input and output counts as quantizing this layer
-        if hasattr(self, 'qo'):
-            self.qo.update(x)
-            x = FakeQuantize.apply(x, self.qo)
+        # if hasattr(self, 'qo'):
+        #     self.qo.update(x)
+        #     x = FakeQuantize.apply(x, self.qo)
         return x

     def quantize_inference(self, x):
         x = F.adaptive_avg_pool2d(x, (1, 1))
-        x = FakeQuantize.apply(x, self.qo)
+        # x = FakeQuantize.apply(x, self.qo)
+        # x = get_nearest_val(self.quant_type, x)  # this may not fit the PoT case; is a rescale missing?
         return x
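With its own fake-quantization commented out, the pool now just averages already-quantized activations and leaves requantization to its neighbours. For reference, a minimal sketch of the quantize-dequantize round trip that a FakeQuantize op of this kind typically performs (assumed to match the repo's FakeQuantize in spirit only; the argument list here is illustrative, since the repo passes a QParam object instead):

    import torch

    class FakeQuantizeSketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, scale, zero_point, qmin, qmax):
            # Snap to the integer grid, then map back to float so downstream
            # layers keep computing in float but see the quantization error.
            q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
            return (q - zero_point) * scale

        @staticmethod
        def backward(ctx, grad_output):
            # Straight-through estimator: the gradient passes through unchanged.
            return grad_output, None, None, None, None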
@@ -625,7 +632,12 @@ class QModule_2(nn.Module):
         if qi1:
             self.qi1 = QParam(quant_type, num_bits, e_bits)  # qi is already given its num_bits and mode here
         if qo:
-            self.qo = QParam(quant_type, num_bits, e_bits)  # qo is already given its num_bits and mode here
+            # if num_bits <= 9:
+            #     self.qo = QParam(quant_type, num_bits, e_bits)
+            # if num_bits > 9 and num_bits < 13:
+            #     self.qo = QParam(quant_type, 8, e_bits)  # qo is already given its num_bits and mode here
+            # else:
+            self.qo = QParam(quant_type, num_bits, e_bits)
         self.quant_type = quant_type
         self.num_bits = num_bits
@@ -694,11 +706,20 @@ class QElementwiseAdd(QModule_2):
         return x

     def quantize_inference(self, x0, x1):
+        # the inputs here are already-quantized qx tensors
+        # x0_d = self.qi0.dequantize_tensor(x0)
+        # x1_d = self.qi1.dequantize_tensor(x1)
+        # print(f"x0={x0_d.reshape(-1)[:10]}")
+        # print(f"x1={x1_d.reshape(-1)[:10]}")
         x0 = x0 - self.qi0.zero_point
         x1 = x1 - self.qi1.zero_point
         x = self.M0 * (x0 + x1 * self.M1)
         # x = get_nearest_val(self.quant_type, x)
         x = get_nearest_val(self.quant_type, x)
         x = x + self.qo.zero_point
-        return x
\ No newline at end of file
+        # x_d = self.qo.dequantize_tensor(x)
+        # print(f"x={x_d.reshape(-1)[:10]}")
+        # print(f"loss={x_d.reshape(-1)[:10]-(x0_d.reshape(-1)[:10]+x1_d.reshape(-1)[:10])}")
+        # print('=============')
+        return x
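Run end to end, this integer path matches the rescaling algebra noted under the BasicBlock hunk above. A small self-contained check with invented scales and zero points (plain rounding stands in for get_nearest_val, so this covers the INT case only):

    import numpy as np

    S0, z0 = 0.05, 3            # branch-0 quantizer (illustrative values)
    S1, z1 = 0.02, 1            # branch-1 quantizer
    S_out, z_out = 0.06, 2      # output quantizer
    M0, M1 = S0 / S_out, S1 / S0

    r0, r1 = 1.25, -0.40                       # real values to add
    q0 = np.round(r0 / S0) + z0                # quantize each input
    q1 = np.round(r1 / S1) + z1

    x = M0 * ((q0 - z0) + M1 * (q1 - z1))      # the diff's x = M0 * (x0 + x1 * M1)
    q_out = np.round(x) + z_out
    print(S_out * (q_out - z_out))             # 0.84, close to r0 + r1 = 0.85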