from model import *
from functools import partial
from lstm_utils import *

import argparse
from easydict import EasyDict as edict
import yaml
from config.alphabets import *

import sys
import torch
from ptflops import get_model_complexity_info

# Module-level settings (data location, layer sizes, evaluation batch size,
# dropout rate, weight tying). NOTE(review): these look like leftovers from a
# PTB language-model setup and are not read by the code in this script —
# confirm before relying on them.
data_path = '../data/ptb'

embed_size = hidden_size = 700
eval_batch_size = 10

dropout = 0.65
tied = True

def parse_arg(argv=None):
    """Parse command-line options and load the experiment YAML configuration.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, matching the
            previous behaviour.

    Returns:
        An ``EasyDict`` built from the YAML file named by ``--cfg``. Its
        ``DATASET.ALPHABETS`` is overwritten with ``alphabet`` (star-imported
        from ``config.alphabets``), and ``MODEL.NUM_CLASSES`` is set to the
        length of that alphabet.
    """
    parser = argparse.ArgumentParser(description="train crnn")
    parser.add_argument('--cfg', help='experiment configuration filename', type=str,
                        default='config/360CC_config.yaml')
    args = parser.parse_args(argv)

    # YAML is defined as UTF-8 text; be explicit so a non-UTF-8 platform
    # default encoding (e.g. on Windows) cannot mis-decode non-ASCII content.
    with open(args.cfg, 'r', encoding='utf-8') as f:
        config = edict(yaml.load(f, Loader=yaml.FullLoader))

    # The class count is derived from the alphabet, not taken from the YAML.
    config.DATASET.ALPHABETS = alphabet
    config.MODEL.NUM_CLASSES = len(config.DATASET.ALPHABETS)

    return config

def lstm_constructor(shape, hidden):
    """Return a keyword-argument dict holding dummy model inputs.

    ``x`` is an all-zero ``torch.int64`` tensor of the given ``shape``.
    ``hidden`` is passed through unchanged.
    """
    zero_ids = torch.zeros(shape, dtype=torch.int64)
    return dict(x=zero_ids, hidden=hidden)

if __name__ == "__main__":

    config = parse_arg()
    model = get_crnn(config)

    full_file = 'ckpt/360cc_lstm-ocr.pt'
    # map_location='cpu' lets a checkpoint saved from a GPU run be loaded on a
    # CPU-only machine; load_state_dict then copies the tensors into the model.
    model.load_state_dict(torch.load(full_file, map_location='cpu'))

    # (1, 32, 160) is the per-sample input resolution without the batch
    # dimension; presumably 1 channel x 32 high x 160 wide — TODO confirm
    # against the config's image size.
    flops, params = get_model_complexity_info(model, (1, 32, 160), as_strings=True,
                                              print_per_layer_stat=True)

    # Report the overall totals returned by ptflops. Per the ptflops docs, the
    # first value counts multiply-accumulate operations (MACs).
    print(f'Computational complexity: {flops}')
    print(f'Number of parameters: {params}')
