# -*- coding: utf-8 -*-
from model import *
from get_weight import *
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
import sys

import time

# import matplotlib.pyplot as plt
# import matplotlib
from torchvision.datasets import ImageFolder

from torch.utils.tensorboard import SummaryWriter


from absl import app, flags
# from easydict import EasyDict
# from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method
# from cleverhans.torch.attacks.projected_gradient_descent import (
#     projected_gradient_descent,
# )







def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch and return the accumulated gradient statistics.

    For every batch the model is run forward, the cross-entropy loss is
    back-propagated, the gradients reported by ``get_model_histogram(model)``
    are added to a running per-key sum, and only then is the optimizer
    stepped. After the last batch each summed entry is divided by
    ``len(train_loader.dataset)``.

    Args:
        model: Network to train; called as ``model(data)``.
        device: Device each batch is moved to before the forward pass.
        train_loader: Iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute that supports ``len()``.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        epoch: Epoch number, used only in the progress message.

    Returns:
        The mapping returned by ``get_model_histogram`` for the first batch,
        updated in place with the sums over all batches and then scaled by
        the dataset size. An empty dict is returned when ``train_loader``
        yields no batches.
    """
    model.train()
    loss_layer = torch.nn.CrossEntropyLoss()
    # None means "no batch seen yet"; it also lets an empty loader return
    # cleanly instead of hitting an unbound local after the loop.
    grads_sum = None

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_layer(output, target)

        loss.backward()

        # Gradients are read after backward() and before step(). The first
        # element returned by the helper is not used here.
        _, grads = get_model_histogram(model)
        if grads_sum is None:
            # NOTE(review): this keeps the helper's own dict (and values) and
            # accumulates into them in place. That is only safe if
            # get_model_histogram returns copies rather than live references
            # to param.grad -- confirm in get_weight.
            grads_sum = grads
        else:
            for key in grads_sum:
                grads_sum[key] += grads[key]

        optimizer.step()

        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
            ))

    if grads_sum is None:
        return {}

    # NOTE(review): the sum has one term per batch but is divided by the
    # number of samples, not the number of batches. Kept as-is because saved
    # checkpoints may rely on this scale -- confirm the intended averaging.
    num_samples = len(train_loader.dataset)
    for key in grads_sum:
        grads_sum[key] = grads_sum[key] / num_samples

    return grads_sum






def test(model, device, test_loader):
    """Evaluate ``model`` on ``test_loader`` and return the accuracy in percent.

    The model is switched to evaluation mode and every batch is run with
    gradient tracking disabled. The summed cross-entropy loss and the number
    of correct arg-max predictions are accumulated over all batches, divided
    by ``len(test_loader.dataset)`` and printed.

    Args:
        model: Network to evaluate; called as ``model(data)``.
        device: Device each batch is moved to before the forward pass.
        test_loader: Iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute that supports ``len()``.

    Returns:
        Accuracy over the dataset as a percentage (``float``).
    """
    model.eval()
    loss_layer = torch.nn.CrossEntropyLoss(reduction='sum')
    test_loss = 0
    correct = 0

    # torch.no_grad is a class: it has to be instantiated to be usable as a
    # context manager. Using the bare class in a ``with`` fails at runtime.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            test_loss += loss_layer(output, target).item()
            # Predicted class = index of the largest score along dim 1.
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    acc = 100. * correct / num_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.0f}%\n'.format(
        test_loss, acc
    ))

    return acc


# The name starts with "test", so pytest would try to collect this helper as
# a test case (and fail on its arguments) in any module that imports it.
test.__test__ = False


def main():
    """Train a network on CIFAR-10, evaluating and checkpointing every epoch.

    The number of epochs is read from the first command-line argument. After
    each epoch the gradient statistics returned by ``train`` and the current
    parameter values are written as TensorBoard histograms to
    ``./fullprecision_log``, and a checkpoint with model state, optimizer
    state, epoch, gradient statistics and test accuracy is saved.

    ``VGG_19`` and ``LeNet`` are not defined in this file; presumably they
    are provided by ``from model import *`` -- confirm against that module.

    Raises:
        SystemExit: If no epoch count is given on the command line.
        ValueError: If the configured network name is not supported.
    """
    if len(sys.argv) < 2:
        sys.exit('usage: python {} <epochs>'.format(sys.argv[0]))
    epochs = int(sys.argv[1])

    batch_size = 32
    test_batch_size = 32
    seed = 1
    lr = 0.001
    momentum = 0.5
    save_model = False
    net = 'LeNet'
    torch.manual_seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    # The same preprocessing is applied to the training and the test split.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('data', train=True, download=True, transform=transform),
        batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=True
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('data', train=False, transform=transform),
        batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=True
    )

    if net == 'VGG19':
        model = VGG_19().to(device)
    elif net == 'LeNet':
        model = LeNet().to(device)
    else:
        # Fail here with a clear message instead of a NameError on `model`.
        raise ValueError('unsupported network: {!r}'.format(net))

    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    checkpoint_dir = 'checkpoint/cifar-10_lenet_bn/full'
    if net == 'LeNet':
        # torch.save does not create missing parent directories.
        os.makedirs(checkpoint_dir, exist_ok=True)

    writer = SummaryWriter(log_dir='./fullprecision_log')
    try:
        for epoch in range(1, epochs + 1):
            grads_sum = train(model, device, train_loader, optimizer, epoch)
            acc = test(model, device, test_loader)
            print('epoch:', epoch)
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
                'grads': grads_sum,
                'accuracy': acc
            }

            # These are the values returned by train() for the whole epoch
            # (accumulated over its batches), not the last batch's gradients.
            for name, param in grads_sum.items():
                writer.add_histogram(tag=name + '_grad', values=param, global_step=epoch)
            # Weights as they are after the last batch of this epoch.
            for name, param in model.named_parameters():
                writer.add_histogram(tag=name + '_data', values=param.data, global_step=epoch)

            # Save the per-epoch checkpoint.
            if net == 'LeNet':
                torch.save(
                    checkpoint,
                    '{}/ckpt_cifar-10_lenet_bn_{}.pth'.format(checkpoint_dir, epoch)
                )
    finally:
        # Flush pending events and release the log file even if training fails.
        writer.close()

    if save_model:
        os.makedirs('ckpt', exist_ok=True)
        if net == 'VGG19':
            torch.save(model.state_dict(), 'ckpt/cifar-10_vgg19_bn.pt')
        elif net == 'LeNet':
            torch.save(model.state_dict(), 'ckpt/cifar-10_lenet_bn.pt')


if __name__ == '__main__':
    main()
