Commit b78e8fd4 by Zhihong Ma

feat: old version for LeNet (ResNet, Module not finished)

parent 32783663
# -*- coding: utf-8 -*-
import numpy as np
import torch
import sys
from mmd_loss import *
from collections import OrderedDict
d1 = sys.argv[1] # bit
d2 = sys.argv[2] # epoch
# d1=4
# d2=5
sum=0
flag=0
total_quan_list=list()
total_base_list=list()
# CNN FLOPs = Cout * Hout * Wout * (2 * Cin * K * K) when counting the bias; otherwise subtract 1
# FC FLOPs = Cout * Cin when counting the bias; otherwise subtract 1
# The associated relu/pool ops are counted as well
# MAdd
# weight0 = np.array([705600.0 + 4704.0 + 3528.0, 480000.0 + 1600.0 + 1200.0, 95880.0 + 120.0,
#                     20076.0 + 84.0, 1670.0])
# weight1 = np.array([705600.0, 480000.0, 95880.0,
#                     20076.0, 1670.0])
# flops
weight_f0 = np.array([357504+4704+4704, 241600+1600+1600, 48000+120, 10080+84, 840])
weight_f1 = np.array([357504, 241600, 48000, 10080, 840])
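# Worked example (a sketch; assumes this repo's LeNet layout on 32x32 CIFAR-10
# inputs with 5x5 kernels, counting one multiply-accumulate as one FLOP, plus bias):
#   conv1: Cout=6,  Hout=Wout=28 -> 6*28*28*(3*5*5+1)  = 357504
#   conv2: Cout=16, Hout=Wout=10 -> 16*10*10*(6*5*5+1) = 241600
#   fc1: 120*400 = 48000, fc2: 84*120 = 10080, fc3: 10*84 = 840
# weight_f0 additionally counts the relu/pool terms; weight_f1 is conv/fc only.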
summary_quan_dict=OrderedDict()
summary_base_dict=OrderedDict()
losses=[]
# Outermost level: one dict per epoch; innermost: the grads of each network layer
for i in range(int(d2)):
total_quan_list.append(torch.load('./project/p/checkpoint/cifar-10_lenet_bn_quant/' + str(d1) + '/ckpt_cifar-10_lenet_bn_quant_'+str(i+1)+'.pth'))
#total_quan_list.append(torch.load('checkpoint/cifar-10_lenet_bn/full' + '/ckpt_cifar-10_lenet_bn_' + str(d2) + '.pth'))
total_base_list.append(torch.load('./project/p/checkpoint/cifar-10_lenet_bn/full' + '/ckpt_cifar-10_lenet_bn_' + str(i+1) + '.pth'))
for k, _ in total_base_list[i]['grads'].items():
if flag == 0:
summary_quan_dict[k] = total_quan_list[i]['grads'][k].reshape(1,-1)
summary_base_dict[k] = total_base_list[i]['grads'][k].reshape(1,-1)
else :
# Dict entries cannot be modified in place; reassign instead
a=summary_quan_dict[k]
b=total_quan_list[i]['grads'][k].reshape(1,-1)
c=np.vstack((a,b))
summary_quan_dict[k] = c
a = summary_base_dict[k]
b = total_base_list[i]['grads'][k].reshape(1,-1)
c = np.vstack((a, b))
summary_base_dict[k] = c
flag = 1
cnt = 0
flag = 0
for k, _ in summary_quan_dict.items():
if flag == 0:
sum += 0.99*weight_f1[cnt] * MK_MMD(source=summary_base_dict[k], target=summary_quan_dict[k]) # weight
else:
sum += 0.01*weight_f1[cnt] * MK_MMD(source=summary_base_dict[k], target=summary_quan_dict[k]) #bias
if flag == 1:
cnt = cnt + 1
flag = 0
else:
flag=1
sum=sum/(weight_f0.sum()*2)
print(sum)
f = open('./project/p/lenet_ptq_similarity.txt','a')
f.write('bit:' + str(d1) + ' epoch_num:' + str(d2) +': '+str(sum)+'\n')
f.close()
# for k,v in summary_base_dict.items():
# if k== 'conv_layers.conv1.weight':
# print(v)
# print('===========')
# print(summary_quan_dict[k])
# -*- coding: utf-8 -*-
import numpy as np
import torch
import sys
from collections import OrderedDict
import scipy.stats
import pandas as pd
import os
# Overall idea: for a given bit width, compute the gradient-distribution similarity at the different epoch checkpoints (5, 10, ...)
# (since these are checkpoint spans, the similarity presumably has to be averaged over the epochs within each span?)
# Called externally: this script is invoked once per bit width
# Each csv row records, for that quantization bit width, the average weighted gradient-distribution similarity at each epoch checkpoint
#
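# A minimal sanity check (hypothetical helper, not called anywhere below):
# moving two unit point masses by 1 each gives a 1-D Wasserstein distance of 1.0.
# This is the per-layer quantity that gets FLOPs-weighted and summed further down.
def _wasserstein_sanity_check():
    assert np.isclose(scipy.stats.wasserstein_distance([0.0, 1.0], [1.0, 2.0]), 1.0)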
d1 = sys.argv[1] # bit
d2 = sys.argv[2] # mode
d3 = sys.argv[3] # n_exp
# d2 = sys.argv[2] # epoch
# d1=4
# d2=5
tag = 0
dirpath = './project/p/qat_analysis_data/mode' + str(d2)
if not os.path.isdir(dirpath):
os.makedirs(dirpath, mode=0o777)
os.chmod(dirpath, mode=0o777)
# if int(d2) == 1:
# csvpath = './project/p/qat_analysis_data/wasserstein_distance.csv'
# else:
if int(d2) != 3:
csvpath = './project/p/qat_analysis_data/mode' + str(d2) + '/wasserstein_distance.csv'
else:
csvpath = './project/p/qat_analysis_data/mode' + str(d2) + '/wasserstein_distance_' + str(d3) + '.csv'
# if os.path.exists("./qat_analysis_data/wasserstein_distance.csv"):
if os.path.exists(csvpath):
tag = 1
if tag == 0: # no csv yet
df = pd.DataFrame()
else: # csv already exists
# df = pd.read_csv("./qat_analysis_data/wasserstein_distance.csv", index_col=0)
df = pd.read_csv(csvpath, index_col=0)
df2 = pd.DataFrame()
# CNN FLOPs = Cout * Hout * Wout * (2 * Cin * K * K) when counting the bias; otherwise subtract 1
# FC FLOPs = Cout * Cin when counting the bias; otherwise subtract 1
# The associated relu/pool ops are counted as well
# MAdd
# weight0 = np.array([705600.0 + 4704.0 + 3528.0, 480000.0 + 1600.0 + 1200.0, 95880.0 + 120.0,
#                     20076.0 + 84.0, 1670.0])
# weight1 = np.array([705600.0, 480000.0, 95880.0,
#                     20076.0, 1670.0])
# flops
weight_f0 = np.array([357504+4704+4704, 241600+1600+1600, 48000+120, 10080+84, 840])
weight_f1 = np.array([357504, 241600, 48000, 10080, 840])
# For each epoch checkpoint
for epoch in [5, 10, 15, 20, 25, 30]:
total_quan_list = []
total_base_list = []
summary_quan_dict = OrderedDict()
summary_base_dict = OrderedDict()
flag = 0
result = 0
# Outermost level: one dict per epoch; innermost: the grads of each network layer
# Iterate over the epochs within this checkpoint span and collect the gradient info
for i in range(epoch):
if int(d2) == 1:
total_quan_list.append(torch.load(
'./project/p/checkpoint/cifar-10_lenet_bn_quant/scratch/' + str(d1) + '/ckpt_cifar-10_lenet_bn_quant_' + str(
i + 1) + '.pth'))
elif int(d2) == 2:
total_quan_list.append(torch.load(
'./project/p/checkpoint/cifar-10_lenet_bn_quant/scratch/mode' + str(d2) + '/' + str(d1)+ '/ckpt_cifar-10_lenet_bn_quant_' + str(
epoch) + '.pth'))
else:
total_quan_list.append(torch.load(
'./project/p/checkpoint/cifar-10_lenet_bn_quant/scratch/mode' + str(d2) + '_' + str(d3) + '/' + str(d1)+ '/ckpt_cifar-10_lenet_bn_quant_' + str(
epoch) + '.pth'))
# total_quan_list.append(torch.load('checkpoint/cifar-10_lenet_bn/full' + '/ckpt_cifar-10_lenet_bn_' + str(d2) + '.pth'))
# the full-precision run does not have enough data
total_base_list.append(
torch.load('./project/p/checkpoint/cifar-10_lenet_bn/full' + '/ckpt_cifar-10_lenet_bn_' + str(i + 1) + '.pth'))
for k, _ in total_base_list[i]['grads'].items(): # get each layer's grads for epoch i
if flag == 0: # the first epoch i read creates the first row of the data matrix; later epochs i stack rows below it
summary_quan_dict[k] = total_quan_list[i]['grads'][k].reshape(1, -1)
summary_base_dict[k] = total_base_list[i]['grads'][k].reshape(1, -1)
else:
# Dict entries cannot be modified in place; reassign instead
a = summary_quan_dict[k]
b = total_quan_list[i]['grads'][k].reshape(1, -1)
c = np.vstack((a, b))
summary_quan_dict[k] = c
a = summary_base_dict[k]
b = total_base_list[i]['grads'][k].reshape(1, -1)
c = np.vstack((a, b))
summary_base_dict[k] = c
flag = 1
# loss = total_quan_list[i]['losses']
# print(loss)
# df = pd.read_csv('./data_analysis_folder/data.csv', index_col=0)
# # df = pd.DataFrame()
# df2 = pd.DataFrame()
# Data collection above; the actual weighted summation below
for j in range(epoch):
flag0 = 0 # alternates between each layer's weight (0) and bias (1)
cnt = 0 # index of the current layer
sum = 0 # sum records the weighted gradient-distribution similarity of a single epoch j
for k, _ in summary_quan_dict.items():
w = summary_base_dict[k][j, :] # not ideal here; needs rework
v = summary_quan_dict[k][j, :]
if flag0 == 0:
cur_weight = weight_f1[cnt] * scipy.stats.wasserstein_distance(w, v) # weight
# awkward to store here; this would need a third dimension (sheets)
# if tag == 1:
# df2[k] = [cur_weight]
# else:
# df[k] = [cur_weight]
sum += 0.99 * cur_weight
else:
cur_bias = weight_f1[cnt] * scipy.stats.wasserstein_distance(w, v) # bias
# if tag == 1:
# df2[k] = [cur_bias]
# else:
# df[k] = [cur_bias]
sum += 0.01 * cur_bias
if flag0 == 1:
cnt = cnt + 1
flag0 = 0
else:
flag0 = 1
sum = sum / (weight_f1.sum() * 2)
result += sum # accumulate the weighted gradient similarity of each epoch i
print(sum)
result /= epoch # average the gradient similarity over the epochs within this checkpoint span
if tag == 1:
df2[str(epoch)] = [result]
else :
df[str(epoch)] = [result]
result = 0
if tag == 1 :
df = pd.concat([df, df2]) # DataFrame.append was removed in pandas 2.x
# df.to_csv('./qat_analysis_data/wasserstein_distance.csv')
df.to_csv(csvpath)
else :
# df.to_csv('./qat_analysis_data/wasserstein_distance.csv')
df.to_csv(csvpath)
# f = open('lenet_ptq_wasserstein_similarity.txt','a')
# f.write('bit:' + str(d1) + ' epoch_num:' + str(d2) +': '+str(sum)+'\n')
# f.close()
# -*- coding: utf-8 -*-
import numpy as np
import torch
import sys
from collections import OrderedDict
import scipy.stats
import pandas as pd
from model import *
# from audtorch.metrics.functional import pearsonr
import math
# This script reads out the weight/bias values of the full-precision and quantized models for inspection
if __name__ == "__main__":
d1 = sys.argv[1]
# d2 = sys.argv[2]
# d1=8
# df = pd.read_csv('./ptq_analysis_data/seperate_data.csv', index_col=0)
df = pd.DataFrame()
# df2 = pd.DataFrame()
base_data = torch.load('./project/p/ckpt/trail/model_trail.pt')
# checkpoint_data = torch.load('./project/p/ckpt/trail/model_trail.pt')
print('full_precision weight/bias loaded!')
checkpoint_dir = './project/p/checkpoint/cifar-10_trail_model'
# quan_data = torch.load('ckpt/cifar-10_lenet_bn_ptq_' + str(d1) + '_.pt')
# print('quantization bit ' + str(d1) + ' weight/bias loaded!')
sum=0
if int(d1) == 1:
print(base_data)
# for k, _ in base_data.items():
# base_data[k] = base_data[k].reshape(1, -1)
# # quan_data[k] = quan_data[k].reshape(1, -1)
# print(base_data[k])
else:
for i in [4,9,14,19]:
check_data = torch.load(checkpoint_dir + '/ckpt_cifar-10_trail_model%s.pt' % (str(i)))
print(check_data)
# if int(d2) == 1:
# print(base_data[k])
# else:
# print(quan_data[k])
# -*- coding: utf-8 -*-
from torch.autograd import Function
class FakeQuantize(Function):
@staticmethod
def forward(ctx, x, qparam): # qparam (i.e. self) already carries mode, scale, zero point, n_exp, etc., so no extra arguments are really needed
x = qparam.quantize_tensor(x, qparam.mode) # INT
x = qparam.dequantize_tensor(x, qparam.mode) # FP(int)
return x
@staticmethod
def backward(ctx, grad_output): # straight-through estimator (STE): approximate the quantizer's gradient as identity
return grad_output, None
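# Hedged usage sketch: `qparam` is assumed to be this repo's QParam-style object
# exposing quantize_tensor/dequantize_tensor. FakeQuantize is applied like any
# torch.autograd.Function:
#   y = FakeQuantize.apply(x, qparam)  # forward: quantize -> dequantize
#   y.sum().backward()                 # backward: STE passes grads through to x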
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import OrderedDict
def get_model_histogram(model):
"""
Description:
- get norm gradients from model, and store in a OrderDict
Args:
- model: (torch.nn.Module), torch model
Returns:
- grads in OrderDict
"""
gradshisto = OrderedDict()
grads = OrderedDict()
for name, params in model.named_parameters():
grad = params.grad
if grad is not None:
tmp = {}
params_np = grad.cpu().numpy()
histogram, bins = np.histogram(params_np.flatten(),bins=20)
tmp['histogram'] = list(histogram)
tmp['bins'] = list(bins)
gradshisto[name] = tmp
grads[name] = params_np
return gradshisto,grads
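# Hypothetical usage sketch (not called by the repo): one backward pass on a toy
# module, then histogram collection via get_model_histogram.
def _demo_get_model_histogram():
    m = nn.Linear(4, 2)
    m(torch.randn(3, 4)).sum().backward()
    histos, grads = get_model_histogram(m)
    # np.histogram(..., bins=20) yields 20 counts and 21 bin edges per parameter
    assert len(histos['weight']['histogram']) == 20
    assert len(histos['weight']['bins']) == 21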
def get_model_norm_gradient(model):
"""
Description:
- get norm gradients from model, and store in a OrderDict
Args:
- model: (torch.nn.Module), torch model
Returns:
- grads in OrderDict
"""
grads = OrderedDict()
for name, params in model.named_parameters():
grad = params.grad
if grad is not None:
grads[name] = grad.norm().item()
return grads
def get_grad_histogram(grads_sum):
gradshisto = OrderedDict()
# grads = OrderedDict()
for name, params in grads_sum.items():
grad = params
if grad is not None:
tmp = {}
#params_np = grad.cpu().numpy()
params_np = grad
histogram, bins = np.histogram(params_np.flatten(),bins=20)
tmp['histogram'] = list(histogram)
tmp['bins'] = list(bins)
gradshisto[name] = tmp # one histogram per layer (tmp holds the histogram description)
# grads[name] = params_np
return gradshisto
# -*- coding: utf-8 -*-
import numpy as np
import torch
import sys
from collections import OrderedDict
import scipy.stats
import pandas as pd
import os
import os.path
#
d1 = sys.argv[1] # bit
d2 = sys.argv[2] # mode
d3 = sys.argv[3] # n_exp
# d1=2
# d2 = sys.argv[2] # epoch
# d1=2
# d2=3
sum=0
flag=0
# CNN FLOPs = Cout * Hout * Wout * (2 * Cin * K * K) when counting the bias; otherwise subtract 1
# FC FLOPs = Cout * Cin when counting the bias; otherwise subtract 1
# The associated relu/pool ops are counted as well
# MAdd
# weight0 = np.array([705600.0 + 4704.0 + 3528.0, 480000.0 + 1600.0 + 1200.0, 95880.0 + 120.0,
#                     20076.0 + 84.0, 1670.0])
# weight1 = np.array([705600.0, 480000.0, 95880.0,
#                     20076.0, 1670.0])
# flops
weight_f0 = np.array([357504+4704+4704, 241600+1600+1600, 48000+120, 10080+84, 840])
weight_f1 = np.array([357504, 241600, 48000, 10080, 840])
summary_quan_dict=OrderedDict()
summary_base_dict=OrderedDict()
# Outermost level: one dict per epoch; innermost: the grads of each network layer
flag = 0
dirpath = './project/p/qat_analysis_data/mode' + str(d2)
if not os.path.isdir(dirpath):
os.makedirs(dirpath, mode=0o777)
os.chmod(dirpath, mode=0o777)
if int(d2) == 1 or int(d2) == 2:
csvpath = dirpath + '/scratch_loss.csv'
else:
csvpath = dirpath + '/scratch_loss_' + str(d3) + '.csv'
if os.path.exists(csvpath):
flag = 1
if flag == 0: # no csv yet
df = pd.DataFrame()
else: # csv already exists
df = pd.read_csv(csvpath, index_col=0)
df2 = pd.DataFrame()
for epoch in ([5, 10, 15, 20, 25, 30]):
sums = []
total_quan_list = []
total_base_list = []
for i in range(int(epoch)):
if int(d2) == 1:
total_quan_list.append(torch.load(
'./project/p/checkpoint/cifar-10_lenet_bn_quant/scratch/' + str(d1) + '/ckpt_cifar-10_lenet_bn_quant_' + str(
i + 1) + '.pth'))
elif int(d2) == 2:
total_quan_list.append(torch.load(
'./project/p/checkpoint/cifar-10_lenet_bn_quant/scratch/mode' + str(d2) + '/' + str(
d1) + '/ckpt_cifar-10_lenet_bn_quant_' + str(
i + 1) + '.pth'))
else:
total_quan_list.append(torch.load(
'./project/p/checkpoint/cifar-10_lenet_bn_quant/scratch/mode' + str(d2) + '_' + str(d3) + '/' + str(
d1) + '/ckpt_cifar-10_lenet_bn_quant_' + str(
i + 1) + '.pth'))
sum_loss = 0
loss = total_quan_list[i]['losses']
# print(len(loss))
# per-batch losses within this epoch
for j in range(len(loss)):
sum_loss += loss[j].cpu()
# print(sum_loss)
sum_loss /= len(loss) # average over the batches (dividing by j left an off-by-one)
sums.append(sum_loss)
# print(sums)
#print(sums[0] - sums[int(d1) - 1])
if flag == 0:
df[str(epoch)] = [(sums[0] - sums[int(epoch) - 1]).detach().numpy()]
else:
df2[str(epoch)] = [(sums[0] - sums[int(epoch) - 1]).detach().numpy()]
if flag == 0:
# df.to_csv('./qat_analysis_data/scratch_loss.csv')
df.to_csv(csvpath)
else:
df = pd.concat([df, df2]) # DataFrame.append was removed in pandas 2.x
# df.to_csv('./qat_analysis_data/scratch_loss.csv')
df.to_csv(csvpath)
# -*- coding: utf-8 -*-
from torch.serialization import load
from model import *
import argparse
import torch
import sys
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
# Produce a fake-quantized version of the PTQ weights (quantize, then dequantize: the result is distributed like the full-precision weights, which makes the Wasserstein-distance similarity computation convenient)
def direct_quantize(model, test_loader, device):
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model.quantize_forward(data) # calls each layer's forward in turn, which updates the quantization parameters (qw)
if i % 5000 == 0:
break
print('direct quantization finish')
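# For reference, the uniform affine fake-quantization relied on here (assumed to
# match quantize_tensor/dequantize_tensor in this repo's model.py) is:
#   q     = clamp(round(x / scale) + zero_point, qmin, qmax)
#   x_hat = scale * (q - zero_point)
# so x_hat lies on the quantization grid while staying in floating point.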
def full_inference(model, test_loader, device):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
print('\nTest set: Full Model Accuracy: {:.4f}%\n'.format(100. * correct / len(test_loader.dataset)))
def quantize_inference(model, test_loader, device):
correct = 0
for i, (data, target) in enumerate(test_loader, 1):
data, target = data.to(device), target.to(device)
output = model.quantize_inference(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
acc = 100. * correct / len(test_loader.dataset)
print('\nTest set: Quant Model Accuracy: {:.4f}%\n'.format(acc))
return acc
if __name__ == "__main__":
d1 = sys.argv[1]
batch_size = 32
using_bn = True
load_quant_model_file = None
# load_model_file = None
net = 'LeNet'
acc = 0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
)
if using_bn:
model = LeNet().to(device)
# when the gradient-distribution plots were generated, training started from scratch (epoch 0)
model.load_state_dict(torch.load('ckpt/cifar-10_lenet_bn.pt', map_location='cpu'))
# else:
# model = Net()
# model.load_state_dict(torch.load('ckpt/mnist_cnn.pt', map_location='cpu'))
# save_file = "ckpt/mnist_cnn_ptq.pt"
# model.to(device)
model.eval()
full_inference(model, test_loader, device)
num_bits = int(d1)
model.quantize(num_bits=num_bits)
model.eval()
print('Quantization bit: %d' % num_bits)
dir_name = './ptq_fake_log/' + 'quant_bit_' + str(d1) + '_log'
if not os.path.isdir(dir_name):
os.makedirs(dir_name, mode=0o777)
os.chmod(dir_name, mode=0o777)
qwriter = SummaryWriter(log_dir=dir_name)
# for name, param in model.named_parameters():
# qwriter.add_histogram(tag=name + '_data', values=param.data)
if load_quant_model_file is not None:
model.load_state_dict(torch.load(load_quant_model_file))
print("Successfully load quantized model %s" % load_quant_model_file)
direct_quantize(model, train_loader, device)
model.fakefreeze() # weight quantization
for name, param in model.named_parameters():
qwriter.add_histogram(tag=name + '_data', values=param.data)
dir_name ='ckpt/ptq_fakefreeze'
if not os.path.isdir(dir_name):
os.makedirs(dir_name, mode=0o777)
os.chmod(dir_name, mode=0o777)
save_file = 'ckpt/ptq_fakefreeze/cifar-10_lenet_bn_ptq_' + str(d1) + '_.pt'
torch.save(model.state_dict(), save_file)
# -*- coding: utf-8 -*-
from model import *
from get_weight import *
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
import sys
import time
# import matplotlib.pyplot as plt
# import matplotlib
from torchvision.datasets import ImageFolder
from torch.utils.tensorboard import SummaryWriter
from absl import app, flags
# from easydict import EasyDict
# from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method
# from cleverhans.torch.attacks.projected_gradient_descent import (
# projected_gradient_descent,
# )
def train(model, device, train_loader, optimizer, epoch):
model.train()
lossLayer = torch.nn.CrossEntropyLoss()
flag = 0
cnt = 0
for batch_idx, (data, target) in enumerate(train_loader):
cnt = cnt + 1
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = lossLayer(output, target)
loss.backward()
histo, grads = get_model_histogram(model)
if flag == 0:
flag = 1
grads_sum = grads
else:
for k,v in grads_sum.items():
grads_sum[k] += grads[k]
optimizer.step()
if batch_idx % 50 == 0:
print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
))
for k, v in grads_sum.items():
grads_sum[k] = v / len(train_loader.dataset)
return grads_sum
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
acc=0
lossLayer = torch.nn.CrossEntropyLoss(reduction='sum')
# report = EasyDict(nb_test=0, correct=0, correct_fgm=0, correct_pgd=0)
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
# x_fgm = fast_gradient_method(model, data, 0.01, np.inf)
# x_pgd = projected_gradient_descent(model, data, 0.01, 0.01, 40, np.inf)
# model prediction on clean examples
# _, y_pred = model(data).max(1)
# # model prediction on FGM adversarial examples
# _, y_pred_fgm = model(x_fgm).max(1)
#
# # model prediction on PGD adversarial examples
# _, y_pred_pgd = model(x_pgd).max(1)
# report.nb_test += target.size(0)
# report.correct += y_pred.eq(target).sum().item()
# report.correct_fgm += y_pred_fgm.eq(target).sum().item()
# report.correct_pgd += y_pred_pgd.eq(target).sum().item()
test_loss += lossLayer(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
acc=100. * correct / len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {:.0f}%\n'.format(
test_loss, acc
))
# print(
# "test acc on clean examples (%): {:.3f}".format(
# report.correct / report.nb_test * 100.0
# )
# )
# print(
# "test acc on FGM adversarial examples (%): {:.3f}".format(
# report.correct_fgm / report.nb_test * 100.0
# )
# )
# print(
# "test acc on PGD adversarial examples (%): {:.3f}".format(
# report.correct_pgd / report.nb_test * 100.0
# )
# )
return acc
batch_size = 32
test_batch_size = 32
seed = 1
# epochs = 15
d1 = sys.argv[1]
epochs = int(d1)
lr = 0.001
momentum = 0.5
save_model = False
using_bn = True
net = 'LeNet'
torch.manual_seed(seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
)
#if using_bn:
if net == 'VGG19':
model = VGG_19().to(device)
elif net == 'LeNet':
model = LeNet().to(device)
# else:
# model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
writer = SummaryWriter(log_dir='./fullprecision_log')
#optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9,0.999),eps=1e-08,weight_decay=0,amsgrad=False)
for epoch in range(1, epochs + 1):
grads_sum = train(model, device, train_loader, optimizer, epoch)
acc = test(model, device, test_loader)
print('epoch:', epoch)
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch,
'grads': grads_sum,
'accuracy':acc
}
# for name, param in model.named_parameters():
# writer.add_histogram(tag=name + '_grad', values=param.grad, global_step=epoch)
# writer.add_histogram(tag=name + '_data', values=param.data, global_step=epoch)
for name, param in grads_sum.items():
# the grad here is presumably an accumulated value rather than a mean (though train() does divide grads_sum by the dataset size)
writer.add_histogram(tag=name + '_grad', values=param, global_step=epoch)
# take the weights as they stand after the last batch of this epoch
for name, param in model.named_parameters():
writer.add_histogram(tag=name + '_data', values=param.data, global_step=epoch)
if net == 'LeNet':
torch.save(checkpoint, 'checkpoint/cifar-10_lenet_bn/full/ckpt_cifar-10_lenet_bn_%s.pth' % (str(epoch)))
# save parameters
# if (net == 'VGG19') == True:
# torch.save(checkpoint, 'checkpoint/cifar-10_vgg19_bn/ckpt_cifar-10_vgg19_bn_%s.pth' % (str(epoch)))
# elif (net == 'LeNet') == True:
# torch.save(checkpoint, 'checkpoint/cifar-10_lenet_bn/ckpt_cifar-10_lenet_bn_%s.pth' % (str(epoch)))
#print('Saved all parameters!\n')
if save_model:
if not osp.exists('ckpt'):
os.makedirs('ckpt')
#if using_bn:
if net == 'VGG19':
torch.save(model.state_dict(), 'ckpt/cifar-10_vgg19_bn.pt')
elif net == 'LeNet':
torch.save(model.state_dict(), 'ckpt/cifar-10_lenet_bn.pt')
# else:
# torch.save(model.state_dict(), 'ckpt/cifar-10_vgg19.pt')