# -*- coding: utf-8 -*-
from torch.serialization import load
from model import *
from extract_ratio import *
from utils import *

import gol
import openpyxl
import sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import torch.utils.bottleneck as bn
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter


def direct_quantize(model, test_loader, device):
    """Feed batches from ``test_loader`` through ``model.quantize_forward``.

    At most 500 batches are consumed; the loop stops right after the 500th
    batch has been processed. The forward outputs are discarded, so the call
    is only useful for whatever side effects ``quantize_forward`` has on
    ``model`` (presumably collecting quantization statistics -- confirm
    against the model implementation).

    Args:
        model: Object exposing ``quantize_forward(data)``.
        test_loader: Iterable yielding ``(data, target)`` pairs.
        device: Device each input batch is moved to before the forward pass.
    """
    max_batches = 500
    processed = 0
    for data, _ in test_loader:
        # The result is moved to the CPU exactly as before, then dropped.
        model.quantize_forward(data.to(device)).cpu()
        processed += 1
        if processed == max_batches:
            break
    print('direct quantization finish')


def full_inference(model, test_loader, device):
    """Evaluate the full-precision model and return its top-1 accuracy.

    Args:
        model: Callable mapping a batch of inputs to per-class scores; the
            predicted class is the argmax over dimension 1.
        test_loader: Iterable yielding ``(data, target)`` batches and exposing
            a ``dataset`` attribute whose length is the number of samples.
        device: Device each input batch is moved to before the forward pass.

    Returns:
        Accuracy in percent, ``100 * correct / len(test_loader.dataset)``.
    """
    correct = 0
    # Evaluation never back-propagates, so disable autograd to avoid building
    # (and holding in memory) a graph for every batch.
    with torch.no_grad():
        for data, target in test_loader:
            # Scores are moved back to the CPU so they can be compared with
            # the targets, which are never moved to `device`.
            output = model(data.to(device)).cpu()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    accuracy = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Full Model Accuracy: {:.2f}%'.format(accuracy))
    return accuracy


def quantize_inference(model, test_loader, device, quant_type):
    """Evaluate the quantized inference path and return top-1 accuracy.

    Args:
        model: Object exposing ``quantize_inference(data, quant_type)`` that
            returns per-class scores; the predicted class is the argmax over
            dimension 1.
        test_loader: Iterable yielding ``(data, target)`` batches and exposing
            a ``dataset`` attribute whose length is the number of samples.
        device: Device each input batch is moved to before the forward pass.
        quant_type: Passed through unchanged to ``model.quantize_inference``.

    Returns:
        Accuracy in percent, ``100 * correct / len(test_loader.dataset)``.
    """
    correct = 0
    # Evaluation never back-propagates, so disable autograd to avoid building
    # (and holding in memory) a graph for every batch.
    with torch.no_grad():
        for data, target in test_loader:
            # Scores are moved back to the CPU so they can be compared with
            # the targets, which are never moved to `device`.
            output = model.quantize_inference(data.to(device), quant_type).cpu()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    accuracy = 100. * correct / len(test_loader.dataset)
    print('Test set: Quant Model Accuracy: {:.2f}%'.format(accuracy))
    return accuracy


def _implicit_softmax_dim(tensor):
    """Return the axis ``F.softmax`` picks when called without ``dim``.

    The deprecated implicit rule is: axis 0 for 0-, 1- and 3-dimensional
    tensors, axis 1 otherwise.
    """
    return 0 if tensor.dim() in (0, 1, 3) else 1


def js_div(p_output, q_output, get_softmax=True, dim=None):
    """Measure the Jensen-Shannon divergence between two outputs.

    Computes ``(KL(p || m) + KL(q || m)) / 2`` with ``m = (p + q) / 2``,
    summed over all elements.

    Args:
        p_output: First tensor (logits when ``get_softmax`` is true,
            probabilities otherwise).
        q_output: Second tensor, combined element-wise with ``p_output``.
        get_softmax: When true, both inputs are passed through softmax first.
        dim: Axis for the softmax. ``None`` (the default) keeps the historical
            behaviour of this function, i.e. the axis that ``F.softmax`` used
            to choose implicitly, but selects it explicitly so that no
            deprecation warning is emitted.

    Returns:
        A zero-dimensional tensor holding the divergence.
    """
    if get_softmax:
        p_dim = _implicit_softmax_dim(p_output) if dim is None else dim
        q_dim = _implicit_softmax_dim(q_output) if dim is None else dim
        p_output = F.softmax(p_output, dim=p_dim)
        q_output = F.softmax(q_output, dim=q_dim)
    # kl_div expects its first argument in log space and the target as
    # probabilities, so this evaluates KL(target || mean).
    log_mean_output = ((p_output + q_output) / 2).log()
    kl_p = F.kl_div(log_mean_output, p_output, reduction='sum')
    kl_q = F.kl_div(log_mean_output, q_output, reduction='sum')
    return (kl_p + kl_q) / 2


if __name__ == "__main__":
    # Script entry point: evaluate a pre-trained CIFAR-10 ResNet, then run the
    # frozen POT_6 PTQ model under every quantization setting and record the
    # accuracy together with a ratio-weighted JS divergence per setting.

    # Supported architectures; also used to validate the --model argument so
    # that an unknown name is rejected up front instead of surfacing later as
    # a NameError on `model`.
    model_builders = {
        'ResNet18': resnet18,
        'ResNet50': resnet50,
        'ResNet152': resnet152,
    }

    parser = argparse.ArgumentParser(description='PyTorch FP32 Training')
    parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='ResNet18',
                        choices=sorted(model_builders))
    parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='BATCH SIZE', help='mini-batch size (default: 128)')
    parser.add_argument('-j','--workers', default=4, type=int, metavar='WORKERS',help='number of data loading workers (default: 4)')

    # Run parameters
    args = parser.parse_args()

    batch_size = args.batch_size
    num_workers = args.workers

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    # Per-channel mean / std used to normalize the CIFAR-10 images.
    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('../../project/p/data', train=True, download=False,
                         transform=transforms.Compose([
                             transforms.RandomCrop(32, padding=2),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             transforms.Normalize(cifar_mean, cifar_std),
                         ])),
        batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('../../project/p/data', train=False, download=False,
                         transform=transforms.Compose([
                             transforms.ToTensor(),
                             transforms.Normalize(cifar_mean, cifar_std),
                         ])),
        batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
    )

    model = model_builders[args.model]()

    writer = SummaryWriter(log_dir='log/' + args.model  +  '/ptq')
    full_file = 'ckpt/cifar10_' + args.model + '.pt'
    # map_location lets a checkpoint that was saved on a GPU be loaded on the
    # CPU fallback selected above.
    model.load_state_dict(torch.load(full_file, map_location=device))
    model.to(device)

    load_ptq = True
    ptq_file_prefix = 'ckpt/cifar10_' + args.model + '_ptq_'

    model.eval()
    full_acc = full_inference(model, test_loader, device)
    model_fold = fold_model(model)

    full_params = []

    # Only the two ratio lists from extract_ratio are used; the layer list is
    # rebuilt right below from the model's own parameter names.
    layer, par_ratio, flop_ratio = extract_ratio(args.model)
    layer = []
    for name, param in model.named_parameters():
        if 'weight' in name:
            # Drop the trailing component ("weight") to get the layer prefix.
            layer.append('.'.join(name.split('.')[:-1]))
    print('===================')
    par_ratio, flop_ratio = fold_ratio(layer, par_ratio, flop_ratio)

    for name, param in model_fold.named_parameters():
        if 'bn' in name or 'sample.1' in name:
            continue
        param_norm = F.normalize(param.data.cpu(),p=2,dim=-1)
        # BN parameters are not collected, only the remaining (conv) ones,
        # and these are the parameters after folding.
        full_params.append(param_norm)
        writer.add_histogram(tag='Full_' + name + '_data', values=param.data)

    gol._init()
    quant_type_list = ['INT','POT','FLOAT']
    title_list = []
    js_flops_list = []
    js_param_list = []
    ptq_acc_list = []
    acc_loss_list = []

    model_ptq = model_builders[args.model]()

    ptq_file = ptq_file_prefix + 'POT_6' + '.pt'
    if load_ptq and osp.exists(ptq_file):
        model_ptq.quantize('POT',6,0)
        model_ptq.load_state_dict(torch.load(ptq_file, map_location=device))
        model_ptq.to(device)
        print('Successfully load ptq model: ' + 'POT_6')
    else:
        # NOTE(review): the original skipped this case silently and carried on
        # with a model that was never quantized or loaded; the flow is kept,
        # but the condition is now reported so it cannot go unnoticed.
        print('Warning: ptq model not loaded: ' + ptq_file, file=sys.stderr)

    # The quantization tables have to be built before the model can be frozen;
    # they are replaced per setting in the loop below.
    bias_list = build_bias_list('POT')
    gol.set_value(bias_list, is_bias=True)
    plist = build_list('POT', 6, 0)
    gol.set_value(plist)
    model_ptq.freeze()

    for quant_type in quant_type_list:
        num_bit_list = numbit_list(quant_type)
        # The bias quantization table only needs to be set once per
        # quantization type. INT has a large bit width, so a lookup table
        # would be too expensive; plain rounding is used instead.
        if quant_type != 'INT':
            bias_list = build_bias_list(quant_type)
            gol.set_value(bias_list, is_bias=True)

        for num_bits in num_bit_list:
            e_bit_list = ebit_list(quant_type,num_bits)
            for e_bits in e_bit_list:

                if quant_type == 'FLOAT':
                    title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
                else:
                    title = '%s_%d' % (quant_type, num_bits)
                print('\nPTQ: '+title)
                title_list.append(title)

                # Set the quantization table for this setting.
                if quant_type != 'INT':
                    plist = build_list(quant_type, num_bits, e_bits)
                    gol.set_value(plist)

                ptq_acc = quantize_inference(model_ptq, test_loader, device, quant_type)
                ptq_acc_list.append(ptq_acc)
                acc_loss = (full_acc - ptq_acc) / full_acc
                acc_loss_list.append(acc_loss)
                # Position of the current parameter within full_params; relies
                # on both models yielding their non-BN parameters in the same
                # order.
                idx = -1

                # JS divergence weighted by each layer's FLOPs / parameter share.
                js_flops = 0.
                js_param = 0.
                for name, param in model_ptq.named_parameters():
                    if 'bn' in name or 'sample.1' in name:
                        continue
                    writer.add_histogram(tag=title +':'+ name + '_data', values=param.data)
                    idx = idx + 1
                    # ResNet parameter names contain several dots, so the
                    # prefix is everything except the last component.
                    prefix = '.'.join(name.split('.')[:-1])
                    # Weight and bias are weighted 1:1 here. The ratios are
                    # assigned per layer; weight and bias could be given
                    # separate weights at this point, e.g. 8:2.
                    if prefix in layer:
                        layer_idx = layer.index(prefix)
                        ptq_param = param.data.cpu()
                        # L2-normalize along the last dimension.
                        ptq_norm = F.normalize(ptq_param,p=2,dim=-1)
                        # NOTE(review): same tag as the histogram written a
                        # few lines above; kept to preserve the logged output.
                        writer.add_histogram(tag=title +':'+ name + '_data', values=ptq_param)
                        # JS distance between the folded model's parameters
                        # before and after quantization.
                        js = js_div(ptq_norm,full_params[idx])
                        js = js.item()
                        if js < 0.:
                            js = 0.
                        js_flops = js_flops + js * flop_ratio[layer_idx]
                        js_param = js_param + js * par_ratio[layer_idx]
                js_flops_list.append(js_flops)
                js_param_list.append(js_param)

                print(title + ': js_flops: %f js_param: %f acc_loss: %f' % (js_flops, js_param, acc_loss))

    # Write the results to an xlsx workbook.
    workbook = openpyxl.Workbook()
    worksheet = workbook.active
    worksheet.cell(row=1,column=1,value='FP32-acc')
    worksheet.cell(row=1,column=2,value=full_acc)
    headers = ['title', 'js_flops', 'js_param', 'ptq_acc', 'acc_loss']
    for col_idx, header in enumerate(headers, start=1):
        worksheet.cell(row=3, column=col_idx, value=header)
    # One row per evaluated setting, starting on row 4; the five lists are
    # appended to together above, so they always have equal length.
    result_rows = zip(title_list, js_flops_list, js_param_list, ptq_acc_list, acc_loss_list)
    for row_idx, row_values in enumerate(result_rows, start=4):
        for col_idx, value in enumerate(row_values, start=1):
            worksheet.cell(row=row_idx, column=col_idx, value=value)

    workbook.save('POT_ptq_result_' + args.model + '.xlsx')
    writer.close()

    # Plain-text dump of the same lists; the context manager guarantees the
    # file is closed even if one of the writes fails.
    with open('POT_ptq_result_' + args.model + '.txt','w') as ft:
        print('title_list:',file=ft)
        print(" ".join(title_list),file=ft)
        print('js_flops_list:',file=ft)
        print(" ".join(str(i) for i in js_flops_list), file=ft)
        print('js_param_list:',file=ft)
        print(" ".join(str(i) for i in js_param_list), file=ft)
        print('ptq_acc_list:',file=ft)
        print(" ".join(str(i) for i in ptq_acc_list), file=ft)
        print('acc_loss_list:',file=ft)
        print(" ".join(str(i) for i in acc_loss_list), file=ft)

