# -*- coding: utf-8 -*-
from model import *

# from easydict import EasyDict
# from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method
# from cleverhans.torch.attacks.projected_gradient_descent import (
#     projected_gradient_descent,
# )


import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision import datasets, transforms
import os
import os.path as osp
import sys

import time

# import matplotlib.pyplot as plt
# import matplotlib

# sys.path.append("./project/p")
from get_weight import *
from torch.utils.tensorboard import SummaryWriter



def quantize_aware_training(model, device, train_loader, optimizer, epoch):
    """Run one epoch of quantization-aware training and accumulate gradients.

    For every batch the data is pushed through ``model.quantize_forward``, the
    cross-entropy loss is back-propagated, the gradients reported by
    ``get_model_histogram(model)`` are summed across batches, and
    ``optimizer.step()`` is applied.

    Args:
        model: network exposing ``quantize_forward(data)``.
        device: device each batch is moved to.
        train_loader: iterable of ``(data, target)`` batches that also exposes
            ``.dataset`` with a length.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only in progress messages.

    Returns:
        ``(grads_sum, losses)``.

        ``grads_sum`` maps each key of the gradient dict returned by
        ``get_model_histogram`` to the per-batch gradients summed over the
        epoch and divided by ``len(train_loader.dataset)``. It is an empty
        dict when the loader yields no batches.

        ``losses`` is the list of per-batch loss tensors, detached from the
        autograd graph.
    """
    loss_fn = torch.nn.CrossEntropyLoss()
    grads_sum = None
    losses = []
    for batch_idx, (data, target) in enumerate(train_loader, 1):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model.quantize_forward(data)  # forward through each quantized layer
        loss = loss_fn(output, target)  # the loss is tied to the layers through this graph
        loss.backward()
        # Detach before storing: keeping the raw loss would keep this batch's
        # whole autograd graph alive until the epoch ends (and pickle it into
        # the checkpoint).
        losses.append(loss.detach())

        _, grads = get_model_histogram(model)
        if grads_sum is None:
            # NOTE(review): the first batch's dict (and its values) becomes the
            # accumulator and later batches are added into it in place. If
            # get_model_histogram returns live references to param.grad rather
            # than copies, optimizer.zero_grad() could corrupt the sum. Confirm.
            grads_sum = grads
        else:
            # Sum the gradients of every batch in this epoch.
            for k in grads_sum:
                grads_sum[k] += grads[k]

        optimizer.step()

        if batch_idx % 50 == 0:
            print('Quantize Aware Training Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
            ))

    if grads_sum is None:
        # The loader yielded no batches, so there is nothing to average.
        return {}, losses

    # "Average" gradient of the epoch.
    # NOTE(review): each batch gradient is already a batch mean (the
    # CrossEntropyLoss default reduction), so dividing by the number of samples
    # rather than the number of batches gives mean / batch_size. The scaling is
    # kept as-is so saved results stay comparable. Confirm the intent.
    n_samples = len(train_loader.dataset)
    for k, v in grads_sum.items():
        grads_sum[k] = v / n_samples

    return grads_sum, losses


def full_inference(model, test_loader, device=None):
    """Evaluate the float (non-quantized) model on ``test_loader`` and print its accuracy.

    Args:
        model: ``nn.Module`` whose output is arg-maxed over dim 1 to get the
            predicted class.
        test_loader: iterable of ``(data, target)`` batches that also exposes
            ``.dataset`` with a length.
        device: device batches are moved to. When omitted, it defaults to the
            device of the model's first parameter (CPU if the model has none),
            so the function no longer depends on a module-level ``device``
            global that exists only when this file runs as a script.

    Returns:
        Accuracy in percent over ``len(test_loader.dataset)`` samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0
    # Pure evaluation: no autograd graph needs to be recorded.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    acc = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Full Model Accuracy: {:.0f}%\n'.format(acc))
    print('============================================')
    return acc

def quantize_inference(model, test_loader, device=None):
    """Evaluate the quantized model via ``model.quantize_inference`` and print its accuracy.

    Args:
        model: ``nn.Module`` exposing ``quantize_inference(data)``; its output
            is arg-maxed over dim 1 to get the predicted class.
        test_loader: iterable of ``(data, target)`` batches that also exposes
            ``.dataset`` with a length.
        device: device batches are moved to. When omitted, it defaults to the
            device of the model's first parameter (CPU if the model has none),
            instead of relying on a module-level ``device`` global.

    Returns:
        Accuracy in percent over ``len(test_loader.dataset)`` samples, or ``0``
        if the loader yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0
    n_batches = 0
    # Pure evaluation: skip building an autograd graph for every batch.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            pred = model.quantize_inference(data).argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            n_batches += 1

    # Computed once after the loop; stays 0 when no batch was evaluated.
    acc = 100. * correct / len(test_loader.dataset) if n_batches else 0
    print('\nTest set: Quant Model Accuracy: {:.0f}%\n'.format(acc))
    return acc


if __name__ == "__main__":
    # Positional command-line arguments:
    #   argv[1] num_bits, argv[2] epochs, argv[3] mode, argv[4] n_exp
    # They are kept as the raw strings for building log/checkpoint paths and
    # converted with int() where a number is needed.
    if len(sys.argv) < 5:
        sys.exit('usage: %s <num_bits> <epochs> <mode> <n_exp>' % sys.argv[0])
    d1 = sys.argv[1]  # num_bits
    d2 = sys.argv[2]  # epochs
    d3 = sys.argv[3]  # mode
    d4 = sys.argv[4]  # n_exp

    num_bits = int(d1)
    epochs = int(d2)

    batch_size = 32
    # Test batch size matches the training one (the network uses batch norm).
    test_batch_size = 32
    seed = 1
    lr = 0.001
    momentum = 0.5
    # Optional quantized state_dict to load after model.quantize(); None = train from scratch.
    load_quant_model_file = None
    # load_quant_model_file = "ckpt/mnist_cnnbn_qat.pt"

    torch.manual_seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    # CIFAR-10 per-channel mean/std normalization, shared by the train and test sets.
    cifar_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./project/p/data', train=True, download=False, transform=cifar_transform),
        batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False,
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('./project/p/data', train=False, download=False, transform=cifar_transform),
        batch_size=test_batch_size, shuffle=True, num_workers=1, pin_memory=False,
    )

    # LeNet is trained from scratch here (no float checkpoint is loaded); the
    # gradient-distribution plots are produced from such scratch runs.
    model = LeNet(n_exp=int(d4), mode=int(d3)).to(device)

    # TODO: consider trying Adam.
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    run_tag = 'mode' + d3 + '_' + d4
    writer = SummaryWriter(log_dir='./project/p/scratchlog/' + run_tag + '/quant_bit_' + d1 + '_log')
    ckpt_dir = './project/p/checkpoint/cifar-10_lenet_bn_quant/scratch/' + run_tag + '/' + d1

    # Baseline accuracy of the float model before quantization
    # (eval mode: dropout off, batch norm uses its running statistics).
    model.eval()
    full_inference(model, test_loader)

    # Define the model's per-layer quantization modules for this bit width.
    model.quantize(num_bits=num_bits)
    print('Quantization bit: %d' % num_bits)

    if load_quant_model_file is not None:
        # map_location allows loading a checkpoint saved on GPU on a CPU-only host.
        model.load_state_dict(torch.load(load_quant_model_file, map_location=device))
        print("Successfully load quantized model %s" % load_quant_model_file)

    # Quantization-aware training. The writer is always closed so buffered
    # TensorBoard events are flushed to disk even if training fails.
    try:
        for epoch in range(1, epochs + 1):
            model.train()
            grads_sum, losses = quantize_aware_training(model, device, train_loader, optimizer, epoch)
            print('epoch:', epoch)
            checkpoint = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'grads': grads_sum,
                'epoch': epoch,
                'losses': losses,
            }

            # grads_sum is the epoch-accumulated gradient returned by
            # quantize_aware_training (check its docs for the exact scaling).
            for name, param in grads_sum.items():
                writer.add_histogram(tag=name + '_grad', values=param, global_step=epoch)
            for name, param in model.named_parameters():
                writer.add_histogram(tag=name + '_data', values=param.data, global_step=epoch)

            if not os.path.isdir(ckpt_dir):
                os.makedirs(ckpt_dir, mode=0o777)
                # makedirs' mode is masked by the umask; chmod forces 0o777.
                os.chmod(ckpt_dir, mode=0o777)
            torch.save(checkpoint,
                       os.path.join(ckpt_dir, 'ckpt_cifar-10_lenet_bn_quant_' + str(epoch) + '.pth'))
    finally:
        writer.close()

    model.eval()
    model.freeze()

    acc = quantize_inference(model, test_loader)
    with open('./project/p/lenet_qat_scratch_acc.txt', 'a') as f:
        f.write('bit ' + str(d1) + ': ' + str(acc) + '\n')





    



    
