Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
M
Model-Transfer-Adaptability
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
haoyifan
Model-Transfer-Adaptability
Commits
32783663
Commit
32783663
authored
Apr 03, 2023
by
Zhihong Ma
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat post_training_quantize (LeNet)
parent
8f0b3e4e
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
140 additions
and
0 deletions
+140
-0
mzh/post_training_quantize.py
+140
-0
No files found.
mzh/post_training_quantize.py
0 → 100644
View file @
32783663
# -*- coding: utf-8 -*-
from
torch.serialization
import
load
from
model
import
*
import
argparse
import
torch
import
sys
import
torch.nn
as
nn
import
torch.optim
as
optim
from
torchvision
import
datasets
,
transforms
import
os
import
os.path
as
osp
from
torch.utils.tensorboard
import
SummaryWriter
def direct_quantize(model, test_loader, device):
    """Run calibration forward passes to update the model's quantization params.

    Each batch goes through ``model.quantize_forward``, which invokes every
    layer's forward in turn and thereby updates qw. Stops after 5000 batches
    as a calibration budget.
    """
    batch_idx = 0
    for inputs, labels in test_loader:
        batch_idx += 1
        inputs = inputs.to(device)
        labels = labels.to(device)
        _ = model.quantize_forward(inputs)
        if batch_idx % 5000 == 0:
            break
    print('direct quantization finish')
def full_inference(model, test_loader, device):
    """Evaluate the full-precision (non-quantized) model on *test_loader*.

    Prints the accuracy as a percentage of ``len(test_loader.dataset)`` and
    returns it, mirroring ``quantize_inference`` so callers can consume the
    value programmatically (previously the accuracy was only printed).
    """
    correct = 0
    # Pure evaluation: disable autograd bookkeeping to save memory/time.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    acc = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Full Model Accuracy: {:.4f}%\n'.format(acc))
    return acc
def quantize_inference(model, test_loader, device):
    """Evaluate the quantized model via ``model.quantize_inference``.

    Returns the accuracy as a percentage of ``len(test_loader.dataset)``
    after printing it.
    """
    n_correct = 0
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model.quantize_inference(inputs)
        predictions = logits.argmax(dim=1, keepdim=True)
        matches = predictions.eq(labels.view_as(predictions))
        n_correct += matches.sum().item()
    acc = 100. * n_correct / len(test_loader.dataset)
    print('\nTest set: Quant Model Accuracy: {:.4f}%\n'.format(acc))
    return acc
if __name__ == "__main__":
    # CLI: post_training_quantize.py <num_bits> <mode> <n_exp>
    d1 = sys.argv[1]  # num_bits: quantization bit width
    d2 = sys.argv[2]  # mode: quantization mode passed to LeNet
    d3 = sys.argv[3]  # n_exp: exponent setting passed to LeNet

    batch_size = 32
    using_bn = True
    load_quant_model_file = None
    net = 'LeNet'

    acc = 0
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    # CIFAR-10 loaders with standard per-channel normalization statistics.
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./project/p/data', train=True, download=False,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                  (0.2023, 0.1994, 0.2010))
                         ])),
        batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False)

    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./project/p/data', train=False, download=False,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize((0.4914, 0.4822, 0.4465),
                                                  (0.2023, 0.1994, 0.2010))
                         ])),
        batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False)

    if using_bn:
        model = LeNet(n_exp=int(d3), mode=int(d2)).to(device)
        model.load_state_dict(
            torch.load('./project/p/ckpt/cifar-10_lenet_bn.pt',
                       map_location='cpu'))

    model.eval()
    full_inference(model, test_loader, device)

    # Log the float model's parameter distributions before quantizing.
    full_writer = SummaryWriter(log_dir='./project/p/ptqlog_mode' + str(d2)
                                + '/' + str(d3) + '/' + 'full_log')
    for name, param in model.named_parameters():
        full_writer.add_histogram(tag=name + '_data', values=param.data)

    num_bits = int(d1)
    model.quantize(num_bits=num_bits)
    model.eval()
    print('Quantization bit: %d' % num_bits)

    writer = SummaryWriter(log_dir='./project/p/ptqlog_mode' + str(d2) + '/'
                           + str(d3) + '/' + 'quant_bit_' + str(d1) + '_log')

    if load_quant_model_file is not None:
        model.load_state_dict(torch.load(load_quant_model_file))
        print("Successfully load quantized model %s" % load_quant_model_file)

    # Calibrate quantization parameters on the train set, then freeze
    # (weight quantization).
    direct_quantize(model, train_loader, device)
    model.freeze()

    # Log the quantized model's parameter distributions.
    for name, param in model.named_parameters():
        writer.add_histogram(tag=name + '_data', values=param.data)

    dir_name = './project/p/ckpt/mode' + str(d2) + '_' + str(d3) + '/ptq'
    if not os.path.isdir(dir_name):
        os.makedirs(dir_name, mode=0o777)
        os.chmod(dir_name, mode=0o777)
    # Reuse dir_name instead of rebuilding the identical path string.
    save_file = dir_name + '/cifar-10_lenet_bn_ptq_' + str(d1) + '_.pt'
    torch.save(model.state_dict(), save_file)

    acc = quantize_inference(model, test_loader, device)

    # Append the result; `with` guarantees the file is closed even if
    # the write raises (the original bare open/close leaked on error).
    with open('./project/p/lenet_ptq_acc' + '.txt', 'a') as f:
        f.write('bit ' + str(d1) + ': ' + str(acc) + '\n')
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment