from model import *
from model_foldbn import *
from extract_ratio import *
from utils import *

import openpyxl
import gol
import sys
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR


def js_div_norm(a, b):
    """Return the JS divergence of `a` and `b` after L2-normalising each along the last dim.

    Both inputs are read through `.data`, so no autograd history is recorded.
    The result comes back as a Python scalar on the host.
    """
    a_unit, b_unit = (F.normalize(t.data, p=2, dim=-1) for t in (a, b))
    divergence = js_div(a_unit, b_unit)
    return divergence.cpu().item()

def js_div_0(a, b):
    """Return `js_div(a, b)` (no pre-normalisation) as a Python scalar."""
    divergence = js_div(a, b)
    host_value = divergence.cpu()
    return host_value.item()


def direct_quantize(model, test_loader, device):
    """Feed up to 500 batches through `model.quantize_forward` for calibration.

    The outputs are discarded; only the forward passes matter. Targets in
    the loader are ignored. Stops right after the 500th batch, or earlier
    if the loader runs out.
    """
    for batch_no, (inputs, _) in enumerate(test_loader, start=1):
        model.quantize_forward(inputs.to(device)).cpu()
        if batch_no == 500:
            break
    print('direct quantization finish')

def quantize_aware_training(model, device, train_loader, optimizer, epoch):
    """Run one epoch of quantization-aware training and collect gradient statistics.

    Each batch goes through ``model.quantize_forward``. The loss is
    back-propagated and ``optimizer.step()`` is applied. The loss and every
    parameter's gradient are accumulated over the epoch and averaged over
    the number of batches.

    Parameter names containing ``downsample.0`` / ``downsample.1`` are
    rewritten to ``conv1`` / ``bn1`` in the returned dict.

    Args:
        model: module exposing ``quantize_forward`` and ``named_parameters``.
        device: device each batch is moved to.
        train_loader: iterable of ``(data, target)`` batches with a
            ``dataset`` attribute (used only for logging).
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only for logging.

    Returns:
        ``(loss_avg, grad_dict)``: the mean per-batch loss as a detached
        tensor, and a dict mapping (renamed) parameter name to the mean
        per-batch gradient.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    old_sub_str0 = "downsample.0"
    new_sub_str0 = "conv1"
    old_sub_str1 = "downsample.1"
    new_sub_str1 = "bn1"

    def rename(name):
        # At most one of the two patterns is rewritten; the first match wins.
        if old_sub_str0 in name:
            return name.replace(old_sub_str0, new_sub_str0)
        if old_sub_str1 in name:
            return name.replace(old_sub_str1, new_sub_str1)
        return name

    lossLayer = torch.nn.CrossEntropyLoss()
    # Running totals of the loss and of every parameter's gradient
    # (param.grad has the same shape as param).
    loss_sum = 0.
    grad_dict = {rename(name): torch.zeros_like(param)
                 for name, param in model.named_parameters()}

    num_batches = 0
    for batch_idx, (data, target) in enumerate(train_loader, 1):
        num_batches = batch_idx
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model.quantize_forward(data)
        # CrossEntropyLoss defaults to reduction='mean': this is the batch mean.
        loss = lossLayer(output, target)
        loss.backward()

        # Detach before accumulating. Adding the live loss would chain every
        # batch's autograd graph into loss_sum and keep it alive all epoch.
        loss_sum += loss.detach()
        for name, param in model.named_parameters():
            if param.grad is not None:
                grad_dict[rename(name)] += param.grad.detach()

        optimizer.step()

        if batch_idx % 50 == 0:
            print('Quantize Aware Training Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
            ))

    if num_batches == 0:
        raise ValueError('train_loader yielded no batches')

    # Turn the per-batch sums into per-batch means.
    for name, grad in grad_dict.items():
        grad_dict[name] = grad / num_batches
    loss_avg = loss_sum / num_batches
    return loss_avg, grad_dict

def full_inference(model, test_loader, device=None):
    """Evaluate ``model`` on ``test_loader`` and print its top-1 accuracy.

    Args:
        model: classifier; ``output.argmax(dim=1)`` is taken as the
            prediction, so the output is assumed to be (batch, classes).
        test_loader: iterable of ``(data, target)`` batches with a
            ``dataset`` attribute whose length is the total sample count.
        device: device batches are moved to. Defaults to the device of the
            model's first parameter, or CPU if it has none. (Previously this
            read a global ``device`` that only exists when the file runs as
            a script.)

    NOTE: the model's train/eval mode is not changed here; put it in eval
    mode beforehand if BatchNorm/Dropout layers should behave as in
    inference.
    """
    if device is None:
        device = next((p.device for p in model.parameters()), torch.device('cpu'))
    correct = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    print('\nTest set: Full Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))

def train(model, device, train_loader, optimizer, epoch):
    """Run one epoch of full-precision training and collect gradient statistics.

    Puts ``model`` in train mode. For every batch it computes the
    cross-entropy loss, back-propagates and steps ``optimizer``. The loss
    and every parameter's gradient are accumulated over the epoch and
    averaged over the number of batches.

    Args:
        model: module to train.
        device: device each batch is moved to.
        train_loader: iterable of ``(data, target)`` batches with a
            ``dataset`` attribute (used only for logging).
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used only for logging.

    Returns:
        ``(loss_avg, grad_dict)``: the mean per-batch loss as a detached
        tensor, and a dict mapping parameter name to the mean per-batch
        gradient.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    lossLayer = torch.nn.CrossEntropyLoss()
    # Running totals of the loss and of every parameter's gradient
    # (param.grad has the same shape as param).
    loss_sum = 0.
    grad_dict = {name: torch.zeros_like(param) for name, param in model.named_parameters()}

    num_batches = 0
    # Start at 1 (as in quantize_aware_training) so batch_idx * len(data)
    # in the log is the number of samples seen so far.
    for batch_idx, (data, target) in enumerate(train_loader, 1):
        num_batches = batch_idx
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = lossLayer(output, target)
        loss.backward()

        # Detach before accumulating. Adding the live loss would chain every
        # batch's autograd graph into loss_sum and keep it alive all epoch.
        loss_sum += loss.detach()
        for name, param in model.named_parameters():
            if param.grad is not None:
                grad_dict[name] += param.grad.detach()

        optimizer.step()

        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
            ))

    if num_batches == 0:
        raise ValueError('train_loader yielded no batches')

    # Turn the per-batch sums into per-batch means.
    for name, grad in grad_dict.items():
        grad_dict[name] = grad / num_batches
    loss_avg = loss_sum / num_batches
    return loss_avg, grad_dict


def quantize_inference(model, test_loader, device=None):
    """Evaluate ``model.quantize_inference`` on ``test_loader`` and print top-1 accuracy.

    Args:
        model: module exposing ``quantize_inference``; ``argmax(dim=1)`` of
            its output is taken as the prediction.
        test_loader: iterable of ``(data, target)`` batches with a
            ``dataset`` attribute whose length is the total sample count.
        device: device batches are moved to. Defaults to the device of the
            model's first parameter, or CPU if it has none. (Previously this
            read a global ``device`` that only exists when the file runs as
            a script.)
    """
    if device is None:
        device = next((p.device for p in model.parameters()), torch.device('cpu'))
    correct = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model.quantize_inference(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    print('\nTest set: Quant Model Accuracy: {:.2f}%\n'.format(100. * correct / len(test_loader.dataset)))


if __name__ == "__main__":
    # Train a full-precision (BN-folded) model, then QAT-train a quantized
    # model per quant config, and record loss statistics plus a
    # flop_ratio-weighted JS divergence between their gradients into an
    # Excel workbook (one sheet per 5th epoch) and TensorBoard.

    parser = argparse.ArgumentParser(description='QAT Training')
    # The default must be one of the names dispatched on below; the old
    # default 'resnet18' matched none of them and left `model` undefined.
    parser.add_argument('-m', '--model', metavar='MODEL ARCH', default='ResNet18',
                        choices=['ResNet18', 'ResNet50', 'ResNet152'])
    parser.add_argument('-e','--epochs', default=15, type=int, metavar='EPOCHS', help='number of total epochs to run')
    parser.add_argument('-b', '--batch_size', default=128, type=int, metavar='BATCH SIZE', help='mini-batch size (default: 128)')
    parser.add_argument('-j','--workers', default=1, type=int, metavar='WORKERS',help='number of data loading workers (default: 1)')
    parser.add_argument('-lr', '--learning-rate', default=0.001, type=float, metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('-wd','--weight_decay',default=0.0001,type=float,metavar='WD',help='lr schduler weight decay',dest='wd')
    parser.add_argument('-t', '--test', dest='test', action='store_true', help='test model on test set')
    args = parser.parse_args()

    batch_size = args.batch_size
    seed = 1
    epochs = args.epochs
    lr = args.lr
    momentum = 0.9
    # NOTE(review): parsed but never passed to the SGD optimizers below.
    weight_decay = args.wd

    torch.manual_seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    writer = SummaryWriter(log_dir='log/' + args.model  +  '/qat')
    wb = openpyxl.Workbook()
    ws = wb.active

    old_sub_str0 = "downsample.0"
    new_sub_str0 = "conv1"
    old_sub_str1 = "downsample.1"
    new_sub_str1 = "bn1"

    def rename_downsample(name):
        """Rewrite downsample.0 -> conv1 or downsample.1 -> bn1 (first match only)."""
        if old_sub_str0 in name:
            return name.replace(old_sub_str0, new_sub_str0)
        if old_sub_str1 in name:
            return name.replace(old_sub_str1, new_sub_str1)
        return name

    if args.model == 'ResNet18':
        model = resnet18_fold()
    elif args.model == 'ResNet50':
        model = resnet50_fold()
    elif args.model == 'ResNet152':
        model = resnet152_fold()
    else:
        raise ValueError('unsupported model: %s' % args.model)

    # Only the ratios are used; the layer list is rebuilt from the model itself.
    _, par_ratio, flop_ratio = extract_ratio(args.model)

    # One entry per parameter whose name contains 'weight' (conv, bn, fc),
    # i.e. the layer name in front of ".weight". Filtering on 'weight' avoids
    # adding the same layer again for its bias. The order follows
    # named_parameters(), which is the order flop_ratio is indexed by below.
    layer = ['.'.join(name.split('.')[:-1])
             for name, _ in model.named_parameters() if 'weight' in name]

    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('../../project/p/data', train=True, download=True,
                       transform=transforms.Compose([
                            transforms.ToTensor(),
                            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                       ])),
        batch_size=batch_size, shuffle=True, num_workers=args.workers, pin_memory=False
    )

    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('../../project/p/data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        ])),
        batch_size=batch_size, shuffle=True, num_workers=args.workers, pin_memory=False
    )

    model.to(device)
    optimizer1 = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    # Created but never stepped (the step call is commented out below).
    lr_scheduler1 = CosineAnnealingLR(optimizer1, T_max=epochs)

    # No .pt checkpoints are saved, so nothing is loaded either.
    load_qat = False
    ckpt_prefix = 'ckpt/qat/'+ args.model + '/'

    # ---- Full-precision training: collect gradient statistics ----
    loss_sum = 0.
    full_grad_sum = {}
    full_grad_avg = {}
    for name, param in model.named_parameters():
        full_grad_sum[name] = torch.zeros_like(param)
        full_grad_avg[name] = torch.zeros_like(param)
    for epoch in range(1, epochs+1):
        loss, full_grad = train(model, device, train_loader, optimizer1, epoch)
        if epoch == 1:
            loss_start = loss
        writer.add_scalar('Full.loss', loss, epoch)

        loss_sum += loss
        loss_avg = loss_sum / epoch
        # Change in loss since epoch 1. For the same epoch count, a larger
        # magnitude means the model reached the low-loss regime faster.
        loss_delta = loss - loss_start

        for name, grad in full_grad.items():
            full_grad_sum[name] += grad
            full_grad_avg[name] = full_grad_sum[name] / epoch

        if epoch % 5 == 0:
            ws = wb.create_sheet('epoch_%d' % epoch)
            for col, label in enumerate(['loss', 'loss_sum', 'loss_avg', 'loss_delta'], start=2):
                ws.cell(row=1, column=col, value=label)
            ws.cell(row=2, column=1, value='FP32')
            for col, value in enumerate([loss, loss_sum, loss_avg, loss_delta], start=2):
                ws.cell(row=2, column=col, value=value.item())
            # Header for the QAT rows written below row 4.
            qat_header = ['title', 'loss', 'loss_sum', 'loss_avg', 'loss_delta',
                          'js_grad', 'js_grad_sum', 'js_grad_avg']
            for col, label in enumerate(qat_header, start=1):
                ws.cell(row=4, column=col, value=label)

        # lr_scheduler1.step()

    # ---- Quantization-aware training per quant config ----
    quant_type_list = ['INT']
    gol._init()

    currow = 4  # row of the QAT header; each config writes to the next row
    for quant_type in quant_type_list:
        num_bit_list = numbit_list(quant_type)

        # The bias quantization table only needs to be set once per quant type.
        # INT uses wide bit widths, so a table would be too costly; it just rounds.
        if quant_type != 'INT':
            bias_list = build_bias_list(quant_type)
            gol.set_value(bias_list, is_bias=True)

        for num_bits in num_bit_list:
            e_bit_list = ebit_list(quant_type, num_bits)
            for e_bits in e_bit_list:

                if quant_type == 'FLOAT':
                    title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
                else:
                    title = '%s_%d' % (quant_type, num_bits)

                if load_qat is True and osp.exists(ckpt_prefix+'epoch_20/'+title+'.pt'):
                    continue

                currow += 1
                print('\nQAT: '+title)

                if args.model == 'ResNet18':
                    model_ptq = resnet18()
                elif args.model == 'ResNet50':
                    model_ptq = resnet50()
                elif args.model == 'ResNet152':
                    model_ptq = resnet152()
                else:
                    raise ValueError('unsupported model: %s' % args.model)

                optimizer2 = optim.SGD(model_ptq.parameters(), lr=lr, momentum=momentum)

                # Set the quantization table.
                if quant_type != 'INT':
                    plist = build_list(quant_type, num_bits, e_bits)
                    gol.set_value(plist)
                model_ptq.to(device)
                # Pretrained weights are not loaded ('ckpt/cifar10_<model>.pt').
                model_ptq.quantize(quant_type, num_bits, e_bits)

                model_ptq.eval()
                direct_quantize(model_ptq, train_loader, device)

                model_ptq.train()

                loss_sum = 0.
                qat_grad_sum = {}
                qat_grad_avg = {}
                # Nothing is frozen, so model and model_ptq have the same
                # parameters; only the downsample names differ.
                for name, param in model_ptq.named_parameters():
                    name = rename_downsample(name)
                    qat_grad_sum[name] = torch.zeros_like(param)
                    qat_grad_avg[name] = torch.zeros_like(param)

                for epoch in range(1, epochs+1):
                    loss, qat_grad = quantize_aware_training(model_ptq, device, train_loader, optimizer2, epoch)
                    if epoch == 1:
                        loss_start = loss
                    writer.add_scalar(title+'.loss', loss, epoch)

                    loss_sum += loss
                    loss_avg = loss_sum / epoch
                    loss_delta = loss - loss_start

                    for name, _ in model_ptq.named_parameters():
                        name = rename_downsample(name)
                        qat_grad_sum[name] += qat_grad[name]
                        # Running mean over epochs. Must be an assignment (as
                        # for full_grad_avg); `+=` kept adding previous averages.
                        qat_grad_avg[name] = qat_grad_sum[name] / epoch

                    if epoch % 5 == 0:
                        ws = wb['epoch_%d' % epoch]
                        js_grad = 0.
                        js_grad_sum = 0.
                        js_grad_avg = 0.

                        for name, _ in model_ptq.named_parameters():
                            # Map downsample.0 -> conv1 and downsample.1 -> bn1.
                            # Nothing is frozen, so neither model has conv biases.
                            # Whether BN similarity/gradients should count is
                            # still an open question.
                            prefix = rename_downsample('.'.join(name.split('.')[:-1]))
                            name = rename_downsample(name)

                            # layer and flop_ratio are both ordered by layer name.
                            layer_idx = layer.index(prefix)

                            # NOTE: full_grad is the FP32 gradient of the final FP32
                            # epoch, and full_grad_sum / full_grad_avg cover all FP32
                            # epochs, not the current QAT epoch.
                            js = max(js_div_0(qat_grad[name], full_grad[name]), 0)
                            js_sum = max(js_div_0(qat_grad_sum[name], full_grad_sum[name]), 0)
                            js_avg = max(js_div_0(qat_grad_avg[name], full_grad_avg[name]), 0)

                            # flop-weighted sum over layers
                            js_grad += flop_ratio[layer_idx] * js
                            js_grad_sum += flop_ratio[layer_idx] * js_sum
                            js_grad_avg += flop_ratio[layer_idx] * js_avg
                            print(f"name:{name},js:{js},idx:{layer_idx}")

                        row_values = [title, loss.item(), loss_sum.item(), loss_avg.item(),
                                      loss_delta.item(), js_grad, js_grad_sum, js_grad_avg]
                        for col, value in enumerate(row_values, start=1):
                            ws.cell(row=currow, column=col, value=value)
                        print(f"title:{title},epoch:{epoch},js_grad:{js_grad},js_sum:{js_grad_sum},js_avg:{js_grad_avg}")

                # lr_scheduler2.step()

    wb.remove(wb['Sheet'])  # drop the default empty sheet
    wb.save(args.model + 'qat_result.xlsx')
    writer.close()
