Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
M
Model-Transfer-Adaptability
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
haoyifan
Model-Transfer-Adaptability
Commits
91d53d31
Commit
91d53d31
authored
Apr 25, 2023
by
Zhihong Ma
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Delete data_analysis_mmd.py
parent
70d2f7fe
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
90 deletions
+0
-90
mzh/data_analysis_mmd.py
+0
-90
No files found.
mzh/data_analysis_mmd.py
deleted
100644 → 0
View file @
70d2f7fe
# -*- coding: utf-8 -*-
import
numpy
import
numpy
as
np
import
torch
import
sys
from
mmd_loss
import
*
from
collections
import
OrderedDict
d1
=
sys
.
argv
[
1
]
# bit
d2
=
sys
.
argv
[
2
]
# epoch
# d1=4
# d2=5
sum
=
0
flag
=
0
total_quan_list
=
list
()
total_base_list
=
list
()
# CNN FLOPs = Cout * Hout * Wout * (2 * Cin * K * K ) 是考虑bias 否则-1
# FCN FLOPs = Cout * Cin 是考虑bias 否则-1
# 把相关的relu,pool也考虑进去了
# MAdd
# weight0 =np.array( [ 705600.0+4704.0+ 3528.0 , 480000.0+ 1600.0 + 1200.0 , 95880.0 + 120.0,
# 20076.0 + 84.0 , 1670.0 ])
# weight1=np.array([705,600.0 , 480,000.0,+ 95,880.0 ,
# 20,076.0 , 1,670.0 ])
# flops
weight_f0
=
np
.
array
([
357504
+
4704
+
4704
,
241600
+
1600
+
1600
,
48000
+
120
,
10080
+
84
,
840
])
weight_f1
=
np
.
array
([
357504
,
241600
,
48000
,
10080
,
840
])
summary_quan_dict
=
OrderedDict
()
summary_base_dict
=
OrderedDict
()
losses
=
[]
# 最外层:不同epoch的字典 内层:各个网络层的grads
for
i
in
range
(
int
(
d2
)):
total_quan_list
.
append
(
torch
.
load
(
'./project/p/checkpoint/cifar-10_lenet_bn_quant/'
+
str
(
d1
)
+
'/ckpt_cifar-10_lenet_bn_quant_'
+
str
(
i
+
1
)
+
'.pth'
))
#total_quan_list.append(torch.load('checkpoint/cifar-10_lenet_bn/full' + '/ckpt_cifar-10_lenet_bn_' + str(d2) + '.pth'))
total_base_list
.
append
(
torch
.
load
(
'./project/p/checkpoint/cifar-10_lenet_bn/full'
+
'/ckpt_cifar-10_lenet_bn_'
+
str
(
i
+
1
)
+
'.pth'
))
for
k
,
_
in
total_base_list
[
i
][
'grads'
]
.
items
():
if
flag
==
0
:
summary_quan_dict
[
k
]
=
total_quan_list
[
i
][
'grads'
][
k
]
.
reshape
(
1
,
-
1
)
summary_base_dict
[
k
]
=
total_base_list
[
i
][
'grads'
][
k
]
.
reshape
(
1
,
-
1
)
else
:
# 字典里的数据不能直接改,需要重新赋值
a
=
summary_quan_dict
[
k
]
b
=
total_quan_list
[
i
][
'grads'
][
k
]
.
reshape
(
1
,
-
1
)
c
=
np
.
vstack
((
a
,
b
))
summary_quan_dict
[
k
]
=
c
a
=
summary_base_dict
[
k
]
b
=
total_base_list
[
i
][
'grads'
][
k
]
.
reshape
(
1
,
-
1
)
c
=
np
.
vstack
((
a
,
b
))
summary_base_dict
[
k
]
=
c
flag
=
1
cnt
=
0
flag
=
0
for
k
,
_
in
summary_quan_dict
.
items
():
if
flag
==
0
:
sum
+=
0.99
*
weight_f1
[
cnt
]
*
MK_MMD
(
source
=
summary_base_dict
[
k
],
target
=
summary_quan_dict
[
k
])
# weight
else
:
sum
+=
0.01
*
weight_f1
[
cnt
]
*
MK_MMD
(
source
=
summary_base_dict
[
k
],
target
=
summary_quan_dict
[
k
])
#bias
if
flag
==
1
:
cnt
=
cnt
+
1
flag
=
0
else
:
flag
=
1
sum
=
sum
/
(
weight_f0
.
sum
()
*
2
)
print
(
sum
)
f
=
open
(
'./project/p/lenet_ptq_similarity.txt'
,
'a'
)
f
.
write
(
'bit:'
+
str
(
d1
)
+
' epoch_num:'
+
str
(
d2
)
+
': '
+
str
(
sum
)
+
'
\n
'
)
f
.
close
()
# for k,v in summary_base_dict.items():
# if k== 'conv_layers.conv1.weight':
# print(v)
# print('===========')
# print(summary_quan_dict[k])
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment