import torch.nn as nn
import torch.nn.functional as F
from module import *


# class BidirectionalLSTM(nn.Module):
#     # Inputs hidden units Out
#     def __init__(self, nIn, nHidden, nOut):
#         super(BidirectionalLSTM, self).__init__()

#         self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
#         self.embedding = nn.Linear(nHidden * 2, nOut)

#     def forward(self, input):
#         recurrent, _ = self.rnn(input)
#         T, b, h = recurrent.size()
#         t_rec = recurrent.view(T * b, h)

#         output = self.embedding(t_rec)  # [T * b, nOut]
#         output = output.view(T, b, -1)

#         return output

class CRNN(nn.Module):
    """Convolutional-recurrent network with a parallel quantized variant.

    Seven convolution stages feed two bidirectional LSTMs, each followed by a
    linear layer; every forward-style method finishes with ``log_softmax``
    over the last dimension. The float modules are stored in ``self.layers``.
    ``quantize`` derives ``self.qlayers`` from them, and ``quantize_forward``,
    ``freeze``, ``fakefreeze`` and ``quantize_inference`` operate on that
    second dictionary, so ``quantize`` has to be called before any of them.

    ``n_rnn`` and ``leakyRelu`` are accepted by the constructor but are not
    used anywhere in this class.
    """

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        """Create the float layers.

        Args:
            imgH: input image height; asserted to be a multiple of 16.
            nc: number of input channels of the first convolution.
            nclass: output size of the last linear layer.
            nh: hidden size of both LSTMs.
            n_rnn: unused.
            leakyRelu: unused.
        """
        super().__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        # Listed in execution order: the forward passes walk the dictionary in
        # insertion order, and quantize() looks the entries up by these keys.
        plan = [
            ('conv1', nn.Conv2d(nc, 64, kernel_size=3, stride=1, padding=1)),
            ('relu1', nn.ReLU()),
            ('pool1', nn.MaxPool2d(kernel_size=2, stride=2)),

            ('conv2', nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)),
            ('relu2', nn.ReLU()),
            ('pool2', nn.MaxPool2d(kernel_size=2, stride=2)),

            ('conv3', nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)),
            ('bn3', nn.BatchNorm2d(256)),
            ('relu3', nn.ReLU()),

            ('conv4', nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)),
            ('relu4', nn.ReLU()),
            # Stride 2 along height but 1 along width, with width padding 1.
            ('pool4', nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1))),

            ('conv5', nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)),
            ('bn5', nn.BatchNorm2d(512)),
            ('relu5', nn.ReLU()),

            ('conv6', nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),
            ('relu6', nn.ReLU()),
            ('pool6', nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 1), padding=(0, 1))),

            ('conv7', nn.Conv2d(512, 512, kernel_size=2, stride=1, padding=0)),
            ('bn7', nn.BatchNorm2d(512)),
            ('relu7', nn.ReLU()),

            # Both LSTM directions are concatenated in the output, hence the
            # nh * 2 input width of each linear layer.
            ('lstm1', nn.LSTM(512, nh, bidirectional=True)),
            ('fc1', nn.Linear(nh * 2, nh)),
            ('lstm2', nn.LSTM(nh, nh, bidirectional=True)),
            ('fc2', nn.Linear(nh * 2, nclass)),
        ]

        self.layers = nn.ModuleDict()
        for key, module in plan:
            self.layers[key] = module

    @staticmethod
    def _per_step(fn, seq):
        """Flatten the first two dims of ``seq``, apply ``fn``, restore them."""
        steps, batch = seq.size(0), seq.size(1)
        flat = seq.view(steps * batch, -1)
        return fn(flat).view(steps, batch, -1)

    @staticmethod
    def _int_linear(qfc, flat):
        """Run ``flat`` through ``qfc``: its ``qi`` quantize, inference, ``qo`` dequantize."""
        codes = qfc.qi.quantize_tensor(flat)
        codes = qfc.quantize_inference(codes)
        return qfc.qo.dequantize_tensor(codes)

    def forward(self, x):
        """Float forward pass; returns log-probabilities shaped [w, b, nclass]."""
        # Convolutional features: every entry that precedes the first LSTM.
        for key, module in self.layers.items():
            if 'lstm' in key:
                break
            x = module(x)

        # Drop dim 2 (b x 512 x width afterwards) and move width to the front
        # so it serves as the LSTM sequence dimension: [w, b, c].
        seq = x.squeeze(2).permute(2, 0, 1)

        seq, _ = self.layers['lstm1'](seq)
        seq = self._per_step(self.layers['fc1'], seq)
        seq, _ = self.layers['lstm2'](seq)
        seq = self._per_step(self.layers['fc2'], seq)

        return F.log_softmax(seq, dim=2)

    def quantize(self, quant_type, num_bits=8, e_bits=3):
        """Build ``self.qlayers``, the quantized wrappers around ``self.layers``.

        Args:
            quant_type: forwarded unchanged to every wrapper constructor.
            num_bits: forwarded unchanged to every wrapper constructor.
            e_bits: forwarded unchanged to every wrapper constructor.

        Only the first convolution is created with ``qi=True``; every other
        conv / linear wrapper gets ``qi=False`` and is handed a predecessor's
        quantizer later by ``freeze``.
        """
        src = self.layers
        precision = {'num_bits': num_bits, 'e_bits': e_bits}

        def qconv(key, qi=False):
            return QConv2d(quant_type, src[key], qi=qi, qo=True, **precision)

        def qrelu():
            return QReLU(quant_type, **precision)

        def qpool(key):
            pool = src[key]
            return QMaxPooling2d(quant_type, kernel_size=pool.kernel_size,
                                 stride=pool.stride, padding=pool.padding,
                                 **precision)

        def qconvbnrelu(conv_key, bn_key):
            return QConvBNReLU(quant_type, src[conv_key], src[bn_key],
                               qi=False, qo=True, **precision)

        def qlstm(key):
            return QLSTM(quant_type, src[key], has_hidden=False,
                         qix=False, qih=True, qic=True, **precision)

        def qfc(key):
            return QLinear(quant_type, src[key], qi=False, qo=True, **precision)

        # Insertion order matters: quantize_forward(), freeze() and
        # quantize_inference() all iterate this dictionary in order.
        self.qlayers = nn.ModuleDict()
        q = self.qlayers

        q['qconv1'] = qconv('conv1', qi=True)
        q['qrelu1'] = qrelu()
        q['qpool1'] = qpool('pool1')

        q['qconv2'] = qconv('conv2')
        q['qrelu2'] = qrelu()
        q['qpool2'] = qpool('pool2')

        q['qconvbnrelu3'] = qconvbnrelu('conv3', 'bn3')

        q['qconv4'] = qconv('conv4')
        q['qrelu4'] = qrelu()
        q['qpool4'] = qpool('pool4')

        q['qconvbnrelu5'] = qconvbnrelu('conv5', 'bn5')

        q['qconv6'] = qconv('conv6')
        q['qrelu6'] = qrelu()
        q['qpool6'] = qpool('pool6')

        q['qconvbnrelu7'] = qconvbnrelu('conv7', 'bn7')

        q['qlstm1'] = qlstm('lstm1')
        q['qfc1'] = qfc('fc1')
        # The second LSTM does not continue from the first one's state, so
        # qih and qic still have to be True to collect their own statistics.
        q['qlstm2'] = qlstm('lstm2')
        q['qfc2'] = qfc('fc2')

    def quantize_forward(self, x):
        """Same data flow as ``forward`` but through the ``qlayers`` wrappers."""
        # Convolutional features: every entry that precedes the first QLSTM.
        for key, qlayer in self.qlayers.items():
            if 'qlstm' in key:
                break
            x = qlayer(x)

        # b x 512 x width after the squeeze; width becomes the sequence dim.
        seq = x.squeeze(2).permute(2, 0, 1)  # [w, b, c]

        seq, _ = self.qlayers['qlstm1'](seq)
        seq = self._per_step(self.qlayers['qfc1'], seq)
        seq, _ = self.qlayers['qlstm2'](seq)
        seq = self._per_step(self.qlayers['qfc2'], seq)

        return F.log_softmax(seq, dim=2)

    def freeze(self):
        """Call ``freeze`` on every quantized layer, in order.

        The first layer is frozen without arguments. Each later layer is
        given an attribute of the most recent conv / fc / lstm layer before
        it: ``qox`` if that layer is an LSTM, ``qo`` otherwise.
        """
        producer = ''
        for key, qlayer in self.qlayers.items():
            if producer == '':
                qlayer.freeze()
            elif 'lstm' in producer:
                qlayer.freeze(self.qlayers[producer].qox)
            else:
                qlayer.freeze(self.qlayers[producer].qo)

            # Only conv / fc / lstm layers become the producer for the layers
            # after them; relu and pool entries reuse the preceding one.
            if any(tag in key for tag in ('conv', 'fc', 'lstm')):
                producer = key

    def fakefreeze(self):
        """Call ``fakefreeze`` on every quantized layer except the LSTMs."""
        for key, qlayer in self.qlayers.items():
            if 'lstm' in key:
                continue
            qlayer.fakefreeze()

    def quantize_inference(self, x):
        """Inference pass that uses each wrapper's ``quantize_inference``.

        The input is quantized with the first convolution's ``qi``, pushed
        through the convolutional wrappers, and dequantized with
        ``qconvbnrelu7``'s ``qo`` before the first LSTM. Each linear layer is
        bracketed by its own ``qi`` quantize and ``qo`` dequantize calls.
        """
        x = self.qlayers['qconv1'].qi.quantize_tensor(x)
        for key, qlayer in self.qlayers.items():
            if 'qlstm' in key:
                break
            x = qlayer.quantize_inference(x)

        # Dequantize before handing the features to the LSTM.
        x = self.qlayers['qconvbnrelu7'].qo.dequantize_tensor(x)

        # b x 512 x width after the squeeze; width becomes the sequence dim.
        seq = x.squeeze(2).permute(2, 0, 1)  # [w, b, c]

        qfc1 = self.qlayers['qfc1']
        qfc2 = self.qlayers['qfc2']

        seq, _ = self.qlayers['qlstm1'].quantize_inference(seq)
        seq = self._per_step(lambda flat: self._int_linear(qfc1, flat), seq)
        seq, _ = self.qlayers['qlstm2'].quantize_inference(seq)
        seq = self._per_step(lambda flat: self._int_linear(qfc2, flat), seq)

        return F.log_softmax(seq, dim=2)



def weights_init(m):
    """Re-initialize convolution and batch-norm parameters in place.

    Intended to be passed to ``nn.Module.apply``. Dispatch is by class
    *name*, not by type:

    * name contains ``'Conv'``: weight is drawn from N(0.0, 0.02);
    * name contains ``'BatchNorm'``: weight is drawn from N(1.0, 0.02) and
      bias is filled with 0.

    The standard deviation is a fixed constant, so this is not
    Xavier/Glorot scaling. A ``weight`` or ``bias`` that is missing or
    ``None`` (for example ``BatchNorm2d(affine=False)``) is left untouched
    instead of raising ``AttributeError``. All other modules are ignored.

    Args:
        m: the module to initialize.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)

    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            bias.data.fill_(0)

def get_crnn(config):
    """Build a single-channel CRNN from ``config.MODEL`` and initialize it.

    Reads ``IMAGE_SIZE.H``, ``NUM_CLASSES`` and ``NUM_HIDDEN`` from
    ``config.MODEL``. The network is created with ``NUM_CLASSES + 1`` output
    classes (presumably the extra one is the CTC blank - confirm against the
    training code), then ``weights_init`` is applied to every submodule.

    Returns:
        The initialized ``CRNN`` instance.
    """
    model_cfg = config.MODEL
    net = CRNN(
        imgH=model_cfg.IMAGE_SIZE.H,
        nc=1,
        nclass=model_cfg.NUM_CLASSES + 1,
        nh=model_cfg.NUM_HIDDEN,
    )
    net.apply(weights_init)
    return net