Commit 32783663 by Zhihong Ma

feat post_training_quantize (LeNet)

parent 8f0b3e4e
# -*- coding: utf-8 -*-
from torch.serialization import load
from model import *
import argparse
import torch
import sys
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
def direct_quantize(model, test_loader, device):
    """Run calibration forward passes so the quantized layers can update
    their quantization statistics.

    Args:
        model: network exposing ``quantize_forward(data)``; each call updates
            the per-layer quantization parameters as a side effect.
        test_loader: iterable yielding ``(data, target)`` batches.
        device: torch.device the batches are moved to.
    """
    # BUG FIX: calibration needs no gradients — without no_grad() every
    # forward pass builds an autograd graph, growing memory over thousands
    # of batches.
    with torch.no_grad():
        for i, (data, target) in enumerate(test_loader, 1):
            data, target = data.to(device), target.to(device)
            # Forward through the quantize-aware path; updates each layer's
            # quantization parameters (the return value is not needed here).
            model.quantize_forward(data)
            if i % 5000 == 0:  # cap calibration at 5000 batches
                break
    print('direct quantization finish')
def full_inference(model, test_loader, device):
    """Evaluate the full-precision model on *test_loader* and print accuracy.

    Args:
        model: callable network returning class logits of shape (batch, classes).
        test_loader: DataLoader yielding ``(data, target)`` batches; must
            expose ``.dataset`` for the total sample count.
        device: torch.device the batches are moved to.
    """
    hits = 0
    total = len(test_loader.dataset)
    for batch, labels in test_loader:
        batch = batch.to(device)
        labels = labels.to(device)
        logits = model(batch)
        predicted = logits.argmax(dim=1, keepdim=True)
        hits += predicted.eq(labels.view_as(predicted)).sum().item()
    print('\nTest set: Full Model Accuracy: {:.4f}%\n'.format(100. * hits / total))
def quantize_inference(model, test_loader, device):
    """Evaluate the quantized model; print and return its accuracy in percent.

    Args:
        model: network exposing ``quantize_inference(data)`` returning logits.
        test_loader: DataLoader yielding ``(data, target)`` batches; must
            expose ``.dataset`` for the total sample count.
        device: torch.device the batches are moved to.

    Returns:
        float: accuracy over the whole dataset, in percent.
    """
    hits = 0
    for batch, labels in test_loader:
        batch, labels = batch.to(device), labels.to(device)
        scores = model.quantize_inference(batch)
        guesses = scores.argmax(dim=1, keepdim=True)
        hits += guesses.eq(labels.view_as(guesses)).sum().item()
    acc = 100. * hits / len(test_loader.dataset)
    print('\nTest set: Quant Model Accuracy: {:.4f}%\n'.format(acc))
    return acc
if __name__ == "__main__":
d1 = sys.argv[1] # num_bits
d2 = sys.argv[2] # mode
d3 = sys.argv[3] # n_exp
# d1 = 8
# d2 = 3
# d3 = 4
batch_size = 32
using_bn = True
load_quant_model_file = None
# load_model_file = None
net = 'LeNet' # 1:
acc = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./project/p/data', train=True, download=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./project/p/data', train=False, download=False ,transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
)
if using_bn:
model = LeNet(n_exp=int(d3), mode=int(d2)).to(device)
# 生成梯度分布图的时候是从0开始训练的
model.load_state_dict(torch.load('./project/p/ckpt/cifar-10_lenet_bn.pt', map_location='cpu'))
# else:
# model = Net()
# model.load_state_dict(torch.load('ckpt/mnist_cnn.pt', map_location='cpu'))
# save_file = "ckpt/mnist_cnn_ptq.pt"
# model.to(device)
model.eval()
full_inference(model, test_loader, device)
full_writer = SummaryWriter(log_dir='./project/p/ptqlog_mode' + str(d2) + '/' + str(d3) + '/' + 'full_log')
for name, param in model.named_parameters():
full_writer.add_histogram(tag=name + '_data', values=param.data)
num_bits = int(d1)
model.quantize(num_bits=num_bits)
model.eval()
print('Quantization bit: %d' % num_bits)
writer = SummaryWriter(log_dir='./project/p/ptqlog_mode' + str(d2) + '/' + str(d3) + '/' + 'quant_bit_' + str(d1) + '_log')
if load_quant_model_file is not None:
model.load_state_dict(torch.load(load_quant_model_file))
print("Successfully load quantized model %s" % load_quant_model_file)
direct_quantize(model, train_loader, device)
model.freeze() # 权重量化
for name, param in model.named_parameters():
writer.add_histogram(tag=name + '_data', values=param.data)
# 原PTQ mode=1时
# save_file = 'ckpt/cifar-10_lenet_bn_ptq_' + str(d1) + '_.pt'
dir_name ='./project/p/ckpt/mode'+ str(d2) + '_' + str(d3) + '/ptq'
if not os.path.isdir(dir_name):
os.makedirs(dir_name, mode=0o777)
os.chmod(dir_name, mode=0o777)
save_file = './project/p/ckpt/mode'+ str(d2) + '_' + str(d3) + '/ptq' + '/cifar-10_lenet_bn_ptq_' + str(d1) + '_.pt'
torch.save(model.state_dict(), save_file)
# 测试是否设备转移是否正确
# model.cuda()
# print(model.qconv1.M.device)
# model.cpu()
# print(model.qconv1.M.device)
acc = quantize_inference(model, test_loader, device)
f = open('./project/p/lenet_ptq_acc' + '.txt', 'a')
f.write('bit ' + str(d1) + ': ' + str(acc) + '\n')
f.close()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment