import argparse
from easydict import EasyDict as edict
import yaml
import os
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from model import *
from utils import *
from lstm_utils import *
from dataset import get_dataset
from config.alphabets import *
import time
import sys

from torch.utils.tensorboard import SummaryWriter

def parse_arg():
    """Parse the command line and load the experiment configuration.

    The YAML file given by ``--cfg`` (default ``config/360CC_config.yaml``)
    is loaded into an ``EasyDict``. The character alphabet and the number of
    classes derived from it are then injected into the config.

    Returns:
        EasyDict: the experiment configuration.
    """
    arg_parser = argparse.ArgumentParser(description="train crnn")
    arg_parser.add_argument(
        '--cfg',
        help='experiment configuration filename',
        type=str,
        default='config/360CC_config.yaml',
    )
    cfg_path = arg_parser.parse_args().cfg

    with open(cfg_path, 'r') as cfg_file:
        # NOTE(review): FullLoader can build some Python-specific YAML tags;
        # yaml.safe_load would be stricter if the config is plain data.
        cfg = edict(yaml.load(cfg_file, Loader=yaml.FullLoader))

    # `alphabet` is provided by the star import of config.alphabets.
    cfg.DATASET.ALPHABETS = alphabet
    cfg.MODEL.NUM_CLASSES = len(cfg.DATASET.ALPHABETS)
    return cfg


def train(config, train_loader, dataset, converter, model, criterion, optimizer, device, epoch):
    """Run one training epoch over ``train_loader``.

    For every batch, labels are looked up from ``dataset`` by index and
    encoded with ``converter``. The CTC loss is computed on the model output
    (moved to CPU), and one optimizer step is taken. Every
    ``config.PRINT_FREQ`` batches, timing, throughput and running-loss
    statistics are printed.

    Args:
        config: experiment config; only ``PRINT_FREQ`` is read here.
        train_loader: yields ``(images, indices)`` batches.
        dataset: dataset the indices refer to (passed to ``get_batch_label``).
        converter: label converter providing ``encode``.
        model: network being trained (switched to train mode).
        criterion: loss called as ``criterion(preds, text, preds_size, length)``.
        optimizer: optimizer stepped once per batch.
        device: device the input images are moved to.
        epoch: epoch number, used only in log messages.
    """
    batch_timer = AverageMeter()
    load_timer = AverageMeter()
    loss_meter = AverageMeter()

    model.train()

    tick = time.time()
    for step, (images, sample_idx) in enumerate(train_loader):
        # time spent waiting for the loader
        load_timer.update(time.time() - tick)

        targets = get_batch_label(dataset, sample_idx)
        images = images.to(device)

        # forward pass; the output is moved to CPU before computing the loss
        outputs = model(images).cpu()

        n_samples = images.size(0)
        # text: character indices of the batch's labels;
        # length: the corresponding label character lengths
        text, length = converter.encode(targets)
        # every sample is given the same output sequence length, outputs.size(0)
        preds_size = torch.IntTensor([outputs.size(0)] * n_samples)
        loss = criterion(outputs, text, preds_size, length)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_meter.update(loss.item(), images.size(0))
        batch_timer.update(time.time() - tick)

        if step % config.PRINT_FREQ == 0:
            report = (
                'Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t'
                'Speed {speed:.1f} samples/s\t'
                'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t'
                'Loss {loss.val:.5f} ({loss.avg:.5f})\t'
            ).format(
                epoch, step, len(train_loader),
                batch_time=batch_timer,
                speed=images.size(0) / batch_timer.val,
                data_time=load_timer,
                loss=loss_meter,
            )
            print(report)

        tick = time.time()


def validate(config, val_loader, dataset, converter, model, criterion, device, epoch):
    """Evaluate ``model`` on ``val_loader`` and return exact-match accuracy.

    For each batch, the CTC loss is computed and the network output is
    greedily decoded: arg-max per time step, then ``converter.decode``.
    A sample counts as correct only when its decoded string equals its label.
    After the loop, up to ``config.TEST.NUM_TEST_DISP`` raw/decoded
    predictions from the last batch are printed next to their labels.

    Args:
        config: experiment config; reads ``PRINT_FREQ`` and ``TEST.NUM_TEST_DISP``.
        val_loader: yields ``(images, indices)`` batches.
        dataset: dataset the indices refer to (passed to ``get_batch_label``).
        converter: label converter providing ``encode`` and ``decode``.
        model: network under evaluation (switched to eval mode).
        criterion: loss called as ``criterion(preds, text, preds_size, length)``.
        device: device the input images are moved to.
        epoch: epoch number, used only in log messages.

    Returns:
        float: correctly predicted samples divided by samples evaluated,
        in [0, 1]; 0.0 if the loader yielded no samples.
    """

    def _as_list(decoded):
        # converter.decode may return a bare str (e.g. for a one-sample
        # batch); iterating or slicing that would walk characters instead of
        # whole predictions, so wrap it in a list.
        return [decoded] if isinstance(decoded, str) else decoded

    losses = AverageMeter()
    model.eval()

    n_correct = 0
    n_evaluated = 0
    # Outputs of the most recent batch, kept for the sample display below.
    last_batch = None

    with torch.no_grad():
        for i, (inp, idx) in enumerate(val_loader):
            # list of labels for this batch
            labels = get_batch_label(dataset, idx)
            inp = inp.to(device)

            # inference
            preds = model(inp).cpu()

            # compute loss
            batch_size = inp.size(0)
            text, length = converter.encode(labels)
            preds_size = torch.IntTensor([preds.size(0)] * batch_size)
            loss = criterion(preds, text, preds_size, length)
            losses.update(loss.item(), batch_size)

            # Greedy decoding: best class per time step. Swapping axes 0/1
            # before flattening groups each sample's time steps together
            # (assumes preds is (time, batch, classes) -- TODO confirm).
            _, pred_idx = preds.max(2)
            pred_idx = pred_idx.transpose(1, 0).contiguous().view(-1)
            sim_preds = _as_list(converter.decode(pred_idx.data, preds_size.data, raw=False))

            # exact string match between decoded prediction and label
            n_correct += sum(1 for pred, target in zip(sim_preds, labels) if pred == target)
            n_evaluated += batch_size

            if (i + 1) % config.PRINT_FREQ == 0:
                print('Epoch: [{0}][{1}/{2}]'.format(epoch, i, len(val_loader)))

            last_batch = (pred_idx, preds_size, sim_preds, labels)

    # Show the first config.TEST.NUM_TEST_DISP predictions of the last batch.
    if last_batch is not None:
        pred_idx, preds_size, sim_preds, labels = last_batch
        raw_preds = _as_list(converter.decode(pred_idx.data, preds_size.data, raw=True))
        for raw_pred, pred, gt in zip(raw_preds[:config.TEST.NUM_TEST_DISP], sim_preds, labels):
            print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

    print("[#correct:{} / #total:{}]".format(n_correct, n_evaluated))
    accuracy = n_correct / float(n_evaluated) if n_evaluated else 0.0
    print('Test loss: {:.4f}, accuracy: {:.4f}'.format(losses.avg, accuracy))

    return accuracy

def main():
    """Train the CRNN recognizer, or in resume mode only evaluate it.

    Builds the model, CTC loss, train/val loaders, optimizer and LR scheduler
    from the YAML config.

    If ``config.TRAIN.RESUME.IS_RESUME`` is set, the checkpoint
    ``ckpt/360cc_<MODEL.NAME>.pt`` is loaded and evaluated once on the
    validation set, and the process exits.

    Otherwise the model is trained for epochs
    ``[TRAIN.BEGIN_EPOCH, TRAIN.END_EPOCH)`` and validated after each epoch.
    When ``TRAIN.SAVE`` is set, the weights are saved to that checkpoint path
    whenever validation accuracy beats the best seen so far.
    """
    # load config
    config = parse_arg()

    # build the CRNN recognition network
    model = get_crnn(config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # define loss function
    criterion = torch.nn.CTCLoss()

    # epoch the run starts from (0 for a fresh run)
    last_epoch = config.TRAIN.BEGIN_EPOCH

    train_dataset = get_dataset(config)(config, is_train=True)
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=config.TRAIN.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    val_dataset = get_dataset(config)(config, is_train=False)
    val_loader = DataLoader(
        dataset=val_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=config.TEST.SHUFFLE,
        num_workers=config.WORKERS,
        pin_memory=config.PIN_MEMORY,
    )

    optimizer = get_optimizer(config, model)
    # A list of LR_STEP values selects MultiStepLR (decay at those epochs);
    # a single value selects StepLR (decay every LR_STEP epochs).
    # NOTE(review): when BEGIN_EPOCH > 0 the scheduler gets last_epoch != -1,
    # which PyTorch accepts only if every param group already carries
    # 'initial_lr'; a freshly built optimizer may not -- confirm before
    # starting mid-schedule.
    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch - 1
        )
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, config.TRAIN.LR_STEP,
            config.TRAIN.LR_FACTOR, last_epoch - 1
        )

    save_dir = 'ckpt'
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir, mode=0o777)
        # makedirs' mode is masked by the process umask; chmod applies it as given
        os.chmod(save_dir, mode=0o777)
    ckpt_path = os.path.join(save_dir, '360cc_' + config.MODEL.NAME + '.pt')

    converter = strLabelConverter(config.DATASET.ALPHABETS)

    # Resume mode only evaluates the saved checkpoint; no training happens.
    if config.TRAIN.RESUME.IS_RESUME:
        # map_location lets a checkpoint written during a GPU run load when
        # `device` has fallen back to "cpu".
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
        acc = validate(config, val_loader, val_dataset, converter, model, criterion, device, config.TRAIN.END_EPOCH)
        # validate() returns a fraction in [0, 1]; the ".2%" spec scales it to a percentage.
        print(f"test accuracy: {acc:.2%}")
        sys.exit()

    best_acc = 0.0

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):

        # train
        train(config, train_loader, train_dataset, converter, model, criterion, optimizer, device, epoch)
        lr_scheduler.step()

        # validate
        acc = validate(config, val_loader, val_dataset, converter, model, criterion, device, epoch)

        # save only when validation accuracy improves on the best so far
        if config.TRAIN.SAVE and acc > best_acc:
            torch.save(model.state_dict(), ckpt_path)

        best_acc = max(acc, best_acc)
        print("best acc is:", best_acc)

if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()