# -*- coding: utf-8 -*-
from torch.serialization import load
from model import *
# from extract_ratio import *
from utils import *

import gol
import openpyxl
import sys
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode
import torch.utils.bottleneck as bn
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
import json
from decoder import seq_mnist_decoder
from data import seq_mnist_train, seq_mnist_val
from torch.utils.data import DataLoader
import random

class objdict(dict):
    """Dictionary whose entries can also be accessed as attributes.

    ``d.key`` reads ``d["key"]``, ``d.key = v`` stores ``d["key"] = v`` and
    ``del d.key`` removes the entry.  Reading or deleting a key that is not
    present raises ``AttributeError`` rather than ``KeyError``.
    """

    def __getattr__(self, name):
        # Python only calls __getattr__ after normal attribute lookup has
        # failed, so built-in dict methods are never shadowed by entries.
        if name not in self:
            raise AttributeError("No such attribute: " + name)
        return self[name]

    def __setattr__(self, name, value):
        # Every attribute assignment becomes a dictionary entry.
        self[name] = value

    def __delattr__(self, name):
        if name not in self:
            raise AttributeError("No such attribute: " + name)
        del self[name]


# def direct_quantize(model, test_loader,device):
#     for i, (data, target) in enumerate(test_loader, 1):
#         data = data.to(device)
#         output = model.quantize_forward(data).cpu()
#         if i % 500 == 0:
#             break
#     print('direct quantization finish')


# def full_inference(model, test_loader, device):
#     correct = 0
#     for i, (data, target) in enumerate(test_loader, 1):
#         data = data.to(device)
#         output = model(data).cpu()
#         pred = output.argmax(dim=1, keepdim=True)
#         correct += pred.eq(target.view_as(pred)).sum().item()
#     print('\nTest set: Full Model Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))
#     return 100. * correct / len(test_loader.dataset)


# def quantize_inference(model, test_loader, device):
#     correct = 0
#     for i, (data, target) in enumerate(test_loader, 1):
#         data = data.to(device)
#         output = model.quantize_inference(data).cpu()
#         pred = output.argmax(dim=1, keepdim=True)
#         correct += pred.eq(target.view_as(pred)).sum().item()
#     print('Test set: Quant Model Accuracy: {:.2f}%'.format(100. * correct / len(test_loader.dataset)))
#     return 100. * correct / len(test_loader.dataset)

def direct_quantize(model, val_loader, val_data, args, trainer_params, decoder, criterion):
    """Run ``model.quantize_forward`` once over every batch of ``val_loader``.

    The forward outputs are discarded; the pass is executed only for whatever
    state ``quantize_forward`` records on the model (presumably quantization
    statistics -- confirm against the model implementation).

    Args:
        model: object exposing ``eval()`` and ``quantize_forward(data)``.
        val_loader: iterable yielding ``(data, labels, output_len, lab_len)``
            batches; only ``data`` is used here.
        val_data, args, trainer_params, decoder, criterion: unused; accepted so
            the signature matches ``full_inference`` / ``quantize_inference``.

    Returns:
        None.  Prints a completion message when all batches have been seen.
    """
    model.eval()

    # The target device cannot change between batches, so resolve it once.
    # NOTE(review): assumes the caller already moved ``model`` to this device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # No backward pass follows, so autograd bookkeeping is unnecessary.
    with torch.no_grad():
        for data, _labels, _output_len, _lab_len in val_loader:
            # The first two axes of the loader's batch are swapped before
            # the forward pass.
            model.quantize_forward(data.transpose(1, 0).to(device))

    print('direct quantization finish')

def _run_validation(forward, tag, val_loader, trainer_params, decoder, criterion):
    """Shared validation loop behind ``full_inference`` and ``quantize_inference``.

    For every batch the data is pushed through ``forward``, one randomly
    chosen sample is decoded and printed, and the criterion's loss is
    accumulated.

    Args:
        forward: callable mapping the transposed input batch to the network
            output (indexed as ``output[:, sample, :]`` below).
        tag: prefix used in the printed messages ("Full" or "Quantize").
        val_loader: iterable yielding ``(data, labels, output_len, lab_len)``.
        trainer_params: object providing ``word_size``.
        decoder: object providing ``decode``, ``hit`` and ``to_string``.
        criterion: loss callable invoked as
            ``criterion(output, labels, output_len, lab_len)``.

    Returns:
        float: loss averaged over the batches that were actually processed.

    Raises:
        ValueError: if ``val_loader`` yields no batch at all.
    """
    # NOTE(review): assumes the model already lives on this device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    word_size = trainer_params.word_size
    total_loss = 0.0
    num_batches = 0

    # Pure evaluation: no gradients are needed.
    with torch.no_grad():
        for data, labels, output_len, lab_len in val_loader:
            data = data.transpose(1, 0).to(device)
            labels = labels.view(-1)
            output_len = output_len.view(-1)
            lab_len = lab_len.view(-1)

            output = forward(data)

            # Draw the sample from the size of *this* batch: the last batch may
            # be shorter than ``test_batch_size`` and would otherwise be indexed
            # out of range.
            index = random.randint(0, output.size(1) - 1)
            # Labels are shifted down by one before decoding -- presumably
            # because index 0 is reserved for the CTC blank; confirm against
            # the dataset.
            label = labels[index * word_size:(index + 1) * word_size].numpy() - 1
            prediction = decoder.decode(output[:, index, :], output_len[index], lab_len[index])
            accuracy = decoder.hit(prediction, label)

            print("Sample Label      = {}".format(decoder.to_string(label)))
            print("Sample Prediction = {}".format(decoder.to_string(prediction)))
            print("{} Model Accuracy on Sample = {:.2f}%\n\n".format(tag, accuracy))

            total_loss += criterion(output, labels, output_len, lab_len).item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("val_loader produced no batches; cannot compute an average loss")

    # Average over the batches really seen; ``len(val_data) // batch_size``
    # under-counts whenever the dataset is not a multiple of the batch size.
    average_loss = total_loss / num_batches
    print("{} Model Average Loss Value for Val Data is = {:.4f}\n".format(tag, average_loss))
    return average_loss


def full_inference(model, val_loader, val_data, args, trainer_params, decoder, criterion):
    """Evaluate ``model`` through its regular forward pass on ``val_loader``.

    Prints one randomly picked label/prediction pair per batch and the
    average loss.  ``val_data`` and ``args`` are accepted for signature
    compatibility but not used.

    Returns:
        float: average loss per batch (earlier versions returned ``None``).
    """
    model.eval()
    return _run_validation(model, "Full", val_loader, trainer_params, decoder, criterion)


def quantize_inference(model, val_loader, val_data, args, trainer_params, decoder, criterion):
    """Evaluate ``model`` through ``model.quantize_inference`` on ``val_loader``.

    Prints one randomly picked label/prediction pair per batch and the
    average loss.  ``val_data`` and ``args`` are accepted for signature
    compatibility but not used.

    Returns:
        float: average loss per batch (earlier versions returned ``None``).
    """
    model.eval()
    return _run_validation(model.quantize_inference, "Quantize", val_loader,
                           trainer_params, decoder, criterion)



def js_div(p_output, q_output, get_softmax=True, dim=None):
    """Jensen-Shannon divergence between two tensors of logits or probabilities.

    Args:
        p_output: first tensor.
        q_output: second tensor, broadcast-compatible with ``p_output``.
        get_softmax: when True the inputs are treated as logits and passed
            through a softmax first; when False they are used as given.
        dim: axis along which the softmax is taken.  ``None`` keeps the axis
            that the former implicit-``dim`` ``F.softmax`` call selected
            (0 for 0-, 1- and 3-dimensional inputs, 1 otherwise) but passes it
            explicitly, which avoids the deprecation warning.

    Returns:
        Scalar tensor ``(KL(p || m) + KL(q || m)) / 2`` with ``m = (p + q) / 2``,
        summed over all elements.
    """
    def _softmax(tensor):
        axis = dim
        if axis is None:
            # Same rule PyTorch applies when ``dim`` is omitted.
            axis = 0 if tensor.dim() in (0, 1, 3) else 1
        return F.softmax(tensor, dim=axis)

    if get_softmax:
        p_output = _softmax(p_output)
        q_output = _softmax(q_output)

    # KLDivLoss expects its first argument in log-space.
    log_mean_output = ((p_output + q_output) / 2).log()
    kl_sum = nn.KLDivLoss(reduction='sum')
    return (kl_sum(log_mean_output, p_output) + kl_sum(log_mean_output, q_output)) / 2


if __name__ == "__main__":
    # Post-training quantization (PTQ) driver:
    #   1. load the FP32 checkpoint and evaluate it,
    #   2. for every quantization configuration build a fresh model, run the
    #      calibration pass (direct_quantize), freeze it and evaluate it with
    #      quantize_inference.

    parser = argparse.ArgumentParser(description='PyTorch FP32 Training')
    parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='ResNet18')
    parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='BATCH SIZE', help='mini-batch size (default: 128)')
    parser.add_argument('-j','--workers', default=4, type=int, metavar='WORKERS',help='number of data loading workers (default: 4)')
    # ``type=bool`` would turn every non-empty string (including "False") into
    # True, so the text is parsed explicitly instead.
    parser.add_argument('-s', '--save', default=False,
                        type=lambda text: text.strip().lower() in ('1', 'true', 't', 'yes', 'y'))
    # parser.add_argument('-t', '--test', dest='test', action='store_true', help='test model on test set')
    parser.add_argument('--params', '-p', type=str, default="default_trainer_params.json", help='Path to params JSON file. Default ignored when resuming.')
    # Training parameters
    args = parser.parse_args()

    # Only the architectures constructed below are usable. Fail fast with a
    # usage error instead of hitting a NameError on ``model`` after the data
    # has already been loaded.
    supported_models = ('LSTM-OCR',)
    if args.model not in supported_models:
        parser.error("unsupported model '%s' (supported: %s)" % (args.model, ', '.join(supported_models)))

    with open(args.params) as d:
        trainer_params = json.load(d)
            # trainer_params = json.load(d, object_hook=ascii_encode_dict)
    trainer_params = objdict(trainer_params)

    # NOTE(review): these two values are not used below -- the data loaders
    # take batch size and worker count from the params file instead.
    batch_size = args.batch_size
    num_workers = args.workers

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    labels = list(range(trainer_params.num_classes-1))
    decoder = seq_mnist_decoder(labels=labels)
    criterion = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)

    random.seed(trainer_params.random_seed)
    torch.manual_seed(trainer_params.random_seed)
    # if args.cuda:
    torch.cuda.manual_seed_all(trainer_params.random_seed)

    train_data = seq_mnist_train(trainer_params)
    val_data = seq_mnist_val(trainer_params)

    train_loader = DataLoader(train_data, batch_size=trainer_params.batch_size, \
                                        shuffle=True, num_workers=trainer_params.num_workers)

    val_loader = DataLoader(val_data, batch_size=trainer_params.test_batch_size, \
                                        shuffle=False, num_workers=trainer_params.num_workers)

    if args.model == 'LSTM-OCR':
        model = BiLSTM(trainer_params)

    # writer = SummaryWriter(log_dir='log/' + args.model  +  '/ptq')
    save_dir = 'ckpt'
    full_file = save_dir + '/mnist_' + trainer_params.reduce_bidirectional +'_' + str(trainer_params.bidirectional) + '.pt'
    # Read the FP32 checkpoint once; every PTQ model below is initialised from
    # this same state dict (load_state_dict copies the values into the model).
    # ``map_location`` lets a checkpoint saved on GPU be opened on a CPU host.
    full_state = torch.load(full_file, map_location=device)
    model.load_state_dict(full_state)
    model.to(device)

    load_ptq = False
    ptq_file_prefix = 'ckpt/mnist_' + trainer_params.reduce_bidirectional +'_' + str(trainer_params.bidirectional)  + '_ptq_'

    model.eval()
    # NOTE(review): despite the name, ``full_acc`` is simply whatever
    # full_inference returns; it is only consumed by the unreachable report
    # code at the bottom of this block.
    full_acc = full_inference(model, val_loader, val_data, args, trainer_params, decoder,criterion)

    # model_fold = fold_model(model)   #

    # full_params = []

    # layer, par_ratio, flop_ratio = extract_ratio(args.model)

    # layer = []
    # for name, param in model.named_parameters():
    #     if 'weight' in name:
    #         n = name.split('.')
    #         pre = '.'.join(n[:len(n)-1])
    #         # Take the name in front of 'weight' (i.e. the layer name; the
    #         # 'weight' check avoids adding the name a second time for bias)
    #         layer.append(pre)

    # print('===================')

    # par_ratio, flop_ratio = fold_ratio(layer, par_ratio, flop_ratio)

    # for name, param in model_fold.named_parameters():
    #     if 'bn' in name or 'sample.1' in name:
    #         continue
    #     param_norm = param.data.cpu()
    #     full_params.append(param_norm)  # bn is not counted, only conv, and after folding
    #     writer.add_histogram(tag='Full_' + name + '_data', values=param.data)


    gol._init()
    quant_type_list = ['INT','POT','FLOAT']
    title_list = []
    js_flops_list = []
    js_param_list = []
    ptq_acc_list = []
    acc_loss_list = []

    for quant_type in quant_type_list:
        num_bit_list = numbit_list(quant_type)
        # For one quantization type the bias quantization table only has to
        # be set once.
        # INT has a large bit-width, so a lookup table would be too costly;
        # plain rounding is used instead.
        if quant_type != 'INT':
            bias_list = build_bias_list(quant_type)
            gol.set_value(bias_list, is_bias=True)

        for num_bits in num_bit_list:
            e_bit_list = ebit_list(quant_type,num_bits)
            for e_bits in e_bit_list:
                # model_ptq = resnet18()
                if args.model == 'LSTM-OCR':
                    model_ptq = BiLSTM(trainer_params)

                if quant_type == 'FLOAT':
                    title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
                else:
                    title = '%s_%d' % (quant_type, num_bits)
                print('\nPTQ: '+title)
                title_list.append(title)

                # Set the quantization table
                if quant_type != 'INT':
                    plist = build_list(quant_type, num_bits, e_bits)
                    gol.set_value(plist)

                # Decide whether a previously saved PTQ model should be loaded
                ptq_file = ptq_file_prefix + title + '.pt'
                if load_ptq is True and osp.exists(ptq_file):
                    model_ptq.quantize(quant_type,num_bits,e_bits)
                    model_ptq.load_state_dict(torch.load(ptq_file, map_location=device))
                    model_ptq.to(device)
                    print('Successfully load ptq model: ' + title)
                else:
                    model_ptq.load_state_dict(full_state)
                    model_ptq.to(device)
                    model_ptq.quantize(quant_type,num_bits,e_bits)
                    model_ptq.eval()
                    direct_quantize(model_ptq, val_loader, val_data, args, trainer_params, decoder,criterion)

                    # if args.save == True:
                    #     torch.save(model_ptq.state_dict(), ptq_file_prefix + title + '.pt')

                model_ptq.freeze()
                quantize_inference(model_ptq, val_loader, val_data, args, trainer_params, decoder,criterion)

                # ptq_acc = quantize_inference(model_ptq, val_loader, val_data, args, trainer_params, decoder,criterion)
                # ptq_acc_list.append(ptq_acc)
                # acc_loss = (full_acc - ptq_acc) / full_acc
                # acc_loss_list.append(acc_loss)
                # idx = -1

                # JS divergence weighted by FLOPs / parameter count
                js_flops = 0.
                js_param = 0.

                # for name, param in model_ptq.named_parameters():
                #     # if '.' not in name or 'bn' in name:
                #     if 'bn' in name or 'sample.1' in name:
                #         continue
                #     writer.add_histogram(tag=title +':'+ name + '_data', values=param.data)
                #     idx = idx + 1
                #     # ResNet names contain several dots, so the prefix has to be re-joined
                #     # prefix = name.split('.')[0]
                #     n = name.split('.')
                #     prefix = '.'.join(n[:len(n) - 1])
                #      # weight and bias 1:1? The ratios are assigned per layer; weight and
                #      # bias could be given separate weights here, e.g. 8:2
                #     if prefix in layer:
                #         layer_idx = layer.index(prefix)
                #         ptq_param = param.data.cpu()
                #         # Take the L2 norm
                #         # ptq_norm = F.normalize(ptq_param,p=2,dim=-1)
                #         ptq_norm = ptq_param
                #         writer.add_histogram(tag=title +':'+ name + '_data', values=ptq_param)
                #         # print(name)
                #         # print('=========')
                #         # print(ptq_norm)
                #         # print('=========')
                #         # print(full_params[idx])
                #         js = js_div(ptq_norm,full_params[idx])   # JS distance between the folded model before and after quantization
                #         js = js.item()
                #         if js < 0.:
                #             js = 0.
                #         js_flops = js_flops + js * flop_ratio[layer_idx]
                #         js_param = js_param + js * par_ratio[layer_idx]
                # js_flops_list.append(js_flops)
                # js_param_list.append(js_param)

                # print(title + ': js_flops: %f js_param: %f acc_loss: %f' % (js_flops, js_param, acc_loss))

    # NOTE(review): the script stops here on purpose. Everything below is
    # unreachable and cannot run as-is: js_flops_list / js_param_list /
    # ptq_acc_list / acc_loss_list are never filled and ``writer`` is never
    # created (both live in the commented-out code above). Restore those
    # pieces before removing this exit.
    sys.exit()
    # Write the results to xlsx
    workbook = openpyxl.Workbook()
    worksheet = workbook.active
    worksheet.cell(row=1,column=1,value='FP32-acc')
    worksheet.cell(row=1,column=2,value=full_acc)
    worksheet.cell(row=3,column=1,value='title')
    worksheet.cell(row=3,column=2,value='js_flops')
    worksheet.cell(row=3,column=3,value='js_param')
    worksheet.cell(row=3,column=4,value='ptq_acc')
    worksheet.cell(row=3,column=5,value='acc_loss')
    for i in range(len(title_list)):
        worksheet.cell(row=i+4, column=1, value=title_list[i])
        worksheet.cell(row=i+4, column=2, value=js_flops_list[i])
        worksheet.cell(row=i+4, column=3, value=js_param_list[i])
        worksheet.cell(row=i+4, column=4, value=ptq_acc_list[i])
        worksheet.cell(row=i+4, column=5, value=acc_loss_list[i])

    workbook.save('ptq_result_' + args.model + '.xlsx')
    writer.close()

    # The context manager closes the report file even if a write fails.
    with open('ptq_result_' + args.model + '.txt','w') as ft:
        print('title_list:',file=ft)
        print(" ".join(title_list),file=ft)
        print('js_flops_list:',file=ft)
        print(" ".join(str(i) for i in js_flops_list), file=ft)
        print('js_param_list:',file=ft)
        print(" ".join(str(i) for i in js_param_list), file=ft)
        print('ptq_acc_list:',file=ft)
        print(" ".join(str(i) for i in ptq_acc_list), file=ft)
        print('acc_loss_list:',file=ft)
        print(" ".join(str(i) for i in acc_loss_list), file=ft)

