# -*- coding: utf-8 -*-
from torch.serialization import load
from model import *
from extract_ratio import *

from train import parse_arg
from model import *
from utils import *
from lstm_utils import *
from dataset import get_dataset
from config.alphabets import *
from torch.utils.data import DataLoader

import gol
import openpyxl
import sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import torch.utils.bottleneck as bn
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter


def direct_quantize(config, val_loader, dataset, converter, model, criterion, device):
    """Push every batch of ``val_loader`` through ``model.quantize_forward``.

    The model is switched to eval mode, gradients are disabled, and the
    outputs of ``quantize_forward`` are discarded.  Presumably the forward
    pass lets the quantized layers collect the statistics they need --
    confirm against the model implementation.

    Args:
        config: Unused; kept so the signature matches ``full_inference`` and
            ``quantize_inference``.
        val_loader: Iterable yielding ``(inp, idx)`` pairs; only ``inp`` is
            used.
        dataset: Unused.
        converter: Unused.
        model: Object exposing ``eval()`` and ``quantize_forward(inp)``.
        criterion: Unused.
        device: Target handed to ``inp.to``.

    Returns:
        None.  A completion message is printed to stdout.
    """
    model.eval()

    with torch.no_grad():
        for inp, _ in val_loader:
            # Only the side effects of the forward pass matter here.
            model.quantize_forward(inp.to(device))

    print('direct quantization finish')


def _run_evaluation(val_loader, dataset, converter, forward, criterion, device, tag):
    """Shared evaluation loop behind ``full_inference`` and ``quantize_inference``.

    For every batch the labels are fetched with ``get_batch_label``, the
    input is run through ``forward``, the loss is accumulated in an
    ``AverageMeter`` and the arg-max decoded strings are compared with the
    labels for exact equality.

    Args:
        val_loader: Iterable yielding ``(inp, idx)`` pairs.
        dataset: Dataset the labels are looked up in; ``len(dataset)`` is the
            denominator of the accuracy.
        converter: Object exposing ``encode(labels)`` and
            ``decode(preds, preds_size, raw=False)``.
        forward: Callable mapping an input batch to predictions.
        criterion: Loss called as ``criterion(preds, text, preds_size, length)``.
        device: Target handed to ``inp.to``.
        tag: Label inserted into the two summary lines that are printed.

    Returns:
        Fraction of samples whose decoded prediction equals its label.

    Raises:
        ZeroDivisionError: If ``dataset`` is empty.
    """
    losses = AverageMeter()
    n_correct = 0

    with torch.no_grad():
        for inp, idx in val_loader:
            # Ground-truth strings of the current batch.
            labels = get_batch_label(dataset, idx)
            inp = inp.to(device)
            batch_size = inp.size(0)

            # inference
            preds = forward(inp).cpu()

            # compute loss; every sample gets the same prediction length,
            # preds.size(0)
            text, length = converter.encode(labels)
            preds_size = torch.IntTensor([preds.size(0)] * batch_size)
            loss = criterion(preds, text, preds_size, length)
            losses.update(loss.item(), batch_size)

            # Arg-max over axis 2, then flatten sample by sample for decode().
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
            # decode() may hand back a bare str instead of a list of str
            # (presumably for a one-sample batch).  Wrap it so that zip()
            # pairs whole strings rather than single characters with labels.
            if isinstance(sim_preds, str):
                sim_preds = [sim_preds]

            # Exact string match between prediction and label.
            n_correct += sum(pred == target for pred, target in zip(sim_preds, labels))

    num_test_sample = len(dataset)
    print("[#{} Model: correct:{} / #total:{}]".format(tag, n_correct, num_test_sample))
    accuracy = n_correct / float(num_test_sample)
    # The spelling "accuray" is kept so existing log output stays unchanged.
    print('{} Model: Test loss: {:.4f}, accuray: {:.4f}'.format(tag, losses.avg, accuracy))
    return accuracy


def full_inference(config, val_loader, dataset, converter, model, criterion, device):
    """Evaluate ``model`` with its regular forward pass and return the accuracy.

    Args:
        config: Unused; kept for signature compatibility.
        val_loader: Iterable yielding ``(inp, idx)`` pairs.
        dataset: Dataset the labels are looked up in.
        converter: Label encoder/decoder.
        model: Model called as ``model(inp)``; switched to eval mode first.
        criterion: Loss function.
        device: Target the inputs are moved to.

    Returns:
        Fraction of exactly matching predictions over ``len(dataset)``.
    """
    model.eval()
    return _run_evaluation(val_loader, dataset, converter, model,
                           criterion, device, 'Full Precision')


def quantize_inference(config, val_loader, dataset, converter, model, criterion, device):
    """Evaluate ``model`` through ``model.quantize_inference`` and return the accuracy.

    Args:
        config: Unused; kept for signature compatibility.
        val_loader: Iterable yielding ``(inp, idx)`` pairs.
        dataset: Dataset the labels are looked up in.
        converter: Label encoder/decoder.
        model: Model exposing ``quantize_inference(inp)``; switched to eval
            mode first.
        criterion: Loss function.
        device: Target the inputs are moved to.

    Returns:
        Fraction of exactly matching predictions over ``len(dataset)``.
    """
    model.eval()
    return _run_evaluation(val_loader, dataset, converter, model.quantize_inference,
                           criterion, device, 'Quantization')



def js_div(p_output, q_output, get_softmax=True, dim=None):
    """Return the Jensen-Shannon divergence between two tensors.

    Computes ``(KL(P || M) + KL(Q || M)) / 2`` with ``M = (P + Q) / 2``.
    Both KL terms use ``reduction='sum'``, so the result is a 0-dim tensor
    summed over every element.

    Args:
        p_output: First tensor (logits if ``get_softmax`` is true, otherwise
            probabilities).
        q_output: Second tensor, same convention as ``p_output``.
        get_softmax: When true, both inputs are passed through softmax first.
        dim: Axis the softmax is taken over.  ``None`` (default) selects the
            axis ``F.softmax`` falls back to when called without ``dim``
            (0 for 0-, 1- and 3-dimensional tensors, 1 otherwise), so results
            are identical to the former implicit call but without the
            deprecation warning it emitted.

    Returns:
        0-dim tensor holding the divergence.
    """
    def _to_prob(tensor):
        axis = dim
        if axis is None:
            # Legacy implicit-dim rule of F.softmax, spelled out explicitly.
            axis = 0 if tensor.dim() in (0, 1, 3) else 1
        return F.softmax(tensor, dim=axis)

    if get_softmax:
        p_output = _to_prob(p_output)
        q_output = _to_prob(q_output)

    # kl_div expects its first argument in log space and the target as
    # plain probabilities.
    log_mean_output = ((p_output + q_output) / 2).log()
    kl_p = F.kl_div(log_mean_output, p_output, reduction='sum')
    kl_q = F.kl_div(log_mean_output, q_output, reduction='sum')
    return (kl_p + kl_q) / 2


if __name__ == "__main__":
    # Post-training-quantization (PTQ) sweep.
    #   1. Evaluate the full-precision checkpoint once.
    #   2. For every quantization configuration build a quantized copy of the
    #      model, measure its accuracy, and accumulate the JS divergence of
    #      its parameters from the folded full-precision ones, weighted by
    #      the per-layer FLOP and parameter ratios.
    #   3. Write all results to TensorBoard, an .xlsx workbook and a .txt file.

    # load config
    config = parse_arg()

    model = get_crnn(config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # define loss function
    criterion = torch.nn.CTCLoss()

    # NOTE(review): train_dataset / train_loader are never used below.  They
    # are kept because constructing the dataset may have side effects that
    # are not visible from this file -- candidate for removal once confirmed.
    train_dataset = get_dataset(config)(config, is_train=True)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    val_dataset = get_dataset(config)(config, is_train=False)
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=config.TEST.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    converter = strLabelConverter(config.DATASET.ALPHABETS)

    writer = SummaryWriter(log_dir='log/' + config.MODEL.NAME + '/ptq')
    full_file = 'ckpt/360cc_' + config.MODEL.NAME + '.pt'
    # map_location lets a checkpoint that was saved on a GPU be restored on a
    # CPU-only host (``device`` already falls back to CPU above).
    model.load_state_dict(torch.load(full_file, map_location=device))
    model.to(device)

    load_ptq = True
    ptq_file_prefix = 'ckpt/360cc_' + config.MODEL.NAME + '_ptq_'

    model.eval()
    # NOTE(review): full_acc is used as a divisor below; a full-precision
    # accuracy of 0 makes the sweep fail with ZeroDivisionError.
    full_acc = full_inference(config, val_loader, val_dataset, converter, model, criterion, device)

    # Per the original note, folding exposes the individual conv / bn / relu /
    # fc layers.  Kept ahead of the layer-name scan below in case fold_model
    # changes ``model`` itself -- TODO confirm.
    model_fold = fold_model(model)

    # Only the ratios are taken from extract_ratio; the layer names are
    # rebuilt from the model's own parameter names right below.
    _, par_ratio, flop_ratio = extract_ratio('lstm-ocr')

    # Unique layer names in first-seen order: the parameter name with its last
    # component stripped ("a.b.weight" -> "a.b").
    layer = list(dict.fromkeys(
        name.rpartition('.')[0] for name, _ in model.named_parameters()
    ))

    print('===================')
    par_ratio, flop_ratio = fold_ratio(layer, par_ratio, flop_ratio)

    # Weights and biases of the folded model, with 'bn' and 'sample.1'
    # parameters skipped, in iteration order.
    full_params = []
    for name, param in model_fold.named_parameters():
        if 'bn' in name or 'sample.1' in name:
            continue
        full_params.append(param.data.cpu())
        writer.add_histogram(tag='Full_' + name + '_data', values=param.data)

    gol._init()
    quant_type_list = ['INT', 'POT', 'FLOAT']
    title_list = []
    js_flops_list = []
    js_param_list = []
    ptq_acc_list = []
    acc_loss_list = []

    for quant_type in quant_type_list:
        num_bit_list = numbit_list(quant_type)
        # The bias quantization table only has to be set once per quantization
        # type.  INT uses wide bit-widths, so a lookup table would be too
        # costly; it is rounded directly instead.
        if quant_type != 'INT':
            bias_list = build_bias_list(quant_type)
            gol.set_value(bias_list, is_bias=True)

        for num_bits in num_bit_list:
            e_bit_list = ebit_list(quant_type, num_bits)
            for e_bits in e_bit_list:
                model_ptq = get_crnn(config).to(device)

                if quant_type == 'FLOAT':
                    title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
                else:
                    title = '%s_%d' % (quant_type, num_bits)
                print('\nPTQ: ' + title)
                title_list.append(title)

                # Set the quantization table for this configuration.
                if quant_type != 'INT':
                    plist = build_list(quant_type, num_bits, e_bits)
                    gol.set_value(plist)

                # Reuse a stored PTQ checkpoint when one exists; otherwise
                # start from the full-precision weights and run the
                # direct_quantize pass.
                # NOTE: nothing in this script saves the resulting model, so
                # the first branch only triggers for checkpoints produced
                # elsewhere.
                ptq_file = ptq_file_prefix + title + '.pt'
                if load_ptq and osp.exists(ptq_file):
                    model_ptq.quantize(quant_type, num_bits, e_bits)
                    model_ptq.load_state_dict(torch.load(ptq_file, map_location=device))
                    model_ptq.to(device)
                    print('Successfully load ptq model: ' + title)
                else:
                    model_ptq.load_state_dict(torch.load(full_file, map_location=device))
                    model_ptq.to(device)
                    model_ptq.quantize(quant_type, num_bits, e_bits)
                    model_ptq.eval()
                    direct_quantize(config, val_loader, val_dataset, converter, model_ptq, criterion, device)

                model_ptq.freeze()
                ptq_acc = quantize_inference(config, val_loader, val_dataset, converter, model_ptq, criterion, device)
                ptq_acc_list.append(ptq_acc)
                # Accuracy drop relative to the full-precision model.
                acc_loss = (full_acc - ptq_acc) / full_acc
                acc_loss_list.append(acc_loss)

                # De-quantize the weight parameters (per the original note)
                # before comparing them with the full-precision ones.
                model_ptq.fakefreeze()

                # JS divergence weighted by FLOP ratio / parameter ratio.
                js_flops = 0.
                js_param = 0.
                # Position of the current parameter among the non-skipped
                # ones.  Indexing full_params with it relies on model_ptq and
                # model_fold yielding those parameters in the same order.
                idx = -1
                for name, param in model_ptq.named_parameters():
                    if 'bn' in name or 'sample.1' in name:
                        continue
                    # Logged once per parameter (the original wrote the same
                    # tag a second time for parameters found in ``layer``).
                    writer.add_histogram(tag=title + ':' + name + '_data', values=param.data)
                    idx = idx + 1

                    # Names may contain several dots, so strip only the last
                    # component to obtain the layer name.
                    prefix = name.rpartition('.')[0]
                    if prefix not in layer:
                        continue

                    # Ratios are assigned per layer, so a layer's weight and
                    # bias share the same ratio (they could be weighted
                    # differently, e.g. 8:2).
                    layer_idx = layer.index(prefix)
                    ptq_param = param.data.cpu()

                    # JS divergence between the folded model's parameter
                    # before and after quantization.
                    js = js_div(ptq_param, full_params[idx])
                    js = js.item()
                    if js < 0.:
                        js = 0.
                    js_flops = js_flops + js * flop_ratio[layer_idx]
                    js_param = js_param + js * par_ratio[layer_idx]
                    print(f"prefix:{prefix}, layer_idx:{layer_idx}")

                js_flops_list.append(js_flops)
                js_param_list.append(js_param)

                print(title + ': js_flops: %f js_param: %f acc_loss: %f' % (js_flops, js_param, acc_loss))

    # Write the results to an xlsx workbook.
    workbook = openpyxl.Workbook()
    worksheet = workbook.active
    worksheet.cell(row=1, column=1, value='FP32-acc')
    worksheet.cell(row=1, column=2, value=full_acc)
    headers = ('title', 'js_flops', 'js_param', 'ptq_acc', 'acc_loss')
    for col, header in enumerate(headers, start=1):
        worksheet.cell(row=3, column=col, value=header)
    result_rows = zip(title_list, js_flops_list, js_param_list, ptq_acc_list, acc_loss_list)
    for row, values in enumerate(result_rows, start=4):
        for col, value in enumerate(values, start=1):
            worksheet.cell(row=row, column=col, value=value)

    workbook.save('ptq_result_' + config.MODEL.NAME + '.xlsx')
    writer.close()

    # Plain-text dump of the same lists; ``with`` closes the file even if one
    # of the writes fails.
    with open('ptq_result_' + config.MODEL.NAME + '.txt', 'w') as ft:
        print('title_list:', file=ft)
        print(" ".join(title_list), file=ft)
        print('js_flops_list:', file=ft)
        print(" ".join(str(i) for i in js_flops_list), file=ft)
        print('js_param_list:', file=ft)
        print(" ".join(str(i) for i in js_param_list), file=ft)
        print('ptq_acc_list:', file=ft)
        print(" ".join(str(i) for i in ptq_acc_list), file=ft)
        print('acc_loss_list:', file=ft)
        print(" ".join(str(i) for i in acc_loss_list), file=ft)

