#!/usr/bin/python

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.dataset import TensorDataset
from torch.utils.data.sampler import SubsetRandomSampler
import model
from model import init_params as w_init
from train import  train_model, train_attack_model, prepare_attack_data
from sklearn.metrics import classification_report
from sklearn.metrics import precision_score, recall_score
import argparse
import numpy as np
import os
import copy
import random
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import CosineAnnealingLR


# Seed NumPy's global RNG so that the index shuffle in split_dataset()
# (np.random.shuffle) yields the same partition on every run.
# NOTE(review): only NumPy is seeded here; torch's RNG is not seeded anywhere
# in the code shown, so samplers and weight init may still vary between runs.
np.random.seed(1234)
# Flag to enable early stopping (forwarded to train_model for target/shadow models)
need_earlystop = False

########################
# Model Hyperparameters
########################
# Legacy per-layer filter counts for the custom target and shadow CNNs (disabled)
# target_filters = [128, 256, 256]
# shadow_filters = [64, 128, 128]
# Legacy FC layer sizes for a pretrained model (disabled)
# n_fc= [256, 128] 
# Number of classes for the MNIST model (CIFAR models rely on their own
# defaults or an explicit num_classes=100)
num_classes = 10
# The following training hyperparameters are now supplied on the command
# line (see get_cmd_arguments); the old hard-coded values are kept for reference:
# # Number of training epochs
# num_epochs = 100
# # How many samples per batch to load
# batch_size = 128
# # Learning rate
# learning_rate = 0.001
# # Learning rate decay 
# lr_decay = 0.0001
# L2 regulariser strength (only referenced by commented-out optimizer code)
reg=1e-4
# Fraction of the dataset intended for the shadow model. It is passed to
# get_data_loader, which currently ignores it and uses a fixed equal split.
shadow_split = 0.6
# Number of validation samples (only referenced by commented-out code)
n_validation = 1000
# Number of data-loading worker processes (read by attack_inference)
num_workers = 2
# Hidden units for the MNIST target model
n_hidden_mnist = 32


################################
# Attack Model Hyperparameters
################################
# Number of epochs and mini-batch size used to train the attack model
NUM_EPOCHS = 100
BATCH_SIZE = 10
# Attack learning rate (only referenced by commented-out optimizer code)
LR_ATTACK = 0.1 
# L2 regulariser (only referenced by commented-out optimizer code)
REG = 1e-7
# Learning-rate decay factor (gamma) for the commented-out ExponentialLR scheduler
LR_DECAY = 0.96
# Number of hidden units in the attack MLP
n_hidden = 256
# Binary classifier: member vs. non-member
out_classes = 2



def get_cmd_arguments(argv=None):
    """Parse the command-line options of the membership inference attack script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour existing callers rely on.

    Returns:
        argparse.Namespace holding the model/dataset selection, the paths,
        the training hyperparameters and the boolean pipeline switches.
    """
    parser = argparse.ArgumentParser(prog="Membership Inference Attack")

    # Model / dataset selection and storage locations
    parser.add_argument('--model', default='resnet18', type=str,
                        choices=['resnet18', 'resnet50', 'resnet152', 'mobilenetv2'],
                        help='Which model to use (resnet18, resnet50, resnet152 or mobilenetv2)')
    parser.add_argument('--dataset', default='CIFAR10', type=str,
                        choices=['CIFAR10', 'CIFAR100', 'MNIST'],
                        help='Which dataset to use (CIFAR10, CIFAR100 or MNIST)')
    parser.add_argument('--dataPath', default='/lustre/datasets', type=str,
                        help='Path to store data')
    parser.add_argument('--modelPath', default='./ckpt', type=str,
                        help='Path to save or load model checkpoints')

    # Training hyperparameters for the target and shadow models
    parser.add_argument('-e', '--num_epochs', default=100, type=int, metavar='EPOCHS',
                        help='number of total epochs to run (default: 100)')
    parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='BATCH SIZE',
                        help='mini-batch size (default: 128)')
    parser.add_argument('-lr', '--learning_rate', default=0.001, type=float, metavar='LR',
                        help='initial learning rate (default: 0.001)')
    parser.add_argument('-wd', '--lr_decay', default=0.0001, type=float, metavar='WD',
                        help='lr scheduler decay (default: 0.0001)')
    # The help text previously advertised a default of 4 while the real default is 2.
    parser.add_argument('-j', '--num_workers', default=2, type=int, metavar='WORKERS',
                        help='number of data loading workers (default: 2)')

    # Pipeline switches (all default to False)
    parser.add_argument('--trainTargetModel', action='store_true',
                        help='Train a target model, if false then load an already trained model')
    parser.add_argument('--trainShadowModel', action='store_true',
                        help='Train a shadow model, if false then load an already trained model')
    parser.add_argument('--trainAttackModel', action='store_true',
                        help='Train an attack model, if false then load an already trained model')
    parser.add_argument('--need_augm', action='store_true',
                        help='To use data augmentation on target and shadow training set or not')
    parser.add_argument('--need_topk', action='store_true',
                        help='Flag to enable using Top 3 posteriors for attack data')
    parser.add_argument('--param_init', action='store_true',
                        help='Flag to enable custom model params initialization')
    parser.add_argument('--verbose', action='store_true', help='Add Verbosity')
    return parser.parse_args(argv)

def get_data_transforms(dataset, augm=False):
    """Build the (train, test) torchvision transform pipelines for a dataset.

    'CIFAR10' and 'CIFAR100' use their own per-channel normalisation
    statistics; every other dataset name is handled as MNIST. The test
    pipeline is always ToTensor + Normalize. With ``augm`` enabled the train
    pipeline is prefixed with a small random rotation and a horizontal flip,
    plus a padded 32x32 random crop for the CIFAR datasets.

    Returns:
        Tuple ``(train_transforms, test_transforms)``.
    """
    if dataset == 'CIFAR10':
        is_cifar = True
        mean, std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
    elif dataset == 'CIFAR100':
        is_cifar = True
        mean, std = (0.507, 0.487, 0.441), (0.267, 0.256, 0.276)
    else:
        # 0.1307 and 0.3081 are the global mean and standard deviation
        # of the MNIST dataset.
        is_cifar = False
        mean, std = (0.1307,), (0.3081,)

    def to_normalized_tensor():
        # Common tail shared by the train and test pipelines.
        return [transforms.ToTensor(), transforms.Normalize(mean, std)]

    test_transforms = transforms.Compose(to_normalized_tensor())

    augmentations = []
    if augm:
        augmentations.append(transforms.RandomRotation(5))
        augmentations.append(transforms.RandomHorizontalFlip(p=0.5))
        if is_cifar:
            augmentations.append(transforms.RandomCrop(32, padding=4))

    train_transforms = transforms.Compose(augmentations + to_normalized_tensor())

    return train_transforms, test_transforms


def split_dataset(train_dataset):
    """Randomly partition the dataset's indices into four parts.

    The indices ``0 .. len(train_dataset) - 1`` are shuffled with NumPy's
    global RNG and cut at multiples of ``len(train_dataset) // 4``. The first
    three parts have exactly that size; the last part also receives the
    remainder when the length is not divisible by four.

    Returns:
        Tuple of index lists, in this order:
        (shadow train, shadow test, target train, target test).
    """
    n_samples = len(train_dataset)
    quarter = n_samples // 4

    shuffled = list(range(n_samples))
    np.random.shuffle(shuffled)

    # Consecutive boundaries delimit the four parts of the shuffled indices.
    boundaries = (0, quarter, 2 * quarter, 3 * quarter, n_samples)
    shadow_train, shadow_test, target_train, target_test = (
        shuffled[start:stop] for start, stop in zip(boundaries, boundaries[1:])
    )

    return shadow_train, shadow_test, target_train, target_test
    

#--------------------------------------------------------------------------------
# Get dataloaders for Shadow and Target models
# Data strategy:
# - Split the combined dataset into 4 parts (T_train, T_test, S_train, S_test)
#   Target - train on T_train, hold out T_test
#   Shadow - train on S_train, hold out S_test
#   Attack - use S_train and S_test for training
#            use T_train and T_test for evaluation
#--------------------------------------------------------------------------------

def get_data_loader(dataset,
                    data_dir,
                    batch,
                    shadow_split=0.5,
                    augm_required=False,
                    num_workers=1):
    """Create the target and shadow data loaders for CIFAR10, CIFAR100 or MNIST.

    The official train and test splits of the dataset are concatenated and
    the result is divided by split_dataset() into four disjoint parts:
    target train, target held-out, shadow train and shadow held-out. Samples
    that come from the official train split keep the train transforms (and
    thus any augmentation); samples from the official test split keep the
    test transforms.

    Args:
        dataset: Dataset name, one of 'CIFAR10', 'CIFAR100' or 'MNIST'.
        data_dir: Root directory that already contains the dataset
            (nothing is downloaded).
        batch: Mini-batch size used by every loader.
        shadow_split: Currently unused; the partition is fixed to four equal
            parts. Kept so that existing callers keep working.
        augm_required: Whether to build the train transforms with augmentation.
        num_workers: Number of worker processes per loader.

    Returns:
        Tuple ``(t_train_loader, t_val_loader, t_out_loader,
        s_train_loader, s_val_loader, s_out_loader)``. The validation loaders
        sample from the same indices as the corresponding held-out loaders.

    Raises:
        ValueError: If ``dataset`` is not a supported dataset name.
    """
    # Fail early with a clear message; previously an unknown name fell through
    # every branch and crashed later with an unbound local variable.
    supported = ('CIFAR10', 'CIFAR100', 'MNIST')
    if dataset not in supported:
        raise ValueError(
            "Unsupported dataset {!r}; expected one of {}".format(dataset, supported))

    train_transforms, test_transforms = get_data_transforms(dataset, augm_required)

    print(f"Using Dataset:{dataset}")
    # The supported names match the torchvision dataset class names exactly.
    dataset_cls = getattr(torchvision.datasets, dataset)
    official_train = dataset_cls(root=data_dir,
                                 train=True,
                                 transform=train_transforms,
                                 download=False)
    official_test = dataset_cls(root=data_dir,
                                train=False,
                                transform=test_transforms)
    # Merge both official splits into one pool before partitioning it.
    train_set = torch.utils.data.ConcatDataset([official_train, official_test])

    s_train_idx, s_out_idx, t_train_idx, t_out_idx = split_dataset(train_set)

    # Data samplers
    s_train_sampler = SubsetRandomSampler(s_train_idx)
    s_out_sampler = SubsetRandomSampler(s_out_idx)
    t_train_sampler = SubsetRandomSampler(t_train_idx)
    t_out_sampler = SubsetRandomSampler(t_out_idx)

    # The validation set is the complete held-out set: those samples are never
    # used for training, so they can be used to monitor training performance.
    t_val_sampler = SubsetRandomSampler(t_out_idx)
    s_val_sampler = SubsetRandomSampler(s_out_idx)

    def make_loader(sampler):
        # Every loader reads from the same pooled dataset; only the sampler differs.
        return torch.utils.data.DataLoader(dataset=train_set,
                                           batch_size=batch,
                                           sampler=sampler,
                                           num_workers=num_workers)

    t_train_loader = make_loader(t_train_sampler)
    t_out_loader = make_loader(t_out_sampler)
    t_val_loader = make_loader(t_val_sampler)
    s_train_loader = make_loader(s_train_sampler)
    s_out_loader = make_loader(s_out_sampler)
    s_val_loader = make_loader(s_val_sampler)

    print('Total Combined Test samples in {} dataset : {}'.format(dataset, len(train_set))) 
    print('Number of Target train samples: {}'.format(len(t_train_sampler)))
    print('Number of Target valid samples: {}'.format(len(t_val_sampler)))
    print('Number of Target test samples: {}'.format(len(t_out_sampler)))
    print('Number of Shadow train samples: {}'.format(len(s_train_sampler)))
    print('Number of Shadow valid samples: {}'.format(len(s_val_sampler)))
    print('Number of Shadow test samples: {}'.format(len(s_out_sampler)))

    return t_train_loader, t_val_loader, t_out_loader, s_train_loader, s_val_loader, s_out_loader


def attack_inference(model,
                    test_X,
                    test_Y,
                    device):
    """Evaluate the attack model and print its accuracy and a class report.

    Args:
        model: Attack model to evaluate; it is switched to eval mode.
        test_X: Sequence of feature tensors, concatenated along dim 0.
        test_Y: Sequence of label tensors, concatenated along dim 0.
            Presumably 0 = non-member and 1 = member, matching the report's
            target names; confirm against prepare_attack_data.
        device: Device on which inference is run.

    Returns:
        None. Results are only printed.
    """
    print('----Attack Model Testing----')

    class_names = ['Non-Member', 'Member']

    # Sequence of per-batch tensors => one tensor covering all samples
    features = torch.cat(test_X)
    membership = torch.cat(test_Y)
    print(f"Y:{membership}")

    # Rebuild a dataset so the samples can be re-batched for inference
    eval_set = TensorDataset(features, membership)
    print('Length of Attack Model test dataset : [{}]'.format(len(eval_set)))

    eval_loader = torch.utils.data.DataLoader(dataset=eval_set,
                                              batch_size=50,
                                              shuffle=False,
                                              num_workers=num_workers)

    label_batches = []
    prediction_batches = []
    n_correct = 0
    n_seen = 0

    model.eval()
    with torch.no_grad():
        for batch_inputs, batch_labels in eval_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            scores = model(batch_inputs)

            # Index of the highest score along dim 1 is the predicted class
            batch_predictions = torch.max(scores.data, 1)[1]
            n_seen += batch_labels.size(0)
            n_correct += (batch_predictions == batch_labels).sum().item()

            label_batches.append(batch_labels.cpu())
            prediction_batches.append(batch_predictions.cpu())

    attack_acc = n_correct / n_seen
    print('Attack Test Accuracy is  : {:.2f}%'.format(100*attack_acc))

    y_true = torch.cat(label_batches).numpy()
    y_pred = torch.cat(prediction_batches).numpy()

    print('---Detailed Results----')
    print(classification_report(y_true, y_pred, target_names=class_names))


# Constructor names in the project `model` module, keyed by the --model CLI value.
_CIFAR_MODEL_BUILDERS = {
    'resnet18': 'resnet18',
    'resnet50': 'resnet50',
    'resnet152': 'resnet152',
    'mobilenetv2': 'MobileNetV2',
}


def _build_classifier(modelname, dataset, device, input_dim, mnist_hidden):
    """Instantiate the target/shadow classifier for `dataset` and move it to `device`."""
    if dataset not in ('CIFAR10', 'CIFAR100'):
        # Every non-CIFAR dataset is handled as MNIST; `modelname` is ignored there.
        return model.MNISTNet(input_dim, mnist_hidden, num_classes).to(device)

    if modelname not in _CIFAR_MODEL_BUILDERS:
        raise ValueError(
            "Unsupported model {!r}; expected one of {}".format(
                modelname, sorted(_CIFAR_MODEL_BUILDERS)))

    # Look up only the requested constructor so unrelated ones are never touched.
    builder = getattr(model, _CIFAR_MODEL_BUILDERS[modelname])
    if dataset == 'CIFAR100':
        return builder(num_classes=100).to(device)
    # CIFAR10 relies on the constructor's default number of classes.
    return builder().to(device)


def _describe_model(arch_header, params_header, net):
    """Print a network and the names of its trainable parameters."""
    print(arch_header)
    print(net)
    print(params_header)
    for name, param in net.named_parameters():
        if param.requires_grad:
            print("\t", name)


def _collect_attack_data(net, member_loader, nonmember_loader, device, top_k):
    """Run `net` over its train and held-out loaders and join the attack features/labels."""
    # Samples the model was trained on
    in_x, in_y = prepare_attack_data([], [], net, member_loader, device, top_k)
    # Held-out samples
    out_x, out_y = prepare_attack_data([], [], net, nonmember_loader, device, top_k,
                                       test_dataset=True)
    # Both are lists of per-batch tensors, so list concatenation is enough here.
    return in_x + out_x, in_y + out_y


#Main Method to initate model training and attack
def create_attack(modelname, dataset, dataPath, modelPath, num_epochs , batch_size,  learning_rate, lr_decay , num_workers,trainTargetModel, trainShadowModel,trainAttackModel, need_augm, need_topk, param_init, verbose):
    """Run the full membership inference pipeline.

    A target and a shadow model are either trained or restored from
    checkpoints in ``<modelPath>/<dataset>/<modelname>``. The shadow model's
    outputs form the training data of the attack model, and the target
    model's outputs form its evaluation data.

    Args:
        modelname: Architecture for the CIFAR models ('resnet18', 'resnet50',
            'resnet152' or 'mobilenetv2'); ignored for MNIST.
        dataset: 'CIFAR10', 'CIFAR100' or 'MNIST'.
        dataPath: Root directory of the datasets.
        modelPath: Root directory of the model checkpoints.
        num_epochs: Epochs for target/shadow training; also the cosine
            annealing period of their schedulers.
        batch_size: Mini-batch size for target/shadow data loaders.
        learning_rate: Learning rate of the target, shadow and attack optimizers.
        lr_decay: Currently unused (the exponential scheduler is disabled).
        num_workers: Number of data-loading workers.
        trainTargetModel: Train the target model instead of loading it.
        trainShadowModel: Train the shadow model instead of loading it.
        trainAttackModel: Train the attack model instead of loading it.
        need_augm: Enable data augmentation for the training transforms.
        need_topk: Use only the top posteriors as attack features.
        param_init: Apply the custom weight initialisation to new models.
        verbose: Print model architectures and extra training output.

    Raises:
        ValueError: If ``modelname`` is unknown for a CIFAR dataset.
    """
    #For using top 3 posterior probabilities
    top_k = need_topk

    #Input channels of the images: RGB for CIFAR, grayscale for MNIST
    input_dim = 3 if dataset in ('CIFAR10', 'CIFAR100') else 1

    datasetDir = os.path.join(dataPath, dataset)
    modelDir = os.path.join(modelPath, dataset, modelname)

    #Create dataset and model directories (best effort: a failure is reported
    #here and surfaces later when the directory is actually needed)
    for directory in (datasetDir, modelDir):
        try:
            os.makedirs(directory, exist_ok=True)
        except OSError as err:
            print('Warning: could not create directory [{}]: {}'.format(directory, err))

    # setting device on GPU if available, else CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    #Creating data loaders (the validation loaders are drawn from the held-out sets)
    t_train_loader, t_val_loader, t_test_loader,\
    s_train_loader, s_val_loader, s_test_loader = get_data_loader(dataset,
                                                                datasetDir,
                                                                batch_size,
                                                                shadow_split,
                                                                need_augm,
                                                                num_workers)

    ###################################
    # Target Model
    ###################################
    if trainTargetModel:
        target_model = _build_classifier(modelname, dataset, device, input_dim, n_hidden_mnist)

        if param_init:
            #Initialize params
            target_model.apply(w_init)

        # Print the model we just instantiated
        if verbose:
            _describe_model('----Target Model Architecure----',
                            '----Model Learnable Params----',
                            target_model)

        # Loss and optimizer for the target model
        loss = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(target_model.parameters(), lr=learning_rate)
        lr_scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

        # train_model returns the attack features/labels derived from the target model
        targetX, targetY = train_model(target_model,
                                    t_train_loader,
                                    t_val_loader,
                                    t_test_loader,
                                    loss,
                                    optimizer,
                                    lr_scheduler,
                                    device,
                                    modelDir,
                                    verbose,
                                    num_epochs,
                                    top_k,
                                    need_earlystop,
                                    is_target=True)

    else: #Target model training not required, load the saved checkpoint
        target_file = os.path.join(modelDir,'best_target_model.pt')
        print('Use Target model at the path ====> [{}] '.format(modelDir))
        target_model = _build_classifier(modelname, dataset, device, input_dim, n_hidden_mnist)

        # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine
        target_model.load_state_dict(torch.load(target_file, map_location=device))
        print('---Peparing Attack Training data---')
        # Evaluation data of the attack model: target train set and target held-out set
        targetX, targetY = _collect_attack_data(target_model, t_train_loader,
                                                t_test_loader, device, top_k)

    ###################################
    # Shadow Model
    ###################################
    #Using less hidden units than target model to mimic the architecture (MNIST only)
    n_shadow_hidden = 16

    if trainShadowModel:
        shadow_model = _build_classifier(modelname, dataset, device, input_dim, n_shadow_hidden)

        if param_init:
            #Initialize params
            shadow_model.apply(w_init)

        # Print the model we just instantiated
        if verbose:
            _describe_model('----Shadow Model Architecure---',
                            '---Model Learnable Params----',
                            shadow_model)

        # Loss and optimizer
        shadow_loss = nn.CrossEntropyLoss()
        shadow_optimizer = torch.optim.Adam(shadow_model.parameters(), lr=learning_rate)
        shadow_lr_scheduler = CosineAnnealingLR(shadow_optimizer, T_max=num_epochs)

        # train_model returns the attack model's training data derived from the shadow model
        shadowX, shadowY = train_model(shadow_model,
                                    s_train_loader,
                                    s_val_loader,
                                    s_test_loader,
                                    shadow_loss,
                                    shadow_optimizer,
                                    shadow_lr_scheduler,
                                    device,
                                    modelDir,
                                    verbose,
                                    num_epochs,
                                    top_k,
                                    need_earlystop,
                                    is_target=False)
    else: #Shadow model training not required, load the saved checkpoint
        print('Using Shadow model at the path  ====> [{}] '.format(modelDir))
        shadow_file = os.path.join(modelDir,'best_shadow_model.pt')
        assert os.path.isfile(shadow_file), 'Shadow Mode Checkpoint not found, aborting load'
        shadow_model = _build_classifier(modelname, dataset, device, input_dim, n_shadow_hidden)

        #Load the saved model (map_location: see the target model above)
        shadow_model.load_state_dict(torch.load(shadow_file, map_location=device))
        #Prepare dataset for training attack model
        print('----Preparing Attack training data---')
        # Training data of the attack model: shadow train set and shadow held-out set
        shadowX, shadowY = _collect_attack_data(shadow_model, s_train_loader,
                                                s_test_loader, device, top_k)

    ###################################
    # Attack Model Training
    ##################################
    #The input dimension to MLP attack model
    input_size = shadowX[0].size(1)
    print('Input Feature dim for Attack Model : [{}]'.format(input_size))

    attack_model = model.AttackMLP(input_size,n_hidden,out_classes).to(device)

    if trainAttackModel:
        if param_init:
            #Initialize params
            attack_model.apply(w_init)

        # Loss and optimizer
        attack_loss = nn.CrossEntropyLoss()
        attack_optimizer = torch.optim.SGD(attack_model.parameters(), lr=learning_rate,momentum=0.7, weight_decay=0.0001)
        attack_lr_scheduler = CosineAnnealingLR(attack_optimizer, T_max=NUM_EPOCHS)

        #Feature vectors and labels for training the attack model; only data
        #produced by the shadow model is used here
        attackdataset = (shadowX, shadowY)

        attack_valacc = train_attack_model(attack_model, attackdataset, attack_loss,
                       attack_optimizer, attack_lr_scheduler, device, modelDir,
                        NUM_EPOCHS, BATCH_SIZE, num_workers, verbose, earlystopping=False)

        print('Validation Accuracy for the Best Attack Model is: {:.2f} %'.format(100* attack_valacc))

    else:
        #Load the trained attack model
        attack_path = os.path.join(modelDir,'best_attack_model.pt')
        attack_model.load_state_dict(torch.load(attack_path, map_location=device))

    #Inference on the trained attack model: the attack model was fitted on
    #shadow-model outputs, so the target-model outputs act as its test set
    attack_inference(attack_model, targetX, targetY, device)

if __name__ == '__main__':
    # Read the user's command-line options and echo them for the run log
    cli_options = get_cmd_arguments()
    print(cli_options)
    # Launch the membership inference attack with the parsed options
    create_attack(
        modelname=cli_options.model,
        dataset=cli_options.dataset,
        dataPath=cli_options.dataPath,
        modelPath=cli_options.modelPath,
        num_epochs=cli_options.num_epochs,
        batch_size=cli_options.batch_size,
        learning_rate=cli_options.learning_rate,
        lr_decay=cli_options.lr_decay,
        num_workers=cli_options.num_workers,
        trainTargetModel=cli_options.trainTargetModel,
        trainShadowModel=cli_options.trainShadowModel,
        trainAttackModel=cli_options.trainAttackModel,
        need_augm=cli_options.need_augm,
        need_topk=cli_options.need_topk,
        param_init=cli_options.param_init,
        verbose=cli_options.verbose,
    )