from model import *
from utils import *
import gol
import sys
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import os
import os.path as osp
from torch.utils.tensorboard import SummaryWriter

def quantize_aware_training(model, device, train_loader, optimizer, epoch):
    """Run one epoch of quantization-aware training and collect statistics.

    For every batch the output of ``model.quantize_forward`` is scored with
    cross-entropy, back-propagated, and ``optimizer.step()`` is applied.
    The caller is responsible for putting the model in training mode.

    Args:
        model: Module exposing ``quantize_forward`` and ``named_parameters``.
        device: Device each batch is moved to before the forward pass.
        train_loader: Iterable of ``(data, target)`` batches that also exposes
            ``batch_sampler`` (for the batch count) and ``dataset`` (for the
            progress message).
        optimizer: Optimizer whose ``zero_grad``/``step`` run once per batch.
        epoch: Epoch number, used only in the progress message.

    Returns:
        A ``(loss_avg, grad_dict)`` tuple: the per-batch loss averaged over
        the epoch (a tensor detached from the autograd graph) and a dict
        mapping each parameter name to its gradient averaged over batches.
    """
    loss_layer = torch.nn.CrossEntropyLoss()

    # Running totals of the loss and of every parameter's gradient.
    loss_sum = 0.
    grad_dict = {
        # A gradient has the same shape as its parameter.
        name: torch.zeros_like(param)
        for name, param in model.named_parameters()
    }

    for batch_idx, (data, target) in enumerate(train_loader, 1):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model.quantize_forward(data)
        # The loss of a batch is already the mean over its samples.
        loss = loss_layer(output, target)
        loss.backward()

        # Accumulate a detached loss: summing the live tensor would chain
        # every batch's autograd history onto the running total.
        loss_sum += loss.detach()
        for name, param in model.named_parameters():
            if param.grad is not None:
                grad_dict[name] += param.grad

        optimizer.step()

        if batch_idx % 50 == 0:
            print('Quantize Aware Training Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
            ))

    # Average the accumulated values over the number of batches.
    num_batches = len(train_loader.batch_sampler)
    for name, grad in grad_dict.items():
        grad_dict[name] = grad / num_batches
    loss_avg = loss_sum / num_batches
    return loss_avg, grad_dict

def full_inference(model, test_loader, device=None):
    """Evaluate the full-precision model and print its top-1 accuracy.

    The model's mode (train/eval) is left untouched; switch it beforehand
    if needed.

    Args:
        model: Callable module mapping a batch of inputs to class scores.
        test_loader: Iterable of ``(data, target)`` batches that exposes
            ``dataset`` (its length is the accuracy denominator).
        device: Device the batches are moved to. When omitted, the device
            of the model's first parameter is used (CPU if the model has
            no parameters), so the function no longer depends on a
            module-level ``device`` global.

    Returns:
        The accuracy in percent, as a float.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Full Model Accuracy: {:.2f}%\n'.format(accuracy))
    return accuracy

def train(model, device, train_loader, optimizer, epoch):
    """Train the full-precision model for one epoch and collect statistics.

    Puts the model in training mode, then for every batch computes the
    cross-entropy loss, back-propagates it and applies ``optimizer.step()``.

    Args:
        model: Callable module exposing ``train`` and ``named_parameters``.
        device: Device each batch is moved to before the forward pass.
        train_loader: Iterable of ``(data, target)`` batches that also exposes
            ``batch_sampler`` (for the batch count) and ``dataset`` (for the
            progress message).
        optimizer: Optimizer whose ``zero_grad``/``step`` run once per batch.
        epoch: Epoch number, used only in the progress message.

    Returns:
        A ``(loss_avg, grad_dict)`` tuple: the per-batch loss averaged over
        the epoch (a tensor detached from the autograd graph) and a dict
        mapping each parameter name to its gradient averaged over batches.
    """
    model.train()
    loss_layer = torch.nn.CrossEntropyLoss()

    # Running totals of the loss and of every parameter's gradient.
    loss_sum = 0.
    grad_dict = {
        # A gradient has the same shape as its parameter.
        name: torch.zeros_like(param)
        for name, param in model.named_parameters()
    }

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_layer(output, target)
        loss.backward()

        # Accumulate a detached loss: summing the live tensor would chain
        # every batch's autograd history onto the running total.
        loss_sum += loss.detach()
        for name, param in model.named_parameters():
            if param.grad is not None:
                grad_dict[name] += param.grad

        optimizer.step()

        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()
            ))

    # Average the accumulated values over the number of batches.
    num_batches = len(train_loader.batch_sampler)
    for name, grad in grad_dict.items():
        grad_dict[name] = grad / num_batches
    loss_avg = loss_sum / num_batches
    return loss_avg, grad_dict


def quantize_inference(model, test_loader, device=None):
    """Evaluate the quantized model and print its top-1 accuracy.

    Predictions come from ``model.quantize_inference``. The model's mode
    (train/eval) is left untouched; switch it beforehand if needed.

    Args:
        model: Module exposing ``quantize_inference``, which maps a batch of
            inputs to class scores.
        test_loader: Iterable of ``(data, target)`` batches that exposes
            ``dataset`` (its length is the accuracy denominator).
        device: Device the batches are moved to. When omitted, the device
            of the model's first parameter is used (CPU if the model has
            no parameters), so the function no longer depends on a
            module-level ``device`` global.

    Returns:
        The accuracy in percent, as a float.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model.quantize_inference(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    accuracy = 100. * correct / len(test_loader.dataset)
    print('\nTest set: Quant Model Accuracy: {:.2f}%\n'.format(accuracy))
    return accuracy


if __name__ == "__main__":
    # Script entry point:
    #   1. train the full-precision AlexNet on CIFAR-10 and record the
    #      per-epoch loss and averaged gradients;
    #   2. for every quantization configuration, run quantization-aware
    #      training from the saved full-precision weights and record the
    #      same statistics to TensorBoard and to checkpoint files.

    batch_size = 32
    seed = 1
    epochs = 20
    lr = 0.001
    momentum = 0.5

    torch.manual_seed(seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    writer = SummaryWriter(log_dir='./log/qat')

    # Train and test data share the same preprocessing.
    cifar_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    train_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('data', train=True, download=True,
                         transform=cifar_transform),
        batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10('data', train=False, transform=cifar_transform),
        batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=False
    )

    full_file = 'ckpt/cifar10_AlexNet.pt'
    model = AlexNet()
    # model.load_state_dict(torch.load(full_file))
    model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    # When True, configurations whose final checkpoint exists are skipped.
    load_qat = False
    # NOTE(review): ``store_qat`` was read below but never defined, which
    # raised NameError on the first epoch. Defaulting to True so that the
    # full-precision statistics are saved like the QAT ones - confirm.
    store_qat = True
    ckpt_prefix = "ckpt/qat/"

    # ---- Full-precision training: collect loss / gradient statistics ----
    loss_sum = 0.
    grad_dict_sum = {}
    grad_dict_avg = {}
    for name, param in model.named_parameters():
        grad_dict_sum[name] = torch.zeros_like(param)
        grad_dict_avg[name] = torch.zeros_like(param)

    for epoch in range(1, epochs + 1):
        loss, grad_dict = train(model, device, train_loader, optimizer, epoch)
        writer.add_scalar('Full.loss', loss, epoch)

        # Running averages over the epochs seen so far.
        loss_sum += loss
        loss_avg = loss_sum / epoch
        for name, grad in grad_dict.items():
            grad_dict_sum[name] += grad
            grad_dict_avg[name] = grad_dict_sum[name] / epoch

        if store_qat and epoch % 5 == 0:
            ckpt = {
                'epoch'     : epoch,
                'loss'      : loss,
                'loss_sum'  : loss_sum,
                'loss_avg'  : loss_avg,
                'grad_dict_avg' : grad_dict_avg
            }
            save_dir = ckpt_prefix + 'epoch_%d/' % epoch
            # torch.save does not create missing directories.
            os.makedirs(save_dir, exist_ok=True)
            torch.save(ckpt, save_dir + 'full.pt')

    # ---- Quantization-aware training for every configuration ----
    quant_type_list = ['INT', 'POT', 'FLOAT']
    gol._init()
    for quant_type in quant_type_list:
        num_bit_list = numbit_list(quant_type)

        # The bias quantization table only needs to be set once per
        # quantization type. INT is skipped: per the original note its bit
        # width makes a table too expensive, so plain rounding is used.
        if quant_type != 'INT':
            bias_list = build_bias_list(quant_type)
            gol.set_value(bias_list, is_bias=True)

        for num_bits in num_bit_list:
            e_bit_list = ebit_list(quant_type, num_bits)
            for e_bits in e_bit_list:

                if quant_type == 'FLOAT':
                    title = '%s_%d_E%d' % (quant_type, num_bits, e_bits)
                else:
                    title = '%s_%d' % (quant_type, num_bits)

                # Skip configurations that already have a final checkpoint.
                final_ckpt = ckpt_prefix + 'epoch_%d/' % epochs + title + '.pt'
                if load_qat is True and osp.exists(final_ckpt):
                    continue

                print('\nQAT: ' + title)

                model_ptq = AlexNet()
                # Set the quantization table for this configuration.
                if quant_type != 'INT':
                    plist = build_list(quant_type, num_bits, e_bits)
                    gol.set_value(plist)
                # map_location lets a checkpoint saved on another device load here.
                model_ptq.load_state_dict(torch.load(full_file, map_location=device))
                model_ptq.to(device)
                model_ptq.quantize(quant_type, num_bits, e_bits)
                model_ptq.train()

                # The quantized model needs its own optimizer: the one built
                # for ``model`` would neither zero nor update these parameters.
                optimizer_ptq = optim.SGD(model_ptq.parameters(), lr=lr, momentum=momentum)

                loss_sum = 0.
                grad_dict_sum = {}
                grad_dict_avg = {}
                for name, param in model_ptq.named_parameters():
                    grad_dict_sum[name] = torch.zeros_like(param)
                    grad_dict_avg[name] = torch.zeros_like(param)

                for epoch in range(1, epochs + 1):
                    loss, grad_dict = quantize_aware_training(
                        model_ptq, device, train_loader, optimizer_ptq, epoch)
                    writer.add_scalar(title + '.loss', loss, epoch)
                    for name, grad in grad_dict.items():
                        writer.add_histogram(title + '.' + name + '_grad', grad, global_step=epoch)

                    # Running averages over the epochs seen so far.
                    loss_sum += loss
                    loss_avg = loss_sum / epoch
                    for name, grad in grad_dict.items():
                        grad_dict_sum[name] += grad
                        grad_dict_avg[name] = grad_dict_sum[name] / epoch

                    if epoch % 5 == 0:
                        ckpt = {
                            'epoch'     : epoch,
                            'loss'      : loss,
                            'loss_sum'  : loss_sum,
                            'loss_avg'  : loss_avg,
                            'grad_dict_avg' : grad_dict_avg
                        }
                        save_dir = ckpt_prefix + 'epoch_%d/' % epoch
                        # torch.save does not create missing directories.
                        os.makedirs(save_dir, exist_ok=True)
                        torch.save(ckpt, save_dir + title + '.pt')
    writer.close()
    # # model.eval()
    
    # # full_inference(model, test_loader)

    # num_bits = 8
    # e_bits = 0
    # gol._init()
    # print("qat: INT8")

    # model.quantize('INT',num_bits,e_bits)
    # print('Quantization bit: %d' % num_bits)

    # if load_quant_model_file is not None:
    #     model.load_state_dict(torch.load(load_quant_model_file))
    #     print("Successfully load quantized model %s" % load_quant_model_file)
    # else:
    #     model.train()

    #     for epoch in range(1, epochs+1):
    #         quantize_aware_training(model, device, train_loader, optimizer, epoch)
    #     # for epoch in range(epochs1 + 1, epochs2 + 1):
    #     #     quantize_aware_training(model, device, train_loader, optimizer2, epoch)

    #     model.eval()
    #     # torch.save(model.state_dict(), save_file)

    # model.freeze()
    # # for name, param in model.named_parameters():
    # #     print(name)
    # #     print(param.data)
    # #     print('----------')

    # # for param_tensor, param_value in model.state_dict().items():
    # #     print(param_tensor, "\t", param_value)
    # quantize_inference(model, test_loader)

    



    
