import math
import numpy as np

class seq_mnist_decoder():
    def __init__(self, labels, blank=0):
        self.blank_chr = blank
        self.labels = labels
    
    def decode(self, predictions, output_len, label_len):
        predictions = predictions.data.cpu().numpy()
        output = []
        # 把结果逐个翻译，再拼成序列
        # predictions参数是一个大小为(time_steps, num_classes)的二维数组，表示模型的预测输出。每一行代表一个时间步长，每一列代表一个可能标签的概率。
        for i in range(output_len):
            pred = np.argmax(predictions[i, :])
            # 对标签做一些去除空和重复的处理(因为lstm序列中可能多个neuron处理同一个字符)
            if (pred != self.blank_chr) and (pred != np.argmax(predictions[i-1, :])): # merging repeats and removing blank character (0)
                output.append(pred-1)
        return np.asarray(output)

    def hit(self, pred, target):
        res = []
        for idx, word in enumerate(target):
            if idx < len(pred):    # 列表长度
                item = pred[idx]
            # 判断当前位置是否已经不小于预测结果列表的长度，则真实标签对应的预测结果已不存在，此时我们将item设置为任意一个
            else:
                item = 10
            res.append(word == item)
        acc = np.mean(np.asarray(res))*100
        if math.isnan(acc):
            return 0.00
        else:
            return acc

    def to_string(self, in_str):
        out_str = ''
        for i in range(in_str.shape[0]):
            out_str += str(in_str[i])
        return out_str