import torch
import torch.nn as nn
import torch.nn.functional as F

from module import *
import module
from global_var import GlobalVariables


# MobileNet V2 model, adapted for CIFAR-10.

class MobileNetV2(nn.Module):
    """MobileNetV2 classifier with a float path and a quantized path.

    Float path: ``forward`` returns raw logits of shape (batch, num_classes).

    Quantized path (the ``Q*`` wrappers come from the project ``module``):
      1. ``quantize`` wraps every conv/bn/linear in its quantized counterpart.
      2. ``quantize_forward`` runs those wrappers on float input and returns
         softmax probabilities (presumably used to collect/calibrate qparams —
         confirm against ``module``).
      3. ``freeze`` threads each module's output qparams (``qo``) into the
         next module's input qparams (``qi``).
      4. ``quantize_inference`` quantizes the input with the first layer's
         ``qi``, runs the frozen modules, dequantizes with ``qfc1.qo`` and
         returns softmax probabilities.
    """

    # One row per bottleneck stage: (in_channels, out_channels, n_repeat, t, stride),
    # where t is the channel expansion factor of each InvertedResidual block.
    # The original author notes the depth was adjusted for CIFAR-10 image size.
    _STAGE_CFG = (
        (32, 16, 1, 1, 1),
        (16, 24, 2, 6, 2),
        (24, 32, 3, 6, 2),
        (32, 96, 3, 6, 1),
        (96, 160, 3, 6, 2),
        (160, 320, 1, 6, 1),
    )

    def __init__(self, num_classes=10):
        super().__init__()

        # Stem: 3x3 conv with stride 1 and padding 1, so spatial size is kept.
        self.conv1 = nn.Conv2d(3, 32, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU6(inplace=True)

        # Bottleneck stages, registered as attributes layer1 ... layer6
        # (same names and creation order as before, so state_dict keys match).
        for idx, (cin, cout, repeat, expand, stride) in enumerate(self._STAGE_CFG, start=1):
            setattr(self, f"layer{idx}", MakeLayer(cin, cout, repeat, t=expand, stride=stride))

        # Head: 1x1 conv to 1280 channels, global average pool, linear classifier.
        # NOTE(review): no BatchNorm/ReLU6 follows conv2 here, unlike the
        # reference MobileNetV2 head; the quantized path mirrors this layout.
        self.conv2 = nn.Conv2d(320, 1280, 1)
        self.avg1 = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(1280, num_classes)

    def _stages(self):
        """Return the bottleneck stages (layer1 ... layer6) in execution order."""
        return tuple(getattr(self, f"layer{idx}") for idx in range(1, len(self._STAGE_CFG) + 1))

    def forward(self, x):
        """Float forward pass; returns logits (no softmax)."""
        out = self.relu1(self.bn1(self.conv1(x)))
        for stage in self._stages():
            out = stage(out)
        out = self.avg1(self.conv2(out))
        return self.fc(out.view(out.size(0), -1))

    def quantize(self, quant_type, num_bits=8, e_bits=3):
        """Create the quantized wrappers for every layer.

        ``quant_type``, ``num_bits`` and ``e_bits`` are forwarded unchanged to the
        ``Q*`` wrappers (``e_bits`` is presumably the exponent width for
        float-style quant types — confirm in ``module``). Only the first wrapper
        owns an input quantizer (``qi=True``); all others receive their ``qi``
        from the previous module in ``freeze``.
        """
        bits = dict(num_bits=num_bits, e_bits=e_bits)
        self.qconvbnrelu1 = QConvBNReLU6(quant_type, self.conv1, self.bn1, qi=True, qo=True, **bits)
        for stage in self._stages():
            stage.quantize(quant_type=quant_type, **bits)
        self.qconv1 = QConv2d(quant_type, self.conv2, qi=False, qo=True, **bits)
        self.qavgpool1 = QAdaptiveAvgPool2d(quant_type, qi=False, qo=True, **bits)
        self.qfc1 = QLinear(quant_type, self.fc, qi=False, qo=True, **bits)

    def quantize_forward(self, x):
        """Run the quantized wrappers on float input; returns softmax probabilities.

        Softmax is monotonic, so it does not change the argmax prediction.
        """
        out = self.qconvbnrelu1(x)
        for stage in self._stages():
            out = stage.quantize_forward(out)
        out = self.qavgpool1(self.qconv1(out))
        out = self.qfc1(out.view(out.size(0), -1))
        return F.softmax(out, dim=1)

    def freeze(self):
        """Freeze all wrappers, chaining each module's ``qo`` into the next ``qi``."""
        # The first wrapper was built with qi=True, so it needs no qi argument here.
        self.qconvbnrelu1.freeze()
        qo = self.qconvbnrelu1.qo
        for stage in self._stages():
            qo = stage.freeze(qinput=qo)
        self.qconv1.freeze(qi=qo)
        self.qavgpool1.freeze(qi=self.qconv1.qo)
        self.qfc1.freeze(qi=self.qavgpool1.qo)

    def quantize_inference(self, x):
        """Quantized inference on float input; returns softmax probabilities.

        The input is quantized once with the first layer's ``qi``. It stays in the
        quantized domain through every frozen module and is dequantized once
        with ``qfc1.qo`` at the end.
        """
        qx = self.qconvbnrelu1.qi.quantize_tensor(x)
        qx = self.qconvbnrelu1.quantize_inference(qx)
        for stage in self._stages():
            qx = stage.quantize_inference(qx)
        qx = self.qavgpool1.quantize_inference(self.qconv1.quantize_inference(qx))
        qx = self.qfc1.quantize_inference(qx.view(qx.size(0), -1))
        logits = self.qfc1.qo.dequantize_tensor(qx)
        return F.softmax(logits, dim=1)



 

class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block: 1x1 expand -> 3x3 depthwise -> 1x1 linear project.

    Args:
        in_channels: channels of the block input.
        out_channels: channels of the block output.
        stride: stride of the 3x3 depthwise convolution.
        expand_ratio: multiplier on ``in_channels`` giving the hidden width.

    A shortcut (input + branch) is added only when ``stride == 1`` and the
    channel count is unchanged; this decision is stored in ``identity_flag``.
    """

    def __init__(self, in_channels, out_channels, stride, expand_ratio):
        super().__init__()
        hidden = int(in_channels * expand_ratio)
        self.identity_flag = in_channels == out_channels and stride == 1

        # Submodules are created in this exact order: it fixes parameter
        # registration order (and hence seeded initialisation / state_dict order).
        # 1x1 pointwise expansion.
        self.conv1 = nn.Conv2d(in_channels, hidden, 1)
        self.bn1 = nn.BatchNorm2d(hidden)
        self.relu1 = nn.ReLU6(inplace=True)
        # 3x3 depthwise convolution (one group per channel).
        self.conv2 = nn.Conv2d(hidden, hidden, 3, stride=stride, padding=1, groups=hidden)
        self.bn2 = nn.BatchNorm2d(hidden)
        self.relu2 = nn.ReLU6(inplace=True)
        # 1x1 linear projection: deliberately no activation after bn3.
        self.conv3 = nn.Conv2d(hidden, out_channels, 1)
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Float forward pass; adds the shortcut when ``identity_flag`` is set."""
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        # The in-place ReLU6s only act on conv/bn outputs, so ``x`` is still the block input.
        return x + out if self.identity_flag else out

    def quantize(self, quant_type, num_bits=8, e_bits=3):
        """Build the quantized wrappers for this block.

        None of them owns an input quantizer (``qi=False``); input qparams are
        supplied later by ``freeze``.
        """
        opts = dict(qi=False, qo=True, num_bits=num_bits, e_bits=e_bits)
        self.qconvbnrelu1 = QConvBNReLU6(quant_type, self.conv1, self.bn1, **opts)
        self.qconvbnrelu2 = QConvBNReLU6(quant_type, self.conv2, self.bn2, **opts)
        self.qconvbn1 = QConvBN(quant_type, self.conv3, self.bn3, **opts)
        # Created even when identity_flag is False; it is then never used.
        self.qelementadd = QElementwiseAdd(
            quant_type, qi0=False, qi1=False, qo=True, num_bits=num_bits, e_bits=e_bits
        )

    def _quant_ops(self):
        """Return the three quantized conv wrappers in execution order."""
        return (self.qconvbnrelu1, self.qconvbnrelu2, self.qconvbn1)

    def quantize_forward(self, x):
        """Run the quantized wrappers; the shortcut goes through ``qelementadd``."""
        out = x
        for op in self._quant_ops():
            out = op(out)
        return self.qelementadd(out, x) if self.identity_flag else out

    def freeze(self, qinput):
        """Freeze the wrappers and return this block's output qparams.

        Args:
            qinput: output qparams (``qo``) of the preceding module. These are used
                as the first conv's ``qi`` and as the shortcut's ``qi1``.
        """
        prev_qo = qinput
        for op in self._quant_ops():
            op.freeze(qi=prev_qo)
            prev_qo = op.qo
        if not self.identity_flag:
            return prev_qo
        self.qelementadd.freeze(qi0=prev_qo, qi1=qinput)
        return self.qelementadd.qo

    def quantize_inference(self, x):
        """Quantized inference on an already-quantized input.

        No quantize_tensor/dequantize_tensor is needed here. This block is neither
        the first nor the last module, so its input and output stay in the
        quantized domain.
        """
        out = x
        for op in self._quant_ops():
            out = op.quantize_inference(out)
        return self.qelementadd.quantize_inference(out, x) if self.identity_flag else out



class MakeLayer(nn.Module):
    """A MobileNetV2 stage: ``n_repeat`` InvertedResidual blocks applied in sequence.

    Only the first block uses ``stride`` and maps ``in_channels`` to
    ``out_channels``. The remaining blocks use stride 1 and keep ``out_channels``.
    Every method of the stage chains the same-named method of each block in order.
    """

    def __init__(self, in_channels, out_channels, n_repeat, t, stride):
        """
        Args:
            in_channels: channels entering the stage.
            out_channels: channels produced by every block of the stage.
            n_repeat: number of InvertedResidual blocks.
            t: channel expansion factor passed to each block.
            stride: stride of the first block (later blocks use 1).
        """
        super(MakeLayer, self).__init__()
        self.layers = nn.ModuleList()
        for i in range(n_repeat):
            # Only the first block of a stage may downsample / change the width.
            block_stride = stride if i == 0 else 1
            self.layers.append(InvertedResidual(in_channels, out_channels, block_stride, t))
            in_channels = out_channels

    def forward(self, x):
        """Float forward pass through every block."""
        for layer in self.layers:
            x = layer(x)
        return x

    def quantize(self, quant_type, num_bits=8, e_bits=3):
        """Build the quantized wrappers inside every block (each block owns its own strategy)."""
        for layer in self.layers:
            layer.quantize(quant_type=quant_type, num_bits=num_bits, e_bits=e_bits)

    def quantize_forward(self, x):
        """Run every block's ``quantize_forward`` in sequence."""
        for layer in self.layers:
            x = layer.quantize_forward(x)
        return x

    def freeze(self, qinput):
        """Freeze every block, feeding each block's output qparams into the next.

        Args:
            qinput: output qparams (``qo``) of the module preceding this stage.

        Returns:
            The output qparams of the last block, for the module that follows.
            An empty stage (``n_repeat == 0``) passes its input through unchanged,
            so ``qinput`` is returned. Previously this raised UnboundLocalError.
        """
        qo = qinput
        for layer in self.layers:
            qo = layer.freeze(qinput=qo)
        return qo

    def quantize_inference(self, x):
        """Quantized inference through every block.

        The input is already quantized by an earlier module and this stage is
        neither first nor last, so no quantize/dequantize step is needed here.
        """
        for layer in self.layers:
            x = layer.quantize_inference(x)
        return x

