Commit 440fa1d7 by Zhihong Ma

feat: Naive MIA

parent 9a015751
#!/usr/bin/python
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.dataset import TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler
import model
from model import init_params as w_init
from train import train_model, train_attack_model, prepare_attack_data
from sklearn.metrics import classification_report
from sklearn.metrics import precision_score, recall_score
import argparse
import numpy as np
import os
import copy
import random
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import CosineAnnealingLR
#set the seed for reproducibility
np.random.seed(1234)
#Flag to enable early stopping
need_earlystop = False
########################
# Model Hyperparameters
########################
#Per-layer filter counts for target and shadow models
# target_filters = [128, 256, 256]
# shadow_filters = [64, 128, 128]
#New FC layers size for pretrained model
# n_fc= [256, 128]
#For CIFAR-10 and MNIST dataset
num_classes = 10
# #No. of training epochs
# num_epochs = 100
# #how many samples per batch to load
# batch_size = 128
# #learning rate
# learning_rate = 0.001
# #Learning rate decay
# lr_decay = 0.0001
#Regularizer
reg=1e-4
#percentage of dataset to use for shadow model
shadow_split = 0.6
#Number of validation samples
n_validation = 1000
#Number of processes
num_workers = 2
#Hidden units for MNIST model
n_hidden_mnist = 32
################################
#Attack Model Hyperparameters
################################
NUM_EPOCHS = 100
BATCH_SIZE = 10
#Learning rate
LR_ATTACK = 0.1
#L2 Regulariser
REG = 1e-7
#Learning rate decay
LR_DECAY = 0.96
#No of hidden units
n_hidden = 256
#Binary classifier
out_classes = 2
def get_cmd_arguments():
parser = argparse.ArgumentParser(prog="Membership Inference Attack")
parser.add_argument('--model', default='resnet18', type=str, choices=['resnet18', 'resnet50', 'resnet152', 'mobilenetv2'], help='Which model to use (resnet18,50,152 or mobilenet)')
parser.add_argument('--dataset', default='CIFAR10', type=str, choices=['CIFAR10', 'CIFAR100' ,'MNIST'], help='Which dataset to use (CIFAR10, CIFAR100 or MNIST)')
parser.add_argument('--dataPath', default='/lustre/datasets', type=str, help='Path to store data')
parser.add_argument('--modelPath', default='./ckpt',type=str, help='Path to save or load model checkpoints')
parser.add_argument('-e','--num_epochs', default=100, type=int, metavar='EPOCHS', help='number of total epochs to run')
parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='BATCH SIZE', help='mini-batch size (default: 128)')
parser.add_argument('-lr', '--learning_rate', default=0.001, type=float, metavar='LR', help='initial learning rate')
parser.add_argument('-wd','--lr_decay',default=0.0001,type=float,metavar='WD',help='lr scheduler weight decay')
parser.add_argument('-j','--num_workers', default=2, type=int, metavar='WORKERS',help='number of data loading workers (default: 2)')
parser.add_argument('--trainTargetModel', action='store_true', help='Train a target model, if false then load an already trained model')
parser.add_argument('--trainShadowModel', action='store_true', help='Train a shadow model, if false then load an already trained model')
parser.add_argument('--trainAttackModel', action='store_true', help='Train an attack model, if false then load an already trained model')
parser.add_argument('--need_augm',action='store_true', help='To use data augmentation on target and shadow training set or not')
parser.add_argument('--need_topk',action='store_true', help='Flag to enable using Top 3 posteriors for attack data')
parser.add_argument('--param_init', action='store_true', help='Flag to enable custom model params initialization')
parser.add_argument('--verbose',action='store_true', help='Add Verbosity')
return parser.parse_args()
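# Hypothetical sketch (never called) of how the parser above is driven; 'mia.py' is a
# placeholder script name and the flags below are only an illustrative combination.
def _example_cli():
    import sys
    saved = sys.argv
    sys.argv = ['mia.py', '--model', 'resnet18', '--dataset', 'CIFAR10',
                '--trainTargetModel', '--trainShadowModel', '--trainAttackModel']
    try:
        return get_cmd_arguments()
    finally:
        sys.argv = saved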
def get_data_transforms(dataset, augm=False):
if dataset == 'CIFAR10':
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
test_transforms = transforms.Compose([transforms.ToTensor(),
normalize])
if augm:
train_transforms = transforms.Compose([transforms.RandomRotation(5),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
normalize])
else:
train_transforms = transforms.Compose([transforms.ToTensor(),
normalize])
elif dataset == 'CIFAR100':
normalize = transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276))
test_transforms = transforms.Compose([transforms.ToTensor(),
normalize])
if augm:
train_transforms = transforms.Compose([transforms.RandomRotation(5),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
normalize])
else:
train_transforms = transforms.Compose([transforms.ToTensor(),
normalize])
else:
#The values 0.1307 and 0.3081 used for the Normalize() transformation below are the global mean and standard deviation
#of the MNIST dataset
test_transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))])
if augm:
train_transforms = torchvision.transforms.Compose([transforms.RandomRotation(5),
transforms.RandomHorizontalFlip(p=0.5),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))])
else:
train_transforms = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize((0.1307,), (0.3081,))])
return train_transforms, test_transforms
def split_dataset(train_dataset):
#For simplicity we split the provided dataset into 4 equal parts
#and assign them to Target train/test and Shadow train/test.
total_size = len(train_dataset)
split1 = total_size // 4
split2 = split1*2
split3 = split1*3
# 0 ~ total_size-1
indices = list(range(total_size))
np.random.shuffle(indices)
#Shadow model train and test set
s_train_idx = indices[:split1]
s_test_idx = indices[split1:split2]
#Target model train and test set
t_train_idx = indices[split2:split3]
t_test_idx = indices[split3:]
return s_train_idx, s_test_idx,t_train_idx,t_test_idx
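# Minimal sketch (hypothetical helper, never called) of how the fixed 1:1:1:1 split above
# behaves on the combined CIFAR-10 data; the total of 60,000 samples (50,000 train +
# 10,000 test) is an assumption used only for illustration. Each index set gets 15,000 samples.
def _demo_split_sizes(total_size=60000):
    demo_indices = list(range(total_size))
    np.random.shuffle(demo_indices)
    quarter = total_size // 4
    s_train = demo_indices[:quarter]               # shadow "in" (member) samples
    s_out = demo_indices[quarter:2*quarter]        # shadow "out" (non-member) samples
    t_train = demo_indices[2*quarter:3*quarter]    # target "in" (member) samples
    t_out = demo_indices[3*quarter:]               # target "out" (non-member) samples
    return len(s_train), len(s_out), len(t_train), len(t_out)   # (15000, 15000, 15000, 15000)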
#--------------------------------------------------------------------------------
# Get dataloaders for Shadow and Target models
# Data Strategy:
# - Split the combined dataset into 4 equal parts (T_train, T_test, S_train, S_test)
#   Target - trained on T_train, with T_test held out
#   Shadow - trained on S_train, with S_test held out
#   Attack - trained on S_train and S_test (labelled via the shadow model)
#            evaluated on T_train and T_test (labelled via the target model)
#--------------------------------------------------------------------------------
def get_data_loader(dataset,
data_dir,
batch,
shadow_split=0.5, # not used yet
augm_required=False,
num_workers=1):
"""
Utility function for loading and returning train and valid
iterators over the CIFAR-10, CIFAR-100 and MNIST datasets.
"""
# The shadow split is fixed at 1:1:1:1 for now
# error_msg = "[!] shadow_split should be in the range [0, 1]."
# assert ((shadow_split >= 0) and (shadow_split <= 1)), error_msg
train_transforms, test_transforms = get_data_transforms(dataset,augm_required)
#Download test and train dataset
print(f"Using Dataset:{dataset}")
if dataset == 'CIFAR10':
#CIFAR10 training set
train_set = torchvision.datasets.CIFAR10(root=data_dir,
train=True,
transform=train_transforms,
download=False)
#CIFAR10 test set: unused in earlier versions; it is now merged with train_set into one combined set
test_set = torchvision.datasets.CIFAR10(root=data_dir,
train = False,
transform = test_transforms)
train_set = torch.utils.data.ConcatDataset([train_set, test_set])
# Originally only the train set was split (into in/out), i.e. partitioned into t_in(t_train), t_out(t_test), s_in(s_train), s_out(s_test); later versions also merge in the test_set
s_train_idx, s_out_idx, t_train_idx, t_out_idx = split_dataset(train_set)
elif dataset == 'CIFAR100':
train_set = torchvision.datasets.CIFAR100(root=data_dir,
train=True,
transform=train_transforms,
download=False)
test_set = torchvision.datasets.CIFAR100(root=data_dir,
train = False,
transform = test_transforms)
train_set = torch.utils.data.ConcatDataset([train_set, test_set])
s_train_idx, s_out_idx, t_train_idx, t_out_idx = split_dataset(train_set)
elif dataset == 'MNIST':
#MNIST train set
train_set = torchvision.datasets.MNIST(root=data_dir,
train=True,
transform=train_transforms,
download=False)
#MNIST test set
test_set = torchvision.datasets.MNIST(root=data_dir,
train = False,
transform = test_transforms)
train_set = torch.utils.data.ConcatDataset([train_set, test_set])
s_train_idx, s_out_idx, t_train_idx, t_out_idx = split_dataset(train_set)
# Data samplers
s_train_sampler = SubsetRandomSampler(s_train_idx)
s_out_sampler = SubsetRandomSampler(s_out_idx)
t_train_sampler = SubsetRandomSampler(t_train_idx)
t_out_sampler = SubsetRandomSampler(t_out_idx)
#We keep a validation set to measure training performance.
#Validation samples come from the held-out (out) split of target and shadow,
#since those samples are never used for training.
# if dataset == 'CIFAR10':
# Previously the first n_validation elements of the out split formed the val set:
# target_val_idx = t_out_idx[:n_validation]
# shadow_val_idx = s_out_idx[:n_validation]
# Now the val set is simply the entire held-out (test) split
target_val_idx = t_out_idx
shadow_val_idx = s_out_idx
t_val_sampler = SubsetRandomSampler(target_val_idx)
s_val_sampler = SubsetRandomSampler(shadow_val_idx)
# elif dataset == 'CIFAR100':
# target_val_idx = t_out_idx
# shadow_val_idx = s_out_idx
# t_val_sampler = SubsetRandomSampler(target_val_idx)
# s_val_sampler = SubsetRandomSampler(shadow_val_idx)
# # elif dataset == 'MNIST':
# # target_val_idx = t_out_idx[:n_validation]
# # shadow_val_idx = s_out_idx[:n_validation]
# target_val_idx = t_out_idx
# shadow_val_idx = s_out_idx
# t_val_sampler = SubsetRandomSampler(target_val_idx)
# s_val_sampler = SubsetRandomSampler(shadow_val_idx)
#-------------------------------------------------
# Data loader
#-------------------------------------------------
# if dataset == 'CIFAR10':
t_train_loader = torch.utils.data.DataLoader(dataset=train_set,
batch_size=batch,
sampler = t_train_sampler,
num_workers=num_workers)
t_out_loader = torch.utils.data.DataLoader(dataset=train_set,
batch_size=batch,
sampler = t_out_sampler,
num_workers=num_workers)
t_val_loader = torch.utils.data.DataLoader(dataset=train_set,
batch_size=batch,
sampler=t_val_sampler,
num_workers=num_workers)
s_train_loader = torch.utils.data.DataLoader(dataset=train_set,
batch_size=batch,
sampler=s_train_sampler,
num_workers=num_workers)
s_out_loader = torch.utils.data.DataLoader(dataset=train_set,
batch_size=batch,
sampler=s_out_sampler,
num_workers=num_workers)
s_val_loader = torch.utils.data.DataLoader(dataset=train_set,
batch_size=batch,
sampler=s_val_sampler,
num_workers=num_workers)
# elif dataset == 'CIFAR100':
# t_train_loader = torch.utils.data.DataLoader(dataset=train_set,
# batch_size=batch,
# sampler = t_train_sampler,
# num_workers=num_workers)
# t_out_loader = torch.utils.data.DataLoader(dataset=train_set,
# batch_size=batch,
# sampler = t_out_sampler,
# num_workers=num_workers)
# t_val_loader = torch.utils.data.DataLoader(dataset=train_set,
# batch_size=batch,
# sampler=t_val_sampler,
# num_workers=num_workers)
# s_train_loader = torch.utils.data.DataLoader(dataset=train_set,
# batch_size=batch,
# sampler=s_train_sampler,
# num_workers=num_workers)
# s_out_loader = torch.utils.data.DataLoader(dataset=train_set,
# batch_size=batch,
# sampler=s_out_sampler,
# num_workers=num_workers)
# s_val_loader = torch.utils.data.DataLoader(dataset=train_set,
# batch_size=batch,
# sampler=s_val_sampler,
# num_workers=num_workers)
# elif dataset == 'MNIST':
# t_train_loader = torch.utils.data.DataLoader(dataset=train_set,
# batch_size=batch,
# sampler=t_train_sampler,
# num_workers=num_workers)
# t_out_loader = torch.utils.data.DataLoader(dataset=train_set,
# batch_size=batch,
# sampler=t_out_sampler,
# num_workers=num_workers)
# t_val_loader = torch.utils.data.DataLoader(dataset=train_set,
# batch_size=batch,
# sampler=t_val_sampler,
# num_workers=num_workers)
# s_train_loader = torch.utils.data.DataLoader(dataset=train_set,
# batch_size=batch,
# sampler=s_train_sampler,
# num_workers=num_workers)
# s_out_loader = torch.utils.data.DataLoader(dataset=train_set,
# batch_size=batch,
# sampler=s_out_sampler,
# num_workers=num_workers)
# s_val_loader = torch.utils.data.DataLoader(dataset=train_set,
# batch_size=batch,
# sampler=s_val_sampler,
# num_workers=num_workers)
print('Total combined samples in {} dataset : {}'.format(dataset, len(train_set)))
# print('Total Train samples in {} dataset : {}'.format(dataset, len(train_set)))
print('Number of Target train samples: {}'.format(len(t_train_sampler)))
print('Number of Target valid samples: {}'.format(len(t_val_sampler))) # validation samples come from the held-out split
print('Number of Target test samples: {}'.format(len(t_out_sampler)))
print('Number of Shadow train samples: {}'.format(len(s_train_sampler)))
print('Number of Shadow valid samples: {}'.format(len(s_val_sampler))) # validation samples come from the held-out split
print('Number of Shadow test samples: {}'.format(len(s_out_sampler)))
return t_train_loader, t_val_loader, t_out_loader, s_train_loader, s_val_loader, s_out_loader
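# Hypothetical usage sketch for get_data_loader (the real call happens in create_attack
# below with command-line arguments; './data/CIFAR10' is a placeholder path and the data
# is assumed to already exist there, since the constructors use download=False):
def _example_loaders():
    return get_data_loader('CIFAR10', './data/CIFAR10', batch=128,
                           shadow_split=0.5, augm_required=False, num_workers=2)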
def attack_inference(model,
test_X,
test_Y,
device):
print('----Attack Model Testing----')
targetnames= ['Non-Member', 'Member']
pred_y = []
true_y = []
#Tuple of tensors => full tensor
# (batch_size, channels, height, width)=>(total_samples, channels, height, width)
X = torch.cat(test_X)
Y = torch.cat(test_Y)
print(f"Y:{Y}")
#Create Inference dataset (reconstructed)
inferdataset = TensorDataset(X,Y)
print('Length of Attack Model test dataset : [{}]'.format(len(inferdataset)))
dataloader = torch.utils.data.DataLoader(dataset=inferdataset,
batch_size=50,
shuffle=False,
num_workers=num_workers)
#Evaluation of Attack Model
model.eval()
with torch.no_grad():
correct = 0
total = 0
for i, (inputs, labels) in enumerate(dataloader):
inputs = inputs.to(device)
labels = labels.to(device)
outputs = model(inputs)
#Predictions for accuracy calculations
_, predictions = torch.max(outputs.data, 1)
total+=labels.size(0)
correct+=(predictions == labels).sum().item()
# print('True Labels for Batch [{}] are : {}'.format(i,labels))
# print('Predictions for Batch [{}] are : {}'.format(i,predictions))
true_y.append(labels.cpu())
pred_y.append(predictions.cpu())
attack_acc = correct / total
print('Attack Test Accuracy is : {:.2f}%'.format(100*attack_acc))
true_y = torch.cat(true_y).numpy()
pred_y = torch.cat(pred_y).numpy()
print('---Detailed Results----')
print(classification_report(true_y,pred_y, target_names=targetnames))
#Main method to initiate model training and the attack
def create_attack(modelname, dataset, dataPath, modelPath, num_epochs , batch_size, learning_rate, lr_decay , num_workers,trainTargetModel, trainShadowModel,trainAttackModel, need_augm, need_topk, param_init, verbose):
dataset = dataset
need_augm = need_augm
verbose = verbose
#For using top 3 posterior probabilities
top_k = need_topk
if dataset == 'CIFAR10' or dataset == 'CIFAR100':
img_size = 32
#Input Channels for the Image
input_dim = 3
else:#MNIST
img_size = 28
input_dim = 1
datasetDir = os.path.join(dataPath,dataset)
modelDir = os.path.join(modelPath, dataset)
modelDir = os.path.join(modelDir,modelname)
#Create dataset and model directories
if not os.path.exists(datasetDir):
try:
os.makedirs(datasetDir)
except OSError:
pass
if not os.path.exists(modelDir):
try:
os.makedirs(modelDir)
except OSError:
pass
# setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
#Creating data loaders
# the val loaders are drawn from the held-out (test) split
t_train_loader, t_val_loader, t_test_loader,\
s_train_loader, s_val_loader, s_test_loader = get_data_loader(dataset,
datasetDir,
batch_size,
shadow_split,
need_augm,
num_workers)
if (trainTargetModel):
if dataset == 'CIFAR10':
# target_model = model.TargetNet(input_dim,target_filters,img_size,num_classes).to(device)
if modelname =='resnet18':
target_model = model.resnet18().to(device)
elif modelname == 'resnet50':
target_model = model.resnet50().to(device)
elif modelname == 'resnet152':
target_model = model.resnet152().to(device)
elif modelname == 'mobilenetv2':
target_model = model.MobileNetV2().to(device)
elif dataset == 'CIFAR100':
if modelname =='resnet18':
target_model = model.resnet18(num_classes=100).to(device)
elif modelname == 'resnet50':
target_model = model.resnet50(num_classes=100).to(device)
elif modelname == 'resnet152':
target_model = model.resnet152(num_classes=100).to(device)
elif modelname == 'mobilenetv2':
target_model = model.MobileNetV2(num_classes=100).to(device)
else:
target_model = model.MNISTNet(input_dim, n_hidden_mnist, num_classes).to(device)
if (param_init):
#Initialize params
target_model.apply(w_init)
# Print the model we just instantiated
if verbose:
print('----Target Model Architecture----')
print(target_model)
print('----Model Learnable Params----')
for name,param in target_model.named_parameters():
if param.requires_grad == True:
print("\t",name)
# Loss and optimizer for Target Model
loss = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(target_model.parameters(), lr=learning_rate, weight_decay=reg)
# lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma=lr_decay)
optimizer = torch.optim.Adam(target_model.parameters(), lr=learning_rate)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
# returns attackX, attackY (both built from the CIFAR10 train set)
targetX, targetY = train_model(target_model,
t_train_loader,
t_val_loader,
t_test_loader, # drawn from the CIFAR10 train+test set
loss,
optimizer,
lr_scheduler,
device,
modelDir,
verbose,
num_epochs,
top_k,
need_earlystop,
is_target=True)
else: #Target model training not required, load the saved checkpoint
target_file = os.path.join(modelDir,'best_target_model.pt')
print('Use Target model at the path ====> [{}] '.format(modelDir))
#Instantiate Target Model Class
if dataset == 'CIFAR10':
# target_model = model.TargetNet(input_dim,target_filters,img_size,num_classes).to(device)
# target_model = model.resnet18().to(device)
if modelname =='resnet18':
target_model = model.resnet18().to(device)
elif modelname == 'resnet50':
target_model = model.resnet50().to(device)
elif modelname == 'resnet152':
target_model = model.resnet152().to(device)
elif modelname == 'mobilenetv2':
target_model = model.MobileNetV2().to(device)
elif dataset == 'CIFAR100':
if modelname =='resnet18':
target_model = model.resnet18(num_classes=100).to(device)
elif modelname == 'resnet50':
target_model = model.resnet50(num_classes=100).to(device)
elif modelname == 'resnet152':
target_model = model.resnet152(num_classes=100).to(device)
elif modelname == 'mobilenetv2':
target_model = model.MobileNetV2(num_classes=100).to(device)
else:
target_model = model.MNISTNet(input_dim,n_hidden_mnist,num_classes).to(device)
target_model.load_state_dict(torch.load(target_file))
print('---Preparing Attack Training data---')
t_trainX = []
t_trainY = []
t_testX = []
t_testY = []
# test dataset for attack model : in
t_trainX, t_trainY = prepare_attack_data(t_trainX,t_trainY,target_model,t_train_loader,device,top_k)
# print(f"len of t_trainX:{len(t_trainX)},len of t_trainY:{len(t_trainY)}")
# test dataset for attack model : out
t_testX, t_testY = prepare_attack_data(t_testX,t_testY,target_model,t_test_loader,device,top_k,test_dataset=True)
# print(f"len of t_testX:{len(t_testX)},len of t_testY:{len(t_testY)}")
# list concat
targetX = t_trainX + t_testX
targetY = t_trainY + t_testY
# print(f"len of targetX:{len(targetX)},len of targetY:{len(targetY)}")
# print(f"targetX:{targetX}")
# print(f"targetY:{targetY}")
if (trainShadowModel):
if dataset == 'CIFAR10' :
# shadow_model = model.ShadowNet(input_dim,shadow_filters,img_size,num_classes).to(device)
if modelname =='resnet18':
shadow_model = model.resnet18().to(device)
elif modelname == 'resnet50':
shadow_model = model.resnet50().to(device)
elif modelname == 'resnet152':
shadow_model = model.resnet152().to(device)
elif modelname == 'mobilenetv2':
shadow_model = model.MobileNetV2().to(device)
elif dataset == 'CIFAR100':
if modelname =='resnet18':
shadow_model = model.resnet18(num_classes=100).to(device)
elif modelname == 'resnet50':
shadow_model = model.resnet50(num_classes=100).to(device)
elif modelname == 'resnet152':
shadow_model = model.resnet152(num_classes=100).to(device)
elif modelname == 'mobilenetv2':
shadow_model = model.MobileNetV2(num_classes=100).to(device)
else:
#Using fewer hidden units than the target model while mimicking its architecture
n_shadow_hidden = 16
shadow_model = model.MNISTNet(input_dim,n_shadow_hidden,num_classes).to(device)
if (param_init):
#Initialize params
shadow_model.apply(w_init)
# Print the model we just instantiated
if verbose:
print('----Shadow Model Architecture---')
print(shadow_model)
print('---Model Learnable Params----')
for name,param in shadow_model.named_parameters():
if param.requires_grad == True:
print("\t",name)
# Loss and optimizer
shadow_loss = nn.CrossEntropyLoss()
# shadow_optimizer = torch.optim.Adam(shadow_model.parameters(), lr=learning_rate, weight_decay=reg)
# shadow_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(shadow_optimizer,gamma=lr_decay)
# shadow_optimizer = torch.optim.Adam(shadow_model.parameters(), lr=learning_rate)
shadow_optimizer = torch.optim.Adam(shadow_model.parameters(), lr=learning_rate)
shadow_lr_scheduler = CosineAnnealingLR(shadow_optimizer, T_max=num_epochs)
# returns the training dataset for the attack model, produced by the shadow model
shadowX, shadowY = train_model(shadow_model,
s_train_loader,
s_val_loader,
s_test_loader,
shadow_loss,
shadow_optimizer,
shadow_lr_scheduler,
device,
modelDir,
verbose,
num_epochs,
top_k,
need_earlystop,
is_target=False)
else: #Shadow model training not required, load the saved checkpoint
print('Using Shadow model at the path ====> [{}] '.format(modelDir))
shadow_file = os.path.join(modelDir,'best_shadow_model.pt')
assert os.path.isfile(shadow_file), 'Shadow Model checkpoint not found, aborting load'
#Instantiate Shadow Model Class
if dataset == 'CIFAR10' :
# shadow_model = model.ShadowNet(input_dim,shadow_filters,img_size,num_classes).to(device)
# shadow_model = model.resnet18().to(device)
if modelname =='resnet18':
shadow_model = model.resnet18().to(device)
elif modelname == 'resnet50':
shadow_model = model.resnet50().to(device)
elif modelname == 'resnet152':
shadow_model = model.resnet152().to(device)
elif modelname == 'mobilenetv2':
shadow_model = model.MobileNetV2().to(device)
elif dataset =='CIFAR100':
if modelname =='resnet18':
shadow_model = model.resnet18(num_classes=100).to(device)
elif modelname == 'resnet50':
shadow_model = model.resnet50(num_classes=100).to(device)
elif modelname == 'resnet152':
shadow_model = model.resnet152(num_classes=100).to(device)
elif modelname == 'mobilenetv2':
shadow_model = model.MobileNetV2(num_classes=100).to(device)
else:
#Using fewer hidden units than the target model while mimicking its architecture
n_shadow_hidden = 16
shadow_model = model.MNISTNet(input_dim,n_shadow_hidden,num_classes).to(device)
#Load the saved model
shadow_model.load_state_dict(torch.load(shadow_file))
#Prepare dataset for training attack model
print('----Preparing Attack training data---')
# Build the attack model's dataset (train and test parts combined); only the shadow model's outputs may be used
# 'in' (member) training data
trainX =[]
trainY = []
testX=[]
testY=[]
# train dataset for attack model: in
trainX, trainY = prepare_attack_data(trainX,trainY,shadow_model,s_train_loader,device,top_k)
# train dataset for attack model: out
testX, testY = prepare_attack_data(testX,testY,shadow_model,s_test_loader,device,top_k,test_dataset=True)
shadowX = trainX + testX
shadowY = trainY + testY
###################################
# Attack Model Training
##################################
#The input dimension to MLP attack model
input_size = shadowX[0].size(1)
print('Input Feature dim for Attack Model : [{}]'.format(input_size))
attack_model = model.AttackMLP(input_size,n_hidden,out_classes).to(device)
if (trainAttackModel):
if (param_init):
#Initialize params
attack_model.apply(w_init)
# Loss and optimizer
attack_loss = nn.CrossEntropyLoss()
# attack_optimizer = torch.optim.Adam(attack_model.parameters(), lr=LR_ATTACK, weight_decay=REG)
# attack_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(attack_optimizer,gamma=LR_DECAY)
# attack_optimizer = torch.optim.Adam(attack_model.parameters(), lr=learning_rate)
attack_optimizer = torch.optim.SGD(attack_model.parameters(), lr=learning_rate,momentum=0.7, weight_decay=0.0001)
attack_lr_scheduler = CosineAnnealingLR(attack_optimizer, T_max=NUM_EPOCHS)
#Feature vector and labels for training Attack model
# Only the shadow model's train/test sets (all drawn from the CIFAR10 data) may be used as the attack model's train set
attackdataset = (shadowX, shadowY)
attack_valacc = train_attack_model(attack_model, attackdataset, attack_loss,
attack_optimizer, attack_lr_scheduler, device, modelDir,
NUM_EPOCHS, BATCH_SIZE, num_workers, verbose, earlystopping=False)
print('Validation Accuracy for the Best Attack Model is: {:.2f} %'.format(100* attack_valacc))
else:
# TODO
#Load the trained attack model
attack_path = os.path.join(modelDir,'best_attack_model.pt')
attack_model.load_state_dict(torch.load(attack_path))
#Inference on trained attack model
# Run inference on the target model's train/test sets (which are themselves still part of the CIFAR10 data)
# The attack model was trained on the shadow model's data, so the target model's data serves as the attack model's test set here
# targetX holds the target model's output vectors, which is what quantization will affect
attack_inference(attack_model, targetX, targetY, device)
if __name__ == '__main__':
#get command line arguments from the user
args = get_cmd_arguments()
print(args)
#Generate Membership inference attack
create_attack(args.model,args.dataset, args.dataPath, args.modelPath, args.num_epochs , args.batch_size, args.learning_rate, args.lr_decay , args.num_workers,args.trainTargetModel, args.trainShadowModel, args.trainAttackModel,args.need_augm, args.need_topk, args.param_init, args.verbose)
from torch.autograd import Function
class FakeQuantize(Function):
@staticmethod
def forward(ctx, x, qparam):
x = qparam.quantize_tensor(x)
x = qparam.dequantize_tensor(x)
return x
@staticmethod
def backward(ctx, grad_output):
return grad_output, None
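# Minimal sketch of the straight-through estimator above: the forward pass performs a real
# quantize/dequantize round trip, while the backward pass lets the gradient flow through
# unchanged. _ToyQParam is a hypothetical stand-in for the real QParam class.
import torch

class _ToyQParam:
    def __init__(self, scale=0.1):
        self.scale = scale
    def quantize_tensor(self, x):
        return (x / self.scale).round()
    def dequantize_tensor(self, q):
        return q * self.scale

def _fake_quantize_demo():
    x = torch.randn(4, requires_grad=True)
    y = FakeQuantize.apply(x, _ToyQParam())   # values snapped to a 0.1 grid
    y.sum().backward()
    return x.grad                             # all ones: identity gradient (straight-through)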
class GlobalVariables:
SELF_INPLANES = 0
# -*- coding: utf-8 -*-
# Shares global variables across multiple modules
def _init(): # initialization
global _global_dict
_global_dict = {}
def set_value(value,is_bias=False):
# store a global value
if is_bias:
_global_dict[0] = value
else:
_global_dict[1] = value
def get_value(is_bias=False): # bias gets its own precision, independent of the other variables
if is_bias:
return _global_dict[0]
else:
return _global_dict[1]
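# Illustrative usage of this tiny global store (the value lists below are made up; the
# real lists are set elsewhere before quantization is run):
def _global_store_demo():
    import torch
    _init()
    set_value(torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0]))                 # list for weights/activations
    set_value(torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]), is_bias=True)   # separate list for bias
    return get_value(), get_value(is_bias=True)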
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
from module import *
import module
from global_var import GlobalVariables
#The methods below that calculate the input features to the FC layer
#and the weight initialization for the CNN models are based on the GitHub repo below
#Based on :https://github.com/Lab41/cyphercat/blob/master/Utils/models.py
def size_conv(size, kernel, stride=1, padding=0):
out = int(((size - kernel + 2*padding)/stride) + 1)
return out
def size_max_pool(size, kernel, stride=None, padding=0):
if stride == None:
stride = kernel
out = int(((size - kernel + 2*padding)/stride) + 1)
return out
#Calculate in_features for the FC layer in the CIFAR models
def calc_feat_linear_cifar(size):
feat = size_conv(size,3,1,1)
feat = size_max_pool(feat,2,2)
feat = size_conv(feat,3,1,1)
out = size_max_pool(feat,2,2)
return out
#Calculate in_features for the FC layer in the MNIST model
def calc_feat_linear_mnist(size):
feat = size_conv(size,5,1)
feat = size_max_pool(feat,2,2)
feat = size_conv(feat,5,1)
out = size_max_pool(feat,2,2)
return out
#Parameter Initialization
def init_params(m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.xavier_normal_(m.weight.data)
nn.init.zeros_(m.bias)
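# Hypothetical usage sketch (never called): init_params is intended to be passed to
# nn.Module.apply, exactly as the --param_init flag does for the target/shadow models;
# the tiny network below exists only for illustration.
def _init_params_demo():
    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1),
                        nn.BatchNorm2d(8),
                        nn.Flatten(),
                        nn.Linear(8 * 32 * 32, 2))
    net.apply(init_params)   # Kaiming for conv, constants for BN, Xavier for linear
    return net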
#####################################################
# Define Target, Shadow and Attack Model Architecture
#####################################################
#Target Model
class TargetNet(nn.Module):
def __init__(self, input_dim, hidden_layers, size, out_classes):
super(TargetNet, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=input_dim, out_channels=hidden_layers[0], kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_layers[0]),
# nn.Dropout(p=0.5),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_channels=hidden_layers[0], out_channels=hidden_layers[1], kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_layers[1]),
# nn.Dropout(p=0.5),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
features = calc_feat_linear_cifar(size)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear((features**2 * hidden_layers[1]), hidden_layers[2]),
nn.ReLU(inplace=True),
nn.Linear(hidden_layers[2], out_classes)
)
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
out = self.classifier(out)
return out
#Shadow model mimicking the target model architecture; in our implementation it differs from the target
class ShadowNet(nn.Module):
def __init__(self, input_dim, hidden_layers,size,out_classes):
super(ShadowNet, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=input_dim, out_channels=hidden_layers[0], kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_layers[0]),
# nn.Dropout(p=0.5),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_channels=hidden_layers[0], out_channels=hidden_layers[1], kernel_size=3, padding=1),
nn.BatchNorm2d(hidden_layers[1]),
# nn.Dropout(p=0.5),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
features = calc_feat_linear_cifar(size)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear((features**2 * hidden_layers[1]), hidden_layers[2]),
nn.ReLU(inplace=True),
nn.Linear(hidden_layers[2], out_classes)
)
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
out = self.classifier(out)
return out
#Pretrained VGG11 model for Target
class VggModel(nn.Module):
def __init__(self, num_classes,layer_config,pretrained=True):
super(VggModel, self).__init__()
#Load the pretrained VGG11_BN model
if pretrained:
pt_vgg = models.vgg11_bn(pretrained=pretrained)
#Deleting old FC layers from pretrained VGG model
print('### Deleting Avg pooling and FC Layers ####')
del pt_vgg.avgpool
del pt_vgg.classifier
self.model_features = nn.Sequential(*list(pt_vgg.features.children()))
#Adding new FC layers with BN and RELU for CIFAR10 classification
self.model_classifier = nn.Sequential(
nn.Linear(layer_config[0], layer_config[1]),
nn.BatchNorm1d(layer_config[1]),
nn.ReLU(inplace=True),
nn.Linear(layer_config[1], num_classes),
)
def forward(self, x):
x = self.model_features(x)
x = x.squeeze()
out = self.model_classifier(x)
return out
#Target/Shadow Model for MNIST
class MNISTNet(nn.Module):
def __init__(self, input_dim, n_hidden,out_classes=10,size=28):
super(MNISTNet, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=input_dim, out_channels=n_hidden, kernel_size=5),
nn.BatchNorm2d(n_hidden),
# nn.Dropout(p=0.5),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_channels=n_hidden, out_channels=n_hidden*2, kernel_size=5),
nn.BatchNorm2d(n_hidden*2),
# nn.Dropout(p=0.5),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2)
)
features = calc_feat_linear_mnist(size)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(features**2 * (n_hidden*2), n_hidden*2),
nn.ReLU(inplace=True),
nn.Linear(n_hidden*2, out_classes)
)
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
out = self.classifier(out)
return out
#Attack MLP Model
class AttackMLP(nn.Module):
def __init__(self, input_size, hidden_size=1024,out_classes=2):
super(AttackMLP, self).__init__()
self.classifier = nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.ReLU(inplace=True),
nn.Linear(hidden_size, out_classes)
)
def forward(self, x):
out = self.classifier(x)
return out
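# Sketch of how AttackMLP is fed: its input is a posterior vector from the target/shadow
# model (the top-3 posteriors when --need_topk is set, otherwise the full class posterior).
# The batch of 5 random 10-class samples below is purely illustrative.
def _attack_mlp_demo():
    posteriors = torch.softmax(torch.randn(5, 10), dim=1)     # 5 samples, 10-class posteriors
    attack = AttackMLP(input_size=10, hidden_size=256, out_classes=2)
    return attack(posteriors)                                 # (5, 2) member / non-member logits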
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=10): # number of classes defaults to 10
super(ResNet, self).__init__()
self.inplanes = 16 # CIFAR-10 images are small, so start with fewer channels
GlobalVariables.SELF_INPLANES = self.inplanes
# print('resnet init:'+ str(GlobalVariables.SELF_INPLANES))
# input (stem) layer
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU()
# residual stages (4 stages of stacked blocks)
self.layer1 = MakeLayer_ResNet(block, 16, layers[0])
self.layer2 = MakeLayer_ResNet(block, 32, layers[1], stride=2)
self.layer3 = MakeLayer_ResNet(block, 64, layers[2], stride=2)
self.layer4 = MakeLayer_ResNet(block, 128, layers[3], stride=2)
# classification head
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(128 * block.expansion, num_classes)
# parameter initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x):
# input layer
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
# Compared with the ImageNet version there is no initial maxpool: CIFAR-10 images are already small, and pooling again would make them too small
# residual stages
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
# classifier
x = self.avgpool(x) # output shape is (B, C, 1, 1)
x = x.view(x.size(0), -1)
x = self.fc(x)
# out = F.softmax(x,dim = 1) # softmax here is optional; it makes little difference
return x
def quantize(self, quant_type, num_bits=8, e_bits=3):
self.qconvbnrelu1 = QConvBNReLU(quant_type,self.conv1,self.bn1,qi=True,qo=True,num_bits=num_bits,e_bits=e_bits)
# num_bits was previously not passed in; needs revision
self.layer1.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
self.layer2.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
self.layer3.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
self.layer4.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
self.qavgpool1 = QAdaptiveAvgPool2d(quant_type,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qfc1 = QLinear(quant_type, self.fc,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
# self.qfc1 = QLinear(quant_type, self.fc,qi=True,qo=True,num_bits=num_bits,e_bits=e_bits)
def quantize_forward(self, x):
# for _, layer in self.quantize_layers.items():
# x = layer(x)
# out = F.softmax(x, dim=1)
# return out
x = self.qconvbnrelu1(x)
x = self.layer1.quantize_forward(x)
x = self.layer2.quantize_forward(x)
x = self.layer3.quantize_forward(x)
x = self.layer4.quantize_forward(x)
x = self.qavgpool1(x)
x = x.view(x.size(0), -1)
x = self.qfc1(x)
# out = F.softmax(x,dim = 1) # softmax here is optional; it makes little difference
return x
def freeze(self):
self.qconvbnrelu1.freeze() # the first layer already has its own qi, so no qi needs to be supplied when freezing
qo = self.layer1.freeze(qinput = self.qconvbnrelu1.qo)
qo = self.layer2.freeze(qinput = qo)
qo = self.layer3.freeze(qinput = qo)
qo = self.layer4.freeze(qinput = qo)
self.qavgpool1.freeze(qi=qo)
self.qfc1.freeze(qi=self.qavgpool1.qo)
# self.qfc1.freeze()
def fakefreeze(self):
self.qconvbnrelu1.fakefreeze()
self.layer1.fakefreeze()
self.layer2.fakefreeze()
self.layer3.fakefreeze()
self.layer4.fakefreeze()
self.qfc1.fakefreeze()
def quantize_inference(self, x):
qx = self.qconvbnrelu1.qi.quantize_tensor(x)
qx = self.qconvbnrelu1.quantize_inference(qx)
qx = self.layer1.quantize_inference(qx)
qx = self.layer2.quantize_inference(qx)
qx = self.layer3.quantize_inference(qx)
qx = self.layer4.quantize_inference(qx)
qx = self.qavgpool1.quantize_inference(qx)
qx = qx.view(qx.size(0), -1)
qx = self.qfc1.quantize_inference(qx)
qx = self.qfc1.qo.dequantize_tensor(qx)
# out = F.softmax(qx,dim = 1) # softmax here is optional; it makes little difference
return qx
# BasicBlock
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
# first conv layer
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
# second conv layer
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
# shortcut
self.relu = nn.ReLU()
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(identity)
out += identity
out = self.relu(out)
return out
def quantize(self, quant_type ,num_bits=8, e_bits=3):
self.qconvbnrelu1 = QConvBNReLU(quant_type,self.conv1,self.bn1,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qconvbn1 = QConvBN(quant_type,self.conv2,self.bn2,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
if self.downsample is not None:
self.qconvbn2 = QConvBN(quant_type,self.downsample[0],self.downsample[1],qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qelementadd = QElementwiseAdd(quant_type,qi0=False, qi1=False, qo=True,num_bits=num_bits,e_bits=e_bits)
self.qrelu1 = QReLU(quant_type,qi= False,num_bits=num_bits,e_bits=e_bits) # a qi is needed (supplied at freeze time)
def quantize_forward(self, x):
identity = x
out = self.qconvbnrelu1(x)
out = self.qconvbn1(out)
if self.downsample is not None:
identity = self.qconvbn2(identity)
# residual add
# out = identity + out # a dedicated elementwise-add op is needed here; handled by qelementadd below
out = self.qelementadd(out,identity)
out = self.qrelu1(out)
return out
def freeze(self, qinput):
# qconvbnrelu1 could reuse the previous layer's qo, but passing it through felt awkward, so it is not done here
# still needs a careful check
self.qconvbnrelu1.freeze(qi= qinput) # takes the last qo of the preceding module
self.qconvbn1.freeze(qi = self.qconvbnrelu1.qo)
if self.downsample is not None:
self.qconvbn2.freeze(qi = qinput) # shortcut branch
self.qelementadd.freeze(qi0 = self.qconvbn1.qo, qi1 = self.qconvbn2.qo)
else:
self.qelementadd.freeze(qi0 = self.qconvbn1.qo, qi1 = qinput)
# an extra layer may be needed here to handle the elementwise add
self.qrelu1.freeze(qi = self.qelementadd.qo)
return self.qrelu1.qi # the post-ReLU qo can reuse the qi statistics collected by the ReLU
def fakefreeze(self):
# qconvbnrelu1 could reuse the previous layer's qo, but passing it through felt awkward, so it is not done here
# still needs a careful check
self.qconvbnrelu1.fakefreeze()
self.qconvbn1.fakefreeze()
if self.downsample is not None:
self.qconvbn2.fakefreeze() # shortcut branch
def quantize_inference(self, x):
# No initial quantize_tensor/dequantize_tensor is needed here: this is not the first/last layer, and as long as every intermediate layer stays in the quantized domain no such handling is required.
identity = x
out = self.qconvbnrelu1.quantize_inference(x)
out = self.qconvbn1.quantize_inference(out)
if self.downsample is not None:
identity = self.qconvbn2.quantize_inference(identity)
# out = identity + out # handled by the dedicated elementwise-add op below
out = self.qelementadd.quantize_inference(out,identity)
out = self.qrelu1.quantize_inference(out)
return out
# Bottleneck
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
# 1x1 conv layer
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
# 3x3 conv layer
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
# 1x1 conv layer
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
# shortcut
self.relu = nn.ReLU()
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity # the residual addition happens here
out = self.relu(out)
return out
def quantize(self, quant_type ,num_bits=8, e_bits=3):
self.qconvbnrelu1 = QConvBNReLU(quant_type,self.conv1,self.bn1,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qconvbnrelu2 = QConvBNReLU(quant_type,self.conv2,self.bn2,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qconvbn1 = QConvBN(quant_type,self.conv3,self.bn3,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
if self.downsample is not None:
self.qconvbn2 = QConvBN(quant_type,self.downsample[0],self.downsample[1],qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qelementadd = QElementwiseAdd(quant_type,qi0=False, qi1=False, qo=True,num_bits=num_bits,e_bits=e_bits)
self.qrelu1 = QReLU(quant_type,qi= False,num_bits=num_bits,e_bits=e_bits) # a qi is needed (supplied at freeze time)
def quantize_forward(self, x):
identity = x
out = self.qconvbnrelu1(x)
out = self.qconvbnrelu2(out)
out = self.qconvbn1(out)
if self.downsample is not None:
identity = self.qconvbn2(identity)
# residual add
# out = identity + out # a dedicated elementwise-add op is needed here; handled by qelementadd below
out = self.qelementadd(out,identity)
out = self.qrelu1(out)
return out
def freeze(self, qinput):
# qconvbnrelu1 could reuse the previous layer's qo, but passing it through felt awkward, so it is not done here
# still needs a careful check
self.qconvbnrelu1.freeze(qi= qinput) # takes the last qo of the preceding module
self.qconvbnrelu2.freeze(qi=self.qconvbnrelu1.qo)
self.qconvbn1.freeze(qi = self.qconvbnrelu2.qo)
if self.downsample is not None:
self.qconvbn2.freeze(qi = qinput) # shortcut branch
self.qelementadd.freeze(qi0 = self.qconvbn1.qo, qi1 = self.qconvbn2.qo)
else:
self.qelementadd.freeze(qi0 = self.qconvbn1.qo, qi1 = qinput)
# an extra layer may be needed here to handle the elementwise add
self.qrelu1.freeze(qi = self.qelementadd.qo) # collects its own qi statistics
return self.qrelu1.qi # the post-ReLU qo can reuse the qi statistics collected by the ReLU
def fakefreeze(self):
# qconvbnrelu1 could reuse the previous layer's qo, but passing it through felt awkward, so it is not done here
# still needs a careful check
self.qconvbnrelu1.fakefreeze()
self.qconvbnrelu2.fakefreeze()
self.qconvbn1.fakefreeze()
if self.downsample is not None:
self.qconvbn2.fakefreeze() # shortcut branch
def quantize_inference(self, x):
# No initial quantize_tensor/dequantize_tensor is needed here: this is not the first/last layer, and as long as every intermediate layer stays in the quantized domain no such handling is required.
identity = x
out = self.qconvbnrelu1.quantize_inference(x)
out = self.qconvbnrelu2.quantize_inference(out)
out = self.qconvbn1.quantize_inference(out)
if self.downsample is not None:
identity = self.qconvbn2.quantize_inference(identity)
# out = identity + out # handled by the dedicated elementwise-add op below
out = self.qelementadd.quantize_inference(out,identity)
out = self.qrelu1.quantize_inference(out)
return out
class MakeLayer_ResNet(nn.Module):
def __init__(self, block, planes, blocks, stride=1):
super(MakeLayer_ResNet, self).__init__()
# print('makelayer init:'+ str(GlobalVariables.SELF_INPLANES))
self.downsample = None
if stride != 1 or GlobalVariables.SELF_INPLANES != planes * block.expansion:
self.downsample = nn.Sequential(
nn.Conv2d(GlobalVariables.SELF_INPLANES, planes * block.expansion,kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(planes * block.expansion)
)
self.blockdict = nn.ModuleDict()
self.blockdict['block1'] = block(inplanes=GlobalVariables.SELF_INPLANES, planes=planes, stride=stride, downsample=self.downsample)
GlobalVariables.SELF_INPLANES = planes * block.expansion
for i in range(1, blocks): # number of blocks; a ModuleDict has to be used here
self.blockdict['block' + str(i+1)] = block(inplanes=GlobalVariables.SELF_INPLANES, planes=planes) # the block is instantiated here
def forward(self,x):
for _, layer in self.blockdict.items():
x = layer(x)
return x
def quantize(self, quant_type, num_bits=8, e_bits=3):
# needs checking
for _, layer in self.blockdict.items():
layer.quantize(quant_type=quant_type,num_bits=num_bits,e_bits=e_bits) # every element is a block with its own quantize strategy; n_exp and mode were already set in __init__
def quantize_forward(self, x):
for _, layer in self.blockdict.items():
x = layer.quantize_forward(x) # each block implements its own quantize_forward
return x
def freeze(self, qinput): # qinput is passed in from the ResNet module's freeze
# qconvbnrelu1 could reuse the previous layer's qo, but passing it through felt awkward, so it is not done here
# still needs a careful check
cnt = 0
for _, layer in self.blockdict.items():
if cnt == 0:
qo = layer.freeze(qinput = qinput)
cnt = 1
else:
qo = layer.freeze(qinput = qo) # each block implements its own freeze
return qo # for use by subsequent layers
def fakefreeze(self):
for _, layer in self.blockdict.items():
layer.fakefreeze()
def quantize_inference(self, x):
# No initial quantize_tensor/dequantize_tensor is needed here: this is not the first/last layer, and as long as every intermediate layer stays in the quantized domain no such handling is required.
for _, layer in self.blockdict.items():
x = layer.quantize_inference(x) # each block implements its own quantize_inference
return x
# ResNet-18
def resnet18(**kwargs):
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
return model
# ResNet-50
def resnet50(**kwargs):
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
return model
# ResNet-152
def resnet152(**kwargs):
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
return model
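# Sketch of the post-training-quantization flow these model classes expose (the calibration
# loader, quantization settings and device below are assumptions for illustration only):
def _ptq_flow_demo(calib_loader, device='cpu'):
    net = resnet18().to(device).eval()
    net.quantize(quant_type='INT', num_bits=8, e_bits=3)   # attach the quantized wrappers
    with torch.no_grad():
        for x, _ in calib_loader:
            net.quantize_forward(x.to(device))             # calibrate qi/qw/qo ranges
    net.freeze()                                           # fix quantized weights and compute M
    with torch.no_grad():
        for x, _ in calib_loader:
            out = net.quantize_inference(x.to(device))     # inference in the quantized domain
    return out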
# ==========================================================
# MobileNetV2
class MobileNetV2(nn.Module):
def __init__(self, num_classes=10):
super(MobileNetV2, self).__init__()
self.conv1 = nn.Conv2d(3, 32, 3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.relu1 = nn.ReLU6(inplace=True)
# Bottleneck stages; t is the channel expansion factor
self.layer1 = MakeLayer_Mobile(32, 16, 1, t=1, stride=1)
self.layer2 = MakeLayer_Mobile(16, 24, 2, t=6, stride=2)
self.layer3 = MakeLayer_Mobile(24, 32, 3, t=6, stride=2)
# layer configuration adjusted for the CIFAR-10 image size
self.layer4 = MakeLayer_Mobile(32, 96, 3, t=6, stride=1)
self.layer5 = MakeLayer_Mobile(96, 160, 3, t=6, stride=2)
self.layer6 = MakeLayer_Mobile(160, 320, 1, t=6, stride=1)
self.conv2 = nn.Conv2d(320, 1280, 1)
self.avg1 = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(1280, num_classes)
def forward(self, x):
# x = self.layers(x)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.layer5(x)
x = self.layer6(x)
x = self.conv2(x)
x = self.avg1(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
def quantize(self, quant_type, num_bits=8, e_bits=3):
self.qconvbnrelu1 = QConvBNReLU6(quant_type,self.conv1,self.bn1,qi=True,qo=True,num_bits=num_bits,e_bits=e_bits)
# num_bits was previously not passed in; needs revision
self.layer1.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
self.layer2.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
self.layer3.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
self.layer4.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
self.layer5.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
self.layer6.quantize(quant_type=quant_type,num_bits=num_bits, e_bits=e_bits)
self.qconv1 = QConv2d(quant_type, self.conv2, qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
self.qavgpool1 = QAdaptiveAvgPool2d(quant_type,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qfc1 = QLinear(quant_type, self.fc,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
def quantize_forward(self, x):
# for _, layer in self.quantize_layers.items():
# x = layer(x)
# out = F.softmax(x, dim=1)
# return out
x = self.qconvbnrelu1(x)
x = self.layer1.quantize_forward(x)
x = self.layer2.quantize_forward(x)
x = self.layer3.quantize_forward(x)
x = self.layer4.quantize_forward(x)
x = self.layer5.quantize_forward(x)
x = self.layer6.quantize_forward(x)
x = self.qconv1(x)
x = self.qavgpool1(x)
x = x.view(x.size(0), -1)
x = self.qfc1(x)
out = F.softmax(x,dim = 1) # softmax here is optional; it makes little difference
return out
def freeze(self):
self.qconvbnrelu1.freeze() # the first layer already has its own qi, so no qi needs to be supplied when freezing
qo = self.layer1.freeze(qinput = self.qconvbnrelu1.qo)
qo = self.layer2.freeze(qinput = qo)
qo = self.layer3.freeze(qinput = qo)
qo = self.layer4.freeze(qinput = qo)
qo = self.layer5.freeze(qinput = qo)
qo = self.layer6.freeze(qinput = qo)
self.qconv1.freeze(qi = qo)
self.qavgpool1.freeze(qi=self.qconv1.qo)
self.qfc1.freeze(qi=self.qavgpool1.qo)
# self.qfc1.freeze()
def fakefreeze(self):
self.qconvbnrelu1.fakefreeze()
self.layer1.fakefreeze()
self.layer2.fakefreeze()
self.layer3.fakefreeze()
self.layer4.fakefreeze()
self.layer5.fakefreeze()
self.layer6.fakefreeze()
self.qconv1.fakefreeze()
self.qfc1.fakefreeze()
def quantize_inference(self, x):
qx = self.qconvbnrelu1.qi.quantize_tensor(x)
qx = self.qconvbnrelu1.quantize_inference(qx)
qx = self.layer1.quantize_inference(qx)
qx = self.layer2.quantize_inference(qx)
qx = self.layer3.quantize_inference(qx)
qx = self.layer4.quantize_inference(qx)
qx = self.layer5.quantize_inference(qx)
qx = self.layer6.quantize_inference(qx)
qx = self.qconv1.quantize_inference(qx)
qx = self.qavgpool1.quantize_inference(qx)
qx = qx.view(qx.size(0), -1)
qx = self.qfc1.quantize_inference(qx)
qx = self.qfc1.qo.dequantize_tensor(qx)
out = F.softmax(qx,dim = 1) # softmax here is optional; it makes little difference
return out
class InvertedResidual(nn.Module):
def __init__(self, in_channels, out_channels, stride, expand_ratio):
super(InvertedResidual, self).__init__()
hidden_dims = int(in_channels * expand_ratio)
self.identity_flag = stride == 1 and in_channels == out_channels
# self.bottleneck = nn.Sequential(
# # Pointwise Convolution
# nn.Conv2d(in_channels, hidden_dims, 1),
# nn.BatchNorm2d(hidden_dims),
# nn.ReLU6(inplace=True),
# # Depthwise Convolution
# nn.Conv2d(hidden_dims, hidden_dims, 3, stride=stride, padding=1, groups=hidden_dims),
# nn.BatchNorm2d(hidden_dims),
# nn.ReLU6(inplace=True),
# # Pointwise & Linear Convolution
# nn.Conv2d(hidden_dims, out_channels, 1),
# nn.BatchNorm2d(out_channels),
# )
self.conv1 = nn.Conv2d(in_channels, hidden_dims, 1)
self.bn1 = nn.BatchNorm2d(hidden_dims)
self.relu1 = nn.ReLU6(inplace=True)
# Depthwise Convolution
self.conv2 = nn.Conv2d(hidden_dims, hidden_dims, 3, stride=stride, padding=1, groups=hidden_dims)
self.bn2 = nn.BatchNorm2d(hidden_dims)
self.relu2 = nn.ReLU6(inplace=True)
# Pointwise & Linear Convolution
self.conv3 = nn.Conv2d(hidden_dims, out_channels, 1)
self.bn3 = nn.BatchNorm2d(out_channels)
def forward(self, x):
# if self.identity_flag:
# return x + self.bottleneck(x)
# else:
# return self.bottleneck(x)
identity = x
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.conv3(x)
x = self.bn3(x)
if self.identity_flag:
return identity + x
else:
return x
def quantize(self, quant_type ,num_bits=8, e_bits=3):
self.qconvbnrelu1 = QConvBNReLU6(quant_type,self.conv1,self.bn1,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qconvbnrelu2 = QConvBNReLU6(quant_type,self.conv2,self.bn2,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qconvbn1 = QConvBN(quant_type,self.conv3,self.bn3,qi=False,qo=True,num_bits=num_bits,e_bits=e_bits)
self.qelementadd = QElementwiseAdd(quant_type,qi0=False, qi1=False, qo=True,num_bits=num_bits,e_bits=e_bits)
def quantize_forward(self, x):
identity = x
out = self.qconvbnrelu1(x)
out = self.qconvbnrelu2(out)
out = self.qconvbn1(out)
if self.identity_flag:
out = self.qelementadd(out, identity)
return out
def freeze(self, qinput):
# qconvbnrelu1 could reuse the previous layer's qo, but passing it through felt awkward, so it is not done here
# still needs a careful check
self.qconvbnrelu1.freeze(qi= qinput) # takes the last qo of the preceding module
self.qconvbnrelu2.freeze(qi=self.qconvbnrelu1.qo)
self.qconvbn1.freeze(qi = self.qconvbnrelu2.qo)
if self.identity_flag:
self.qelementadd.freeze(qi0 = self.qconvbn1.qo, qi1 = qinput)
return self.qelementadd.qo
else:
return self.qconvbn1.qo
def fakefreeze(self):
self.qconvbnrelu1.fakefreeze()
self.qconvbnrelu2.fakefreeze()
self.qconvbn1.fakefreeze()
def quantize_inference(self, x):
# No initial quantize_tensor/dequantize_tensor is needed here: this is not the first/last layer, and as long as every intermediate layer stays in the quantized domain no such handling is required.
identity = x
out = self.qconvbnrelu1.quantize_inference(x)
out = self.qconvbnrelu2.quantize_inference(out)
out = self.qconvbn1.quantize_inference(out)
if self.identity_flag:
out = self.qelementadd.quantize_inference(out, identity)
return out
class MakeLayer_Mobile(nn.Module):
def __init__(self, in_channels, out_channels, n_repeat, t, stride):
super(MakeLayer_Mobile, self).__init__()
# print('makelayer init:'+ str(GlobalVariables.SELF_INPLANES))
self.layers = nn.ModuleList()
for i in range(n_repeat):
if i == 0:
self.layers.append(InvertedResidual(in_channels, out_channels, stride, t))
else:
self.layers.append(InvertedResidual(in_channels, out_channels, 1, t))
in_channels = out_channels
# for l in self.layers:
# print(l)
def forward(self,x):
for layer in self.layers:
x = layer(x)
return x
def quantize(self, quant_type, num_bits=8, e_bits=3):
# needs checking
# print('CHECK======')
for layer in self.layers:
layer.quantize(quant_type=quant_type,num_bits=num_bits,e_bits=e_bits) # every element is a block with its own quantize strategy; n_exp and mode were already set in __init__
# print(layer)
# print('CHECK======')
def quantize_forward(self, x):
for layer in self.layers:
x = layer.quantize_forward(x) # each block implements its own quantize_forward
return x
def freeze(self, qinput): # qinput is passed in from the parent MobileNetV2 module's freeze
# qconvbnrelu1 could reuse the previous layer's qo, but passing it through felt awkward, so it is not done here
# still needs a careful check
cnt = 0
for layer in self.layers:
if cnt == 0:
qo = layer.freeze(qinput = qinput)
cnt = 1
else:
qo = layer.freeze(qinput = qo) # each block implements its own freeze
return qo # for use by subsequent layers
def fakefreeze(self):
for layer in self.layers:
layer.fakefreeze()
def quantize_inference(self, x):
# No initial quantize_tensor/dequantize_tensor is needed here: this is not the first/last layer, and as long as every intermediate layer stays in the quantized domain no such handling is required.
for layer in self.layers:
x = layer.quantize_inference(x) # each block implements its own quantize_inference
return x
import math
import numpy as np
import gol
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from function import FakeQuantize
# Get the nearest quantized value
# def get_nearest_val(quant_type,x,is_bias=False):
# if quant_type=='INT':
# return x.round_()
# plist = gol.get_value(is_bias)
# # print('get')
# # print(plist)
# # x = x / 64
# shape = x.shape
# xhard = x.view(-1)
# plist = plist.type_as(x)
# # take the nearest power as the index
# idx = (xhard.unsqueeze(0) - plist.unsqueeze(1)).abs().min(dim=0)[1]
# xhard = plist[idx].view(shape)
# xout = (xhard - x).detach() + x
# # xout = xout * 64
# return xout
def get_nearest_val(quant_type, x, is_bias=False, block_size=1000000):
if quant_type == 'INT':
return x.round_()
plist = gol.get_value(is_bias)
shape = x.shape
xhard = x.view(-1)
xout = torch.zeros_like(xhard)
plist = plist.type_as(x)
n_blocks = (x.numel() + block_size - 1) // block_size
for i in range(n_blocks):
start_idx = i * block_size
end_idx = min(start_idx + block_size, xhard.numel())
block_size_i = end_idx - start_idx
# print(x.numel())
# print(block_size_i)
# print(start_idx)
# print(end_idx)
xblock = xhard[start_idx:end_idx]
# xblock = xblock.view(shape[start_idx:end_idx])
plist_block = plist.unsqueeze(1) #.expand(-1, block_size_i)
idx = (xblock.unsqueeze(0) - plist_block).abs().min(dim=0)[1]
# print(xblock.shape)
xhard_block = plist[idx].view(xblock.shape)
xout[start_idx:end_idx] = (xhard_block - xblock).detach() + xblock
xout = xout.view(shape)
return xout
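# Worked sketch of the nearest-value lookup for a non-INT quant type: every element of x is
# snapped to the closest entry of the value list held by gol. This assumes gol exposes the
# _init/set_value helpers shown earlier in this commit; the value list below is made up.
def _nearest_val_demo():
    gol._init()
    gol.set_value(torch.tensor([-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0]))
    x = torch.tensor([0.30, -0.60, 0.05])
    return get_nearest_val('POT', x)   # -> tensor([0.2500, -0.5000, 0.0000])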
# For symmetric signed quantization, get the maximum of the quantization range
def get_qmax(quant_type,num_bits=None, e_bits=None):
if quant_type == 'INT':
qmax = 2. ** (num_bits - 1) - 1
elif quant_type == 'POT':
qmax = 1
else: #FLOAT
m_bits = num_bits - 1 - e_bits
dist_m = 2 ** (-m_bits)
e = 2 ** (e_bits - 1)
expo = 2 ** e
m = 2 ** m_bits -1
frac = 1. + m * dist_m
qmax = frac * expo
return qmax
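# Illustrative worked example for the FLOAT branch above with num_bits=8, e_bits=3:
# m_bits = 8 - 1 - 3 = 4, dist_m = 2**-4 = 0.0625, e = 2**(3-1) = 4, expo = 2**4 = 16,
# m = 2**4 - 1 = 15, frac = 1 + 15 * 0.0625 = 1.9375, hence qmax = 1.9375 * 16 = 31.0.
assert get_qmax('FLOAT', num_bits=8, e_bits=3) == 31.0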
# Signed quantization throughout; the zero point is always 0
def calcScaleZeroPoint(min_val, max_val, qmax):
scale = torch.max(max_val.abs(),min_val.abs()) / qmax
zero_point = torch.tensor(0.)
return scale, zero_point
# Quantize the input; both input and output are tensors
def quantize_tensor(quant_type, x, scale, zero_point, qmax, is_bias=False):
# quantized range, determined directly by the bit width
qmin = -qmax
q_x = zero_point + x / scale
q_x.clamp_(qmin, qmax)
q_x = get_nearest_val(quant_type, q_x, is_bias)
return q_x
# Bias uses a different precision; num_bits/e_bits are chosen per quantization type
def bias_qmax(quant_type):
if quant_type == 'INT':
return get_qmax(quant_type, 64)
elif quant_type == 'POT':
return get_qmax(quant_type)
else:
return get_qmax(quant_type, 16, 7)
# Convert back to FP32; no further clamping is needed
def dequantize_tensor(q_x, scale, zero_point):
return scale * (q_x - zero_point)
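# Worked sketch of the symmetric INT8 round trip built from the helpers above (the input
# values are made up): qmax = 2**7 - 1 = 127; for max(|x|) = 0.5 the scale is
# 0.5 / 127 ≈ 0.003937 and the zero point is 0.
def _roundtrip_demo():
    x = torch.tensor([-0.2, 0.0, 0.31, 0.5])
    qmax = get_qmax('INT', num_bits=8)
    scale, zp = calcScaleZeroPoint(x.min(), x.max(), qmax)
    q = quantize_tensor('INT', x, scale, zp, qmax)   # integer grid values in [-127, 127]
    return dequantize_tensor(q, scale, zp)           # ≈ original x, up to rounding error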
class QParam(nn.Module):
def __init__(self,quant_type, num_bits=8, e_bits=3):
super(QParam, self).__init__()
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.qmax = get_qmax(quant_type, num_bits, e_bits)
scale = torch.tensor([], requires_grad=False)
zero_point = torch.tensor([], requires_grad=False)
min = torch.tensor([], requires_grad=False)
max = torch.tensor([], requires_grad=False)
# Register these as buffers so they are saved into the state_dict
self.register_buffer('scale', scale)
self.register_buffer('zero_point', zero_point)
self.register_buffer('min', min)
self.register_buffer('max', max)
# Update the observed range and the quantization parameters
def update(self, tensor):
if self.max.nelement() == 0 or self.max.data < tensor.max().data:
self.max.data = tensor.max().data
self.max.clamp_(min=0)
if self.min.nelement() == 0 or self.min.data > tensor.min().data:
self.min.data = tensor.min().data
self.min.clamp_(max=0)
self.scale, self.zero_point = calcScaleZeroPoint(self.min, self.max, self.qmax)
def quantize_tensor(self, tensor):
return quantize_tensor(self.quant_type, tensor, self.scale, self.zero_point, self.qmax)
def dequantize_tensor(self, q_x):
return dequantize_tensor(q_x, self.scale, self.zero_point)
# This method ensures the parameters can be restored from a state_dict
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
key_names = ['scale', 'zero_point', 'min', 'max']
for key in key_names:
value = getattr(self, key)
value.data = state_dict[prefix + key].data
state_dict.pop(prefix + key)
# The return value is what gets printed for this object
def __str__(self):
info = 'scale: %.10f ' % self.scale
info += 'zp: %.6f ' % self.zero_point
info += 'min: %.6f ' % self.min
info += 'max: %.6f' % self.max
return info
# Parent class for concrete quantized layers; qi and qo are the input/output quantizers
class QModule(nn.Module):
def __init__(self,quant_type, qi=True, qo=True, num_bits=8, e_bits=3):
super(QModule, self).__init__()
if qi:
self.qi = QParam(quant_type,num_bits, e_bits)
if qo:
self.qo = QParam(quant_type,num_bits, e_bits)
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.bias_qmax = bias_qmax(quant_type)
def freeze(self):
pass # no-op
def fakefreeze(self):
pass
def quantize_inference(self, x):
raise NotImplementedError('quantize_inference should be implemented.')
"""
QModule 量化卷积
:quant_type: 量化类型
:conv_module: 卷积模块
:qi: 是否量化输入特征图
:qo: 是否量化输出特征图
:num_bits: 8位bit数
"""
class QConv2d(QModule):
def __init__(self, quant_type, conv_module, qi=True, qo=True, num_bits=8, e_bits=3):
super(QConv2d, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.qw = QParam(quant_type, num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
# freeze fixes the truly-quantized weight parameters and writes them back onto the original full-precision layer, which simplifies the divergence computation
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
# For pooling/activation inputs no extra min/max statistics are needed; the same output quantizer is shared
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
# Following https://zhuanlan.zhihu.com/p/156835141, this is the coefficient of Eq. (3)
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
self.conv_module.weight.data = self.qw.quantize_tensor(self.conv_module.weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
self.conv_module.bias.data = quantize_tensor(self.quant_type,
self.conv_module.bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0.,qmax=self.bias_qmax, is_bias=True)
def fakefreeze(self):
self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data,scale=self.qi.scale * self.qw.scale, zero_point=0.)
def forward(self, x): # forward pass; x is a floating-point tensor
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi) # fake-quantize the input tensor x
# update qw before the forward pass so the scale is correct when quantizing the weight
self.qw.update(self.conv_module.weight.data)
# Note: this mainly collects per-layer ranges of x and weight; the bias is not quantized here
# tmp_wgt = FakeQuantize.apply(self.conv_module.weight, self.qw)
# x = F.conv2d(x, tmp_wgt, self.conv_module.bias,
# stride=self.conv_module.stride,
# padding=self.conv_module.padding, dilation=self.conv_module.dilation,
# groups=self.conv_module.groups)
x = F.conv2d(x, FakeQuantize.apply(self.conv_module.weight, self.qw), self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
# applies q_a = M * (Σ(q_w - Z_w)(q_x - Z_x) + q_b)
def quantize_inference(self, x): # the input here is the already-quantized qx
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
# if self.quant_type is 'INT':
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
class QLinear(QModule):
def __init__(self, quant_type, fc_module, qi=True, qo=True, num_bits=8, e_bits=3):
super(QLinear, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.fc_module = fc_module
self.qw = QParam(quant_type, num_bits, e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
self.fc_module.weight.data = self.qw.quantize_tensor(self.fc_module.weight.data)
self.fc_module.weight.data = self.fc_module.weight.data - self.qw.zero_point
self.fc_module.bias.data = quantize_tensor(self.quant_type,
self.fc_module.bias.data, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax, is_bias=True)
def fakefreeze(self):
self.fc_module.weight.data = self.qw.dequantize_tensor(self.fc_module.weight.data)
self.fc_module.bias.data = dequantize_tensor(self.fc_module.bias.data, scale=self.qi.scale * self.qw.scale, zero_point=0.)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
self.qw.update(self.fc_module.weight.data)
# tmp_wgt = FakeQuantize.apply(self.fc_module.weight, self.qw)
# x = F.linear(x, tmp_wgt, self.fc_module.bias)
x = F.linear(x, FakeQuantize.apply(self.fc_module.weight, self.qw), self.fc_module.bias)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = self.fc_module(x)
x = self.M * x
# if self.quant_type is 'INT':
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
class QReLU(QModule):
def __init__(self,quant_type, qi=False, qo=True, num_bits=8, e_bits=3):
super(QReLU, self).__init__(quant_type, qi, qo, num_bits, e_bits)
def freeze(self, qi=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if qi is not None:
self.qi = qi
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
x = F.relu(x)
return x
def quantize_inference(self, x):
x = x.clone()
# x[x < self.qi.zero_point] = self.qi.zero_point
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
a = self.qi.zero_point.float().to(device)
x[x < a] = a
return x
class QMaxPooling2d(QModule):
def __init__(self, quant_type, kernel_size=3, stride=1, padding=0, qi=False, qo=True, num_bits=8,e_bits=3):
super(QMaxPooling2d, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
def freeze(self, qi=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if qi is not None:
self.qi = qi
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
x = F.max_pool2d(x, self.kernel_size, self.stride, self.padding)
return x
def quantize_inference(self, x):
return F.max_pool2d(x, self.kernel_size, self.stride, self.padding)
class QConvBNReLU(QModule):
def __init__(self, quant_type, conv_module, bn_module, qi=True, qo=True, num_bits=8, e_bits=3):
super(QConvBNReLU, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.bn_module = bn_module
self.qw = QParam(quant_type, num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = self.qw.quantize_tensor(weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
if self.conv_module.bias is None:
self.conv_module.bias = nn.Parameter(quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True))
else:
self.conv_module.bias.data = quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
def fakefreeze(self):
self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data,scale=self.qi.scale * self.qw.scale, zero_point=0.)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
self.qw.update(weight.data)
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
x = F.relu(x)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
# if self.quant_type is 'INT':
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
x.clamp_(min=0)
return x
class QConvBN(QModule):
def __init__(self, quant_type, conv_module, bn_module, qi=True, qo=True, num_bits=8, e_bits=3):
super(QConvBN, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.bn_module = bn_module
self.qw = QParam(quant_type, num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = self.qw.quantize_tensor(weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
if self.conv_module.bias is None:
self.conv_module.bias = nn.Parameter(quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True))
else:
self.conv_module.bias.data = quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
def fakefreeze(self):
self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data,scale=self.qi.scale * self.qw.scale, zero_point=0.)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
self.qw.update(weight.data)
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
# x = F.relu(x)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
# if self.quant_type is 'INT':
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
# x.clamp_(min=0)
return x
# TODO: to be revised; this presumably needs a qo
class QAdaptiveAvgPool2d(QModule):
def __init__(self, quant_type, qi=False, qo=True, num_bits=8, e_bits=3):
super(QAdaptiveAvgPool2d, self).__init__(quant_type,qi,qo,num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if qi is not None:
self.qi = qi
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qo is not None:
self.qo = qo
self.M.data = (self.qi.scale / self.qo.scale).data
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi) # as with ReLU: update qi's scale first, then map x onto the quantization grid (PoT); usually the previous layer's qo is True, so x is already quantized
x = F.adaptive_avg_pool2d(x,(1, 1)) # quantizing the input and output is treated as quantizing this layer
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
x = F.adaptive_avg_pool2d(x,(1, 1)) # quantizing the input and output is treated as quantizing this layer
x = self.M * x
# if self.quant_type is 'INT':
x = get_nearest_val(self.quant_type,x)
return x
class QConvBNReLU6(QModule):
def __init__(self, quant_type, conv_module, bn_module, qi=True, qo=True, num_bits=8, e_bits=3):
super(QConvBNReLU6, self).__init__(quant_type, qi, qo, num_bits, e_bits)
self.conv_module = conv_module
self.bn_module = bn_module
self.qw = QParam(quant_type, num_bits,e_bits)
self.register_buffer('M', torch.tensor([], requires_grad=False)) # register M as a buffer
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self, qi=None, qo=None):
if hasattr(self, 'qi') and qi is not None:
raise ValueError('qi has been provided in init function.')
if not hasattr(self, 'qi') and qi is None:
raise ValueError('qi is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
if qi is not None:
self.qi = qi
if qo is not None:
self.qo = qo
self.M.data = (self.qw.scale * self.qi.scale / self.qo.scale).data
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = self.qw.quantize_tensor(weight.data)
self.conv_module.weight.data = self.conv_module.weight.data - self.qw.zero_point
self.conv_module.bias.data = quantize_tensor(self.quant_type,
bias, scale=self.qi.scale * self.qw.scale,
zero_point=0., qmax=self.bias_qmax,is_bias=True)
def fakefreeze(self):
self.conv_module.weight.data = self.qw.dequantize_tensor(self.conv_module.weight.data)
self.conv_module.bias.data = dequantize_tensor(self.conv_module.bias.data,scale=self.qi.scale * self.qw.scale, zero_point=0.)
def forward(self, x):
if hasattr(self, 'qi'):
self.qi.update(x)
x = FakeQuantize.apply(x, self.qi)
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
self.qw.update(weight.data)
x = F.conv2d(x, FakeQuantize.apply(weight, self.qw), bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
x = F.relu6(x)
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x):
a = torch.tensor(6)
a = self.qo.quantize_tensor(a)
x = x - self.qi.zero_point
x = self.conv_module(x)
x = self.M * x
# if self.quant_type is not 'POT':
# x = get_nearest_val(self.quant_type,x)
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point # values now lie in qo's quantized range
x.clamp_(min=0, max=a.item())
return x
class QModule_2(nn.Module):
def __init__(self,quant_type, qi0=True, qi1=True, qo=True, num_bits=8, e_bits=3):
super(QModule_2, self).__init__()
if qi0:
self.qi0 = QParam(quant_type,num_bits, e_bits) # qi0 is already configured here with num_bits and the quantization mode
if qi1:
self.qi1 = QParam(quant_type,num_bits, e_bits) # qi1 is already configured here with num_bits and the quantization mode
if qo:
self.qo = QParam(quant_type,num_bits, e_bits) # qo is already configured here with num_bits and the quantization mode
self.quant_type = quant_type
self.num_bits = num_bits
self.e_bits = e_bits
self.bias_qmax = bias_qmax(quant_type)
def freeze(self):
pass
def fakefreeze(self):
pass
def quantize_inference(self, x):
raise NotImplementedError('quantize_inference should be implemented.')
class QElementwiseAdd(QModule_2):
def __init__(self, quant_type, qi0=True, qi1=True, qo=True, num_bits=8, e_bits=3):
super(QElementwiseAdd, self).__init__(quant_type, qi0, qi1, qo, num_bits, e_bits)
self.register_buffer('M0', torch.tensor([], requires_grad=False)) # register M0 as a buffer
self.register_buffer('M1', torch.tensor([], requires_grad=False)) # register M1 as a buffer
def freeze(self, qi0=None, qi1=None ,qo=None):
if hasattr(self, 'qi0') and qi0 is not None:
raise ValueError('qi0 has been provided in init function.')
if not hasattr(self, 'qi0') and qi0 is None:
raise ValueError('qi0 is not existed, should be provided.')
if hasattr(self, 'qi1') and qi1 is not None:
raise ValueError('qi1 has been provided in init function.')
if not hasattr(self, 'qi1') and qi1 is None:
raise ValueError('qi1 is not existed, should be provided.')
if hasattr(self, 'qo') and qo is not None:
raise ValueError('qo has been provided in init function.')
if not hasattr(self, 'qo') and qo is None:
raise ValueError('qo is not existed, should be provided.')
# For pooling/activation inputs no extra min/max statistics are needed; the same output quantizer is shared
if qi0 is not None:
self.qi0 = qi0
if qi1 is not None:
self.qi1 = qi1
if qo is not None:
self.qo = qo
# Following https://zhuanlan.zhihu.com/p/156835141, these are the coefficients of Eq. (3)
self.M0.data = self.qi0.scale / self.qo.scale
self.M1.data = self.qi1.scale / self.qi0.scale
# self.M0.data = self.qi0.scale / self.qo.scale
# self.M1.data = self.qi1.scale / self.qo.scale
def forward(self, x0, x1): # forward pass; x0 and x1 are floating-point tensors
if hasattr(self, 'qi0'):
self.qi0.update(x0)
x0 = FakeQuantize.apply(x0, self.qi0) # fake-quantize the input tensor x0
if hasattr(self, 'qi1'):
self.qi1.update(x1)
x1 = FakeQuantize.apply(x1, self.qi1) # fake-quantize the input tensor x1
x = x0 + x1
if hasattr(self, 'qo'):
self.qo.update(x)
x = FakeQuantize.apply(x, self.qo)
return x
def quantize_inference(self, x0, x1): # the inputs here are already-quantized values
x0 = x0 - self.qi0.zero_point
x1 = x1 - self.qi1.zero_point
x = self.M0 * (x0 + x1*self.M1)
# if self.quant_type is 'INT':
x = get_nearest_val(self.quant_type,x)
x = x + self.qo.zero_point
return x
# new modules for full-precision model - fold bn
# inference should also be adapted accordingly
class ConvBNReLU(nn.Module):
def __init__(self,conv_module, bn_module):
super(ConvBNReLU, self).__init__()
self.conv_module = conv_module
self.bn_module = bn_module
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self):
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = weight.data
if self.conv_module.bias is None:
self.conv_module.bias = nn.Parameter(bias)
else:
self.conv_module.bias.data = bias
def fakefreeze(self):
pass
def forward(self, x):
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
x = F.conv2d(x, weight, bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
x = F.relu(x)
return x
def quantize_inference(self, x):
x = self.conv_module(x)
x.clamp_(min=0)
return x
class ConvBN(nn.Module):
def __init__(self,conv_module, bn_module):
super(ConvBN, self).__init__()
self.conv_module = conv_module
self.bn_module = bn_module
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self):
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = weight.data
if self.conv_module.bias is None:
self.conv_module.bias = nn.Parameter(bias)
else:
self.conv_module.bias.data = bias
def fakefreeze(self):
pass
def forward(self, x):
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
x = F.conv2d(x, weight, bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
return x
def quantize_inference(self, x):
x = self.conv_module(x)
return x
class ConvBNReLU6(nn.Module):
def __init__(self,conv_module, bn_module):
super(ConvBNReLU6, self).__init__()
self.conv_module = conv_module
self.bn_module = bn_module
def fold_bn(self, mean, std):
if self.bn_module.affine:
gamma_ = self.bn_module.weight / std
weight = self.conv_module.weight * gamma_.view(self.conv_module.out_channels, 1, 1, 1)
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean + self.bn_module.bias
else:
bias = self.bn_module.bias - gamma_ * mean
else:
gamma_ = 1 / std
weight = self.conv_module.weight * gamma_
if self.conv_module.bias is not None:
bias = gamma_ * self.conv_module.bias - gamma_ * mean
else:
bias = -gamma_ * mean
return weight, bias
def freeze(self):
std = torch.sqrt(self.bn_module.running_var + self.bn_module.eps)
weight, bias = self.fold_bn(self.bn_module.running_mean, std)
self.conv_module.weight.data = weight.data
if self.conv_module.bias is None:
self.conv_module.bias = nn.Parameter(bias)
else:
self.conv_module.bias.data = bias
def fakefreeze(self):
pass
def forward(self, x):
if self.training:
y = F.conv2d(x, self.conv_module.weight, self.conv_module.bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding,
dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
y = y.permute(1, 0, 2, 3) # NCHW -> CNHW
y = y.contiguous().view(self.conv_module.out_channels, -1) # CNHW -> C,NHW
# mean = y.mean(1)
# var = y.var(1)
mean = y.mean(1).detach()
var = y.var(1).detach()
self.bn_module.running_mean = \
(1 - self.bn_module.momentum) * self.bn_module.running_mean + \
self.bn_module.momentum * mean
self.bn_module.running_var = \
(1 - self.bn_module.momentum) * self.bn_module.running_var + \
self.bn_module.momentum * var
else:
mean = Variable(self.bn_module.running_mean)
var = Variable(self.bn_module.running_var)
std = torch.sqrt(var + self.bn_module.eps)
weight, bias = self.fold_bn(mean, std)
x = F.conv2d(x, weight, bias,
stride=self.conv_module.stride,
padding=self.conv_module.padding, dilation=self.conv_module.dilation,
groups=self.conv_module.groups)
x = F.relu6(x)
return x
def quantize_inference(self, x):
x = self.conv_module(x)
x.clamp_(min=0,max=6)
return x
## Naive MIA
#### update 2023.5.24
1. Approach
(1) Three models are trained in total: the Target Model, the Shadow Model and the Attack Model. The Target Model is the model under attack.
(2) We assume the attacker knows the Target Model's architecture and the characteristics of the dataset it was trained on, so the Shadow Model uses the same architecture as the Target Model. The dataset is split into disjoint parts used to train and test the Target Model and the Shadow Model respectively.
(3) While training and testing the Target Model and the Shadow Model, we can build the Attack Model's test set and training set respectively. The Attack Model's input is the output vector produced by the Target Model or the Shadow Model for an image, and its label is whether that image belongs to the corresponding model's training set. The Shadow Model exists precisely to build the Attack Model's training set (as attackers we only know the Shadow Model's train/test split, not the Target Model's); the Target Model's output vectors together with its train/test membership then serve as the Attack Model's test set, which measures how well the attack works. A sketch of this dataset construction follows the list below.
(4) The Attack Model is a binary classifier made up of several FC, ReLU and BN layers.
(5) The final outputs are the attack success rate (i.e. the Attack Model's accuracy) and some statistics.
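The sketch below illustrates how the attack dataset described in (3) can be assembled: each image's output vector becomes one feature row, labelled 1 if the image was in the queried model's training set and 0 otherwise. This is a minimal illustration rather than the repo's exact code; `build_attack_split`, the model/loader variables and `device` are assumed names.
```
# Minimal sketch (assumed helper, not the repo's exact code): collect one split of the
# attack dataset from a queried model's outputs.
import torch

def build_attack_split(model, loader, is_member, device='cpu'):
    model.eval()
    xs, ys = [], []
    with torch.no_grad():
        for inputs, _ in loader:
            posteriors = model(inputs.to(device))  # output vector for each image
            xs.append(posteriors.cpu())
            ys.append(torch.full((posteriors.size(0),), int(is_member), dtype=torch.long))
    return torch.cat(xs), torch.cat(ys)

# Shadow Model outputs (members = its train set, non-members = its test set) form the
# Attack Model's training data; Target Model outputs built the same way form its test data.
```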
2. Code files
attack.py:
The core file. ``get_cmd_arguments()`` reads the various configuration arguments.
``split_dataset()`` splits the dataset.
``get_data_loader()`` builds the data loaders.
``attack_inference()`` runs inference with the Attack Model.
``create_attack()`` does most of the work: it trains the Target Model and the Shadow Model, builds the Attack Model's dataset during training via ``train_model()`` (alternatively, trained weights can be loaded and the dataset built via ``prepare_attack_data()``), and then trains and tests the Attack Model. A rough outline of this flow is sketched below.
model.py:
The supported model architectures; those currently used are ResNet18/50/152 and MobileNetV2.
train.py:
``prepare_attack_data()`` builds the dataset used to train/test the Attack Model: attack_X holds the Target/Shadow Model output vectors and attack_Y records whether each sample belongs to that model's training set.
``train_per_epoch()`` runs one training epoch for the Target or Shadow Model.
``val_per_epoch()`` runs one test/validation epoch for the Target or Shadow Model.
``train_attack_model()`` trains the Attack Model.
``train_model()`` orchestrates model training and calls the functions above.
The remaining files have essentially the same meaning as before.
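For orientation, the call sequence inside ``create_attack()`` described above roughly looks like the sketch below; argument lists are abbreviated and the variable names are illustrative, only the function names come from this repo.
```
# Rough shape of create_attack() based on the descriptions above (arguments abbreviated,
# names other than the function names are illustrative):
#
# 1) Train (or load) the target and shadow models. train_model() also returns the
#    attack-model data collected from each model's train/test outputs.
#      t_X, t_Y = train_model(target_model, ..., is_target=True)    # -> attack test data
#      s_X, s_Y = train_model(shadow_model, ..., is_target=False)   # -> attack train data
#
# 2) Train the binary attack classifier on the shadow-model data.
#      best_val_acc = train_attack_model(attack_model, (s_X, s_Y), criterion,
#                                        optimizer, lr_scheduler, device)
#
# 3) Evaluate the attack on the target-model data.
#      attack_inference(attack_model, (t_X, t_Y), device)
```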
3. Results
- Attacks on ResNet18/50/152 and MobileNetV2 were run on CIFAR10 and then on CIFAR100. CIFAR100 generally gives better attack results than CIFAR10, probably because CIFAR100 has fewer images per class, making the models harder to train and more prone to overfitting, and therefore more vulnerable to MIA.
Only the attack on MobileNetV2 trained on CIFAR10 achieved a relatively notable result (Attack Model acc = 61.57%); in all other cases the Attack Model accuracy stays in the 52%-56% range. Considering that random guessing already gives 50% accuracy on a binary classification task, the attack is not very effective.
- Detailed results
The Target Model and the Shadow Model use the Adam optimizer, the Attack Model uses SGD, and all of them use the CosineAnnealingLR scheduler; a minimal sketch of this setup is shown below.
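A minimal sketch of that setup, with illustrative variable names and hyperparameter values (not the exact ones used in attack.py):
```
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

# Adam for the target/shadow models, SGD for the attack model (values illustrative)
target_opt = torch.optim.Adam(target_model.parameters(), lr=1e-3, weight_decay=1e-4)
shadow_opt = torch.optim.Adam(shadow_model.parameters(), lr=1e-3, weight_decay=1e-4)
attack_opt = torch.optim.SGD(attack_model.parameters(), lr=0.1, weight_decay=1e-7)

# one cosine schedule per optimizer, annealed over that model's number of epochs
target_sched = CosineAnnealingLR(target_opt, T_max=num_epochs)
shadow_sched = CosineAnnealingLR(shadow_opt, T_max=num_epochs)
attack_sched = CosineAnnealingLR(attack_opt, T_max=attack_epochs)
```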
* ResNet18 + CIFAR10:
```
Validation Accuracy for the Best Attack Model is: 54.40 %
----Attack Model Testing----
Y:tensor([1, 1, 1, ..., 0, 0, 0])
Length of Attack Model test dataset : [30000]
Attack Test Accuracy is : 53.26%
---Detailed Results----
precision recall f1-score support
Non-Member 0.55 0.38 0.45 15000
Member 0.52 0.69 0.59 15000
accuracy 0.53 30000
macro avg 0.54 0.53 0.52 30000
weighted avg 0.54 0.53 0.52 30000
```
* ResNet50 + CIFAR10:
```
Validation Accuracy for the Best Attack Model is: 55.70 %
----Attack Model Testing----
Y:tensor([1, 1, 1, ..., 0, 0, 0])
Length of Attack Model test dataset : [30000]
Attack Test Accuracy is : 52.20%
---Detailed Results----
precision recall f1-score support
Non-Member 0.52 0.47 0.49 15000
Member 0.52 0.58 0.55 15000
accuracy 0.52 30000
macro avg 0.52 0.52 0.52 30000
weighted avg 0.52 0.52 0.52 30000
```
* ResNet152 + CIFAR10
```
----Attack Model Testing----
Y:tensor([1, 1, 1, ..., 0, 0, 0])
Length of Attack Model test dataset : [30000]
Attack Test Accuracy is : 52.50%
---Detailed Results----
precision recall f1-score support
Non-Member 0.54 0.37 0.44 15000
Member 0.52 0.68 0.59 15000
accuracy 0.53 30000
macro avg 0.53 0.53 0.51 30000
weighted avg 0.53 0.53 0.51 30000
```
* MobileNetV2 + CIFAR10
```
Validation Accuracy for the Best Attack Model is: 65.20 %
----Attack Model Testing----
Y:tensor([1, 1, 1, ..., 0, 0, 0])
Length of Attack Model test dataset : [30000]
Attack Test Accuracy is : 61.57%
---Detailed Results----
precision recall f1-score support
Non-Member 0.63 0.57 0.60 15000
Member 0.61 0.66 0.63 15000
accuracy 0.62 30000
macro avg 0.62 0.62 0.61 30000
weighted avg 0.62 0.62 0.61 30000
```
* ResNet50 + CIFAR100:
```
----Attack Model Testing----
Y:tensor([1, 1, 1, ..., 0, 0, 0])
Length of Attack Model test dataset : [30000]
Attack Test Accuracy is : 54.30%
---Detailed Results----
precision recall f1-score support
Non-Member 0.54 0.61 0.57 15000
Member 0.55 0.47 0.51 15000
accuracy 0.54 30000
macro avg 0.54 0.54 0.54 30000
weighted avg 0.54 0.54 0.54 30000
```
* ResNet152 + CIFAR100:
```
Attack Test Accuracy is : 56.11%
---Detailed Results----
precision recall f1-score support
Non-Member 0.56 0.60 0.58 15000
Member 0.57 0.52 0.54 15000
accuracy 0.56 30000
macro avg 0.56 0.56 0.56 30000
weighted avg 0.56 0.56 0.56 30000
```
* MobileNetV2 + CIFAR100:
```
Validation Accuracy for the Best Attack Model is: 55.50 %
----Attack Model Testing----
Y:tensor([1, 1, 1, ..., 0, 0, 0])
Length of Attack Model test dataset : [30000]
Attack Test Accuracy is : 51.33%
---Detailed Results----
precision recall f1-score support
Non-Member 0.51 0.52 0.52 15000
Member 0.51 0.51 0.51 15000
accuracy 0.51 30000
macro avg 0.51 0.51 0.51 30000
weighted avg 0.51 0.51 0.51 30000
```
4. Issues and directions for improvement
Q1: In most cases the attack is not very effective.
A1: Try a different MIA method, e.g. Loss Trajectory MIA, and then carefully re-check the code for bugs.
Q2: On CIFAR100, the Target Model and the Shadow Model train poorly.
A2: CIFAR100 already has little training data per class, and that data is further split across the Target Model's and the Shadow Model's train and test sets, so each model sees too little data and trains poorly. We could adjust the training strategy, or **switch to the CINIC-10 dataset**: CINIC-10 extends CIFAR-10 with downsampled ImageNet images, contains 270,000 images (4.5x CIFAR), and keeps the same image size as CIFAR, so little code would need to change. One caveat is that its images come from both CIFAR and ImageNet, whose distributions are not necessarily identical, which may hurt training.
Q3: How can quantization be brought into the picture?
A3: It mainly depends on the threat model we assume.
If the attacker does not know whether the model under attack is full-precision or quantized, they would use a full-precision Shadow Model and train the Attack Model from it. In that case we only need to quantize the Target Model accordingly and build the Attack Model's test set from its outputs to account for quantization.
If the attacker knows the model is quantized, and also knows the numeric representation, bit width, etc., things get more involved. The simple approach is still to train a full-precision Shadow Model, quantize it, and then build the Attack Model's training set from the quantized Shadow Model; the remaining steps are the same as in the previous scenario. This may yield a stronger attack, since the Attack Model's training set comes from a quantized Shadow Model and is therefore better matched to a quantized model's outputs. The more involved approach is to introduce QAT, but QAT is currently hard to train on the ResNet family and MobileNetV2. A sketch of the first scenario follows below.
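Below is a minimal sketch of the first scenario under stated assumptions: only the Target Model is quantized, and the Attack Model's test set is rebuilt from its quantized-inference outputs. `quantize_model` is an assumed helper standing in for the wrapping/calibration/freeze utilities in module.py, and the model/loader names are placeholders.
```
# Hypothetical sketch: quantize only the target model, then rebuild the Attack Model's
# test set from its quantized-inference outputs. quantize_model() is an assumed helper.
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
q_target = quantize_model(target_model, quant_type='INT', num_bits=8)  # assumed helper
q_target.freeze()  # fix the quantized parameters, as in module.py

attack_test_X, attack_test_Y = [], []
for loader, member_label in [(target_train_loader, 1), (target_test_loader, 0)]:
    with torch.no_grad():
        for inputs, _ in loader:
            out = q_target.quantize_inference(inputs.to(device))
            attack_test_X.append(out.cpu())
            attack_test_Y.append(torch.full((out.size(0),), member_label, dtype=torch.long))
# (attack_test_X, attack_test_Y) then replaces the full-precision target outputs when
# evaluating the Attack Model trained on the full-precision Shadow Model.
```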
import torch
from torch.utils.data.dataset import TensorDataset
import torch.nn.functional as F
import copy
import os
#Prepare data for Attack Model
def prepare_attack_data(attack_X,
attack_Y,
model,
iterator,
device,
top_k=False,
test_dataset=False):
model.eval()
with torch.no_grad():
for inputs, _ in iterator: # train loader
# Move tensors to the configured device
inputs = inputs.to(device)
#Forward pass through the model
outputs = model(inputs)
#To get class probabilities
# posteriors = F.softmax(outputs, dim=1)
posteriors = outputs
if top_k:
#Top 3 posterior probabilities(high to low) for train samples
topk_probs, _ = torch.topk(posteriors, 3, dim=1)
attack_X.append(topk_probs.cpu())
else:
# only the output probability vector is used as the attack feature
# TODO: other features such as a loss could be used (e.g. a loss between the target model's and the shadow model's outputs, which requires access to the target model's output)
attack_X.append(posteriors.cpu())
# size(0) = batch size
if test_dataset:
attack_Y.append(torch.zeros(posteriors.size(0),dtype=torch.long))
else:
attack_Y.append(torch.ones(posteriors.size(0), dtype=torch.long))
return attack_X, attack_Y
def train_per_epoch(model,
train_iterator,
criterion,
optimizer,
device,
bce_loss=False):
epoch_loss = 0
epoch_acc = 0
correct = 0
total = 0
model.train()
for _ , (features, target) in enumerate(train_iterator):
# Move tensors to the configured device
features = features.to(device)
target = target.to(device)
# Forward pass
outputs = model(features)
# For Attack Model
if bce_loss:
#For BCE loss
loss = criterion(outputs, target.unsqueeze(1))
# For Target Model and Shadow Model
else:
loss = criterion(outputs, target)
# Backward pass and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
#Record Loss
epoch_loss += loss.item()
#Get predictions for accuracy calculation
_, predicted = torch.max(outputs.data, 1)
# accumulate totals over the epoch
total += target.size(0)
correct += (predicted == target).sum().item()
#Per-epoch training accuracy and loss calculation
epoch_acc = correct / total
epoch_loss = epoch_loss / total
return epoch_loss, epoch_acc
def val_per_epoch(model,
val_iterator,
criterion,
device,
bce_loss=False):
epoch_loss = 0
epoch_acc = 0
correct = 0
total =0
model.eval()
with torch.no_grad():
for _,(features,target) in enumerate(val_iterator):
features = features.to(device)
target = target.to(device)
outputs = model(features)
#Calculate the loss
if bce_loss:
#For BCE loss
loss = criterion(outputs, target.unsqueeze(1))
else:
loss = criterion(outputs,target)
#record the loss
epoch_loss += loss.item()
#Check Accuracy
_, predicted = torch.max(outputs.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
#Per-epoch validation accuracy and loss calculation
epoch_acc = correct / total
epoch_loss = epoch_loss / total
return epoch_loss, epoch_acc
###############################
# Training Attack Model
###############################
def train_attack_model(model,
dataset,
criterion,
optimizer,
lr_scheduler,
device,
model_path='./model',
epochs=10,
b_size=20,
num_workers=1,
verbose=False,
earlystopping=False):
n_validation = 1000 # number of validation samples
best_valacc = 0
stop_count = 0
patience = 10 # Early stopping
path = os.path.join(model_path,'best_attack_model.pt')
train_loss_hist = []
valid_loss_hist = []
val_acc_hist = []
train_X, train_Y = dataset
#Concatenate the list of tensors into a single tensor
t_X = torch.cat(train_X)
t_Y = torch.cat(train_Y)
print(f"t_Y:{t_Y}")
# #Create Attack Dataset
attackdataset = TensorDataset(t_X,t_Y)
print('Shape of Attack Feature Data : {}'.format(t_X.shape))
print('Shape of Attack Target Data : {}'.format(t_Y.shape))
print('Length of Attack Model train dataset : [{}]'.format(len(attackdataset)))
print('Epochs [{}] and Batch size [{}] for Attack Model training'.format(epochs,b_size))
#Create Train and Validation Split
n_train_samples = len(attackdataset) - n_validation
train_data, val_data = torch.utils.data.random_split(attackdataset,
[n_train_samples, n_validation])
train_loader = torch.utils.data.DataLoader(dataset=train_data,
batch_size=b_size,
shuffle=True,
num_workers=num_workers)
val_loader = torch.utils.data.DataLoader(dataset=val_data,
batch_size=b_size,
shuffle=False,
num_workers=num_workers)
print('----Attack Model Training------')
for i in range(epochs):
train_loss, train_acc = train_per_epoch(model, train_loader, criterion, optimizer, device)
valid_loss, valid_acc = val_per_epoch(model, val_loader, criterion, device)
valid_loss_hist.append(valid_loss)
train_loss_hist.append(train_loss)
val_acc_hist.append(valid_acc)
lr_scheduler.step()
print ('Epoch [{}/{}], Train Loss: {:.3f} | Train Acc: {:.2f}% | Val Loss: {:.3f} | Val Acc: {:.2f}%'
.format(i+1, epochs, train_loss, train_acc*100, valid_loss, valid_acc*100))
if earlystopping:
if best_valacc<valid_acc:
print('Saving model checkpoint')
best_valacc = valid_acc
#Store best model weights
best_model = copy.deepcopy(model.state_dict())
torch.save(best_model, path)
stop_count = 0
else:
stop_count+=1
if stop_count >=patience: #early stopping check
print('End Training after [{}] Epochs'.format(i+1))
break
else:#Continue model training for all epochs
if best_valacc < valid_acc:
print('Saving model checkpoint')
best_valacc = valid_acc
#Store best model weights
best_model = copy.deepcopy(model.state_dict())
torch.save(best_model, path)
return best_valacc
###################################
# Training Target and Shadow Model
###################################
#
# Each call to this function yields a fresh set of attack-model training data. (For the various quantized target models we currently plan to reuse the same shadow model; arguably a separately quantized shadow model per configuration would mimic the target more closely and yield a distinct full-precision attack model for each.)
def train_model(model,
train_loader,
val_loader,
test_loader,
loss,
optimizer,
scheduler,
device,
model_path,
verbose=False,
num_epochs=50,
top_k=False,
earlystopping=False,
is_target=False):
best_valacc = 0
patience = 20 # Early stopping
stop_count= 0
train_loss_hist = []
valid_loss_hist = []
val_acc_hist = []
attack_X = []
attack_Y = []
if is_target:
print('----Target model training----')
else:
print('---Shadow model training----')
#Path for saving best target and shadow models
target_path = os.path.join(model_path,'best_target_model.pt')
shadow_path = os.path.join(model_path,'best_shadow_model.pt')
for epoch in range(num_epochs):
train_loss, train_acc = train_per_epoch(model, train_loader, loss, optimizer, device)
valid_loss, valid_acc = val_per_epoch(model, val_loader, loss, device)
valid_loss_hist.append(valid_loss)
train_loss_hist.append(train_loss)
val_acc_hist.append(valid_acc)
scheduler.step()
print ('Epoch [{}/{}], Train Loss: {:.3f} | Train Acc: {:.2f}% | Val Loss: {:.3f} | Val Acc: {:.2f}%'.format(epoch+1, num_epochs, train_loss, train_acc*100, valid_loss, valid_acc*100))
if earlystopping:
if best_valacc<valid_acc:
print('Saving model checkpoint')
best_valacc = valid_acc
#Store best model weights
# best_model = copy.deepcopy(model.state_dict())
if is_target:
torch.save(model.state_dict(), target_path)
else:
torch.save(model.state_dict(), shadow_path)
stop_count = 0
else:
stop_count+=1
if stop_count >=patience: #early stopping check
print('End Training after [{}] Epochs'.format(epoch+1))
break
else:#Continue model training for all epochs
if best_valacc<valid_acc:
print('Saving model checkpoint')
best_valacc = valid_acc
#Store best model weights
# best_model = copy.deepcopy(model.state_dict())
if is_target:
torch.save(model.state_dict(), target_path)
else:
torch.save(model.state_dict(), shadow_path)
# best_valacc = valid_acc
#Store best model weights
# best_model = copy.deepcopy(model.state_dict())
# if is_target:
# torch.save(best_model, target_path)
# else:
# torch.save(best_model, shadow_path)
if is_target:
print('----Target model training finished----')
print('Validation Accuracy for the Target Model is: {:.2f} %'.format(100* best_valacc))
else:
print('----Shadow model training finished-----')
print('Validation Accuracy for the Shadow Model is: {:.2f} %'.format(100* best_valacc))
if is_target:
print('----LOADING the best Target model for Test----')
model.load_state_dict(torch.load(target_path))
else:
print('----LOADING the best Shadow model for Test----')
model.load_state_dict(torch.load(shadow_path))
#As the model is fully trained, time to prepare data for attack model.
#Training data for members comes from the shadow model's train set; member inference is then evaluated on the target model's train set.
# The target model's outputs should not be used as attack-model training data but only as test data; only the shadow model's outputs are used to build the attack-model training data
# if not is_target:
attack_X, attack_Y = prepare_attack_data(attack_X,attack_Y,model,train_loader,device,top_k)
# the dataset is split evenly into four parts: t_train, t_test, s_train, s_test
# these overlap with the test-set data (val is drawn from the test data)
# val_loader holds 1k samples drawn from t_train, so t_train shrinks by that amount
# attack_X, attack_Y = prepare_attack_data(attack_X,attack_Y,model,val_loader,device,top_k,test_dataset=True)
# In test phase, we don't need to compute gradients (for memory efficiency)
print('----Test the Trained Network----')
model.eval()
with torch.no_grad():
correct = 0
total = 0
for inputs, labels in test_loader:
inputs = inputs.to(device)
labels = labels.to(device)
test_outputs = model(inputs)
#Predictions for accuracy calculations
_, predicted = torch.max(test_outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
# Posterior and labels for non-members
# probs_test = F.softmax(test_outputs, dim=1)
probs_test = test_outputs
if top_k:
#Take top K posteriors ranked high ---> low
topk_t_probs, _ = torch.topk(probs_test, 3, dim=1)
attack_X.append(topk_t_probs.cpu())
else:
# the test-set outputs are also added to the attack model's data as the non-member ("out") part
attack_X.append(probs_test.cpu())
attack_Y.append(torch.zeros(probs_test.size(0), dtype=torch.long))
if is_target:
print('Test Accuracy of the Target model: {:.2f}%'.format(100 * correct / total))
else:
print('Test Accuracy of the Shadow model: {:.2f}%'.format(100 * correct / total))
return attack_X, attack_Y